Skip to content

Commit 0f1cada

Browse files
committed
Add clamp conversion functionality
1 parent b787c5e commit 0f1cada

File tree

2 files changed

+71
-6
lines changed

2 files changed

+71
-6
lines changed

core/conversion/converters/impl/element_wise.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,41 @@ auto element_wise_registrations TRTORCH_UNUSED =
144144
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
145145
return true;
146146
}})
147+
.pattern({"aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // Lower clamp as min(max(input, lower_bound), upper_bound).
            // Either bound may be None, in which case that side is skipped
            // and the tensor passes through unchanged for that stage.
            auto in = args[0].ITensorOrFreeze(ctx);
            auto clamped = in;

            // Lower bound present: elementwise kMAX against a constant.
            if (args[1].isIValue() && args[1].IValue()->isScalar()) {
              auto lower = args[1].unwrapToScalar().to<float>();
              auto lower_const = tensor_to_const(ctx, torch::tensor({lower}));
              auto max_layer = add_elementwise(
                  ctx,
                  nvinfer1::ElementWiseOperation::kMAX,
                  clamped,
                  lower_const,
                  util::node_info(n) + std::string("_max"));
              TRTORCH_CHECK(max_layer, "Unable to create elementwise max layer for node: " << *n);
              clamped = max_layer->getOutput(0);
            }

            // Upper bound present: elementwise kMIN against a constant.
            if (args[2].isIValue() && args[2].IValue()->isScalar()) {
              auto upper = args[2].unwrapToScalar().to<float>();
              auto upper_const = tensor_to_const(ctx, torch::tensor({upper}));
              auto min_layer = add_elementwise(
                  ctx,
                  nvinfer1::ElementWiseOperation::kMIN,
                  clamped,
                  upper_const,
                  util::node_info(n) + std::string("_min"));
              TRTORCH_CHECK(min_layer, "Unable to create elementwise min layer for node: " << *n);
              clamped = min_layer->getOutput(0);
            }

            // Bind the (possibly pass-through) result to the node's output.
            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], clamped);
            LOG_DEBUG("Clamp layer output tensor shape: " << clamped->getDimensions());
            return true;
          }})
147182
.pattern({"aten::sub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> "
148183
"Tensor",
149184
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

tests/core/conversion/converters/test_element_wise.cpp

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,16 @@ TEST(Converters, ATenAddImplicitWithAlphaConvertsCorrectly) {
7676
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
7777
}
7878

79+
// Verifies conversion of aten::add(Tensor, Scalar, alpha) where the second
// operand is a python float scalar (2.4) and alpha is the int constant 1.
TEST(Converters, ATenAddWithScalarConvertsCorrectly) {
80+
const auto graph = R"IR(
81+
graph(%0 : Tensor):
82+
%2 : int = prim::Constant[value=1]()
83+
%scalar : float = prim::Constant[value=2.4]()
84+
%3 : Tensor = aten::add(%0, %scalar, %2)
85+
return (%3))IR";
86+
pointwise_test_helper(graph, true);
87+
}
88+
7989
TEST(Converters, ATenSubConvertsCorrectly) {
8090
const auto graph = R"IR(
8191
graph(%0 : Tensor, %1 : Tensor):
@@ -134,12 +144,32 @@ TEST(Converters, ATenPowScalarConvertsCorrectly) {
134144
pointwise_test_helper(graph, true);
135145
}
136146

// NOTE(review): in this diff hunk the '-' lines remove the relocated
// ATenAddWithScalarConvertsCorrectly test; the '+' lines add the three
// clamp converter tests below.
137-
TEST(Converters, ATenAddWithScalarConvertsCorrectly) {
147+
// Exercises aten::clamp with only a min bound (-2); max is None.
TEST(Converters, ATenClampMinConvertsCorrectly) {
138148
const auto graph = R"IR(
139-
graph(%0 : Tensor):
140-
%2 : int = prim::Constant[value=1]()
141-
%scalar : float = prim::Constant[value=2.4]()
142-
%3 : Tensor = aten::add(%0, %scalar, %2)
143-
return (%3))IR";
149+
graph(%x.1 : Tensor):
150+
%2 : int = prim::Constant[value=-2]()
151+
%3 : None = prim::Constant()
152+
%4 : Tensor = aten::clamp(%x.1, %2, %3)
153+
return (%4))IR";
154+
pointwise_test_helper(graph, true);
155+
}
156+
157+
// Exercises aten::clamp with only a max bound (3); min is None.
TEST(Converters, ATenClampMaxConvertsCorrectly) {
158+
const auto graph = R"IR(
159+
graph(%x.1 : Tensor):
160+
%2 : int = prim::Constant[value=3]()
161+
%3 : None = prim::Constant()
162+
%4 : Tensor = aten::clamp(%x.1, %3, %2)
163+
return (%4))IR";
164+
pointwise_test_helper(graph, true);
165+
}
166+
167+
// Exercises aten::clamp with both bounds set: min = -2 (%3), max = 3 (%2).
TEST(Converters, ATenClampMinMaxConvertsCorrectly) {
168+
const auto graph = R"IR(
169+
graph(%x.1 : Tensor):
170+
%2 : int = prim::Constant[value=3]()
171+
%3 : int = prim::Constant[value=-2]()
172+
%4 : Tensor = aten::clamp(%x.1, %3, %2)
173+
return (%4))IR";
144174
pointwise_test_helper(graph, true);
145175
}

0 commit comments

Comments
 (0)