
Commit 9bbba4f

narendasan (Naren Dasan) and co-author authored

Fix: Layer norm Torchscript converter (#3062)

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
Co-authored-by: Naren Dasan <[email protected]>

1 parent 2589fdb · commit 9bbba4f

File tree: 11 files changed, +161 −261 lines

core/conversion/converters/converter_util.cpp
Lines changed: 71 additions & 0 deletions

@@ -438,6 +438,77 @@ nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s) {
   return out;
 }
 
+nvinfer1::ITensor* add_expand(ConversionCtx* ctx, nvinfer1::ITensor* in, nvinfer1::Dims expandedDims) {
+  auto input_dims = in->getDimensions();
+  TORCHTRT_CHECK(
+      input_dims.nbDims <= expandedDims.nbDims,
+      "Number of dimensions of the desired expansion must be greater than or equal to the number of input dimensions");
+
+  // Validate the expansion. Eg: an input of [3, 1] can be expanded to [1, 3, 4] but not [3, 4, 1]
+  for (int64_t i = expandedDims.nbDims - 1; i >= 0; --i) {
+    int64_t offset = expandedDims.nbDims - 1 - i;
+    int64_t dim = input_dims.nbDims - 1 - offset;
+    int64_t size = (dim >= 0) ? input_dims.d[dim] : 1;
+    int64_t targetSize = expandedDims.d[i];
+    // In expand layer passing -1 as the size for a dimension means not changing the size of that dimension.
+    if (targetSize != -1) {
+      if (size != targetSize) {
+        if (size != 1) {
+          TORCHTRT_THROW_ERROR(
+              "The expanded size of tensor (" << targetSize << ")"
+                                              << " must match the existing size (" << size << ")"
+                                              << " at dimension " << i);
+        }
+      }
+    } else {
+      // For the new dimensions, the size cannot be set to -1. Eg: an input of [3, 1] can be expanded to [3, -1, 4] but
+      // not [-1, 3, 4].
+      if (dim < 0) {
+        TORCHTRT_THROW_ERROR(
+            "The expanded size of the tensor (" << targetSize << ") isn't allowed in a leading, non-existing dimension "
+                                                << i);
+      } else {
+        // in(3, 1), expand(3, -1, 4) -> expand(3, 3, 4)
+        expandedDims.d[i] = input_dims.d[dim];
+      }
+    }
+  }
+
+  auto num_expand_dims = expandedDims.nbDims - input_dims.nbDims;
+  if (num_expand_dims > 0) {
+    nvinfer1::Dims reshape_dims;
+    reshape_dims.nbDims = expandedDims.nbDims;
+    for (int64_t i = 0; i < num_expand_dims; i++) {
+      reshape_dims.d[i] = 1;
+    }
+    for (int64_t i = 0; i < input_dims.nbDims; i++) {
+      reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
+    }
+    // Add a reshape layer to expand dims
+    auto reshape_layer = ctx->net->addShuffle(*in);
+    reshape_layer->setReshapeDimensions(reshape_dims);
+    in = reshape_layer->getOutput(0);
+    LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
+  }
+
+  // Start the slicing from beginning of tensor since this is an expand layer
+  std::vector<int64_t> start_vec(expandedDims.nbDims, 0);
+  auto start_offset = util::toDims(c10::IntArrayRef(start_vec));
+
+  // Set the stride of non singleton dimension to 1
+  std::vector<int64_t> strides_vec(expandedDims.nbDims, 0);
+  for (int64_t i = 0; i < expandedDims.nbDims; i++) {
+    strides_vec[i] = (in->getDimensions().d[i] != 1);
+  }
+
+  auto strides = util::toDims(c10::IntArrayRef(strides_vec));
+  // Slice layer does the expansion in TRT. Desired output size is specified by expandedDims
+  auto slice_layer = ctx->net->addSlice(*in, start_offset, expandedDims, strides);
+  LOG_DEBUG(ctx->logger, "Expand Tensor: " << in->getName());
+
+  return slice_layer->getOutput(0);
+}
+
 } // namespace converters
 } // namespace conversion
 } // namespace core
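
The new add_expand helper follows the semantics of PyTorch's Tensor.expand: target dimensions are matched against the input right-to-left, each existing size must either equal the target size or be 1, and -1 keeps the existing size. As a rough illustration only (not part of the commit), the same shape rule can be written against plain vectors instead of nvinfer1::Dims; the helper name compute_expanded_shape below is hypothetical.

// Standalone sketch of the validation/shape logic mirrored from add_expand above.
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

std::vector<int64_t> compute_expanded_shape(const std::vector<int64_t>& input, std::vector<int64_t> target) {
  if (input.size() > target.size()) {
    throw std::runtime_error("expansion rank must be >= input rank");
  }
  // Walk the target dimensions right-to-left, as the loop in add_expand does.
  for (int64_t i = static_cast<int64_t>(target.size()) - 1; i >= 0; --i) {
    int64_t offset = static_cast<int64_t>(target.size()) - 1 - i;
    int64_t dim = static_cast<int64_t>(input.size()) - 1 - offset;
    int64_t size = (dim >= 0) ? input[dim] : 1;
    if (target[i] == -1) {
      if (dim < 0) {
        throw std::runtime_error("-1 is not allowed in a leading, non-existing dimension");
      }
      target[i] = size; // keep the existing size, e.g. (3, 1) expanded by (3, -1, 4) -> (3, 3, 4)
    } else if (size != target[i] && size != 1) {
      throw std::runtime_error("existing size must match the target size or be 1");
    }
  }
  return target;
}

int main() {
  // An input of [3, 1] can be expanded to [1, 3, 4]: the 1 broadcasts and a leading dim is added.
  for (int64_t d : compute_expanded_shape({3, 1}, {1, 3, 4})) {
    std::cout << d << " "; // prints: 1 3 4
  }
  std::cout << std::endl;
  return 0;
}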

core/conversion/converters/converter_util.h
Lines changed: 2 additions & 0 deletions

@@ -101,6 +101,8 @@ nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s);
 
 nvinfer1::DataType promote_types(nvinfer1::DataType type_a, nvinfer1::DataType type_b);
 
+nvinfer1::ITensor* add_expand(ConversionCtx* ctx, nvinfer1::ITensor* in, nvinfer1::Dims expandedDims);
+
 } // namespace converters
 } // namespace conversion
 } // namespace core

core/conversion/converters/impl/expand.cpp
Lines changed: 9 additions & 73 deletions

@@ -27,78 +27,14 @@ nvinfer1::ITensor* concat(int max_rank, int old_rank, ConversionCtx* ctx, nvinfe
   }
 }
 
-bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, nvinfer1::Dims expandedDims) {
-  auto input_dims = in->getDimensions();
-  TORCHTRT_CHECK(
-      input_dims.nbDims <= expandedDims.nbDims,
-      "Number of dimensions of the desired expansion must be greater than or equal to the number of input dimensions");
-
-  // Validate the expansion. Eg: an input of [3, 1] can be expanded to [1, 3, 4] but not [3, 4, 1]
-  for (int64_t i = expandedDims.nbDims - 1; i >= 0; --i) {
-    int64_t offset = expandedDims.nbDims - 1 - i;
-    int64_t dim = input_dims.nbDims - 1 - offset;
-    int64_t size = (dim >= 0) ? input_dims.d[dim] : 1;
-    int64_t targetSize = expandedDims.d[i];
-    // In expand layer passing -1 as the size for a dimension means not changing the size of that dimension.
-    if (targetSize != -1) {
-      if (size != targetSize) {
-        if (size != 1) {
-          TORCHTRT_THROW_ERROR(
-              "The expanded size of tensor (" << targetSize << ")"
-                                              << " must match the existing size (" << size << ")"
-                                              << " at dimension " << i);
-        }
-      }
-    } else {
-      // For the new dimensions, the size cannot be set to -1. Eg: an input of [3, 1] can be expanded to [3, -1, 4] but
-      // not [-1, 3, 4].
-      if (dim < 0) {
-        TORCHTRT_THROW_ERROR(
-            "The expanded size of the tensor (" << targetSize << ") isn't allowed in a leading, non-existing dimension "
-                                                << i);
-      } else {
-        // in(3, 1), expand(3, -1, 4) -> expand(3, 3, 4)
-        expandedDims.d[i] = input_dims.d[dim];
-      }
-    }
-  }
-
-  auto num_expand_dims = expandedDims.nbDims - input_dims.nbDims;
-  if (num_expand_dims > 0) {
-    nvinfer1::Dims reshape_dims;
-    reshape_dims.nbDims = expandedDims.nbDims;
-    for (int64_t i = 0; i < num_expand_dims; i++) {
-      reshape_dims.d[i] = 1;
-    }
-    for (int64_t i = 0; i < input_dims.nbDims; i++) {
-      reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
-    }
-    // Add a reshape layer to expand dims
-    auto reshape_layer = ctx->net->addShuffle(*in);
-    reshape_layer->setReshapeDimensions(reshape_dims);
-    in = reshape_layer->getOutput(0);
-    LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
-  }
-
-  // Start the slicing from beginning of tensor since this is an expand layer
-  std::vector<int64_t> start_vec(expandedDims.nbDims, 0);
-  auto start_offset = util::toDims(c10::IntArrayRef(start_vec));
-
-  // Set the stride of non singleton dimension to 1
-  std::vector<int64_t> strides_vec(expandedDims.nbDims, 0);
-  for (int64_t i = 0; i < expandedDims.nbDims; i++) {
-    strides_vec[i] = (in->getDimensions().d[i] != 1);
-  }
-
-  auto strides = util::toDims(c10::IntArrayRef(strides_vec));
-  // Slice layer does the expansion in TRT. Desired output size is specified by expandedDims
-  auto slice_layer = ctx->net->addSlice(*in, start_offset, expandedDims, strides);
-  slice_layer->setName(util::node_info(n).c_str());
-
-  auto out = ctx->AssociateValueAndTensor(n->outputs()[0], slice_layer->getOutput(0));
-
+bool add_expand_static(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* in,
+    nvinfer1::Dims expandedDims) {
+  auto expand_out = add_expand(ctx, in, expandedDims);
+  auto out = ctx->AssociateValueAndTensor(n->outputs()[0], expand_out);
   LOG_DEBUG("Expand layer output tensor shape: " << out->getDimensions());
-
   return true;
 }
 
@@ -209,7 +145,7 @@ auto expand_registrations TORCHTRT_UNUSED =
             auto expandedDimsTensor = tensor_to_const(ctx, thExpanded_size);
             return add_expand_dynamic(ctx, n, in, expandedDimsTensor, expandedDims, true);
           } else {
-            return add_expand(ctx, n, in, expandedDims);
+            return add_expand_static(ctx, n, in, expandedDims);
           }
         }})
       .pattern(
@@ -223,7 +159,7 @@ auto expand_registrations TORCHTRT_UNUSED =
           if (ctx->input_is_dynamic) {
            return add_expand_dynamic(ctx, n, in, getShapeOutput(ctx, targetTensor), targetDims, false);
          } else {
-            return add_expand(ctx, n, in, targetDims);
+            return add_expand_static(ctx, n, in, targetDims);
          }
        }})
      .pattern(
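
With the duplicate removed, the static path is a thin wrapper: add_expand (now in converter_util.cpp) does the work by reshaping the input to the target rank and then adding a slice layer whose stride is 0 on singleton axes, so those elements are read repeatedly to fill the expanded output. A rough standalone sketch of that zero-stride replication on a flat, row-major buffer follows; the helper expand_with_zero_stride is hypothetical and is not a TensorRT API.

// Standalone sketch: emulate a slice with stride 0 on singleton axes,
// the mechanism the TRT slice layer in add_expand uses to expand tensors.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<float> expand_with_zero_stride(
    const std::vector<float>& data,
    const std::vector<int64_t>& in_shape, // already padded to the output rank with leading 1s
    const std::vector<int64_t>& out_shape) {
  // Step is 0 for singleton input dims (re-read the same element), 1 otherwise.
  std::vector<int64_t> step(in_shape.size());
  for (size_t i = 0; i < in_shape.size(); ++i) {
    step[i] = (in_shape[i] != 1) ? 1 : 0;
  }
  // Row-major strides of the padded input.
  std::vector<int64_t> in_stride(in_shape.size(), 1);
  for (int64_t i = static_cast<int64_t>(in_shape.size()) - 2; i >= 0; --i) {
    in_stride[i] = in_stride[i + 1] * in_shape[i + 1];
  }
  int64_t total = 1;
  for (auto d : out_shape) {
    total *= d;
  }
  std::vector<float> out(total);
  std::vector<int64_t> idx(out_shape.size(), 0);
  for (int64_t flat = 0; flat < total; ++flat) {
    int64_t src = 0;
    for (size_t i = 0; i < idx.size(); ++i) {
      src += idx[i] * step[i] * in_stride[i];
    }
    out[flat] = data[src];
    // Advance the multi-dimensional output index.
    for (int64_t i = static_cast<int64_t>(idx.size()) - 1; i >= 0; --i) {
      if (++idx[i] < out_shape[i]) {
        break;
      }
      idx[i] = 0;
    }
  }
  return out;
}

int main() {
  // A [1, 3] tensor expanded to [2, 3]: both output rows read the same three values.
  for (float v : expand_with_zero_stride({1.f, 2.f, 3.f}, {1, 3}, {2, 3})) {
    std::cout << v << " "; // prints: 1 2 3 1 2 3
  }
  std::cout << std::endl;
  return 0;
}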

core/conversion/converters/impl/layer_norm.cpp
Lines changed: 7 additions & 40 deletions

@@ -10,41 +10,6 @@ namespace converters {
 namespace impl {
 namespace {
 
-nvinfer1::ITensor* broadcast(
-    ConversionCtx* ctx,
-    const torch::jit::Node* n,
-    nvinfer1::ITensor* to_broadcast,
-    const int nbDims,
-    const std::string& tag) {
-  auto to_broadcast_nbdims = to_broadcast->getDimensions().nbDims;
-  TORCHTRT_CHECK(to_broadcast_nbdims <= nbDims, "Cannot broadcast tensor with more dimensions than the target");
-  if (to_broadcast_nbdims == nbDims) {
-    return to_broadcast;
-  }
-  auto shape_layer = ctx->net->addShape(*to_broadcast);
-  TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
-  shape_layer->setName((util::node_info(n) + "_shape_" + tag).c_str());
-  auto shape_layer_out = shape_layer->getOutput(0);
-
-  auto extra_dims_tensor = torch::ones({nbDims - to_broadcast_nbdims}, torch::TensorOptions().dtype(torch::kInt32));
-  auto extra_dims_itensor = tensor_to_const(ctx, extra_dims_tensor);
-
-  std::vector<nvinfer1::ITensor*> to_concat = {extra_dims_itensor, shape_layer_out};
-  auto concat_layer = ctx->net->addConcatenation(to_concat.data(), to_concat.size());
-  TORCHTRT_CHECK(concat_layer, "Unable to create concat layer from node: " << *n);
-  concat_layer->setName((util::node_info(n) + "_concat_" + tag).c_str());
-  auto target_shape = concat_layer->getOutput(0);
-
-  auto shuffle_layer = ctx->net->addShuffle(*to_broadcast);
-  TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
-  shuffle_layer->setName((util::node_info(n) + "_shuffle_" + tag).c_str());
-  shuffle_layer->setInput(1, *target_shape);
-  auto output = shuffle_layer->getOutput(0);
-  LOG_DEBUG(
-      "Broadcast " << tag << " to shape: " << output->getDimensions() << " from " << to_broadcast->getDimensions());
-  return output;
-}
-
 auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern({
     R"SIG(aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? gamma, Tensor? beta,
                            float eps, bool cudnn_enabled) -> (Tensor))SIG",
@@ -62,20 +27,22 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
 
       nvinfer1::ITensor* gamma = nullptr;
       if (args[2].IValue()->isNone()) {
-        auto gamma_torch_tensor = torch::ones(input_shape_vec, torch::TensorOptions().dtype(torch::kFloat32));
+        auto gamma_torch_tensor =
+            torch::ones(input_shape_vec, torch::TensorOptions().dtype(util::TRTDataTypeToScalarType(input->getType())));
        gamma = tensor_to_const(ctx, gamma_torch_tensor);
       } else {
        gamma = args[2].ITensorOrFreeze(ctx);
-        gamma = broadcast(ctx, n, gamma, input_shape_vec.size(), "gamma");
+        gamma = add_expand(ctx, gamma, input_shape);
       }
 
       nvinfer1::ITensor* beta = nullptr;
       if (args[3].IValue()->isNone()) {
-        auto beta_torch_tensor = torch::zeros(input_shape_vec, torch::TensorOptions().dtype(torch::kFloat32));
+        auto beta_torch_tensor = torch::zeros(
+            input_shape_vec, torch::TensorOptions().dtype(util::TRTDataTypeToScalarType(input->getType())));
        beta = tensor_to_const(ctx, beta_torch_tensor);
       } else {
        beta = args[3].ITensorOrFreeze(ctx);
-        beta = broadcast(ctx, n, beta, input_shape_vec.size(), "beta");
+        beta = add_expand(ctx, beta, input_shape);
      }
 
       auto eps = args[4].unwrapToDouble();
@@ -84,7 +51,7 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
       TORCHTRT_CHECK(normalize_layer, "Unable to create layer_norm from node: " << *n);
       normalize_layer->setName(util::node_info(n).c_str());
       normalize_layer->setEpsilon(eps);
-      normalize_layer->setComputePrecision(nvinfer1::DataType::kFLOAT);
+      normalize_layer->setComputePrecision(input->getType());
       auto normalized = normalize_layer->getOutput(0);
 
       ctx->AssociateValueAndTensor(n->outputs()[0], normalized);
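
In aten::layer_norm the optional gamma and beta carry normalized_shape, a suffix of the input shape, so before this fix the converter used its own shape-layer based broadcast helper; it now reuses add_expand to bring them to the full input shape, and when they are absent it creates the defaults with the input's dtype instead of hard-coded float32. A small LibTorch sketch of the shape equivalence this relies on, assuming an input of shape [2, 4, 8] with normalized_shape [8] (illustration only, not the converter code):

// Sketch: expanding gamma/beta of shape normalized_shape to the input shape
// produces the same result as letting layer_norm apply them directly.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto input = torch::randn({2, 4, 8}); // input shape [2, 4, 8]
  auto gamma = torch::randn({8});       // weight/bias carry normalized_shape, i.e. [8]
  auto beta = torch::randn({8});

  // Conceptually what the converter now does: right-align and expand the
  // parameters to the full input shape (reshape to [1, 1, 8], then replicate).
  auto gamma_full = gamma.expand({2, 4, 8});
  auto beta_full = beta.expand({2, 4, 8});

  auto ref = torch::layer_norm(input, {8}, gamma, beta, 1e-5);
  auto manual = torch::layer_norm(input, {8}, {}, {}, 1e-5) * gamma_full + beta_full;

  std::cout << std::boolalpha << torch::allclose(ref, manual) << std::endl; // prints: true
  return 0;
}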

core/lowering/lowering.cpp
Lines changed: 1 addition & 1 deletion

@@ -142,11 +142,11 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
   passes::SiluToSigmoidMultipication(g);
   passes::RemoveSingleUse0DTensors(g);
   passes::RemoveUnnecessaryCasts(g);
+  passes::UnpackScaledDotProductAttention(g);
   passes::ReplaceAtenInt(g);
   if (lower_info.converting_to_trt_engine) {
     passes::RemoveCollectionCast(g);
   }
-  passes::UnpackScaledDotProductAttention(g);
   passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
   passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());
   passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());
