Skip to content

Commit 65c8e0a

Browse files
committed
Make changes to scale implementation
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 9bdb5c8 commit 65c8e0a

File tree

2 files changed

+36
-30
lines changed

2 files changed

+36
-30
lines changed

core/conversion/converters/impl/element_wise.cpp

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,14 @@ auto element_wise_registrations TRTORCH_UNUSED =
153153
auto other = args[1].ITensorOrFreeze(ctx);
154154

155155
if (1 != scalar) {
156-
auto scaleW = Weights(ctx, scalar);
157-
auto unuse = Weights();
158-
// IScaleLayer assert shift, scale and power to have
159-
// the same dtype
160-
auto scaleLayer = ctx->net->addScale(
161-
*other, nvinfer1::ScaleMode::kUNIFORM, unuse.data, scaleW.data, unuse.data);
162-
TRTORCH_CHECK(scaleLayer, "Unable to create scale layer from node: " << *n);
156+
auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
157+
auto scaleLayer = add_elementwise(
158+
ctx,
159+
nvinfer1::ElementWiseOperation::kPROD,
160+
other,
161+
alphaTensor,
162+
util::node_info(n) + std::string("_AlphaMultiplier"));
163+
TRTORCH_CHECK(scaleLayer, "Unable to create alpha*input layer from node: " << *n);
163164
other = scaleLayer->getOutput(0);
164165
}
165166

@@ -181,13 +182,14 @@ auto element_wise_registrations TRTORCH_UNUSED =
181182
auto other = args[1].ITensorOrFreeze(ctx);
182183

183184
if (1 != scalar) {
184-
auto scaleW = Weights(ctx, scalar);
185-
auto unuse = Weights();
186-
// IScaleLayer assert shift, scale and power to have
187-
// the same dtype
188-
auto scaleLayer = ctx->net->addScale(
189-
*other, nvinfer1::ScaleMode::kUNIFORM, unuse.data, scaleW.data, unuse.data);
190-
TRTORCH_CHECK(scaleLayer, "Unable to create scale layer from node: " << *n);
185+
auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
186+
auto scaleLayer = add_elementwise(
187+
ctx,
188+
nvinfer1::ElementWiseOperation::kPROD,
189+
other,
190+
alphaTensor,
191+
util::node_info(n) + std::string("_AlphaMultiplier"));
192+
TRTORCH_CHECK(scaleLayer, "Unable to create alpha*input layer from node: " << *n);
191193
other = scaleLayer->getOutput(0);
192194
}
193195

@@ -209,13 +211,14 @@ auto element_wise_registrations TRTORCH_UNUSED =
209211
auto scalar = args[2].unwrapToScalar().to<float>();
210212

211213
if (1 != scalar) {
212-
auto scaleW = Weights(ctx, scalar);
213-
auto unuse = Weights();
214-
// IScaleLayer assert shift, scale and power to have
215-
// the same dtype
216-
auto scaleLayer =
217-
ctx->net->addScale(*self, nvinfer1::ScaleMode::kUNIFORM, unuse.data, scaleW.data, unuse.data);
218-
TRTORCH_CHECK(scaleLayer, "Unable to create scale layer from node: " << *n);
214+
auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
215+
auto scaleLayer = add_elementwise(
216+
ctx,
217+
nvinfer1::ElementWiseOperation::kPROD,
218+
self,
219+
alphaTensor,
220+
util::node_info(n) + std::string("_AlphaMultiplier"));
221+
TRTORCH_CHECK(scaleLayer, "Unable to create alpha*input layer from node: " << *n);
219222
self = scaleLayer->getOutput(0);
220223
}
221224

@@ -236,13 +239,14 @@ auto element_wise_registrations TRTORCH_UNUSED =
236239
auto scalar = args[2].unwrapToScalar().to<float>();
237240

238241
if (1 != scalar) {
239-
auto scaleW = Weights(ctx, scalar);
240-
auto unuse = Weights();
241-
// IScaleLayer assert shift, scale and power to have
242-
// the same dtype
243-
auto scaleLayer =
244-
ctx->net->addScale(*self, nvinfer1::ScaleMode::kUNIFORM, unuse.data, scaleW.data, unuse.data);
245-
TRTORCH_CHECK(scaleLayer, "Unable to create scale layer from node: " << *n);
242+
auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
243+
auto scaleLayer = add_elementwise(
244+
ctx,
245+
nvinfer1::ElementWiseOperation::kPROD,
246+
self,
247+
alphaTensor,
248+
util::node_info(n) + std::string("_AlphaMultiplier"));
249+
TRTORCH_CHECK(scaleLayer, "Unable to create alpha*input layer from node: " << *n);
246250
self = scaleLayer->getOutput(0);
247251
}
248252

tests/core/conversion/converters/test_element_wise.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ TEST(Converters, ATenAddWithScalarConvertsCorrectly) {
8989
TEST(Converters, ATenSubConvertsCorrectly) {
9090
const auto graph = R"IR(
9191
graph(%0 : Tensor, %1 : Tensor):
92-
%2 : int = prim::Constant[value=1]()
92+
%2 : int = prim::Constant[value=2.3]()
9393
%3 : Tensor = aten::sub(%0, %1, %2)
9494
return (%3))IR";
9595
pointwise_test_helper(graph, false);
@@ -215,6 +215,8 @@ TEST(Converters, ATenRsubWithTensorConvertsCorrectly) {
215215
%2 : int = prim::Constant[value=2]()
216216
%3 : Tensor = aten::rsub(%0, %1, %2)
217217
return (%3))IR";
218+
pointwise_test_helper(graph, false, false, {3, 4}, {4});
219+
pointwise_test_helper(graph, false, false, {4}, {3, 4});
218220
pointwise_test_helper(graph, false, true, {4, 3, 3, 3}, {4, 3, 3, 3});
219221
}
220222

@@ -226,4 +228,4 @@ TEST(Converters, ATenRsubWithScalarConvertsCorrectly) {
226228
%3 : Tensor = aten::rsub(%0, %scalar, %2)
227229
return (%3))IR";
228230
pointwise_test_helper(graph, true, false, {4, 3, 3, 3});
229-
}
231+
}

0 commit comments

Comments
 (0)