Skip to content
Snippets Groups Projects
Commit 54d7f907 authored by Jonathan Mace's avatar Jonathan Mace
Browse files

Update CUDA loading code to handle both old and new TVM binaries

parent d7661392
No related branches found
No related tags found
No related merge requests found
...@@ -22,23 +22,40 @@ UnloadedCUDAModule::UnloadedCUDAModule(const char* &cuda_blob) { ...@@ -22,23 +22,40 @@ UnloadedCUDAModule::UnloadedCUDAModule(const char* &cuda_blob) {
dmlc::Stream* stream = &fs; dmlc::Stream* stream = &fs;
uint64_t size; uint64_t size;
CHECK(stream->Read(&size)); CHECK(stream->Read(&size));
CHECK(size == 1) << "Only expected one dev_mblob, found " << size;
std::string tkey; CHECK(size == 1 || size == 3) << "Found " << size << " dev_mblob; expected 1 (legacy) or 3 (tvm v0.6)";
CHECK(stream->Read(&tkey));
std::string fkey = "module.loadbinary_" + tkey; bool found_cuda = false;
CHECK(tkey == "cuda") << "Expected dev_mblob of type cuda, found " << tkey; for (uint64_t i = 0; i < size; i++) {
std::string tkey;
stream->Read(&this->fmt); CHECK(stream->Read(&tkey));
if (tkey == "cuda") {
std::unordered_map<std::string, tvm::runtime::FunctionInfo> fmap; stream->Read(&this->fmt);
stream->Read(&fmap);
std::unordered_map<std::string, tvm::runtime::FunctionInfo> fmap;
this->functions.reserve(fmap.size()); stream->Read(&fmap);
for (auto & e : fmap) {
this->functions[e.first] = new UnloadedCUDAFunc(e.first, e.second); this->functions.reserve(fmap.size());
for (auto & e : fmap) {
this->functions[e.first] = new UnloadedCUDAFunc(e.first, e.second);
}
stream->Read(&this->data);
found_cuda = true;
} else if (tkey == "_lib") {
// Skip
} else if (tkey == "_import_tree") {
std::vector<uint64_t> import_tree_row_ptr;
std::vector<uint64_t> import_tree_child_indices;
CHECK(stream->Read(&import_tree_row_ptr));
CHECK(stream->Read(&import_tree_child_indices));
CHECK(import_tree_row_ptr.size() == 3 && import_tree_child_indices.size() == 1) <<
"Possible invalid TVM dev_mblob; import_tree has stuff in it";
} else {
CHECK(false) << "Found unexpected key " << tkey << " in dev_mblob";
}
} }
stream->Read(&this->data);
CHECK(found_cuda) << "Expected dev_mblob of type cuda but did not find one";
} }
UnloadedCUDAModule::~UnloadedCUDAModule() { UnloadedCUDAModule::~UnloadedCUDAModule() {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment