Add trt_executed_modules & default_torch_execution interface #1122

Closed · wants to merge 2 commits
2 changes: 1 addition & 1 deletion core/lowering/lowering.cpp

@@ -39,7 +39,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   }
   torch::jit::EliminateDeadCode(g);
   if (lower_info.forced_fallback_modules.size() > 0) {
-    passes::MarkNodesForFallback(g, true);
+    passes::MarkNodesForFallback(g, true, lower_info.default_torch_execution);
   }
   passes::UnpackHardSwish(g);
   passes::EliminateExceptionOrPassPattern(g);
1 change: 1 addition & 0 deletions core/lowering/lowering.h

@@ -15,6 +15,7 @@ struct LowerInfo {
   // Since these QDQ nodes will be identical as they share same input, one of them is eliminated due to CSE lowering
   // pass. Disable this in order to not disturb TensorRT's QAT optimizations.
   bool disable_cse = false;
+  bool default_torch_execution = false;
   std::vector<std::string> forced_fallback_modules;
   friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);
 };
4 changes: 2 additions & 2 deletions core/lowering/passes/module_fallback.cpp

@@ -92,7 +92,7 @@ void NotateModuleForFallback(
   }
 }

-void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_delims) {
+void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_delims, bool default_torch_execution) {
   auto b = g->block();

   std::stack<bool> mark = std::stack<bool>({false});
@@ -126,7 +126,7 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_del
       if (n->s(c10::Symbol::attr("compilation_edge")) == "end") {
         LOG_WARNING("Found the end of segmented block targeted for torch while not actively marking a block");
       }
-    } else if (mark.top()) {
+    } else if ((!mark.top() && default_torch_execution) or (mark.top() && !default_torch_execution)) {
      LOG_GRAPH("Marking " << util::node_info(n) << " to run in PyTorch");
      n->i_(c10::Symbol::attr("to_compile"), (int64_t) false);
    }
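The rewritten condition is an exclusive-or: a node is sent to PyTorch when exactly one of `mark.top()` (the node sits inside a delimited fallback-module block) and `default_torch_execution` is true. A minimal Python sketch of that decision table (not part of the PR; the function name is illustrative):

```python
# Sketch of the marking decision in MarkNodesForFallback. "inside_marked_block"
# stands in for mark.top(); the C++ condition
#   (!mark.top() && default_torch_execution) or (mark.top() && !default_torch_execution)
# reduces to inequality (XOR) of the two booleans.
def runs_in_pytorch(inside_marked_block: bool, default_torch_execution: bool) -> bool:
    return inside_marked_block != default_torch_execution

# Classic mode: only the marked (torch_executed_modules) blocks fall back to PyTorch.
assert runs_in_pytorch(True, False) is True
assert runs_in_pytorch(False, False) is False

# New mode: everything falls back to PyTorch *except* the marked blocks, which
# now come from trt_executed_modules and stay eligible for TensorRT compilation.
assert runs_in_pytorch(False, True) is True
assert runs_in_pytorch(True, True) is False
```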
2 changes: 1 addition & 1 deletion core/lowering/passes/passes.h

@@ -22,7 +22,7 @@ void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
 void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph);
 void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph);
 void ReduceRemainder(std::shared_ptr<torch::jit::Graph>& graph);
-void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_delims);
+void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_delims, bool default_torch_execution);
 void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
 void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
 void ViewToReshape(std::shared_ptr<torch::jit::Graph>& graph);
1 change: 1 addition & 0 deletions py/torch_tensorrt/csrc/tensorrt_classes.cpp

@@ -218,6 +218,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
   info.partition_info.forced_fallback_operators = torch_fallback.forced_fallback_operators;
   info.partition_info.truncate_long_and_double = truncate_long_and_double;
   info.lower_info.forced_fallback_modules = torch_fallback.forced_fallback_modules;
+  info.lower_info.default_torch_execution = default_torch_execution;
   info.convert_info.engine_settings.truncate_long_and_double = truncate_long_and_double;

   info.convert_info.engine_settings.capability = toTRTEngineCapability(capability);
1 change: 1 addition & 0 deletions py/torch_tensorrt/csrc/tensorrt_classes.h

@@ -163,6 +163,7 @@ struct CompileSpec : torch::CustomClassHolder {
   bool refit = false;
   bool debug = false;
   bool truncate_long_and_double = false;
+  bool default_torch_execution = false;
   Device device;
   TorchFallback torch_fallback;
   EngineCapability capability = EngineCapability::kDEFAULT;
3 changes: 2 additions & 1 deletion py/torch_tensorrt/csrc/torch_tensorrt_py.cpp

@@ -304,7 +304,8 @@ PYBIND11_MODULE(_C, m) {
       .def_readwrite("num_avg_timing_iters", &CompileSpec::num_avg_timing_iters)
       .def_readwrite("workspace_size", &CompileSpec::workspace_size)
       .def_readwrite("torch_fallback", &CompileSpec::torch_fallback)
-      .def_readwrite("truncate_long_and_double", &CompileSpec::truncate_long_and_double);
+      .def_readwrite("truncate_long_and_double", &CompileSpec::truncate_long_and_double)
+      .def_readwrite("default_torch_execution", &CompileSpec::default_torch_execution);

   py::class_<TorchFallback>(ts_sub_mod, "TorchFallback")
       .def(py::init<>())
5 changes: 4 additions & 1 deletion py/torch_tensorrt/ts/_compile_spec.py

@@ -221,7 +221,10 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> _ts_C.CompileSpec:

     if "torch_fallback" in compile_spec:
         info.torch_fallback = _parse_torch_fallback(compile_spec["torch_fallback"])
-
+
+    if "default_torch_execution" in compile_spec:
+        assert type(compile_spec["default_torch_execution"]) is bool
+        info.default_torch_execution = compile_spec["default_torch_execution"]
     return info
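For illustration, a spec dict that exercises the new key might look like the sketch below; the input shape and module path are placeholders, and the key names mirror the spec built in `_compiler.py` further down. Note that under `default_torch_execution` the modules routed through `forced_fallback_modules` are the ones to *compile*:

```python
import torch_tensorrt
from torch_tensorrt.ts._compile_spec import _parse_compile_spec

# Hypothetical compile_spec exercising the new key (shape and module path are
# placeholders, not values from the PR).
compile_spec = {
    "inputs": [torch_tensorrt.Input((1, 3, 224, 224))],
    "torch_fallback": {
        "enabled": True,
        "forced_fallback_ops": [],
        # With default_torch_execution=True these are the modules to *compile*;
        # the key is reused from the classic module-fallback path.
        "forced_fallback_modules": ["torchvision.models.resnet.BasicBlock"],
        "min_block_size": 3,
    },
    "default_torch_execution": True,
}

info = _parse_compile_spec(compile_spec)  # -> _ts_C.CompileSpec
```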
19 changes: 16 additions & 3 deletions py/torch_tensorrt/ts/_compiler.py

@@ -26,7 +26,9 @@ def compile(module: torch.jit.ScriptModule,
             require_full_compilation=False,
             min_block_size=3,
             torch_executed_ops=[],
-            torch_executed_modules=[]) -> torch.jit.ScriptModule:
+            torch_executed_modules=[],
+            default_torch_execution=False,
+            trt_executed_modules=[]) -> torch.jit.ScriptModule:
     """Compile a TorchScript module for NVIDIA GPUs using TensorRT

     Takes a existing TorchScript module and a set of settings to configure the compiler
@@ -74,6 +76,8 @@ def compile(module: torch.jit.ScriptModule,
         min_block_size (int): The minimum number of contiguous TensorRT convertable operations in order to run a set of operations in TensorRT
         torch_executed_ops (List[str]): List of aten operators that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
         torch_executed_modules (List[str]): List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
+        default_torch_execution (bool): If enabled, modules are executed in PyTorch by default and only the modules listed in ``trt_executed_modules`` are compiled
+        trt_executed_modules (List[str]): List of modules that will be compiled to TensorRT. An error will be thrown if this list is not empty but ``default_torch_execution`` is False

     Returns:
         torch.jit.ScriptModule: Compiled TorchScript Module, when run it will execute via TensorRT
@@ -87,6 +91,14 @@
         raise ValueError(
             "require_full_compilation is enabled however the list of modules and ops to run in torch is not empty. Found: torch_executed_ops: "
             + torch_executed_ops + ", torch_executed_modules: " + torch_executed_modules)
+
+    if default_torch_execution:
+        if require_full_compilation:
+            raise ValueError("require_full_compilation is enabled, which conflicts with default_torch_execution mode")
+        if len(torch_executed_modules) > 0:
+            raise ValueError("With default_torch_execution=True, it is unnecessary to specify torch_executed_modules")
+        if len(trt_executed_modules) == 0:
+            raise ValueError("With default_torch_execution=True, it is necessary to specify some trt_executed_modules, otherwise nothing will be compiled")

     spec = {
         "inputs": inputs,
@@ -105,9 +117,10 @@
         "torch_fallback": {
             "enabled": not require_full_compilation,
             "forced_fallback_ops": torch_executed_ops,
-            "forced_fallback_modules": torch_executed_modules,
+            "forced_fallback_modules": torch_executed_modules if not default_torch_execution else trt_executed_modules,
             "min_block_size": min_block_size
-        }
+        },
+        "default_torch_execution": default_torch_execution
     }

     compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
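Taken together, the new interface would be driven from user code roughly as in the sketch below. `ConvBlock`, `MyModel`, and the qualified module path passed to `trt_executed_modules` are all placeholders invented for illustration:

```python
import torch
import torch.nn as nn
import torch_tensorrt

class ConvBlock(nn.Module):  # placeholder submodule we want compiled to TensorRT
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x))

class MyModel(nn.Module):  # placeholder top-level model
    def __init__(self):
        super().__init__()
        self.block = ConvBlock()
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        return self.pool(self.block(x))

model = torch.jit.script(MyModel().eval().cuda())

trt_model = torch_tensorrt.ts.compile(
    model,
    inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
    # Invert the default: run everything in PyTorch ...
    default_torch_execution=True,
    # ... except these modules, which are segmented out and compiled to TensorRT.
    # Module names are qualified class names, e.g. "__main__.ConvBlock" here.
    trt_executed_modules=["__main__.ConvBlock"],
)
```

The design reuses the existing module-fallback machinery end to end: `trt_executed_modules` flows into `forced_fallback_modules`, and the single `default_torch_execution` flag flips the meaning of the resulting marks in `MarkNodesForFallback`, so no new partitioning pass is needed.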