Skip to content

Commit f00de94

Browse files
committed
feat(//core/ir): Implementing new internal input spec type
This commit implements the new input spec type trtorch::core::ir::Input, which encapsulates InputRange and adds the new dtype and tensor format arguments. It also changes DataType op_precision in the engine settings to std::set<nvinfer1::DataType> enabled_precisions, allowing the compiler to enable more than a single precision without resorting to catch-all rules such as FP32 and Int8 without FP16. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 6dc3bfa commit f00de94

22 files changed

+333
-188
lines changed

core/compiler.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
194194
LOG_INFO(*g << "(LoweringGraph)\n");
195195

196196
// segment the graph and convert segmented TensorRT block
197-
auto segmented_blocks = partitioning::Partition(g, convert_cfg.input_ranges, cfg.partition_info);
197+
auto segmented_blocks = partitioning::Partition(g, convert_cfg.inputs, cfg.partition_info);
198198
if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
199199
LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
200200
return mod;
@@ -208,16 +208,16 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
208208
for (auto& seg_block : segmented_blocks) {
209209
std::string cur_block_target =
210210
seg_block.target() == partitioning::SegmentedBlock::kTensorRT ? "TensorRT" : "Torch";
211-
LOG_INFO(*seg_block.g() << "(MiniGraphIn" << cur_block_target << "Block)\n");
211+
LOG_INFO(*seg_block.g() << "(Sub Graph" << cur_block_target << "Block)\n");
212212
std::ostringstream trt_engine_id;
213213
trt_engine_id << reinterpret_cast<const int*>(&seg_block);
214214
if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
215-
std::vector<ir::InputRange> input_ranges;
215+
std::vector<ir::Input> inputs;
216216
for (auto& shape : seg_block.in_shape()) {
217-
input_ranges.push_back(ir::InputRange(shape));
217+
inputs.push_back(ir::Input(shape));
218218
}
219219
// update the input ranges for each segments
220-
convert_cfg.input_ranges = input_ranges;
220+
convert_cfg.inputs = inputs;
221221
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
222222
auto temp_g = std::make_shared<torch::jit::Graph>();
223223
AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id.str(), true);

core/compiler.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace trtorch {
1111
namespace core {
1212

1313
struct CompileSpec {
14-
CompileSpec(std::vector<ir::InputRange> input_ranges) : convert_info(std::move(input_ranges)) {}
14+
CompileSpec(std::vector<ir::Input> inputs) : convert_info(std::move(inputs)) {}
1515
conversion::ConversionInfo convert_info;
1616
partitioning::PartitionInfo partition_info;
1717
};

core/conversion/conversion.cpp

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
128128
void AddInputs(
129129
ConversionCtx* ctx,
130130
at::ArrayRef<const torch::jit::Value*> inputs,
131-
std::vector<ir::InputRange>& input_dims) {
131+
std::vector<ir::Input>& input_specs) {
132132
std::vector<const torch::jit::Value*> input_tensors;
133133
for (auto in : inputs) {
134134
// Disregarding inputs that are not tensors
@@ -142,36 +142,33 @@ void AddInputs(
142142
}
143143
}
144144

145+
std::stringstream ss;
146+
ss << "Input Dimension Specs: [\n";
147+
for (auto i : input_specs) {
148+
ss << " " << i << ",";
149+
}
150+
ss << ']';
151+
LOG_DEBUG(ss.str());
152+
145153
TRTORCH_CHECK(
146-
input_tensors.size() == input_dims.size(),
154+
input_tensors.size() == input_specs.size(),
147155
"Expected dimension specifications for all input tensors"
148-
<< ", but found " << input_tensors.size() << " input tensors and " << input_dims.size()
156+
<< ", but found " << input_tensors.size() << " input tensors and " << input_specs.size()
149157
<< " dimension specs (conversion.AddInputs)");
150158

151159
auto profile = ctx->builder->createOptimizationProfile();
152160

153-
TRTORCH_CHECK(
154-
ctx->input_dtypes.size() == 0 || ctx->input_dtypes.size() == input_tensors.size(),
155-
"Number of input_dtypes : " << ctx->input_dtypes.size()
156-
<< " should either be 0 or equal to number of input_tensors which is "
157-
<< input_tensors.size() << " (conversion.AddInputs)");
158-
159-
// If the input_dtypes is not provided, assume all the input tensors to be in float32
160-
if (ctx->input_dtypes.size() == 0) {
161-
LOG_DEBUG("Input datatypes are not provided explicitly. Default float32 datatype is being used for all inputs");
162-
ctx->input_dtypes = std::vector<nvinfer1::DataType>{input_tensors.size(), nvinfer1::DataType::kFLOAT};
163-
}
164-
165161
for (size_t i = 0; i < input_tensors.size(); i++) {
166162
auto in = input_tensors[i];
167-
auto dims = input_dims[i];
163+
auto dims = input_specs[i];
168164
std::string name = std::string("input_") + std::to_string(ctx->num_inputs);
169165
LOG_INFO(
170166
ctx->logger,
171-
"Adding Input " << in->debugName() << " named : " << name << ", shape: " << dims.input_shape
172-
<< ", dtype : " << ctx->input_dtypes[i] << " in engine (conversion.AddInputs)");
173-
auto trt_in = ctx->net->addInput(name.c_str(), ctx->input_dtypes[i], dims.input_shape);
167+
"Adding Input " << in->debugName() << " (named: " << name << "): " << dims << " in engine (conversion.AddInputs)");
168+
169+
auto trt_in = ctx->net->addInput(name.c_str(), dims.dtype, dims.input_shape);
174170
TRTORCH_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)");
171+
trt_in->setAllowedFormats(1U << static_cast<int>(dims.format));
175172

176173
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMIN, dims.min);
177174
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kOPT, dims.opt);
@@ -191,7 +188,7 @@ void AddInputs(
191188

192189
ctx->cfg->addOptimizationProfile(profile);
193190
#if NV_TENSORRT_MAJOR > 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR >= 1)
194-
if (ctx->op_precision == nvinfer1::DataType::kINT8) {
191+
if (ctx->enabled_precisions.find(nvinfer1::DataType::kINT8) != ctx->enabled_precisions.end()) {
195192
ctx->cfg->setCalibrationProfile(profile);
196193
}
197194
#endif
@@ -363,7 +360,7 @@ void ConvertBlockToNetDef(
363360

364361
auto inputs = b->inputs();
365362
AddParamsToCtxValueMap(ctx, static_params);
366-
AddInputs(ctx, inputs, build_info.input_ranges);
363+
AddInputs(ctx, inputs, build_info.inputs);
367364

368365
auto nodes = b->nodes();
369366

core/conversion/conversion.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ namespace core {
1212
namespace conversion {
1313

1414
struct ConversionInfo {
15-
std::vector<ir::InputRange> input_ranges;
15+
std::vector<ir::Input> inputs;
1616
BuilderSettings engine_settings;
17-
ConversionInfo(std::vector<ir::InputRange> input_ranges)
18-
: input_ranges(std::move(input_ranges)), engine_settings(BuilderSettings()) {}
17+
ConversionInfo(std::vector<ir::Input> inputs)
18+
: inputs(std::move(inputs)), engine_settings(BuilderSettings()) {}
1919
};
2020

2121
// TODO: REMOVE GRAPH AND PARAMS AND MOVE FULLY TO INLINED CONSTANTS

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@ namespace conversion {
1010
// clang-format off
1111
std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
1212
os << "Settings requested for TensorRT engine:" \
13-
<< "\n Operating Precision: " << s.op_precision \
14-
<< "\n TF32 Floating Point Computation Enabled: " << !s.disable_tf32 \
13+
<< "\n Enabled Precisions: ";
14+
for (auto p = s.enabled_precisions.begin(); p != s.enabled_precisions.end(); ++p) {
15+
os << *p << ' ';
16+
}
17+
os << "\n TF32 Floating Point Computation Enabled: " << !s.disable_tf32 \
1518
<< "\n Truncate Long and Double: " << s.truncate_long_and_double \
1619
<< "\n Make Refittable Engine: " << s.refit \
1720
<< "\n Debuggable Engine: " << s.debug \
@@ -57,30 +60,29 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
5760
LOG_DEBUG(build_settings);
5861
cfg = builder->createBuilderConfig();
5962

60-
switch (settings.op_precision) {
61-
case nvinfer1::DataType::kHALF:
62-
TRTORCH_CHECK(builder->platformHasFastFp16(), "Requested inference in FP16 but platform does not support FP16");
63-
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
64-
break;
65-
case nvinfer1::DataType::kINT8:
66-
TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does not support INT8");
67-
cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
68-
if (!settings.strict_types) {
63+
for(auto p = settings.enabled_precisions.begin(); p != settings.enabled_precisions.end(); ++p) {
64+
switch (*p) {
65+
case nvinfer1::DataType::kHALF:
66+
TRTORCH_CHECK(builder->platformHasFastFp16(), "Requested inference in FP16 but platform does not support FP16");
6967
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
70-
}
71-
TRTORCH_CHECK(
72-
settings.calibrator != nullptr,
73-
"Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec struct with your calibrator");
74-
cfg->setInt8Calibrator(settings.calibrator);
75-
break;
76-
case nvinfer1::DataType::kFLOAT:
77-
case nvinfer1::DataType::kINT32:
78-
case nvinfer1::DataType::kBOOL:
79-
default:
80-
break;
68+
break;
69+
case nvinfer1::DataType::kINT8:
70+
TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does not support INT8");
71+
cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
72+
TRTORCH_CHECK(
73+
settings.calibrator != nullptr,
74+
"Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec struct with your calibrator");
75+
cfg->setInt8Calibrator(settings.calibrator);
76+
break;
77+
case nvinfer1::DataType::kFLOAT:
78+
case nvinfer1::DataType::kINT32:
79+
case nvinfer1::DataType::kBOOL:
80+
default:
81+
break;
82+
}
8183
}
8284

83-
op_precision = settings.op_precision;
85+
enabled_precisions = settings.enabled_precisions;
8486
input_dtypes = settings.input_dtypes;
8587

8688
if (settings.disable_tf32) {
@@ -119,7 +121,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
119121
static_cast<int>(settings.device.dla_core) < nbDLACores,
120122
"Configured DLA Core ID: " << settings.device.dla_core
121123
<< " not available. Total number of available DLA Cores: " << nbDLACores);
122-
TRTORCH_CHECK(settings.op_precision != nvinfer1::DataType::kFLOAT, "DLA supports only fp16 or int8 precision");
124+
TRTORCH_CHECK(settings.enabled_precisions.find(nvinfer1::DataType::kFLOAT) == settings.enabled_precisions.end(), "DLA supports only fp16 or int8 precision");
123125
cfg->setDLACore(settings.device.dla_core);
124126
}
125127
}

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <map>
44
#include <memory>
55
#include <unordered_map>
6+
#include <set>
67

78
#include "NvInfer.h"
89
#include "torch/csrc/jit/ir/ir.h"
@@ -23,7 +24,7 @@ struct Device {
2324
};
2425

2526
struct BuilderSettings {
26-
nvinfer1::DataType op_precision = nvinfer1::DataType::kFLOAT;
27+
std::set<nvinfer1::DataType> enabled_precisions = {nvinfer1::DataType::kFLOAT};
2728
std::vector<nvinfer1::DataType> input_dtypes;
2829
bool disable_tf32 = false;
2930
bool refit = false;
@@ -59,7 +60,7 @@ struct ConversionCtx {
5960
nvinfer1::INetworkDefinition* net;
6061
nvinfer1::IBuilderConfig* cfg;
6162
std::vector<nvinfer1::DataType> input_dtypes;
62-
nvinfer1::DataType op_precision;
63+
std::set<nvinfer1::DataType> enabled_precisions;
6364
BuilderSettings settings;
6465
util::logging::TRTorchLogger logger;
6566
// Pointers to data that needs to remain alive until conversion is done

core/conversion/converters/impl/activation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,8 @@ auto acthardtanh TRTORCH_UNUSED =
177177
std::string pluginName = "CustomGeluPluginDynamic";
178178
nvinfer1::PluginFieldCollection fc;
179179
std::vector<nvinfer1::PluginField> f;
180-
int type_id = ctx->settings.op_precision == nvinfer1::DataType::kFLOAT
180+
//REVIEW is this right?
181+
int type_id = ctx->settings.enabled_precisions.find(nvinfer1::DataType::kHALF) == ctx->settings.enabled_precisions.end()
181182
? 0
182183
: 1; // Integer encoding the DataType (0: FP32, 1: FP16)
183184
f.emplace_back(nvinfer1::PluginField("type_id", &type_id, nvinfer1::PluginFieldType::kINT32, 1));

core/ir/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ cc_library(
1313
"ir.h"
1414
],
1515
srcs = [
16-
"InputRange.cpp",
16+
"Input.cpp"
1717
],
1818
deps = [
1919
"@tensorrt//:nvinfer",

0 commit comments

Comments
 (0)