Add a "no_conversion" flow to torch-tensorrt #1360

Status: Closed

74 changes: 56 additions & 18 deletions core/compiler.cpp
@@ -168,6 +168,37 @@ void AddSegmentedBlockToGraph(
   return;
 }
 
+void AddSegmentedBlockToGraphAsFunction(
+    torch::jit::script::Module& new_mod,
+    std::shared_ptr<torch::jit::Graph>& g,
+    partitioning::SegmentedBlock& seg,
+    const std::string& function_name) {
+  auto method_self = seg.g()->insertInput(0, "self_1");
+  method_self->setType(new_mod.type());
+  auto engine_method = new_mod._ivalue()->compilation_unit()->create_function(function_name, seg.g());
+  auto schema = util::GenerateGraphSchema(engine_method->name(), seg.g());
+  std::vector<torch::jit::Value*> method_inputs;
+
+  auto self = g->insertInput(0, "self_1");
+  self->setType(new_mod.type());
+  method_inputs.push_back(self);
+
+  new_mod.type()->addMethod(engine_method);
+  engine_method->setSchema(schema);
+  for (size_t idx = 0UL; idx < seg.raw_inputs().size(); ++idx) {
+    auto in_val = g->addInput("input_" + std::to_string(idx));
+    in_val->setType(seg.raw_inputs()[idx]->type());
+    method_inputs.push_back(in_val);
+  }
+  auto function_call = g->create(at::prim::CallMethod, method_inputs, seg.raw_outputs().size());
+  function_call->s_(at::attr::name, function_name);
+  g->appendNode(function_call);
+
+  for (auto out : function_call->outputs()) {
+    g->registerOutput(out);
+  }
+}
+
 typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
     GraphAndMapping;
 
@@ -245,25 +276,32 @@ GraphAndMapping ConstructFallbackGraph(
     trt_engine_id << reinterpret_cast<const int*>(&seg_block);
 
     if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
-      auto shapes = seg_block.in_shapes();
-      auto types = seg_block.in_types();
-      std::vector<ir::Input> inputs;
-      for (size_t i = 0; i < shapes.size(); i++) {
-        auto in = ir::Input(shapes[i]);
-        in.dtype = util::ScalarTypeToTRTDataType(types[i]);
-        inputs.push_back(in);
-      }
-      // update the input ranges for each segments
-      convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
+      if (partition_info.no_conversion) {
+        // Embed a method call for each segment which would be converted to a TRT engine in the standard flow
+        auto temp_g = std::make_shared<torch::jit::Graph>();
+        AddSegmentedBlockToGraphAsFunction(new_mod, temp_g, seg_block, "trt_engine_" + trt_engine_id.str());
+        seg_block.update_graph(temp_g);
+      } else {
+        auto shapes = seg_block.in_shapes();
+        auto types = seg_block.in_types();
+        std::vector<ir::Input> inputs;
+        for (size_t i = 0; i < shapes.size(); i++) {
+          auto in = ir::Input(shapes[i]);
+          in.dtype = util::ScalarTypeToTRTDataType(types[i]);
+          inputs.push_back(in);
+        }
+        // update the input ranges for each segment
+        convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
 
-      // TODO mapping Inputs Ivalue to flatten one here
-      auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
-      auto temp_g = std::make_shared<torch::jit::Graph>();
-      auto device_spec = convert_cfg.engine_settings.device;
-      auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-      AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
+        // TODO: map input IValues to the flattened ones here
+        auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
+        auto temp_g = std::make_shared<torch::jit::Graph>();
+        auto device_spec = convert_cfg.engine_settings.device;
+        auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+        AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
 
-      seg_block.update_graph(temp_g);
+        seg_block.update_graph(temp_g);
+      }
       AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
     } else {
       if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
@@ -434,7 +472,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   if (cfg.partition_info.enabled &&
       (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
          cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
-       outputIsCollection)) {
+       outputIsCollection || cfg.partition_info.no_conversion)) {
     std::unordered_map<torch::jit::Node*, int> fallback_nodes;
     auto collection_input_ivalues_map =
         partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
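
For illustration, a minimal sketch of the effect of this flow, assuming the behavior implemented above: with no_conversion set, each TensorRT-convertible segment stays in the module as a TorchScript method named "trt_engine_<id>" and is invoked through a prim::CallMethod node, so the segments can be enumerated from Python. The helper below is hypothetical, not part of this PR:

import torch

# Sketch: list the embedded segment calls in a module compiled with
# no_conversion enabled. Method names follow the "trt_engine_" +
# trt_engine_id convention used in ConstructFallbackGraph above.
# Only top-level nodes of forward() are scanned; segments nested in
# prim::If blocks (see the conditionals test below) are not visited.
def embedded_segment_calls(mod: torch.jit.ScriptModule):
    return [
        node.s("name")  # method name stored in the node's "name" attribute
        for node in mod.graph.nodes()
        if node.kind() == "prim::CallMethod"
    ]
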
1 change: 1 addition & 0 deletions core/partitioning/PartitionInfo.h
@@ -10,6 +10,7 @@ namespace partitioning {
 
 struct PartitionInfo {
   bool enabled = false;
+  bool no_conversion = false;
   uint64_t min_block_size = 1;
   std::vector<std::string> forced_fallback_operators;
   bool truncate_long_and_double;
6 changes: 6 additions & 0 deletions cpp/include/torch_tensorrt/torch_tensorrt.h
@@ -687,6 +687,12 @@ struct CompileSpec {
    */
   bool require_full_compilation = false;
 
+  /**
+   * Do not convert TensorRT convertible partitions to engines. Instead embed them in the PyTorch graph as function
+   * calls
+   */
+  bool no_conversion = false;
+
   /**
    * Minimum number of contiguous supported operators to compile a subgraph to TensorRT
    */
1 change: 1 addition & 0 deletions cpp/src/compile_spec.cpp
@@ -123,6 +123,7 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
 
   internal.partition_info.enabled = !external.require_full_compilation;
   internal.partition_info.min_block_size = external.min_block_size;
+  internal.partition_info.no_conversion = external.no_conversion;
   internal.partition_info.forced_fallback_operators = std::move(external.torch_executed_ops);
   internal.partition_info.truncate_long_and_double = external.truncate_long_and_double;
   internal.lower_info.forced_fallback_modules = std::move(external.torch_executed_modules);
1 change: 1 addition & 0 deletions py/torch_tensorrt/csrc/register_tensorrt_classes.cpp
@@ -53,6 +53,7 @@ void RegisterTRTCompileSpec() {
 
   ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, torch_tensorrt::pyapi::TorchFallback, enabled);
   ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, torch_tensorrt::pyapi::TorchFallback, min_block_size);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, torch_tensorrt::pyapi::TorchFallback, no_conversion);
   ADD_FIELD_GET_SET_REGISTRATION(
       TRTFallbackTSRegistration, torch_tensorrt::pyapi::TorchFallback, forced_fallback_operators);
   ADD_FIELD_GET_SET_REGISTRATION(
1 change: 1 addition & 0 deletions py/torch_tensorrt/csrc/tensorrt_classes.cpp
@@ -315,6 +315,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
   info.convert_info.engine_settings.device.allow_gpu_fallback = device.allow_gpu_fallback;
   info.partition_info.enabled = torch_fallback.enabled;
   info.partition_info.min_block_size = torch_fallback.min_block_size;
+  info.partition_info.no_conversion = torch_fallback.no_conversion;
   info.partition_info.forced_fallback_operators = torch_fallback.forced_fallback_operators;
   info.partition_info.truncate_long_and_double = truncate_long_and_double;
   info.lower_info.forced_fallback_modules = torch_fallback.forced_fallback_modules;
2 changes: 2 additions & 0 deletions py/torch_tensorrt/csrc/tensorrt_classes.h
@@ -96,12 +96,14 @@ nvinfer1::DeviceType toTRTDeviceType(DeviceType value);
 
 struct TorchFallback : torch::CustomClassHolder {
   bool enabled;
+  bool no_conversion;
   int64_t min_block_size;
   std::vector<std::string> forced_fallback_operators;
   std::vector<std::string> forced_fallback_modules;
   TorchFallback() : enabled(false), min_block_size(1) {}
 
   ADD_FIELD_GET_SET(enabled, bool);
+  ADD_FIELD_GET_SET(no_conversion, bool);
   ADD_FIELD_GET_SET(min_block_size, int64_t);
   ADD_FIELD_GET_SET(forced_fallback_operators, std::vector<std::string>);
   ADD_FIELD_GET_SET(forced_fallback_modules, std::vector<std::string>);
1 change: 1 addition & 0 deletions py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
@@ -366,6 +366,7 @@ PYBIND11_MODULE(_C, m) {
       .def(py::init<>())
       .def("__str__", &torch_tensorrt::pyapi::TorchFallback::to_str)
       .def_readwrite("enabled", &TorchFallback::enabled)
+      .def_readwrite("no_conversion", &TorchFallback::no_conversion)
      .def_readwrite("min_block_size", &TorchFallback::min_block_size)
      .def_readwrite("forced_fallback_operators", &TorchFallback::forced_fallback_operators)
      .def_readwrite("forced_fallback_modules", &TorchFallback::forced_fallback_modules);
4 changes: 4 additions & 0 deletions py/torch_tensorrt/ts/_compile_spec.py
@@ -178,6 +178,9 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> _ts_C.TorchFallback:
     else:
         assert isinstance(fallback_info["enabled"], bool)
         info.enabled = fallback_info["enabled"]
+    if "no_conversion" in fallback_info:
+        assert isinstance(fallback_info["no_conversion"], bool)
+        info.no_conversion = fallback_info["no_conversion"]
     if "min_block_size" in fallback_info:
         assert isinstance(fallback_info["min_block_size"], int)
         info.min_block_size = fallback_info["min_block_size"]
@@ -460,6 +463,7 @@ def TensorRTCompileSpec(
 
     torch_fallback = torch.classes.tensorrt._TorchFallback()
     torch_fallback._set_enabled(parsed_spec.torch_fallback.enabled)
+    torch_fallback._set_no_conversion(parsed_spec.torch_fallback.no_conversion)
     torch_fallback._set_min_block_size(parsed_spec.torch_fallback.min_block_size)
     torch_fallback._set_forced_fallback_operators(
         parsed_spec.torch_fallback.forced_fallback_operators
3 changes: 3 additions & 0 deletions py/torch_tensorrt/ts/_compiler.py
@@ -28,6 +28,7 @@ def compile(
     calibrator=None,
     truncate_long_and_double=False,
     require_full_compilation=False,
+    no_conversion=False,
     min_block_size=3,
     torch_executed_ops=[],
     torch_executed_modules=[],
@@ -91,6 +92,7 @@ def compile(
         truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32
         calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
         require_full_compilation (bool): Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch
+        no_conversion (bool): Do not convert TensorRT-convertible segments to TensorRT engines; instead, embed the convertible segments in the PyTorch graph as function calls
         min_block_size (int): The minimum number of contiguous TensorRT convertable operations in order to run a set of operations in TensorRT
         torch_executed_ops (List[str]): List of aten operators that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
         torch_executed_modules (List[str]): List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
@@ -130,6 +132,7 @@ def compile(
             "forced_fallback_ops": torch_executed_ops,
             "forced_fallback_modules": torch_executed_modules,
             "min_block_size": min_block_size,
+            "no_conversion": no_conversion,
         },
     }
 
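
For reference, a hedged usage sketch of the new flag through the Python frontend; the API shape follows this PR, and the model path and input shape are placeholders:

import torch
import torch_tensorrt

# Sketch: partition as usual, but keep TensorRT-convertible segments as
# TorchScript function calls instead of building engines for them.
mod = torch.jit.load("my_model.jit.pt")  # placeholder model
partitioned_mod = torch_tensorrt.ts.compile(
    mod,
    inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
    require_full_compilation=False,  # partitioning must be enabled
    no_conversion=True,              # flag added by this PR
)
print(partitioned_mod.graph)  # convertible segments appear as prim::CallMethod calls
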
39 changes: 39 additions & 0 deletions tests/core/partitioning/test_conditionals.cpp
@@ -22,6 +22,23 @@ size_t count_trt_engines_in_conditionals(std::shared_ptr<torch::jit::Graph> g) {
   return count;
 }
 
+size_t count_trt_engine_functions_in_conditionals(std::shared_ptr<torch::jit::Graph> g) {
+  size_t count = 0;
+  for (auto n : g->nodes()) {
+    if (n->kind() == torch::jit::prim::If) {
+      std::vector<torch::jit::Block*> blocks{n->blocks()[0], n->blocks()[1]};
+      for (auto cur_block : blocks) {
+        for (auto n : cur_block->nodes()) {
+          if (n->kind().toQualString() == std::string("prim::CallMethod")) {
+            ++count;
+          }
+        }
+      }
+    }
+  }
+  return count;
+}
+
 TEST(Partitioning, FallbackOnConditionalsCorrectly) {
   torch::jit::script::Module mod;
   try {
@@ -43,6 +60,28 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) {
   ASSERT_TRUE(conditional_engines_count == 2);
 }
 
+TEST(Partitioning, FallbackOnConditionalsCorrectlyNoConversion) {
+  torch::jit::script::Module mod;
+  try {
+    mod = torch::jit::load("external/torch-tensorrt-tests/modules/conditional_scripted.jit.pt");
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+    return;
+  }
+
+  std::vector<torch_tensorrt::core::ir::Input> inputs{torch_tensorrt::core::ir::Input({3, 3, 16, 16})};
+  auto g = mod.get_method("forward").graph();
+  torch_tensorrt::core::CompileSpec cfg(inputs);
+  cfg.partition_info.enabled = true;
+  cfg.partition_info.no_conversion = true;
+  torch::jit::script::Module new_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
+  auto new_g = new_mod.get_method("forward").graph();
+  std::cout << *new_g << std::endl;
+  auto conditional_engines_count = count_trt_engine_functions_in_conditionals(new_g);
+
+  ASSERT_TRUE(conditional_engines_count == 1);
+}
+
 TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) {
   torch::jit::script::Module mod;
   try {
62 changes: 62 additions & 0 deletions tests/core/partitioning/test_fallback_graph_output.cpp
@@ -37,6 +37,37 @@ TEST(Partitioning, ComputeResNet50FallbackGraphCorrectly) {
   ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results, 0.99));
 }
 
+TEST(Partitioning, ComputeResNet50NoConvertFallbackGraphCorrectly) {
+  torch::jit::script::Module mod;
+  try {
+    mod = torch::jit::load("external/torch-tensorrt-tests/modules/resnet50_traced.jit.pt");
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+    return;
+  }
+
+  const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}};
+  std::vector<torch::jit::IValue> jit_inputs_ivalues;
+  std::vector<torch::jit::IValue> trt_inputs_ivalues;
+  for (auto in_shape : input_shapes) {
+    auto in = at::randint(5, in_shape, {at::kCUDA});
+    jit_inputs_ivalues.push_back(in.clone());
+    trt_inputs_ivalues.push_back(in.clone());
+  }
+
+  std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input({1, 3, 224, 224})};
+
+  torch_tensorrt::core::CompileSpec cfg(input_ranges);
+  cfg.partition_info.enabled = true;
+  cfg.partition_info.no_conversion = true;
+  cfg.partition_info.forced_fallback_operators.push_back("aten::add");
+
+  auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
+  auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
+  auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-6));
+}
+
 TEST(Partitioning, ComputeMobileNetFallbackGraphCorrectly) {
   torch::jit::script::Module mod;
   try {
@@ -66,4 +97,35 @@
   auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
   ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results, 0.99));
 }
+
+TEST(Partitioning, ComputeMobileNetNoConvertFallbackGraphCorrectly) {
+  torch::jit::script::Module mod;
+  try {
+    mod = torch::jit::load("external/torch-tensorrt-tests/modules/mobilenet_v2_traced.jit.pt");
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+    return;
+  }
+
+  const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}};
+  std::vector<torch::jit::IValue> jit_inputs_ivalues;
+  std::vector<torch::jit::IValue> trt_inputs_ivalues;
+  for (auto in_shape : input_shapes) {
+    auto in = at::randint(5, in_shape, {at::kCUDA});
+    jit_inputs_ivalues.push_back(in.clone());
+    trt_inputs_ivalues.push_back(in.clone());
+  }
+
+  std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input({1, 3, 224, 224})};
+  auto g = mod.get_method("forward").graph();
+  torch_tensorrt::core::CompileSpec cfg(input_ranges);
+  cfg.partition_info.enabled = true;
+  cfg.partition_info.no_conversion = true;
+  cfg.partition_info.forced_fallback_operators.push_back("aten::hardtanh");
+
+  auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
+  auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
+  auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-6));
+}
 #endif