Skip to content

Commit a5183c8

Browse files
committed
Fix roberta conversion bugs
1 parent 8801573 commit a5183c8

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

core/conversion/converters/impl/cumsum.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ auto cumsum_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pat
4848
auto data = iterator->getOutput(0);
4949
auto newDims = data->getDimensions();
5050

51-
torch::Tensor zeroValue = at::full(util::toVec(newDims), 0, torch::kFloat32);
51+
torch::Tensor zeroValue =
52+
at::full(util::toVec(newDims), 0, torch_tensorrt::core::util::TRTDataTypeToScalarType(in->getType()));
5253
auto zeroTensor = tensor_to_const(ctx, zeroValue);
5354
auto runningSum = loop->addRecurrence(*zeroTensor);
5455
auto runningSumTensor = runningSum->getOutput(0);

core/conversion/converters/impl/element_wise.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -484,8 +484,12 @@ auto element_wise_registrations TORCHTRT_UNUSED =
484484
.pattern({"aten::ne.Scalar(Tensor self, Scalar other) -> (Tensor)",
485485
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
486486
auto self = args[0].ITensorOrFreeze(ctx);
487-
auto scalar = args[1].unwrapToScalar().to<float>();
488-
auto scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar}));
487+
auto scalar = args[1].unwrapToScalar();
488+
nvinfer1::ITensor* scalar_tensor;
489+
if (scalar.isFloatingPoint())
490+
scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar.to<float>()}));
491+
else
492+
scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar.to<int>()}));
489493
auto equal = add_elementwise(
490494
ctx,
491495
nvinfer1::ElementWiseOperation::kEQUAL,

core/util/trt_util.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
238238
{at::kFloat, nvinfer1::DataType::kFLOAT},
239239
{at::kHalf, nvinfer1::DataType::kHALF},
240240
{at::kInt, nvinfer1::DataType::kINT32},
241+
{at::kLong, nvinfer1::DataType::kINT32},
241242
{at::kChar, nvinfer1::DataType::kINT8},
242243
{at::kBool, nvinfer1::DataType::kBOOL}};
243244
return at_trt_type_map;

0 commit comments

Comments
 (0)