
Commit 3c59ece

Merge pull request #964 from Njuapp/roberta_fix
Fix roberta conversion bugs
2 parents: e9fb6ff + 08920e7 · commit 3c59ece

3 files changed: 25 additions (+25), 7 deletions (-7)

core/conversion/converters/impl/cast.cpp

Lines changed: 16 additions & 4 deletions
@@ -18,7 +18,14 @@ auto cast_registrations TORCHTRT_UNUSED =
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                auto self = args[0].ITensorOrFreeze(ctx);
                auto output_dtype = args[1].unwrapToScalar().to<int64_t>();
-               auto trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
+               auto scalar_dtype = static_cast<at::ScalarType>(output_dtype);
+               nvinfer1::DataType trt_dtype;
+               if (scalar_dtype == at::kLong) {
+                 LOG_WARNING("Truncating aten::to output type from at::kLong to at::kInt");
+                 trt_dtype = nvinfer1::DataType::kINT32;
+               } else {
+                 trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
+               }
                auto casted_itensor = castITensor(ctx, self, trt_dtype);
                auto output = ctx->AssociateValueAndTensor(n->outputs()[0], casted_itensor);
                LOG_DEBUG("[aten::to.dtype] Output tensor shape: " << output->getDimensions());
@@ -33,9 +40,14 @@ auto cast_registrations TORCHTRT_UNUSED =
                // later shape analysis phase of fallback
                auto self = args[0].ITensorOrFreeze(ctx);
                auto output_dtype = args[2].unwrapToScalar().to<int64_t>();
-
-               auto trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
-
+               auto scalar_dtype = static_cast<at::ScalarType>(output_dtype);
+               nvinfer1::DataType trt_dtype;
+               if (scalar_dtype == at::kLong) {
+                 LOG_WARNING("Truncating aten::to output type from at::kLong to at::kInt");
+                 trt_dtype = nvinfer1::DataType::kINT32;
+               } else {
+                 trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
+               }
                auto casted_itensor = castITensor(ctx, self, trt_dtype);
                auto output = ctx->AssociateValueAndTensor(n->outputs()[0], casted_itensor);
                LOG_DEBUG("[aten::to.device] Output tensor shape: " << output->getDimensions());
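
Both hunks apply the same fallback: when the requested output type is at::kLong, the converter warns and emits an INT32 cast instead, since the TensorRT releases targeted here expose no 64-bit integer tensor type. Below is a minimal standalone sketch of that rule, assuming libtorch and TensorRT 8.x headers; to_trt_dtype_with_truncation and its switch are hypothetical stand-ins for the converter's call to util::ScalarTypeToTRTDataType, not library code.

#include <ATen/ATen.h>
#include <NvInfer.h>
#include <iostream>
#include <stdexcept>

nvinfer1::DataType to_trt_dtype_with_truncation(at::ScalarType scalar_dtype) {
  if (scalar_dtype == at::kLong) {
    // No INT64 tensor type in TensorRT here, so warn and narrow to INT32,
    // mirroring the converter's LOG_WARNING + kINT32 fallback.
    std::cerr << "Truncating aten::to output type from at::kLong to at::kInt\n";
    return nvinfer1::DataType::kINT32;
  }
  // Stand-in for util::ScalarTypeToTRTDataType, covering the common cases only.
  switch (scalar_dtype) {
    case at::kFloat:
      return nvinfer1::DataType::kFLOAT;
    case at::kHalf:
      return nvinfer1::DataType::kHALF;
    case at::kInt:
      return nvinfer1::DataType::kINT32;
    case at::kBool:
      return nvinfer1::DataType::kBOOL;
    default:
      throw std::runtime_error("dtype not representable as a TensorRT tensor type");
  }
}

int main() {
  // An aten::to that requests int64 now maps to INT32 instead of failing.
  auto trt_dtype = to_trt_dtype_with_truncation(at::kLong);
  std::cout << std::boolalpha << (trt_dtype == nvinfer1::DataType::kINT32) << std::endl;
  return 0;
}

Any value produced this way is therefore INT32 inside the engine, which is why the converter logs a warning rather than aborting the conversion.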

core/conversion/converters/impl/cumsum.cpp

Lines changed: 2 additions & 1 deletion
@@ -48,7 +48,8 @@ auto cumsum_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pat
                auto data = iterator->getOutput(0);
                auto newDims = data->getDimensions();

-               torch::Tensor zeroValue = at::full(util::toVec(newDims), 0, torch::kFloat32);
+               torch::Tensor zeroValue =
+                   at::full(util::toVec(newDims), 0, torch_tensorrt::core::util::TRTDataTypeToScalarType(in->getType()));
                auto zeroTensor = tensor_to_const(ctx, zeroValue);
                auto runningSum = loop->addRecurrence(*zeroTensor);
                auto runningSumTensor = runningSum->getOutput(0);
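
The only change here is the dtype of the loop's zero seed: the recurrence initializer now matches the element type of the tensor being accumulated instead of always being float32, so an integer cumsum no longer mixes types inside the TensorRT loop. A small sketch of that rule under the same header assumptions as above; TRTDataTypeToScalarType below is a hypothetical stand-in for torch_tensorrt::core::util::TRTDataTypeToScalarType.

#include <ATen/ATen.h>
#include <NvInfer.h>
#include <iostream>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for torch_tensorrt::core::util::TRTDataTypeToScalarType.
at::ScalarType TRTDataTypeToScalarType(nvinfer1::DataType t) {
  switch (t) {
    case nvinfer1::DataType::kFLOAT:
      return at::kFloat;
    case nvinfer1::DataType::kHALF:
      return at::kHalf;
    case nvinfer1::DataType::kINT32:
      return at::kInt;
    case nvinfer1::DataType::kBOOL:
      return at::kBool;
    default:
      throw std::runtime_error("unexpected TensorRT dtype");
  }
}

int main() {
  // Suppose the tensor reaching aten::cumsum is INT32 inside the TensorRT graph.
  nvinfer1::DataType in_type = nvinfer1::DataType::kINT32;
  std::vector<int64_t> dims = {1, 8};

  // Old behaviour: the running-sum seed was always float32.
  at::Tensor old_zero = at::full(dims, 0, at::kFloat);
  // New behaviour: the seed follows the input's TensorRT element type.
  at::Tensor new_zero = at::full(dims, 0, TRTDataTypeToScalarType(in_type));

  std::cout << old_zero.dtype() << " vs " << new_zero.dtype() << std::endl;  // float vs int
  return 0;
}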

core/conversion/converters/impl/element_wise.cpp

Lines changed: 7 additions & 2 deletions
@@ -484,8 +484,13 @@ auto element_wise_registrations TORCHTRT_UNUSED =
          .pattern({"aten::ne.Scalar(Tensor self, Scalar other) -> (Tensor)",
                    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                      auto self = args[0].ITensorOrFreeze(ctx);
-                     auto scalar = args[1].unwrapToScalar().to<float>();
-                     auto scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar}));
+                     auto scalar = args[1].unwrapToScalar();
+                     nvinfer1::ITensor* scalar_tensor;
+                     if (self->getType() == nvinfer1::DataType::kFLOAT || self->getType() == nvinfer1::DataType::kHALF) {
+                       scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar.to<float>()}));
+                     } else {
+                       scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar.to<int>()}));
+                     }
                      auto equal = add_elementwise(
                          ctx,
                          nvinfer1::ElementWiseOperation::kEQUAL,
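
For aten::ne.Scalar, the scalar operand is no longer always materialized as a float constant; its dtype now follows self, float for kFLOAT/kHALF inputs and int otherwise, so the kEQUAL elementwise layer compares operands of one type. A brief sketch of that choice, again with a hypothetical helper name rather than the converter's tensor_to_const path.

#include <torch/torch.h>
#include <NvInfer.h>
#include <iostream>

// Hypothetical helper: build the constant that would be fed to tensor_to_const,
// choosing a float tensor for floating-point inputs and an int tensor otherwise.
at::Tensor scalar_constant_for(nvinfer1::DataType self_type, const at::Scalar& scalar) {
  if (self_type == nvinfer1::DataType::kFLOAT || self_type == nvinfer1::DataType::kHALF) {
    return torch::tensor({scalar.to<float>()});
  }
  return torch::tensor({scalar.to<int>()});
}

int main() {
  at::Scalar other = 1;  // the Scalar operand of aten::ne
  std::cout << scalar_constant_for(nvinfer1::DataType::kINT32, other).dtype() << "\n";  // int
  std::cout << scalar_constant_for(nvinfer1::DataType::kFLOAT, other).dtype() << "\n";  // float
  return 0;
}

In the libtorch C++ frontend, torch::tensor({int_value}) yields a kInt tensor while torch::tensor({float_value}) yields kFloat, which is the behaviour the converter relies on here.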
