
Commit 48a7f28

Merge pull request #1201 from pytorch/squashed_collections
feat: support for grouped inputs
2 parents: b62df15 + 223dfd1
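This commit teaches the compiler to accept grouped (collection) inputs: CompileSpec now carries an ir::GraphInputs that can be built either from the legacy flat std::vector<ir::Input> or from a torch::jit::IValue input signature, conversion tracks per-element specs through a CollectionInputSpecMap, and partitioning is forced whenever a graph returns a tuple or list. As a rough illustration of the idea (a sketch, not code from this commit; the tensors and shapes are made up), a grouped input is a single graph input whose value is itself a collection:

#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>

// Two tensors packed into one tuple-typed input instead of
// two flat tensor arguments.
auto t1 = at::randn({1, 3, 224, 224});
auto t2 = at::randn({1, 3, 224, 224});
c10::IValue grouped = c10::ivalue::Tuple::create({t1, t2});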

Note: large commits have some content hidden by default, so only a subset of the changed files is reproduced below.

44 files changed: 1,694 additions and 326 deletions.

core/compiler.cpp

Lines changed: 61 additions & 33 deletions
@@ -256,6 +256,7 @@ GraphAndMapping ConstructFallbackGraph(
       // update the input ranges for each segments
       convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
 
+      // TODO mapping Inputs Ivalue to flatten one here
       auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
       auto temp_g = std::make_shared<torch::jit::Graph>();
       auto device_spec = convert_cfg.engine_settings.device;
@@ -306,57 +307,80 @@ void MapInputsAndDetermineDTypes(
     CompileSpec& cfg,
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
-    ir::TypeMap& first_use_type_map) {
-  // Associate input specs with inputs
-  cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
-
-  for (auto& in : g->inputs()) {
-    if (static_params.find(in) == static_params.end()) {
-      ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
-      auto est_type_opt = first_use_type_map.find(in)->second;
-      if (est_type_opt && !spec.dtype_is_user_defined) {
+    ir::CollectionTypeMap& first_use_type_map) {
+  cfg.convert_info.collection_input_spec_map =
+      std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
+
+  auto collection_inputs = ir::get_collection_inputs(g, static_params);
+  LOG_DEBUG(
+      "In MapInputsAndDetermineDTypes, the g->inputs() size is "
+      << g->inputs().size() << ", CollectionInputSpecMap size is" << collection_inputs.size());
+
+  for (auto in : collection_inputs) {
+    std::vector<ir::Input>& spec = cfg.convert_info.collection_input_spec_map.find(in)->second;
+    std::vector<c10::optional<at::ScalarType>> est_type_opt;
+
+    auto est_it = first_use_type_map.find(in);
+    if (est_it != first_use_type_map.end()) {
+      est_type_opt = first_use_type_map.find(in)->second;
+    }
+    // traverse elements in est_type_out and spec
+    for (size_t i = 0; i < est_type_opt.size(); i++) {
+      if (est_type_opt[i] && !spec[i].dtype_is_user_defined) {
         // If we can calculate the type from the graph and the type was not defined by the user then use the calculated
         // type
         LOG_INFO(
-            "Since input type is not explicitly defined, infering using first tensor calculation\n Found input "
-            << in->debugName() << " has type " << est_type_opt.value()
-            << ". If this is incorrect explicitly set dtype for input and file a bug");
-        spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
-      } else if (!est_type_opt && !spec.dtype_is_user_defined) {
+            "Since input type is not explicitly defined, infering using first tensor calculation\n Inferred input "
+            << in->debugName() << " has type " << est_type_opt[i].value());
+        spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value());
+      } else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) {
         // If we cannot calculate the type and the user did not define the type, then default to FP32
         LOG_WARNING(
             "Cannot infer input type from calcuations in graph for input "
             << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
-        spec.dtype = nvinfer1::DataType::kFLOAT;
-      } else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
-        if (!est_type_opt) {
-          LOG_INFO("Cannot infer input tensor dtype in graph. Using user provided input dtype settings");
-          first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
+        spec[i].dtype = nvinfer1::DataType::kFLOAT;
+      } else if (spec[i].dtype_is_user_defined && cfg.partition_info.enabled) {
+        if (!est_type_opt[i]) {
+          LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
+          std::stringstream ss;
+          ss << "For input " << in->debugName() << ", found user specified input dtype as ";
+          ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          ss << ". The compiler is going to use the user setting "
+             << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          auto warn_str = ss.str();
+          LOG_WARNING(warn_str);
+          // Overwrite type map with user settings
+          first_use_type_map[in][i] = {
+              util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
+
         } else {
-          if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
+          if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) !=
+              est_type_opt[i].value()) {
             std::stringstream ss;
             ss << "For input " << in->debugName() << ", found user specified input dtype as ";
-            ss << cfg.convert_info.inputs.find(in)->second.dtype;
+            ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
             ss << ", however when inspecting the graph, the input type expected was inferred to be ";
-            ss << est_type_opt.value() << std::endl;
-            ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype;
+            ss << est_type_opt[i].value() << std::endl;
+            ss << "The compiler is going to use the user setting "
+               << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
             ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
             ss << "compatibility with PyTorch's data type convention is required.\n";
             ss << "If you do indeed see errors at runtime either:\n";
             ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
             ss << "- Disable partial compilation by setting require_full_compilation to True";
             auto warn_str = ss.str();
             LOG_WARNING(warn_str);
+            // Overwrite type map with user settings
+            first_use_type_map[in][i] = {
+                util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
           }
-          // Overwrite type map with user settings
-          // We use this map for partitiioning since we need c10::ScalarTypes not nvinfer::DataTypes
-          first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
         }
       } else {
         // The user defined the type so no changes are necessary
       }
     }
   }
+  // }
 }
361385

362386
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
@@ -370,7 +394,7 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
   auto params = graph_and_parameters.second;
   auto static_params = ir::get_static_params(g->inputs(), params);
   // Infer the type of an input from the weights of the calculation
-  auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
+  auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
 
   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);

@@ -395,23 +419,26 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   auto params = graph_and_parameters.second;
   auto static_params = ir::get_static_params(g->inputs(), params);
   // Infer the type of an input from the weights of the calculation
-  auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
+  auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
 
   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
+  auto outputIsCollection = conversion::OutputIsCollection(g->block());
   if (cfg.partition_info.enabled &&
       (cfg.lower_info.forced_fallback_modules.size() == 0 &&
        cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
     LOG_INFO("Skipping partitioning since model is fully supported");
   }
 
   if (cfg.partition_info.enabled &&
-      !(cfg.lower_info.forced_fallback_modules.size() == 0 &&
-        cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
-    auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
+      (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
+         cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
+       outputIsCollection)) {
     std::unordered_map<torch::jit::Node*, int> fallback_nodes;
-    auto graph_and_mapping =
-        ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params, fallback_nodes);
+    auto collection_input_ivalues_map =
+        partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
+    auto graph_and_mapping = ConstructFallbackGraph(
+        new_mod, g->block(), collection_input_ivalues_map, cfg, static_params, fallback_nodes);
     new_g = graph_and_mapping.first;
     // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
     for (size_t i = 0; i < new_g->inputs().size(); ++i) {
@@ -429,6 +456,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
     TORCHTRT_CHECK(
         conversion::VerifyConverterSupportForBlock(g->block()),
         "Not all operations in graph are supported by the compiler");
+    // TODO find the right
     auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
     AddEngineToGraph(new_mod, new_g, engine, cuda_device);
   }
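Net effect of the CompileGraph changes above: the partitioning/fallback path is now taken not only when the block is not fully convertible, but also whenever the graph output is a collection. Restated as a sketch with the long conditions folded into named booleans (illustrative only, not code from the commit):

bool fully_convertible = cfg.lower_info.forced_fallback_modules.size() == 0 &&
    cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible;
bool use_fallback_path = cfg.partition_info.enabled && (!fully_convertible || outputIsCollection);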

core/compiler.h

Lines changed: 4 additions & 2 deletions
@@ -8,13 +8,15 @@
 #include "core/partitioning/partitioning.h"
 #include "core/runtime/runtime.h"
 #include "torch/csrc/jit/api/module.h"
+#include "torch/csrc/jit/ir/ir.h"
 
 namespace torch_tensorrt {
 namespace core {
 
 struct CompileSpec {
-  CompileSpec(std::vector<ir::Input> inputs) : inputs(inputs) {}
-  std::vector<ir::Input> inputs;
+  CompileSpec(std::vector<ir::Input> inputs) : graph_inputs(inputs) {}
+  CompileSpec(torch::jit::IValue& input_signature) : graph_inputs(input_signature) {}
+  ir::GraphInputs graph_inputs;
   conversion::ConversionInfo convert_info;
   lowering::LowerInfo lower_info;
   partitioning::PartitionInfo partition_info;
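For reference, a minimal sketch of the two construction paths the new CompileSpec exposes (illustrative, not from this commit; how the IValue's elements map to ir::Input specs is resolved later by ir::associate_specs_with_collection_inputs):

// Legacy path: a flat list of tensor specs.
std::vector<torch_tensorrt::core::ir::Input> flat = {torch_tensorrt::core::ir::Input({1, 3, 224, 224})};
torch_tensorrt::core::CompileSpec spec_flat(flat);

// Grouped path: a single IValue describing the whole input signature.
torch::jit::IValue signature = c10::ivalue::Tuple::create({/* per-input specs */});
torch_tensorrt::core::CompileSpec spec_grouped(signature);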

core/conversion/conversion.cpp

Lines changed: 24 additions & 7 deletions
@@ -135,10 +135,11 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
       << "please report this error to https://www.github.com/NVIDIA/Torch-TensorRT/issues");
 }
 
-void AddInputs(
-    ConversionCtx* ctx,
-    c10::ArrayRef<const torch::jit::Value*> inputs,
-    std::unordered_map<const torch::jit::Value*, ir::Input>& input_specs) {
+void AddInputs(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> inputs, ConversionInfo& conversion_info) {
+  std::unordered_map<const torch::jit::Value*, ir::Input>& input_specs = conversion_info.inputs;
+  std::unordered_map<const torch::jit::Value*, std::vector<ir::Input>> collection_input_spec =
+      conversion_info.collection_input_spec_map;
+
   std::vector<const torch::jit::Value*> input_tensors;
   for (auto in : inputs) {
     // Disregarding inputs that are not tensors
@@ -166,9 +167,15 @@ void AddInputs(
   for (auto input : input_tensors) {
     const torch::jit::Value* in = input;
     TORCHTRT_CHECK(
-        input_specs.find(in) != input_specs.end(),
+        input_specs.find(in) != input_specs.end() || collection_input_spec.find(in) != collection_input_spec.end(),
         "Cannot find an input spec associated with input: " << in->debugName());
-    ir::Input& spec = input_specs.find(in)->second;
+    ir::Input spec;
+    if (input_specs.find(in) != input_specs.end()) {
+      spec = input_specs.find(in)->second;
+    } else {
+      spec = collection_input_spec.find(in)->second[0]; // assume input is tensor
+    }
+    // ir::Input& spec = input_specs.find(in)->second;
 
     std::string name = std::string("input_") + std::to_string(ctx->num_inputs);
     LOG_INFO(
@@ -408,7 +415,7 @@ void ConvertBlockToNetDef(
 
   auto inputs = b->inputs();
   AddParamsToCtxValueMap(ctx, static_params);
-  AddInputs(ctx, inputs, build_info.inputs);
+  AddInputs(ctx, inputs, build_info);
 
   auto nodes = b->nodes();
 
@@ -549,6 +556,16 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
   return convertable_ops;
 }
 
+bool OutputIsCollection(const torch::jit::Block* b) {
+  for (auto out : b->outputs()) {
+    if (out->type()->kind() == torch::jit::TypeKind::TupleType ||
+        out->type()->kind() == torch::jit::TypeKind::ListType) {
+      return true;
+    }
+  }
+  return false;
+}
+
 bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors) {
   auto unsupported_ops = GetUnsupportedOpsInBlock(b);
   if (unsupported_ops.size() != 0) {

core/conversion/conversion.h

Lines changed: 3 additions & 0 deletions
@@ -13,6 +13,7 @@ namespace conversion {
 
 struct ConversionInfo {
   ir::InputSpecMap inputs;
+  ir::CollectionInputSpecMap collection_input_spec_map;
   BuilderSettings engine_settings;
 };
 
@@ -25,6 +26,8 @@ std::string ConvertBlockToEngine(
 
 bool OpSupported(const torch::jit::Node* n);
 
+bool OutputIsCollection(const torch::jit::Block* b);
+
 bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors = false);
 
 c10::optional<torch::jit::IValue> EvaluateNode(

core/conversion/converters/converter_util.cpp

Lines changed: 7 additions & 0 deletions
@@ -65,6 +65,13 @@ nvinfer1::ILayer* add_elementwise(
     nvinfer1::ITensor* self,
     nvinfer1::ITensor* other,
     const std::string& name) {
+  if (self->getType() == nvinfer1::DataType::kFLOAT && other->getType() == nvinfer1::DataType::kINT32) {
+    LOG_DEBUG("Type mismatch, casting other to " << self->getType());
+    other = castITensor(ctx, other, self->getType());
+  } else if (self->getType() == nvinfer1::DataType::kINT32 && other->getType() == nvinfer1::DataType::kFLOAT) {
+    LOG_DEBUG("Type mismatch, casting self to " << other->getType());
+    self = castITensor(ctx, self, other->getType());
+  }
   // ensure self to have larger number of dimension
   bool swapSelfOther = false;
   if (self->getDimensions().nbDims < other->getDimensions().nbDims) {
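The block added above gives add_elementwise an implicit float/int promotion: when exactly one operand is kINT32 and the other kFLOAT, the integer side is cast so TensorRT's IElementWiseLayer sees matching operand types. The rule in isolation (a standalone sketch, not the commit's code; the helper name is hypothetical):

#include <NvInfer.h>
#include <utility>

// Returns {cast_self, cast_other}: which operand should be cast to the
// other's (float) type before building the elementwise layer.
std::pair<bool, bool> needs_float_promotion(nvinfer1::DataType self, nvinfer1::DataType other) {
  using DT = nvinfer1::DataType;
  if (self == DT::kFLOAT && other == DT::kINT32) {
    return {false, true}; // cast other up to float
  }
  if (self == DT::kINT32 && other == DT::kFLOAT) {
    return {true, false}; // cast self up to float
  }
  return {false, false}; // types already agree or are an unhandled pair
}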

core/conversion/converters/impl/element_wise.cpp

Lines changed: 2 additions & 0 deletions
@@ -412,6 +412,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
                      // Should implement self * other
                      auto self = args[0].ITensorOrFreeze(ctx);
                      auto other = args[1].ITensorOrFreeze(ctx);
+
                      auto mul =
                          add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
                      TORCHTRT_CHECK(mul, "Unable to create mul layer from node: " << *n);
@@ -426,6 +427,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
                      // TODO: Remove with functionalization
                      auto self = args[0].ITensorOrFreeze(ctx);
                      auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
+
                      auto mul =
                          add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
                      TORCHTRT_CHECK(mul, "Unable to create mul layer from node: " << *n);

core/conversion/evaluators/aten.cpp

Lines changed: 0 additions & 8 deletions
@@ -19,14 +19,6 @@ namespace conversion {
 namespace evaluators {
 namespace {
 
-int64_t normalizeIndex(int64_t idx, int64_t list_size) {
-  if (idx < 0) {
-    // Handle negative indexing
-    idx = list_size + idx;
-  }
-  return idx;
-}
-
 DEFINE_GENERIC_TWO_INPUT_EVALUATOR(
     eq,
     "aten::eq",

core/conversion/evaluators/eval_util.cpp

Lines changed: 9 additions & 0 deletions
@@ -12,6 +12,15 @@ namespace core {
 namespace conversion {
 namespace evaluators {
 
+int64_t normalizeIndex(int64_t idx, int64_t list_size) {
+  if (idx < 0) {
+    // Handle negative indexing
+    idx = list_size + idx;
+  }
+  return idx;
+}
+
+
 // TODO: Switch back to PyTorch canonical implimentation
 c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v) {
   if (v->node()->kind() != torch::jit::prim::Constant || v->type()->cast<c10::FunctionType>()) {
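normalizeIndex implements Python-style negative indexing, mapping negative indices into [0, list_size). For example (illustrative):

int64_t last = normalizeIndex(-1, 4); // == 3
int64_t first = normalizeIndex(-4, 4); // == 0
int64_t same = normalizeIndex(2, 4); // == 2 (non-negative passes through)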

core/conversion/evaluators/eval_util.h

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,8 @@ at::Tensor createTensorFromList(
     const torch::jit::IValue& dtype,
     const torch::jit::IValue& device);
 
+int64_t normalizeIndex(int64_t idx, int64_t list_size);
+
 at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device = at::kCPU);
 
 } // namespace evaluators
