Skip to content

Commit 05ac70e

Browse files
authored
Merge pull request #293 from NVIDIA/clamp
Add clamp conversion functionality
2 parents 20b0043 + 96dbbf3 commit 05ac70e

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

core/conversion/converters/impl/element_wise.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,41 @@ auto element_wise_registrations TRTORCH_UNUSED =
144144
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
145145
return true;
146146
}})
147+
.pattern({"aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // Lowers aten::clamp as min(max(input, min_threshold), max_threshold),
            // emitting an elementwise TensorRT layer only for each bound that is
            // actually supplied (either scalar may be None).
            auto in = args[0].ITensorOrFreeze(ctx);
            auto clamp_layer_out = in;

            // Lower bound present: clamp from below with an elementwise kMAX
            // against a single-element constant tensor holding the threshold.
            if (args[1].isIValue() && args[1].IValue()->isScalar()) {
              auto lower = args[1].unwrapToScalar().to<float>();
              auto lower_tensor = tensor_to_const(ctx, torch::tensor({lower}));
              auto max_layer = add_elementwise(
                  ctx,
                  nvinfer1::ElementWiseOperation::kMAX,
                  clamp_layer_out,
                  lower_tensor,
                  util::node_info(n) + std::string("_max"));
              TRTORCH_CHECK(max_layer, "Unable to create elementwise max layer for node: " << *n);
              clamp_layer_out = max_layer->getOutput(0);
            }

            // Upper bound present: clamp from above with an elementwise kMIN,
            // chained onto whatever the lower-bound stage produced (or the raw input).
            if (args[2].isIValue() && args[2].IValue()->isScalar()) {
              auto upper = args[2].unwrapToScalar().to<float>();
              auto upper_tensor = tensor_to_const(ctx, torch::tensor({upper}));
              auto min_layer = add_elementwise(
                  ctx,
                  nvinfer1::ElementWiseOperation::kMIN,
                  clamp_layer_out,
                  upper_tensor,
                  util::node_info(n) + std::string("_min"));
              TRTORCH_CHECK(min_layer, "Unable to create elementwise min layer for node: " << *n);
              clamp_layer_out = min_layer->getOutput(0);
            }

            // With both bounds None this is an identity mapping of the input tensor.
            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], clamp_layer_out);
            LOG_DEBUG("Clamp layer output tensor shape: " << clamp_layer_out->getDimensions());
            return true;
          }})
147182
.pattern({"aten::sub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> "
148183
"Tensor",
149184
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

tests/core/conversion/converters/test_element_wise.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,35 @@ TEST(Converters, ATenNeScalarConvertsCorrectly) {
171171
return (%3))IR";
172172
pointwise_test_helper(graph, true, false, {3, 4, 2});
173173
;
174+
175+
// aten::clamp with only a lower bound (max is None) should still convert.
TEST(Converters, ATenClampMinConvertsCorrectly) {
  const auto source = R"IR(
    graph(%x.1 : Tensor):
      %2 : int = prim::Constant[value=-2]()
      %3 : None = prim::Constant()
      %4 : Tensor = aten::clamp(%x.1, %2, %3)
      return (%4))IR";
  pointwise_test_helper(source, true);
}
184+
185+
// aten::clamp with only an upper bound (min is None) should still convert.
TEST(Converters, ATenClampMaxConvertsCorrectly) {
  const auto source = R"IR(
    graph(%x.1 : Tensor):
      %2 : int = prim::Constant[value=3]()
      %3 : None = prim::Constant()
      %4 : Tensor = aten::clamp(%x.1, %3, %2)
      return (%4))IR";
  pointwise_test_helper(source, true);
}
194+
195+
// aten::clamp with both bounds supplied exercises the max-then-min layer chain.
TEST(Converters, ATenClampMinMaxConvertsCorrectly) {
  const auto source = R"IR(
    graph(%x.1 : Tensor):
      %2 : int = prim::Constant[value=3]()
      %3 : int = prim::Constant[value=-2]()
      %4 : Tensor = aten::clamp(%x.1, %3, %2)
      return (%4))IR";
  pointwise_test_helper(source, true);
}
175204

176205
TEST(Converters, ATenGreaterThanConvertsCorrectly) {

0 commit comments

Comments (0)