@@ -144,6 +144,41 @@ auto element_wise_registrations TRTORCH_UNUSED =
144
144
LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
145
145
return true ;
146
146
}})
147
+ .pattern({" aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> (Tensor)" ,
148
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
149
+ // Compute min(max(min_threshold, input), max_threshold)
150
+ auto self = args[0 ].ITensorOrFreeze (ctx);
151
+ auto clamp_layer_out = self;
152
+ if (args[1 ].isIValue () && args[1 ].IValue ()->isScalar ()) {
153
+ auto minScalar = args[1 ].unwrapToScalar ().to <float >();
154
+ auto minTensor = tensor_to_const (ctx, torch::tensor ({minScalar}));
155
+ auto max_layer = add_elementwise (
156
+ ctx,
157
+ nvinfer1::ElementWiseOperation::kMAX ,
158
+ clamp_layer_out,
159
+ minTensor,
160
+ util::node_info (n) + std::string (" _max" ));
161
+ TRTORCH_CHECK (max_layer, " Unable to create elementwise max layer for node: " << *n);
162
+ clamp_layer_out = max_layer->getOutput (0 );
163
+ }
164
+
165
+ if (args[2 ].isIValue () && args[2 ].IValue ()->isScalar ()) {
166
+ auto maxScalar = args[2 ].unwrapToScalar ().to <float >();
167
+ auto maxTensor = tensor_to_const (ctx, torch::tensor ({maxScalar}));
168
+ auto min_layer = add_elementwise (
169
+ ctx,
170
+ nvinfer1::ElementWiseOperation::kMIN ,
171
+ clamp_layer_out,
172
+ maxTensor,
173
+ util::node_info (n) + std::string (" _min" ));
174
+ TRTORCH_CHECK (min_layer, " Unable to create elementwise min layer for node: " << *n);
175
+ clamp_layer_out = min_layer->getOutput (0 );
176
+ }
177
+
178
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], clamp_layer_out);
179
+ LOG_DEBUG (" Clamp layer output tensor shape: " << clamp_layer_out->getDimensions ());
180
+ return true ;
181
+ }})
147
182
.pattern({" aten::sub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> "
148
183
" Tensor" ,
149
184
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
0 commit comments