
Commit 28ee445

feat(CheckMethodOperatorSupport): a new API that checks the graph to see whether all operators are supported. Addresses #26.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 3da4947 commit 28ee445

File tree

10 files changed (+95, -30 lines)
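
As context for the diffs that follow: the commit adds a single public entry point, trtorch::CheckMethodOperatorSupport, which the trtorchexec change near the bottom of this commit calls before engine conversion. The sketch below is illustrative only and is not part of the commit; the include paths, the placeholder model path, and the {1, 3, 224, 224} input shape are assumptions, and the input-dimension handling simply mirrors the cpp/trtorchexec/main.cpp pattern shown later in the diff.

#include <iostream>
#include <vector>
#include "torch/script.h"     // torch::jit::load / torch::jit::script::Module
#include "trtorch/trtorch.h"  // public TRTorch API extended by this commit

int main() {
  // Load an existing TorchScript module (path is a placeholder).
  torch::jit::script::Module mod = torch::jit::load("model.ts");

  // New in this commit: confirm every operator in the method can be handled
  // before attempting a full compile; unsupported ops are logged by TRTorch.
  if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
    std::cerr << "forward() uses operators TRTorch cannot convert yet" << std::endl;
    return -1;
  }

  // Safe to compile; the nested-vector input sizes follow the trtorchexec pattern.
  std::vector<std::vector<int64_t>> dims = {{1, 3, 224, 224}};
  auto trt_mod = trtorch::CompileGraph(mod, dims);
  return 0;
}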

core/compiler.cpp

Lines changed: 22 additions & 1 deletion
@@ -64,6 +64,28 @@ void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit
   return;
 }
 
+bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod,
+                                std::string method_name) {
+  auto g = mod.get_method(method_name).graph();
+  // Go through PyTorch Lowering to simplify graph and extract weight parameters
+  auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue());
+
+  g = graph_and_parameters.first;
+
+  // Go through TRTorch Lowering to reformat graph to be conversion friendly
+  // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
+  lowering::LowerGraph(g);
+
+  auto params = graph_and_parameters.second;
+  auto named_params = conversion::get_named_params(g->inputs(), params);
+  LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n");
+
+  // Is this necessary?
+  lowering::LowerBlock(g->block());
+
+  return conversion::VerifyConverterSupportForBlock(g->block());
+}
+
 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
                                     std::string method_name,
                                     conversion::ExtraInfo cfg) {
@@ -87,7 +109,6 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
   return std::move(engine);
 }
 
-// TODO: Consider if there is a better way to deal with input size
 torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
                                         conversion::ExtraInfo cfg) {
   // TODO: Should be doing a functional transform but need PR #31978

core/compiler.h

Lines changed: 2 additions & 0 deletions
@@ -6,9 +6,11 @@
 
 namespace trtorch {
 namespace core {
+bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);
 
 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
                                     std::string method_name, conversion::ExtraInfo cfg);
+
 torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, conversion::ExtraInfo cfg);
 
 } // namespace core

core/conversion/conversion.cpp

Lines changed: 36 additions & 9 deletions
@@ -11,15 +11,15 @@ namespace core {
 namespace conversion {
 
 // Defined in core/conversion/conversion_blacklist.cpp
-bool isNodeConversionBlacklisted(torch::jit::Node* n);
+bool isNodeConversionBlacklisted(const torch::jit::Node* n);
 
 bool OpSupported(const torch::jit::Node* n) {
   bool evalable = evaluators::shouldEvalAtConversionTime(n);
   bool convertable = converters::node_is_convertable(n);
   return evalable || convertable;
 }
 
-c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, torch::jit::Node* n, int level=0, int limit=10) {
+c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::jit::Node* n, int level=0, int limit=10) {
   // Check to see if you can just go through and eval all of these AOT (saves the recursion)
   // Also probably a better way to deal with the two error cases;
   TRTORCH_CHECK(level < limit, "Failed to evaluate node: " << *n \
@@ -55,7 +55,7 @@ c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, torch::jit::N
   return eval;
 }
 
-bool AddLayer(ConversionCtx* ctx, torch::jit::Node* n) {
+bool AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
   LOG_INFO(ctx->logger,
            "Adding Layer " << util::node_info(n) << " (ctx.AddLayer)");
   converters::args node_args;
@@ -114,11 +114,11 @@ bool AddLayer(ConversionCtx* ctx, torch::jit::Node* n) {
 }
 
 bool AddInputs(ConversionCtx* ctx,
-               at::ArrayRef<torch::jit::Value*> inputs,
+               at::ArrayRef<const torch::jit::Value*> inputs,
                std::vector<InputRange>& input_dims) {
 
   auto type_lut = torch::jit::script::string_to_type_lut();
-  std::vector<torch::jit::Value*> input_tensors;
+  std::vector<const torch::jit::Value*> input_tensors;
   for (auto in : inputs) {
     // Disregarding inputs that are not tensors
     //
@@ -163,7 +163,7 @@ bool AddInputs(ConversionCtx* ctx,
   return true;
 }
 
-bool MarkOutputs(ConversionCtx* ctx, at::ArrayRef<torch::jit::Value*> outputs) {
+bool MarkOutputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> outputs) {
   for (auto out : outputs) {
     ctx->net->markOutput(*(ctx->value_tensor_map[out]));
     LOG_INFO(ctx->logger,
@@ -178,7 +178,7 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
   }
 }
 
-void ConvertBlockToNetDef(ConversionCtx* ctx, torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) {
+void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) {
   LOG_INFO(ctx->logger, "Converting Block");
 
   auto inputs = b->inputs();
@@ -188,7 +188,6 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, torch::jit::Block* b, ExtraInfo bu
   auto nodes = b->nodes();
 
   for (const auto n : nodes) {
-
     bool to_eval = evaluators::shouldEvalAtConversionTime(n);
     bool blacklisted = isNodeConversionBlacklisted(n);
     if (!to_eval && !blacklisted) {
@@ -220,13 +219,41 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, torch::jit::Block* b, ExtraInfo bu
 // a serialized TensorRT engine that can be deserialized and run
 
 // Probably should consolidate these two functions
-std::string ConvertBlockToEngine(torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) {
+std::string ConvertBlockToEngine(const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) {
   ConversionCtx ctx(build_info.engine_settings);
   ConvertBlockToNetDef(&ctx, b, build_info, static_params);
   std::string engine = ctx.SerializeEngine();
   return engine;
 }
 
+bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
+  bool supported = true;
+  std::set<std::string> unsupported_ops;
+  for (const auto n : b->nodes()) {
+    if (!OpSupported(n)) {
+      auto schema = n->maybeSchema();
+      TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \
+                              << " (conversion.AddLayer)");
+      std::stringstream ss;
+      ss << *schema;
+      unsupported_ops.insert(ss.str());
+      supported = false;
+    }
+  }
+
+  if (!supported) {
+    std::stringstream unsupported_msg;
+    unsupported_msg << "Method requested cannot be compiled by TRTorch.\nUnsupported operators listed below:" << std::endl;
+    for (auto s : unsupported_ops) {
+      unsupported_msg << " - " << s << std::endl;
+    }
+    unsupported_msg << "You can either implement converters for these ops in your application or file a bug" << std::endl;
+    unsupported_msg << "https://www.github.com/nvidia/TRTorch/issues" << std::endl;
+    LOG_ERROR(unsupported_msg.str());
+  }
+  return supported;
+}
+
 } // namespace conversion
 } // namespace core
 } // namespace trtorch
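
The new VerifyConverterSupportForBlock walks a block and reuses the existing OpSupported predicate above, where a node counts as supported if it can either be evaluated at conversion time or matched to a registered converter. Below is a minimal sketch of using that per-node predicate in isolation; it is a hypothetical helper, not part of this commit, and the PyTorch include path is an assumption.

#include <iostream>
#include <memory>
#include "torch/csrc/jit/ir/ir.h"        // assumed include path for torch::jit::Graph / Node
#include "core/conversion/conversion.h"  // declares trtorch::core::conversion::OpSupported

// Hypothetical helper: print the kind of every node in an already lowered
// graph that neither the evaluator nor the converter registry can handle.
void ReportUnsupportedNodes(const std::shared_ptr<torch::jit::Graph>& g) {
  for (const auto n : g->block()->nodes()) {
    // OpSupported(n) is shouldEvalAtConversionTime(n) || node_is_convertable(n)
    if (!trtorch::core::conversion::OpSupported(n)) {
      std::cout << "Unsupported node kind: " << n->kind().toQualString() << std::endl;
    }
  }
}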

core/conversion/conversion.h

Lines changed: 3 additions & 1 deletion
@@ -43,10 +43,12 @@ GraphParams get_named_params(c10::ArrayRef<torch::jit::Value*> inputs, std::vect
 
 // Converts a already lowered block (blocks with no sub blocks) to
 // a serialized TensorRT engine that can be deserialized and run
-std::string ConvertBlockToEngine(torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params);
+std::string ConvertBlockToEngine(const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params);
 
 bool OpSupported(const torch::jit::Node* n);
 
+bool VerifyConverterSupportForBlock(const torch::jit::Block* b);
+
 } // namespace conversion
 } // namespace core
 } // namespace trtorch

core/conversion/conversion_blacklist.cpp

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ const std::unordered_set<std::string>& get_non_convertable_nodes() {
   return nonconvertable_nodes;
 }
 
-bool isNodeConversionBlacklisted(torch::jit::Node* n) {
+bool isNodeConversionBlacklisted(const torch::jit::Node* n) {
   auto kind = n->kind();
   auto convertableIt = get_non_convertable_nodes().find(kind.toQualString());
   if (convertableIt == get_non_convertable_nodes().end()) {

core/conversion/converters/NodeConverterRegistry.cpp

Lines changed: 5 additions & 10 deletions
@@ -46,10 +46,6 @@ using ConverterLUT = std::unordered_map<torch::jit::Symbol, OpConverter>;
 class NodeConverterRegistry {
 public:
   bool RegisterConverter(torch::jit::FunctionSchema* signature, OpConverter& converter) {
-    // NOTE: This is useful for people developing extentions to the conversion registry as is
-    // If you are working on the core conversion library and the conversion registry
-    // itself, it might helpful to set -DDEBUG_MSGS when you compile so you can watch the
-    // registration of core converters during init, otherwise the messages will be masked
     LOG_DEBUG("Registering Converter for " << canonical_schema_string(*signature));
     auto sym = torch::jit::Symbol::fromQualString(signature->name());
     converter_lut_[sym] = std::move(converter);
@@ -70,13 +66,12 @@ class NodeConverterRegistry {
   bool Convertable(const torch::jit::Node* n) {
     auto schema = n->maybeSchema();
     if (schema) {
-      auto converter = GetConverter(schema);
-      if (converter) {
-        return true;
+      auto sym = torch::jit::Symbol::fromQualString(schema->name());
+      auto iter = converter_lut_.find(sym);
+      if (iter == converter_lut_.end()) {
+        return false;
       } else {
-        LOG_DEBUG("Node has no registered converter: " << util::node_info(n) \
-                  << " (NodeConverterRegistry.Convertable)\nSchema: " << *schema);
-        return false;
+        return true;
       }
     } else {
       LOG_DEBUG("Unable to get schema for Node " << util::node_info(n) \

core/conversion/evaluators/NodeEvaluatorRegistry.cpp

Lines changed: 0 additions & 4 deletions
@@ -20,10 +20,6 @@ using EvaluatorLUT = std::unordered_map<torch::jit::NodeKind, NodeEvaluator>;
 class NodeEvaluatorRegistry {
 public:
   void RegisterEvaluator(torch::jit::NodeKind node_kind, NodeEvaluator& evaluator) {
-    // NOTE: This is useful for people developing extentions to the conversion registry as is
-    // If you are working on the core conversion library and the conversion registry
-    // itself, it might helpful to set -DDEBUG_MSGS when you compile so you can watch the
-    // registration of core converters during init, otherwise the messages will be masked
     LOG_DEBUG("Registering evaluator for " << node_kind.toQualString());
     evaluator_lut_[node_kind] = std::move(evaluator);
   }

cpp/api/include/trtorch/trtorch.h

Lines changed: 14 additions & 1 deletion
@@ -215,6 +215,19 @@ TRTORCH_API std::string get_build_info();
  */
 TRTORCH_API void dump_build_info();
 
+/**
+ * @brief Check to see if a module is fully supported by the compiler
+ *
+ * @param module: torch::jit::script::Module - Existing TorchScript module
+ * @param method_name: std::string - Name of method to compile
+ *
+ * Takes a module and a method name and checks if the method graph contains purely
+ * convertable operators
+ *
+ * Will print out a list of unsupported operators if the graph is unsupported
+ */
+TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name);
+
 /**
  * @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT
  *
@@ -239,5 +252,5 @@ TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::script::Mo
  * and will convert selected method to a serialized TensorRT engine which can be run with
  * TensorRT
  */
-TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, ExtraInfo info);
+TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, std::string method_name, ExtraInfo info);
 } // namespace trtorch

cpp/api/src/trtorch.cpp

Lines changed: 7 additions & 2 deletions
@@ -10,12 +10,17 @@ namespace trtorch {
 // Defined in extra_info.cpp
 core::conversion::ExtraInfo to_internal_extra_info(ExtraInfo external);
 
-std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
+bool CheckMethodOperatorSupport(const torch::jit::script::Module& module,
+                                std::string method_name) {
+  return core::CheckMethodOperatorSupport(module, method_name);
+}
+
+std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module,
                                     std::string method_name, ExtraInfo info) {
   LOG_DEBUG(get_build_info());
   // Want to export a much simpler (non TRT header dependent) API so doing the
   // type conversion here
-  return std::move(core::ConvertGraphToTRTEngine(mod, method_name, to_internal_extra_info(info)));
+  return std::move(core::ConvertGraphToTRTEngine(module, method_name, to_internal_extra_info(info)));
 }
 
 torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo info) {

cpp/trtorchexec/main.cpp

Lines changed: 5 additions & 1 deletion
@@ -55,6 +55,11 @@ int main(int argc, const char* argv[]) {
     dims.push_back(v);
   }
 
+  if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
+    std::cerr << "Method is not currently supported by TRTorch" << std::endl;
+    return -1;
+  }
+
   auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", dims);
   std::ofstream out("/tmp/engine_converted_from_jit.trt");
   out << engine;
@@ -69,7 +74,6 @@ int main(int argc, const char* argv[]) {
   torch::jit::IValue jit_results_ivalues = mod.forward(jit_inputs_ivalues);
   std::vector<at::Tensor> jit_results;
   jit_results.push_back(jit_results_ivalues.toTensor());
-
 
   auto trt_mod = trtorch::CompileGraph(mod, dims);
   torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues);
