
Commit 4969495

feat: Implement dynamic shape support for floordiv, NumToTensor, layernorm
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 5b156dc commit 4969495

4 files changed: 33 additions & 9 deletions

core/conversion/converters/impl/layer_norm.cpp

Lines changed: 3 additions & 3 deletions
@@ -20,8 +20,8 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
 
       /* Layer_Norm normalizes over last N dimensions.
          normalizaed_shape could be (C,H,W), (H,W), or (W). */
-      auto normalized_shape = args[1].unwrapToIntList();
-      auto normalized_shape_vec = util::toVec(util::toDims(normalized_shape));
+      // This could be an IntList or ITensorList. We only need the size of this list.
+      auto normalized_shape = args[1].IValue()->toList();
 
       // Unwrap eps.
       auto eps = args[4].unwrapToDouble();
@@ -30,7 +30,7 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
 
       // Set up axis_ask for E[x].
       uint32_t axis_mask = 0;
-      for (size_t i = 0; i < normalized_shape_vec.size(); i++) {
+      for (size_t i = 0; i < normalized_shape.size(); i++) {
         axis_mask |= 1 << (shape.size() - i - 1);
       }
       LOG_DEBUG("Axis Mask for E[x]" << std::bitset<32>(axis_mask));
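
Only the length of normalized_shape feeds the mask: each normalized dimension sets the bit for one trailing axis of the input, which is why a plain list size works for both the IntList and ITensorList cases. A minimal standalone sketch of that computation, assuming a 4-D input normalized over its last three dimensions (input_rank and normalized_rank are illustrative names, not the converter's):

#include <bitset>
#include <cstdint>
#include <iostream>

int main() {
  size_t input_rank = 4;       // e.g. an (N, C, H, W) input
  size_t normalized_rank = 3;  // normalized_shape = (C, H, W)

  uint32_t axis_mask = 0;
  for (size_t i = 0; i < normalized_rank; i++) {
    axis_mask |= 1 << (input_rank - i - 1);  // sets bits 3, 2, 1
  }
  // Prints ...00001110: reduce over the last three axes for E[x].
  std::cout << std::bitset<32>(axis_mask) << "\n";
  return 0;
}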

core/conversion/evaluators/aten.cpp

Lines changed: 20 additions & 0 deletions
@@ -9,6 +9,7 @@
 #include "torch/csrc/jit/ir/ir.h"
 #include "torch/torch.h"
 
+#include "core/conversion/converters/converter_util.h"
 #include "core/conversion/evaluators/eval_macros.h"
 #include "core/conversion/evaluators/eval_util.h"
 #include "core/conversion/evaluators/evaluators.h"
@@ -677,6 +678,25 @@ auto aten_registrations TORCHTRT_UNUSED =
         .evaluator(
             {c10::Symbol::fromQualString("aten::floordiv"),
              [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+               // Dynamic version of aten::floordiv
+               if (args.at(n->input(0)).isITensor()) {
+                 if (args.at(n->input(1)).IValue()->isInt()) {
+                   auto int_tensor = scalar_to_tensor(args.at(n->input(1)).IValue()->toInt());
+                   auto int_itensor = converters::tensor_to_const(ctx, int_tensor, util::node_info(n) + "_constant");
+                   auto elementwise_layer = converters::add_elementwise(
+                       ctx,
+                       nvinfer1::ElementWiseOperation::kFLOOR_DIV,
+                       args.at(n->input(0)).ITensor(),
+                       int_itensor,
+                       util::node_info(n));
+                   auto output_tensor = elementwise_layer->getOutput(0);
+                   auto tensor_holder = TensorContainer();
+                   tensor_holder.hold_tensor(output_tensor);
+                   auto output_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+                   return output_ivalue;
+                 }
+               }
+               // Static version
                if (args.at(n->input(0)).IValue()->isInt()) {
                  auto a = args.at(n->input(0)).unwrapToInt();
                  auto b = args.at(n->input(1)).unwrapToInt();
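
The library helpers converters::tensor_to_const and converters::add_elementwise used above boil down to ordinary TensorRT network calls: the integer operand becomes a one-element constant and a kFLOOR_DIV elementwise layer divides the runtime value by it. A rough sketch of the same idea against the raw TensorRT C++ API, assuming the left operand is a rank-1 Int32 tensor coming from a dynamic shape lookup (the helper name floordiv_by_scalar and its signature are hypothetical, not part of the codebase):

#include <cstdint>

#include "NvInfer.h"

// Hypothetical helper: divide a rank-1 Int32 shape tensor by a compile-time integer.
// `divisor` must stay alive until the engine is built; TensorRT keeps a pointer to the weights.
nvinfer1::ITensor* floordiv_by_scalar(
    nvinfer1::INetworkDefinition* network, nvinfer1::ITensor* dim_tensor, const int32_t* divisor) {
  // Bake the scalar into the network as a one-element constant so it can broadcast
  // against the rank-1 left operand.
  nvinfer1::Weights w{nvinfer1::DataType::kINT32, divisor, 1};
  auto* const_layer = network->addConstant(nvinfer1::Dims{1, {1}}, w);

  // Elementwise floor division between the runtime shape value and the constant.
  auto* div_layer = network->addElementWise(
      *dim_tensor, *const_layer->getOutput(0), nvinfer1::ElementWiseOperation::kFLOOR_DIV);
  return div_layer->getOutput(0);
}

In the commit itself this plumbing stays behind the helpers, and the resulting ITensor is wrapped in a TensorContainer so it can be returned through the evaluator's c10::IValue interface.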

core/conversion/evaluators/prim.cpp

Lines changed: 4 additions & 0 deletions
@@ -32,6 +32,10 @@ auto prim_registrations =
         .evaluator(
             {torch::jit::prim::NumToTensor,
              [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+               // Dynamic version receives an ITensor here so pass that as output directly.
+               if (args.at(n->input(0)).isITensor()) {
+                 return args.at(n->input(0)).ITensor();
+               }
                return evaluators::scalar_to_tensor(args.at(n->input(0)).IValue()->toScalar());
              }})
         .evaluator(
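
For context, the ITensor that reaches prim::NumToTensor on the dynamic path is typically a runtime shape value rather than a Python number, so the evaluator simply forwards it. A rough sketch of how such a value is produced with the raw TensorRT API, assuming `network` is the nvinfer1::INetworkDefinition being built and `input` has dynamic dimensions (dynamic_dim is a hypothetical helper, not the converter's code):

#include <cstdint>

#include "NvInfer.h"

// Hypothetical helper: fetch dimension `d` of `input` as a one-element shape tensor.
// In a dynamic-shape graph this is the kind of ITensor the evaluator now forwards untouched.
nvinfer1::ITensor* dynamic_dim(
    nvinfer1::INetworkDefinition* network, nvinfer1::ITensor* input, int32_t d) {
  // addShape yields a rank-1 tensor holding all runtime dimensions of `input`.
  auto* shape = network->addShape(*input)->getOutput(0);
  // Slice out the single entry at index `d`.
  auto* slice = network->addSlice(
      *shape, nvinfer1::Dims{1, {d}}, nvinfer1::Dims{1, {1}}, nvinfer1::Dims{1, {1}});
  return slice->getOutput(0);
}

Downstream evaluators, such as the aten::floordiv change above, then consume this ITensor directly instead of a constant scalar.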

tests/core/conversion/converters/BUILD

Lines changed: 6 additions & 6 deletions
@@ -224,33 +224,33 @@ test_suite(
         ":test_div",
         ":test_einsum",
         ":test_expand",
+        ":test_index",
         ":test_instance_norm",
         ":test_interpolate",
-        ":test_index",
         ":test_layer_norm",
         ":test_linear",
         ":test_lstm_cell",
-        ":test_matrix_multiply",
         ":test_masked_fill",
+        ":test_matrix_multiply",
         ":test_max",
         ":test_normalize",
         ":test_pooling",
         ":test_reduce",
-        ":test_roll",
         ":test_replication_pad",
+        ":test_roll",
         ":test_scatter",
         ":test_select",
         ":test_shuffle",
+        ":test_slice",
         ":test_softmax",
+        ":test_split",
         ":test_squeeze",
         ":test_stack",
-        ":test_split",
-        ":test_slice",
         ":test_topk",
         ":test_unary",
-        ":test_unsqueeze",
         ":test_unbind",
         ":test_unpack",
+        ":test_unsqueeze",
         ":test_where",
     ],
 )
