Skip to content

Commit 7113cbe

Browse files
authored
Merge pull request #521 from lamhoangtung/raise_error_with_better_pytorch_traceback
feat: Show PyTorch code of unsupported operators
2 parents cf06222 + 9a48752 commit 7113cbe

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

core/conversion/conversion.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -436,8 +436,8 @@ std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo buil
436436
return engine;
437437
}
438438

439-
std::set<std::string> GetUnsupportedOpsInBlock(const torch::jit::Block* b) {
440-
std::set<std::string> unsupported_ops;
439+
std::unordered_map<c10::OperatorName, std::string> GetUnsupportedOpsInBlock(const torch::jit::Block* b) {
440+
std::unordered_map<c10::OperatorName, std::string> unsupported_ops;
441441
for (const auto n : b->nodes()) {
442442
if (n->kind() != torch::jit::prim::Loop && n->kind() != torch::jit::prim::If && !OpSupported(n)) {
443443
auto schema = n->maybeSchema();
@@ -446,7 +446,7 @@ std::set<std::string> GetUnsupportedOpsInBlock(const torch::jit::Block* b) {
446446
"Unable to get schema for Node " << util::node_info(n) << " (conversion.VerifyCoverterSupportForBlock)");
447447
std::stringstream ss;
448448
ss << *schema;
449-
unsupported_ops.insert(ss.str());
449+
unsupported_ops[schema->operator_name()] = ss.str();
450450
}
451451
for (const auto sub_b : n->blocks()) {
452452
auto sub_b_unsupported_ops = GetUnsupportedOpsInBlock(sub_b);
@@ -488,12 +488,27 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
488488
unsupported_msg << "Method requested cannot be compiled by TRTorch.\nUnsupported operators listed below:"
489489
<< std::endl;
490490
for (auto s : unsupported_ops) {
491-
unsupported_msg << " - " << s << std::endl;
491+
unsupported_msg << " - " << s.second << std::endl;
492492
}
493493
unsupported_msg << "You can either implement converters for these ops in your application or request implementation"
494494
<< std::endl;
495495
unsupported_msg << "https://www.github.com/nvidia/TRTorch/issues" << std::endl;
496+
unsupported_msg << std::endl << "In Module:" << std::endl;
497+
496498
LOG_ERROR(unsupported_msg.str());
499+
500+
for (const auto n : b->nodes()) {
501+
auto schema = n->maybeSchema();
502+
if (schema) {
503+
for (const auto& x : unsupported_ops) {
504+
if (x.first == schema->operator_name()) {
505+
LOG_ERROR(
506+
"Unsupported operator: " << *schema << std::endl
507+
<< trtorch::core::util::GetPyTorchSourceCode(n) << std::endl);
508+
}
509+
}
510+
}
511+
}
497512
return false;
498513
}
499514

core/util/jit_util.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ inline c10::FunctionSchema GenerateGraphSchema(std::string method_name, std::sha
4747
return c10::FunctionSchema(method_name, method_name, args, returns);
4848
}
4949

50+
inline std::string GetPyTorchSourceCode(const torch::jit::Node* n) {
51+
std::string source_code = n->sourceRange().str();
52+
return source_code;
53+
}
54+
5055
} // namespace util
5156
} // namespace core
5257
} // namespace trtorch

0 commit comments

Comments
 (0)