@@ -18,7 +18,15 @@ auto cast_registrations TORCHTRT_UNUSED =
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
           auto self = args[0].ITensorOrFreeze(ctx);
           auto output_dtype = args[1].unwrapToScalar().to<int64_t>();
-          auto trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
+          auto scalar_dtype = static_cast<at::ScalarType>(output_dtype);
+          nvinfer1::DataType trt_dtype;
+          if (scalar_dtype == at::kLong) {
+            LOG_WARNING("Truncating aten::to output type from at::kLong to at::kInt");
+            trt_dtype = nvinfer1::DataType::kINT32;
+          }
+          else {
+            trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
+          }
           auto casted_itensor = castITensor(ctx, self, trt_dtype);
           auto output = ctx->AssociateValueAndTensor(n->outputs()[0], casted_itensor);
           LOG_DEBUG("[aten::to.dtype] Output tensor shape: " << output->getDimensions());
@@ -33,9 +41,15 @@ auto cast_registrations TORCHTRT_UNUSED =
           // later shape analysis phase of fallback
           auto self = args[0].ITensorOrFreeze(ctx);
           auto output_dtype = args[2].unwrapToScalar().to<int64_t>();
-
-          auto trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
-
+          auto scalar_dtype = static_cast<at::ScalarType>(output_dtype);
+          nvinfer1::DataType trt_dtype;
+          if (scalar_dtype == at::kLong) {
+            LOG_WARNING("Truncating aten::to output type from at::kLong to at::kInt");
+            trt_dtype = nvinfer1::DataType::kINT32;
+          }
+          else {
+            trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
+          }
           auto casted_itensor = castITensor(ctx, self, trt_dtype);
           auto output = ctx->AssociateValueAndTensor(n->outputs()[0], casted_itensor);
           LOG_DEBUG("[aten::to.device] Output tensor shape: " << output->getDimensions());
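
The kLong-to-kINT32 truncation branch is now duplicated verbatim across the aten::to.dtype and aten::to.device converters. A minimal sketch of factoring it into a shared helper (the name trtDataTypeWithTruncation is hypothetical, not part of this patch), assuming Torch-TensorRT's core util and logging headers are available:

// Hypothetical helper, not part of this patch: maps an at::ScalarType to a
// TensorRT data type, truncating 64-bit integers to 32-bit since TensorRT
// has no int64 tensor type.
#include <NvInfer.h>
#include <ATen/core/ScalarType.h>
#include "core/util/prelude.h"

nvinfer1::DataType trtDataTypeWithTruncation(at::ScalarType scalar_dtype) {
  if (scalar_dtype == at::kLong) {
    // Same warning the two converters emit today
    LOG_WARNING("Truncating aten::to output type from at::kLong to at::kInt");
    return nvinfer1::DataType::kINT32;
  }
  return util::ScalarTypeToTRTDataType(scalar_dtype);
}

Each converter body would then reduce to a single line:

auto trt_dtype = trtDataTypeWithTruncation(static_cast<at::ScalarType>(output_dtype));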