
Commit eb69bf2

fix RoBERTa compilation bugs
1 parent 6afac82

File tree

3 files changed, +28 -7 lines changed


core/conversion/converters/impl/cast.cpp

Lines changed: 18 additions & 4 deletions
@@ -18,7 +18,15 @@ auto cast_registrations TORCHTRT_UNUSED =
            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
              auto self = args[0].ITensorOrFreeze(ctx);
              auto output_dtype = args[1].unwrapToScalar().to<int64_t>();
-             auto trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
+             auto scalar_dtype = static_cast<at::ScalarType>(output_dtype);
+             nvinfer1::DataType trt_dtype;
+             if(scalar_dtype == at::kLong){
+               LOG_WARNING("Truncating aten::to output type from at::kLong to at::kInt");
+               trt_dtype = nvinfer1::DataType::kINT32;
+             }
+             else{
+               trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
+             }
              auto casted_itensor = castITensor(ctx, self, trt_dtype);
              auto output = ctx->AssociateValueAndTensor(n->outputs()[0], casted_itensor);
              LOG_DEBUG("[aten::to.dtype] Output tensor shape: " << output->getDimensions());
@@ -33,9 +41,15 @@ auto cast_registrations TORCHTRT_UNUSED =
              // later shape analysis phase of fallback
              auto self = args[0].ITensorOrFreeze(ctx);
              auto output_dtype = args[2].unwrapToScalar().to<int64_t>();
-
-             auto trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
-
+             auto scalar_dtype = static_cast<at::ScalarType>(output_dtype);
+             nvinfer1::DataType trt_dtype;
+             if(scalar_dtype == at::kLong){
+               LOG_WARNING("Truncating aten::to output type from at::kLong to at::kInt");
+               trt_dtype = nvinfer1::DataType::kINT32;
+             }
+             else{
+               trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
+             }
              auto casted_itensor = castITensor(ctx, self, trt_dtype);
              auto output = ctx->AssociateValueAndTensor(n->outputs()[0], casted_itensor);
              LOG_DEBUG("[aten::to.device] Output tensor shape: " << output->getDimensions());
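
Both aten::to hunks apply the same guard: if the requested output dtype is at::kLong, warn and emit an INT32 tensor, otherwise defer to the existing dtype mapping. Below is a minimal sketch of that rule as a standalone helper; it assumes cast.cpp's existing includes (which provide util::ScalarTypeToTRTDataType, LOG_WARNING, and the nvinfer1/ATen types), and the helper name is purely illustrative, not something the commit adds.

// Illustrative sketch only: the truncation rule both aten::to converters now share.
// TensorRT (as targeted here) has no 64-bit integer tensor type, so Long outputs
// are demoted to INT32 with a warning instead of failing conversion outright.
nvinfer1::DataType scalar_type_to_trt_dtype_with_truncation(at::ScalarType scalar_dtype) {
  if (scalar_dtype == at::kLong) {
    LOG_WARNING("Truncating aten::to output type from at::kLong to at::kInt");
    return nvinfer1::DataType::kINT32;
  }
  return util::ScalarTypeToTRTDataType(scalar_dtype);
}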

core/conversion/converters/impl/cumsum.cpp

Lines changed: 2 additions & 1 deletion
@@ -48,7 +48,8 @@ auto cumsum_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pat
              auto data = iterator->getOutput(0);
              auto newDims = data->getDimensions();
 
-             torch::Tensor zeroValue = at::full(util::toVec(newDims), 0, torch::kFloat32);
+             torch::Tensor zeroValue =
+                 at::full(util::toVec(newDims), 0, torch_tensorrt::core::util::TRTDataTypeToScalarType(in->getType()));
              auto zeroTensor = tensor_to_const(ctx, zeroValue);
              auto runningSum = loop->addRecurrence(*zeroTensor);
              auto runningSumTensor = runningSum->getOutput(0);
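
The cumsum converter builds a TensorRT loop whose running sum starts from a constant zero tensor; hard-coding that seed as torch::kFloat32 mismatched the recurrence whenever the input was an integer tensor. A sketch of the corrected seed construction, assuming cumsum.cpp's existing includes (at::full, util::toVec, TRTDataTypeToScalarType) and the converter's input ITensor named in; the helper name is illustrative.

// Illustrative sketch only: seed the running-sum recurrence with a zero tensor
// whose dtype matches the input ITensor, rather than an unconditional float32 tensor.
torch::Tensor make_zero_seed(nvinfer1::ITensor* in, const nvinfer1::Dims& newDims) {
  auto scalar_type = torch_tensorrt::core::util::TRTDataTypeToScalarType(in->getType());
  return at::full(util::toVec(newDims), 0, scalar_type);
}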

core/conversion/converters/impl/element_wise.cpp

Lines changed: 8 additions & 2 deletions
@@ -484,8 +484,14 @@ auto element_wise_registrations TORCHTRT_UNUSED =
         .pattern({"aten::ne.Scalar(Tensor self, Scalar other) -> (Tensor)",
                   [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                     auto self = args[0].ITensorOrFreeze(ctx);
-                    auto scalar = args[1].unwrapToScalar().to<float>();
-                    auto scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar}));
+                    auto scalar = args[1].unwrapToScalar();
+                    nvinfer1::ITensor* scalar_tensor;
+                    if(self->getType() == nvinfer1::DataType::kFLOAT || self->getType() == nvinfer1::DataType::kHALF){
+                      scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar.to<float>()}));
+                    }
+                    else{
+                      scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar.to<int>()}));
+                    }
                     auto equal = add_elementwise(
                         ctx,
                         nvinfer1::ElementWiseOperation::kEQUAL,
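
aten::ne.Scalar is lowered via the elementwise kEQUAL comparison visible above, and the scalar operand was previously always materialized as a float32 constant, which mismatches an integer-typed self tensor. The fix picks the constant's dtype family from self. A sketch of that selection follows, assuming element_wise.cpp's existing includes (tensor_to_const, torch::tensor, ConversionCtx); the helper name is illustrative. Half-precision inputs still receive a float32 constant here, presumably reconciled later by add_elementwise.

// Illustrative sketch only: build the one-element constant for the scalar operand
// in the same dtype family (float vs. int) as the tensor it is compared against.
nvinfer1::ITensor* scalar_operand_like(ConversionCtx* ctx, nvinfer1::ITensor* self, const at::Scalar& scalar) {
  if (self->getType() == nvinfer1::DataType::kFLOAT || self->getType() == nvinfer1::DataType::kHALF) {
    return tensor_to_const(ctx, torch::tensor({scalar.to<float>()}));
  }
  return tensor_to_const(ctx, torch::tensor({scalar.to<int>()}));
}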

0 commit comments
