Skip to content

Commit d8c9a0c

Browse files
authored
Merge pull request #1513 from mfeliz-cruise/michael.feliz/cast_name_conflict
[fix] Disambiguate cast layer names
2 parents da34c69 + ecff91f commit d8c9a0c

File tree

4 files changed

+45
-6
lines changed

4 files changed

+45
-6
lines changed

core/conversion/converters/converter_util.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,11 @@ nvinfer1::ITensor* applyIdentityOp(ConversionCtx* ctx, nvinfer1::ITensor* tensor
205205
return id_out_tensor;
206206
}
207207

208-
nvinfer1::ITensor* castITensor(ConversionCtx* ctx, nvinfer1::ITensor* tensor, nvinfer1::DataType dtype) {
208+
nvinfer1::ITensor* castITensor(
209+
ConversionCtx* ctx,
210+
nvinfer1::ITensor* tensor,
211+
nvinfer1::DataType dtype,
212+
const std::string& layer_name_prefix) {
209213
if (tensor->getType() != dtype) {
210214
std::ostringstream tensor_id;
211215
tensor_id << reinterpret_cast<int*>(tensor);
@@ -219,6 +223,9 @@ nvinfer1::ITensor* castITensor(ConversionCtx* ctx, nvinfer1::ITensor* tensor, nv
219223
LOG_DEBUG(ctx->logger, "Casting ITensor " << tensor_id.str() << " from " << tensor->getType() << " to " << dtype);
220224

221225
std::stringstream ss;
226+
if (layer_name_prefix.size()) {
227+
ss << layer_name_prefix << " ";
228+
}
222229
ss << "[Cast ITensor " << tensor_id.str() << " from " << tensor->getType() << " to " << dtype << "]";
223230
id_layer->setName(ss.str().c_str());
224231
return casted_tensor;

core/conversion/converters/converter_util.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@ nvinfer1::ITensor* add_abs(
5656
nvinfer1::ITensor* applyIdentityOp(ConversionCtx* ctx, nvinfer1::ITensor* tensor, const std::string& name);
5757

5858
// If an ITensor is not of type dtype, add an Identity layer to cast it to dtype
59-
nvinfer1::ITensor* castITensor(ConversionCtx* ctx, nvinfer1::ITensor* tensor, nvinfer1::DataType dtype);
59+
nvinfer1::ITensor* castITensor(
60+
ConversionCtx* ctx,
61+
nvinfer1::ITensor* tensor,
62+
nvinfer1::DataType dtype,
63+
const std::string& layer_name_prefix = "");
6064

6165
// Freeze an at::Tensor in a IConstant layer
6266
nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t, const std::string& name = std::string());

core/conversion/converters/impl/cast.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ auto cast_registrations TORCHTRT_UNUSED =
2626
} else {
2727
trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
2828
}
29-
auto casted_itensor = castITensor(ctx, self, trt_dtype);
29+
auto casted_itensor = castITensor(ctx, self, trt_dtype, util::node_info(n));
3030
auto output = ctx->AssociateValueAndTensor(n->outputs()[0], casted_itensor);
3131
LOG_DEBUG("[aten::to.dtype] Output tensor shape: " << output->getDimensions());
3232

@@ -48,7 +48,7 @@ auto cast_registrations TORCHTRT_UNUSED =
4848
} else {
4949
trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
5050
}
51-
auto casted_itensor = castITensor(ctx, self, trt_dtype);
51+
auto casted_itensor = castITensor(ctx, self, trt_dtype, util::node_info(n));
5252
auto output = ctx->AssociateValueAndTensor(n->outputs()[0], casted_itensor);
5353
LOG_DEBUG("[aten::to.device] Output tensor shape: " << output->getDimensions());
5454

@@ -59,7 +59,7 @@ auto cast_registrations TORCHTRT_UNUSED =
5959
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
6060
auto self = args[0].ITensorOrFreeze(ctx);
6161
nvinfer1::DataType other_dtype = args[1].ITensorOrFreeze(ctx)->getType();
62-
auto casted_itensor = castITensor(ctx, self, other_dtype);
62+
auto casted_itensor = castITensor(ctx, self, other_dtype, util::node_info(n));
6363
auto output = ctx->AssociateValueAndTensor(n->outputs()[0], casted_itensor);
6464
LOG_DEBUG("[aten::to.other] Output tensor shape: " << output->getDimensions());
6565

@@ -77,7 +77,7 @@ auto cast_registrations TORCHTRT_UNUSED =
7777

7878
auto output_dtype = args[2].unwrapToScalar().to<int64_t>();
7979
auto trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
80-
auto casted_itensor = castITensor(ctx, self, trt_dtype);
80+
auto casted_itensor = castITensor(ctx, self, trt_dtype, util::node_info(n));
8181
auto output = ctx->AssociateValueAndTensor(n->outputs()[0], casted_itensor);
8282
LOG_DEBUG("[aten::to.prim_Device] Output tensor shape: " << output->getDimensions());
8383

tests/core/conversion/converters/test_cast.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,34 @@ TEST(Converters, ATenToSingleConvertsCorrectly) {
163163
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
164164
}
165165

166+
TEST(Converters, ATenToDuplicateConvertsCorrectly) {
  // Regression test for duplicate cast-layer names: two aten::to nodes with
  // identical dtypes on the same input used to produce identically named
  // TensorRT identity (cast) layers. With node_info added as a name prefix,
  // both casts get unique layer names and the engine builds successfully.
  const auto graph = R"IR(
      graph(%y.1 : Tensor):
        %4 : int = prim::Constant[value=3]()
        %5 : bool = prim::Constant[value=0]()
        %6 : None = prim::Constant()
        %y0.1 : Tensor = aten::to(%y.1, %4, %5, %5, %6)
        %y0.2 : Tensor = aten::to(%y.1, %4, %5, %5, %6)
        return (%y0.1, %y0.2))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  auto in = at::randint(1, 10, {3}, {at::kCUDA});

  // Baseline: run the graph through the TorchScript interpreter.
  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  // Under test: convert to a TensorRT engine and run the same input.
  auto trt_in = at::clone(in);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
  for (size_t i = 0UL; i < jit_results.size(); ++i) {
    // Both outputs must match the interpreter in dtype and (near-)value.
    ASSERT_TRUE(jit_results[i].scalar_type() == trt_results[i].scalar_type());
    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt_results[i], 2e-6));
  }
}
193+
166194
TEST(Converters, ATenTypeAsConvertsCorrectly) {
167195
const auto graph = R"IR(
168196
graph(%0 : Tensor,

0 commit comments

Comments
 (0)