Skip to content

Commit 78a1c61

Browse files
committed
feat(//core/conversion/conversionctx): Make op precision available at
conversion time through ctx Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent cd6b1b9 commit 78a1c61

File tree

3 files changed

+6
-0
lines changed

3 files changed

+6
-0
lines changed

core/conversion/conversion.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,10 @@ void AddInputs(ConversionCtx* ctx,
160160
TRTORCH_CHECK(profile->isValid(), "Optimization profile is invalid, please check the input range provided (conversion.AddInputs)");
161161

162162
ctx->cfg->addOptimizationProfile(profile);
163+
// TODO: Enable in TRT 7.1
164+
// if (ctx->op_precision == nvinfer1::DataType::kINT8) {
165+
// ctx->cfg->setCalibrationProfile(profile);
166+
// }
163167
}
164168

165169
void MarkOutputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> outputs) {

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
6060
input_type = nvinfer1::DataType::kFLOAT;
6161
break;
6262
}
63+
op_precision = settings.op_precision;
6364

6465
if (settings.refit) {
6566
cfg->setFlag(nvinfer1::BuilderFlag::kREFIT);

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ struct ConversionCtx {
4747
nvinfer1::INetworkDefinition* net;
4848
nvinfer1::IBuilderConfig* cfg;
4949
nvinfer1::DataType input_type;
50+
nvinfer1::DataType op_precision;
5051
BuilderSettings settings;
5152
util::logging::TRTorchLogger logger;
5253
// Pointers to data that needs to remain alive until conversion is done

0 commit comments

Comments
 (0)