Skip to content

Commit c62474a

Browse files
committed
fix: Return static size for desired dimension if it's available
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 0eccfbf commit c62474a

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -299,20 +299,22 @@ auto aten_registrations TORCHTRT_UNUSED =
299299
} else {
300300
auto dim = args.at(n->input(1)).unwrapToInt();
301301
if (tensor_var.isITensor()) {
302-
if (ctx->input_is_dynamic) {
302+
auto tensor = tensor_var.ITensor();
303+
auto dims = util::toVec(tensor->getDimensions());
304+
auto nbDims = tensor->getDimensions().nbDims;
305+
if (dim < 0) {
306+
dim += nbDims;
307+
}
308+
// Check if selected dimension size is -1 else return static size
309+
if (ctx->input_is_dynamic && dims[dim] == -1) {
303310
if (ctx->settings.allow_shape_tensors) {
304311
return dynamic_size_layer(ctx, n, args);
305312
} else {
306313
LOG_WARNING(
307314
"There may be undefined behavior using dynamic shape and aten::size without setting allow_shape_tensors");
308315
}
309316
}
310-
auto tensor = tensor_var.ITensor();
311-
auto dims = util::toVec(tensor->getDimensions());
312-
auto nbDims = tensor->getDimensions().nbDims;
313-
if (dim < 0) {
314-
dim += nbDims;
315-
}
317+
316318
return dims[dim];
317319
} else if (tensor_var.IValue()->isTensor()) {
318320
auto tensor = tensor_var.unwrapToTensor();

core/conversion/evaluators/eval_util.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kw
4545
// Handle negative axis by refering to nbDims of input Tensor
4646
dim = dim < 0 ? dim + maxDim : dim;
4747
LOG_DEBUG("Dimension to select: " << dim);
48+
// Check if selected dimension size is -1 else return static size
49+
if (input_dims.d[dim] != -1) {
50+
return input_dims.d[dim];
51+
}
4852
shape_1d_tensor = index_layer(ctx, n, shape_1d_tensor, dim);
4953
LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());
5054

0 commit comments

Comments
 (0)