Skip to content

Commit 36d27da

Browse files
authored
Merge pull request #44 from narendasan/ptq
Post training quantization support in TRTorch
2 parents 6be3f1f + 54a24b3 commit 36d27da

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+1980
-182
lines changed

.gitignore

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,14 @@ experiments/
1313
py/build/
1414
py/tmp/
1515
py/.eggs
16-
.vscode/
16+
.vscode/
17+
.DS_Store
18+
._DS_Store
19+
*.pth
20+
*.pyc
21+
cpp/ptq/training/vgg16/data/*
22+
*.bin
23+
cpp/ptq/datasets/data/
24+
tests/accuracy/datasets/data/*
25+
._.DS_Store
26+
*.tar.gz

README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,19 @@ More Information / System Architecture:
1717
...
1818
auto compile_settings = trtorch::ExtraInfo(dims);
1919
// FP16 execution
20-
compile_settings.op_precision = torch::kHalf;
20+
compile_settings.op_precision = torch::kFloat;
2121
// Compile module
2222
auto trt_mod = trtorch::CompileGraph(ts_mod, compile_settings);
2323
// Run like normal
2424
auto results = trt_mod.forward({in_tensor});
2525
...
2626
```
2727
28+
> Notes on running in lower precisions:
29+
> - Set precision with extra_info.op_precision
30+
> - The module should be left in FP32 before compilation (FP16 can support half tensor models)
31+
> - In FP16, only the input tensors should be converted to FP16; for other precisions, use FP32 inputs
32+
2833
## Platform Support
2934
3035
| Platform | Support |

core/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ cc_library(
1616
"@libtorch//:libtorch",
1717
"@tensorrt//:nvinfer"
1818
],
19-
alwayslink=True,
19+
alwayslink=True,
2020
)
2121

2222

core/compiler.cpp

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,24 +24,24 @@
2424
namespace trtorch {
2525
namespace core {
2626

27-
c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::string method_name, std::shared_ptr<torch::jit::Graph>& g) {
27+
c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::string method_name, std::shared_ptr<torch::jit::Graph>& g) {
2828

2929
std::vector<c10::Argument> args;
3030
for (auto in : g->inputs()) {
3131
args.push_back(c10::Argument(in->debugName(), in->type()));
3232
}
33-
33+
3434
std::vector<c10::Argument> returns;
3535
for (auto out : g->outputs()) {
3636
returns.push_back(c10::Argument(out->debugName(), out->type()));
3737
}
38-
38+
3939
return c10::FunctionSchema(method_name, method_name, args, returns);
4040
}
4141

4242

4343
void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit::Graph>& g, std::string& serialized_engine) {
44-
execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine);
44+
execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine);
4545
auto schema = execution::GetEngineFunctionSchema(uid);
4646
auto num_io = execution::GetEngineIO(uid);
4747

@@ -53,14 +53,14 @@ void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit
5353
in_val->setType(c10::TensorType::get());
5454
graph_inputs.push_back(in_val);
5555
}
56-
56+
5757
auto engine_node = g->create(c10::Symbol::fromQualString(schema.name()), torch::jit::ArrayRef<torch::jit::Value*>(graph_inputs), num_io.second);
5858
g->block()->appendNode(engine_node);
5959

6060
for (auto o : engine_node->outputs()) {
6161
g->registerOutput(o);
6262
}
63-
63+
6464
return;
6565
}
6666

@@ -69,48 +69,50 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod,
6969
auto g = mod.get_method(method_name).graph();
7070
// Go through PyTorch Lowering to simplify graph and extract weight parameters
7171
auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue());
72-
72+
7373
g = graph_and_parameters.first;
74-
74+
7575
// Go through TRTorch Lowering to reformat graph to be conversion friendly
7676
// and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
7777
lowering::LowerGraph(g);
78-
78+
7979
auto params = graph_and_parameters.second;
8080
auto named_params = conversion::get_named_params(g->inputs(), params);
8181
LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n");
82-
82+
8383
// Is this necessary?
8484
lowering::LowerBlock(g->block());
85-
85+
8686
return conversion::VerifyConverterSupportForBlock(g->block());
8787
}
8888

8989
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
9090
std::string method_name,
91-
conversion::ExtraInfo cfg) {
91+
ExtraInfo cfg) {
92+
auto convert_cfg = std::move(cfg.convert_info);
93+
9294
auto g = mod.get_method(method_name).graph();
9395
// Go through PyTorch Lowering to simplify graph and extract weight parameters
9496
auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue());
95-
97+
9698
g = graph_and_parameters.first;
97-
99+
98100
// Go through TRTorch Lowering to reformat graph to be conversion friendly
99101
// and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
100102
lowering::LowerGraph(g);
101-
103+
102104
auto params = graph_and_parameters.second;
103105
auto named_params = conversion::get_named_params(g->inputs(), params);
104106
LOG_INFO(*g << "(CompileGraph)\n");
105-
107+
106108
// Is this necessary?
107109
lowering::LowerBlock(g->block());
108-
auto engine = ConvertBlockToEngine(g->block(), cfg, named_params);
110+
auto engine = ConvertBlockToEngine(g->block(), convert_cfg, named_params);
109111
return std::move(engine);
110112
}
111113

112114
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
113-
conversion::ExtraInfo cfg) {
115+
ExtraInfo cfg) {
114116
// TODO: Should be doing a functional transform but need PR #31978
115117
// [jit] More robust mangling
116118
// torch::jit::script::Module new_mod = mod.clone();
@@ -128,7 +130,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
128130

129131
return new_mod;
130132
}
131-
133+
132134
} // namespace core
133135
} // namespace trtorch
134136

core/compiler.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,19 @@
66

77
namespace trtorch {
88
namespace core {
9+
10+
struct ExtraInfo {
11+
ExtraInfo(std::vector<conversion::InputRange> input_ranges)
12+
: convert_info(std::move(input_ranges)) {}
13+
conversion::ConversionInfo convert_info;
14+
};
15+
916
bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);
1017

1118
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
12-
std::string method_name, conversion::ExtraInfo cfg);
19+
std::string method_name, ExtraInfo cfg);
1320

14-
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, conversion::ExtraInfo cfg);
21+
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo cfg);
1522

1623
} // namespace core
17-
} // namespace trtorch
24+
} // namespace trtorch

core/conversion/conversion.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ void AddInputs(ConversionCtx* ctx,
133133
"Expected dimension specifications for all input tensors" \
134134
<< ", but found " << input_tensors.size() \
135135
<< " input tensors and " \
136-
<< input_dims.size() << "dimension specs (conversion.AddInputs)");
136+
<< input_dims.size() << " dimension specs (conversion.AddInputs)");
137137

138138
auto profile = ctx->builder->createOptimizationProfile();
139139

@@ -179,7 +179,7 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
179179
}
180180
}
181181

182-
void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) {
182+
void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params) {
183183
LOG_INFO(ctx->logger, "Converting Block");
184184

185185
auto inputs = b->inputs();
@@ -221,7 +221,7 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ExtraI
221221
// a serialized TensorRT engine that can be deserialized and run
222222

223223
// Probably should consolidate these two functions
224-
std::string ConvertBlockToEngine(const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) {
224+
std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params) {
225225
ConversionCtx ctx(build_info.engine_settings);
226226
ConvertBlockToNetDef(&ctx, b, build_info, static_params);
227227
std::string engine = ctx.SerializeEngine();
@@ -235,7 +235,7 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
235235
if (!OpSupported(n)) {
236236
auto schema = n->maybeSchema();
237237
TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \
238-
<< " (conversion.AddLayer)");
238+
<< " (conversion.VerifyConverterSupportForBlock)");
239239
std::stringstream ss;
240240
ss << *schema;
241241
unsupported_ops.insert(ss.str());

core/conversion/conversion.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ struct InputRange {
3030
std::vector<int64_t> max_shape);
3131
};
3232

33-
struct ExtraInfo {
33+
struct ConversionInfo {
3434
std::vector<InputRange> input_ranges;
3535
BuilderSettings engine_settings;
36-
ExtraInfo(std::vector<InputRange> input_ranges)
36+
ConversionInfo(std::vector<InputRange> input_ranges)
3737
: input_ranges(std::move(input_ranges)), engine_settings(BuilderSettings()) {}
3838
};
3939

@@ -43,7 +43,7 @@ GraphParams get_named_params(c10::ArrayRef<torch::jit::Value*> inputs, std::vect
4343

4444
// Converts a already lowered block (blocks with no sub blocks) to
4545
// a serialized TensorRT engine that can be deserialized and run
46-
std::string ConvertBlockToEngine(const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params);
46+
std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params);
4747

4848
bool OpSupported(const torch::jit::Node* n);
4949

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,25 @@ namespace core {
99
namespace conversion {
1010

1111
std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
12-
os << "Settings requested for TensorRT engine:" \
13-
<< "\n Operating Precision: " << s.op_precision \
14-
<< "\n Make Refittable Engine: " << s.refit \
15-
<< "\n Debuggable Engine: " << s.debug \
16-
<< "\n Strict Type: " << s.strict_type \
17-
<< "\n Allow GPU Fallback (if running on DLA): " << s.allow_gpu_fallback \
18-
<< "\n Min Timing Iterations: " << s.num_min_timing_iters \
19-
<< "\n Avg Timing Iterations: " << s.num_avg_timing_iters \
20-
<< "\n Max Workspace Size: " << s.workspace_size \
21-
<< "\n Device Type: " << s.device \
22-
<< "\n Engine Capability: " << s.capability;
12+
os << "Settings requested for TensorRT engine:" \
13+
<< "\n Operating Precision: " << s.op_precision \
14+
<< "\n Make Refittable Engine: " << s.refit \
15+
<< "\n Debuggable Engine: " << s.debug \
16+
<< "\n Strict Type: " << s.strict_types \
17+
<< "\n Allow GPU Fallback (if running on DLA): " << s.allow_gpu_fallback \
18+
<< "\n Min Timing Iterations: " << s.num_min_timing_iters \
19+
<< "\n Avg Timing Iterations: " << s.num_avg_timing_iters \
20+
<< "\n Max Workspace Size: " << s.workspace_size;
21+
22+
if (s.max_batch_size != 0) {
23+
os << "\n Max Batch Size: " << s.max_batch_size;
24+
} else {
25+
os << "\n Max Batch Size: Not set";
26+
}
27+
28+
os << "\n Device Type: " << s.device \
29+
<< "\n Engine Capability: " << s.capability \
30+
<< "\n Calibrator Created: " << (s.calibrator != nullptr);
2331
return os;
2432
}
2533

@@ -36,13 +44,17 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
3644

3745
switch(settings.op_precision) {
3846
case nvinfer1::DataType::kHALF:
47+
TRTORCH_CHECK(builder->platformHasFastFp16(), "Requested inference in FP16 but platform does not support FP16");
3948
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
4049
input_type = nvinfer1::DataType::kHALF;
4150
break;
42-
// case nvinfer1::DataType::kINT8:
43-
// cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
44-
// input_type = nvinfer1::DataType::kFLOAT;
45-
// break;
51+
case nvinfer1::DataType::kINT8:
52+
TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does not support INT8");
53+
cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
54+
input_type = nvinfer1::DataType::kFLOAT;
55+
TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the ExtraInfo struct with your calibrator");
56+
cfg->setInt8Calibrator(settings.calibrator);
57+
break;
4658
case nvinfer1::DataType::kFLOAT:
4759
default:
4860
input_type = nvinfer1::DataType::kFLOAT;
@@ -57,14 +69,18 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
5769
cfg->setFlag(nvinfer1::BuilderFlag::kDEBUG);
5870
}
5971

60-
if (settings.strict_type) {
72+
if (settings.strict_types) {
6173
cfg->setFlag(nvinfer1::BuilderFlag::kSTRICT_TYPES);
6274
}
6375

6476
if (settings.allow_gpu_fallback) {
6577
cfg->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
6678
}
6779

80+
if (settings.max_batch_size != 0) {
81+
builder->setMaxBatchSize(settings.max_batch_size);
82+
}
83+
6884
cfg->setMinTimingIterations(settings.num_min_timing_iters);
6985
cfg->setAvgTimingIterations(settings.num_avg_timing_iters);
7086
cfg->setMaxWorkspaceSize(settings.workspace_size);

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@ struct BuilderSettings {
2020
nvinfer1::DataType op_precision = nvinfer1::DataType::kFLOAT;
2121
bool refit = false;
2222
bool debug = false;
23-
bool strict_type = false;
23+
bool strict_types = false;
2424
bool allow_gpu_fallback = true;
2525
nvinfer1::DeviceType device = nvinfer1::DeviceType::kGPU;
2626
nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kDEFAULT;
27+
nvinfer1::IInt8Calibrator* calibrator = nullptr;
2728
uint64_t num_min_timing_iters = 2;
2829
uint64_t num_avg_timing_iters = 1;
2930
uint64_t workspace_size = 0;
31+
uint64_t max_batch_size = 0;
3032

3133
BuilderSettings() = default;
3234
BuilderSettings(const BuilderSettings& other) = default;

core/conversion/converters/impl/batch_norm.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ volatile auto batch_norm_registrations = RegisterNodeConversionPatterns()
8383
auto gamma = args[1].unwrapToTensor();
8484

8585
if (/*training*/ args[5].unwrapToBool()) {
86-
LOG_WARNING("TensorRT only converts forward pass of graphs, but saw training = True, may see undefined behavior, consider placing module in eval mode");
86+
LOG_WARNING(R"WARN(TRTorch only converts forward pass of graphs, but saw training = True, may see
87+
unexpected behavior, consider placing module in eval mode before exporting the TorchScript module)WARN");
8788
}
8889

8990
// If gamma is None this fails

core/conversion/converters/impl/pooling.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,20 +79,17 @@ auto pooling_registrations = RegisterNodeConversionPatterns()
7979
for (size_t i = 0; i < out_shape.size(); i++) {
8080
stride[(stride.size() - 1) - i] = in_shape[(in_shape.size() - 1) - i] / out_shape[(out_shape.size() - 1) - i];
8181
}
82-
LOG_DEBUG("Stride" << util::toDims(stride));
82+
LOG_DEBUG("Stride: " << util::toDims(stride));
8383

8484
std::vector<int64_t> window(out_shape.size());
8585
for (size_t i = 0; i < out_shape.size(); i++) {
8686
window[window.size() - 1 - i] = in_shape[in_shape.size() - 1 - i] - (out_shape[out_shape.size() - 1 - i] - 1) * stride[stride.size() - 1 - i];
8787
}
8888

89-
LOG_DEBUG("Window" << util::toDims(window));
89+
LOG_DEBUG("Window: " << util::toDims(window));
9090

9191
auto new_layer = ctx->net->addPoolingNd(*in, nvinfer1::PoolingType::kAVERAGE, util::toDims(window));
92-
if (!new_layer) {
93-
LOG_ERROR("Unable to create average pooling layer from node: " << *n);
94-
return false;
95-
}
92+
TRTORCH_CHECK(new_layer, "Unable to create average pooling layer from node: " << *n);
9693

9794
new_layer->setStrideNd(util::toDims(stride));
9895

0 commit comments

Comments
 (0)