Skip to content

Commit 96dbbf3

Browse files
committed
Add clamp conversion functionality
Signed-off-by: Dheeraj Peri <[email protected]>

Apply linting

Signed-off-by: Dheeraj Peri <[email protected]>

Fix merge conflicts

Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 05652b8 commit 96dbbf3

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

core/conversion/converters/impl/element_wise.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,41 @@ auto element_wise_registrations TRTORCH_UNUSED =
144144
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
145145
return true;
146146
}})
147+
.pattern({"aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // clamp(x, min, max) == min(max(x, min_threshold), max_threshold).
            // Either bound may be None (non-scalar IValue), in which case that
            // side is left unclamped and the input passes through.
            auto self = args[0].ITensorOrFreeze(ctx);
            auto clamp_layer_out = self;

            // Applies one bound as an elementwise op against a 1-element constant
            // (broadcast by TensorRT): kMAX enforces the lower bound, kMIN the
            // upper bound. Shared by both branches to avoid duplicated layer setup.
            auto apply_bound = [&](nvinfer1::ElementWiseOperation op, float threshold, const std::string& name) {
              auto bound_tensor = tensor_to_const(ctx, torch::tensor({threshold}));
              auto bound_layer = add_elementwise(ctx, op, clamp_layer_out, bound_tensor, util::node_info(n) + std::string("_") + name);
              TRTORCH_CHECK(bound_layer, "Unable to create elementwise " << name << " layer for node: " << *n);
              clamp_layer_out = bound_layer->getOutput(0);
            };

            // Lower bound: out = max(out, min_threshold)
            if (args[1].isIValue() && args[1].IValue()->isScalar()) {
              apply_bound(nvinfer1::ElementWiseOperation::kMAX, args[1].unwrapToScalar().to<float>(), "max");
            }

            // Upper bound: out = min(out, max_threshold)
            if (args[2].isIValue() && args[2].IValue()->isScalar()) {
              apply_bound(nvinfer1::ElementWiseOperation::kMIN, args[2].unwrapToScalar().to<float>(), "min");
            }

            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], clamp_layer_out);
            LOG_DEBUG("Clamp layer output tensor shape: " << clamp_layer_out->getDimensions());
            return true;
          }})
147182
.pattern({"aten::sub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> "
148183
"Tensor",
149184
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

tests/core/conversion/converters/test_element_wise.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,4 +162,33 @@ TEST(Converters, ATenNeScalarConvertsCorrectly) {
162162
return (%3))IR";
163163
pointwise_test_helper(graph, true, false, {3, 4, 2});
164164
;
165+
166+
TEST(Converters, ATenClampMinConvertsCorrectly) {
  // clamp with only a lower bound (max=None): out = max(x, -2)
  const auto source_ir = R"IR(
    graph(%x.1 : Tensor):
      %2 : int = prim::Constant[value=-2]()
      %3 : None = prim::Constant()
      %4 : Tensor = aten::clamp(%x.1, %2, %3)
      return (%4))IR";
  pointwise_test_helper(source_ir, true);
}
175+
176+
TEST(Converters, ATenClampMaxConvertsCorrectly) {
  // clamp with only an upper bound (min=None): out = min(x, 3)
  const auto source_ir = R"IR(
    graph(%x.1 : Tensor):
      %2 : int = prim::Constant[value=3]()
      %3 : None = prim::Constant()
      %4 : Tensor = aten::clamp(%x.1, %3, %2)
      return (%4))IR";
  pointwise_test_helper(source_ir, true);
}
185+
186+
TEST(Converters, ATenClampMinMaxConvertsCorrectly) {
  // clamp with both bounds: out = min(max(x, -2), 3)
  const auto source_ir = R"IR(
    graph(%x.1 : Tensor):
      %2 : int = prim::Constant[value=3]()
      %3 : int = prim::Constant[value=-2]()
      %4 : Tensor = aten::clamp(%x.1, %3, %2)
      return (%4))IR";
  pointwise_test_helper(source_ir, true);
}

0 commit comments

Comments
 (0)