-
Notifications
You must be signed in to change notification settings - Fork 364
Add clamp conversion functionality #293
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -144,6 +144,41 @@ auto element_wise_registrations TRTORCH_UNUSED = | |
LOG_DEBUG("Output tensor shape: " << out->getDimensions()); | ||
return true; | ||
}}) | ||
.pattern({"aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> (Tensor)", | ||
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { | ||
// Compute min(max(min_threshold, input), max_threshold) | ||
auto self = args[0].ITensorOrFreeze(ctx); | ||
auto clamp_layer_out = self; | ||
if (args[1].isIValue() && args[1].IValue()->isScalar()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we might be able to do !isNone on the IValue instead of two checks There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I remember correctly, isNone() didn't work for me. I will double check again. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The op does basically this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see |
||
auto minScalar = args[1].unwrapToScalar().to<float>(); | ||
auto minTensor = tensor_to_const(ctx, torch::tensor({minScalar})); | ||
auto max_layer = add_elementwise( | ||
ctx, | ||
nvinfer1::ElementWiseOperation::kMAX, | ||
clamp_layer_out, | ||
minTensor, | ||
util::node_info(n) + std::string("_max")); | ||
TRTORCH_CHECK(max_layer, "Unable to create elementwise max layer for node: " << *n); | ||
clamp_layer_out = max_layer->getOutput(0); | ||
} | ||
|
||
if (args[2].isIValue() && args[2].IValue()->isScalar()) { | ||
auto maxScalar = args[2].unwrapToScalar().to<float>(); | ||
auto maxTensor = tensor_to_const(ctx, torch::tensor({maxScalar})); | ||
auto min_layer = add_elementwise( | ||
ctx, | ||
nvinfer1::ElementWiseOperation::kMIN, | ||
clamp_layer_out, | ||
maxTensor, | ||
util::node_info(n) + std::string("_min")); | ||
TRTORCH_CHECK(min_layer, "Unable to create elementwise min layer for node: " << *n); | ||
clamp_layer_out = min_layer->getOutput(0); | ||
} | ||
|
||
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], clamp_layer_out); | ||
LOG_DEBUG("Clamp layer output tensor shape: " << clamp_layer_out->getDimensions()); | ||
return true; | ||
}}) | ||
.pattern({"aten::sub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> " | ||
"Tensor", | ||
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we possibly lower to hardtanh or is the functionality different?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hardtanh looks a bit different https://pytorch.org/docs/stable/generated/torch.nn.Hardtanh.html
hardtanh compares the elements to -1 and 1 and replaces with default values (thresholds) where as clamp compares the input elements to thresholds directly. They both probably can match in certain scenarios
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hardtanh lets you configure the min and max values. really my point is we should be able to lower clamp and hardtanh to the same clip activation converter. This is something we could leave for later though
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally our converter library should be as small as possible and we should do as much as we can through lowering