@@ -18,7 +18,14 @@ auto cast_registrations TORCHTRT_UNUSED =
      [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
        auto self = args[0].ITensorOrFreeze(ctx);
        auto output_dtype = args[1].unwrapToScalar().to<int64_t>();
-       auto trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
+       auto scalar_dtype = static_cast<at::ScalarType>(output_dtype);
+       nvinfer1::DataType trt_dtype;
+       if (scalar_dtype == at::kLong) {
+         LOG_WARNING("Truncating aten::to output type from at::kLong to at::kInt");
+         trt_dtype = nvinfer1::DataType::kINT32;
+       } else {
+         trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
+       }
        auto casted_itensor = castITensor(ctx, self, trt_dtype);
        auto output = ctx->AssociateValueAndTensor(n->outputs()[0], casted_itensor);
        LOG_DEBUG("[aten::to.dtype] Output tensor shape: " << output->getDimensions());
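Rationale, for context: TensorRT's nvinfer1::DataType has no 64-bit integer tensor type (at least in the versions this converter targets), so passing at::kLong straight into util::ScalarTypeToTRTDataType cannot produce a usable engine type. The guard above instead downgrades the requested cast to nvinfer1::DataType::kINT32 and logs a warning. The aten::to.device overload below receives the same guard.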
@@ -33,9 +40,14 @@ auto cast_registrations TORCHTRT_UNUSED =
        // later shape analysis phase of fallback
        auto self = args[0].ITensorOrFreeze(ctx);
        auto output_dtype = args[2].unwrapToScalar().to<int64_t>();
-
-       auto trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
-
+       auto scalar_dtype = static_cast<at::ScalarType>(output_dtype);
+       nvinfer1::DataType trt_dtype;
+       if (scalar_dtype == at::kLong) {
+         LOG_WARNING("Truncating aten::to output type from at::kLong to at::kInt");
+         trt_dtype = nvinfer1::DataType::kINT32;
+       } else {
+         trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
+       }
        auto casted_itensor = castITensor(ctx, self, trt_dtype);
        auto output = ctx->AssociateValueAndTensor(n->outputs()[0], casted_itensor);
        LOG_DEBUG("[aten::to.device] Output tensor shape: " << output->getDimensions());
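Since both hunks install the identical guard, the shared logic could live in one helper. A minimal sketch under stated assumptions, not part of this PR: the helper name scalarTypeToSafeTRTDataType is hypothetical, and LOG_WARNING plus util::ScalarTypeToTRTDataType are assumed to come from Torch-TensorRT's core/util headers as used above.

// Hypothetical helper (not in this diff): collapses the duplicated kLong guard.
// Assumes TensorRT's NvInfer.h, ATen, and Torch-TensorRT's core/util utilities
// (LOG_WARNING, util::ScalarTypeToTRTDataType) are in scope.
nvinfer1::DataType scalarTypeToSafeTRTDataType(at::ScalarType scalar_dtype) {
  if (scalar_dtype == at::kLong) {
    // TensorRT has no 64-bit integer tensor type, so fall back to 32-bit.
    LOG_WARNING("Truncating aten::to output type from at::kLong to at::kInt");
    return nvinfer1::DataType::kINT32;
  }
  return util::ScalarTypeToTRTDataType(scalar_dtype);
}

// Each converter body would then reduce to:
//   auto trt_dtype = scalarTypeToSafeTRTDataType(static_cast<at::ScalarType>(output_dtype));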