Adding limited support for aten::Int #870
Conversation
Code conforms to Python style guidelines
There are some changes that do not conform to C++ style guidelines:
diff --git a/workspace/core/lowering/passes/remove_unnecessary_casts.cpp b/tmp/changes.txt
index 7f6cc85..064a2cb 100644
--- a/workspace/core/lowering/passes/remove_unnecessary_casts.cpp
+++ b/tmp/changes.txt
@@ -1,5 +1,5 @@
-#include "torch/csrc/jit/passes/subgraph_rewrite.h"
#include "torch/csrc/jit/ir/constants.h"
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
#include "core/util/prelude.h"
@@ -10,7 +10,6 @@ namespace core {
namespace lowering {
namespace passes {
-
// Presumably this is safe since torch::jit::EraseNumberTypesOnBlock exists which just
// removes prim::TensorToNum, aten::Float, aten::Int and prim::NumToTensor nodes outright
void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph) {
@@ -77,8 +76,8 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
if (user->output()->uses().size() == 1) {
auto potential_cast = user->output()->uses()[0].user;
// The downstream user is aten::Int
- if (potential_cast->kind() == c10::Symbol::fromQualString("aten::Int")
- || potential_cast->kind() == c10::Symbol::fromQualString("aten::Float")) {
+ if (potential_cast->kind() == c10::Symbol::fromQualString("aten::Int") ||
+ potential_cast->kind() == c10::Symbol::fromQualString("aten::Float")) {
LOG_GRAPH("Downstream user is aten::Int/aten::Float");
auto arg = use.offset;
@@ -88,13 +87,11 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
LOG_GRAPH("Input " << k << " is a Tensor");
if (user->inputs()[k]->node()->kind() == c10::Symbol::fromQualString("prim::NumToTensor")) {
auto num_to_tensor = user->inputs()[k]->node();
-
- LOG_GRAPH("Found a prim::NumToTensor / aten::[Int/Float] pair with an intermediate operation:\n "
- << *(*it)
- << *num_to_tensor
- << *user
- << *potential_cast);
-
+
+ LOG_GRAPH(
+ "Found a prim::NumToTensor / aten::[Int/Float] pair with an intermediate operation:\n "
+ << *(*it) << *num_to_tensor << *user << *potential_cast);
+
// Replace the Tensor Constant with a scalar constant
LOG_GRAPH("Deleting 0-dim Tensor: " << **it);
torch::jit::WithInsertPoint gaurd(*it);
@@ -126,19 +123,16 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
// has a different schema than the original
case c10::aten::add:
new_node = g->create(
- user->kind(),
- torch::jit::ArrayRef<torch::jit::Value*>({user->inputs()[0], user->inputs()[1]}),
- 1);
+ user->kind(),
+ torch::jit::ArrayRef<torch::jit::Value*>({user->inputs()[0], user->inputs()[1]}),
+ 1);
new_node->insertAfter(user);
new_node->outputs()[0]->setType(c10::IntType::get());
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
user->destroy();
break;
default:
- new_node = g->create(
- user->kind(),
- user->inputs(),
- 1);
+ new_node = g->create(user->kind(), user->inputs(), 1);
new_node->insertAfter(user);
new_node->outputs()[0]->setType(c10::IntType::get());
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
@@ -148,7 +142,7 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
LOG_GRAPH("New intermediate operation: " << *new_node);
LOG_GRAPH(new_node->schema());
-
+
// Delete aten::Int
LOG_GRAPH("Deleting aten::[Int/Float]: " << *potential_cast);
potential_cast->output()->replaceAllUsesWith(potential_cast->inputs()[0]);
@@ -163,12 +157,11 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
}
}
}
- }
+ }
}
LOG_ERROR("Post removing single use 0-dim Tensor operations: " << *g);
}
-
} // namespace passes
} // namespace lowering
} // namespace core
diff --git a/workspace/tests/core/lowering/test_remove_unnecessary_casts.cpp b/tmp/changes.txt
index ef370a8..62f913e 100644
--- a/workspace/tests/core/lowering/test_remove_unnecessary_casts.cpp
+++ b/tmp/changes.txt
@@ -102,8 +102,7 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsIntCorrectly) {
auto first_op = *(sg->block()->nodes().begin());
torch::jit::WithInsertPoint guard(first_op);
- torch::jit::Value* r = sg->insertConstant(
- c10::scalar_to_tensor(8), c10::nullopt, first_op->scope());
+ torch::jit::Value* r = sg->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, first_op->scope());
r->copyMetadata(first_op->output());
r->setType(c10::TensorType::get());
first_op->output()->replaceAllUsesWith(r);
@@ -141,8 +140,7 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsFloatCorrectly) {
auto first_op = *(sg->block()->nodes().begin());
torch::jit::WithInsertPoint guard(first_op);
- torch::jit::Value* r = sg->insertConstant(
- c10::scalar_to_tensor(8.0), c10::nullopt, first_op->scope());
+ torch::jit::Value* r = sg->insertConstant(c10::scalar_to_tensor(8.0), c10::nullopt, first_op->scope());
r->copyMetadata(first_op->output());
r->setType(c10::TensorType::get());
first_op->output()->replaceAllUsesWith(r);
ERROR: Some files do not conform to style guidelines
Likely fixes #732 as well.
This commit adds a pass to lower out aten::[Int/Float/Bool] / prim::NumToTensor pairs without exception. We are assuming this is safe, as there are similar passes in PyTorch for ONNX lowering; however, the scope of this rule is intentionally limited to avoid cases where it is not safe. Therefore it should not be expected that all aten::Int issues will be solved by this change, and the operator itself remains a limitation of TorchTRT. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
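For context, the simplest case this pass handles, a direct prim::NumToTensor → aten::Int pair, can be expressed as a pattern rewrite. Below is a minimal sketch of such a rewrite using torch::jit::SubgraphRewriter; the function name and exact pattern strings are illustrative, not necessarily what the pass in this PR uses.

```cpp
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

// Sketch: fold prim::NumToTensor(%x) followed by aten::Int back into %x.
void RemoveNumToTensorIntPair(std::shared_ptr<torch::jit::Graph>& graph) {
  // Pattern matching the redundant int -> Tensor -> int round-trip
  std::string int_cast_pattern = R"IR(
    graph(%1: int):
      %2: Tensor = prim::NumToTensor(%1)
      %3: int = aten::Int(%2)
      return (%3))IR";
  // Replacement: forward the original int unchanged
  std::string no_cast_pattern = R"IR(
    graph(%1: int):
      return (%1))IR";

  torch::jit::SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(int_cast_pattern, no_cast_pattern);
  rewriter.runOnGraph(graph);
}
```

This also illustrates why the scope is kept narrow: a whole-chain rewrite like this is only safe when the intermediate Tensor has no other consumers, i.e. the single-use cases the commit message describes.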
0D Tensors: we now also remove more complex aten::Int cases found in models such as BERT, like the following.

Before:
```
graph(%0: int):
  %1: Tensor = prim::Constant[value={8}]()
  %2: int = prim::Constant[value=1]()
  %3: Tensor = prim::NumToTensor(%0)
  %4: Tensor = aten::add(%1, %3, %2)
  %5: int = aten::Int(%4)
  %6: int = aten::add(%5, %5)
  return (%6)
```

After:
```
graph(%0: int):
  %1: int = prim::Constant[value=8]()
  %4: int = aten::add(%1, %0)
  %6: int = aten::add(%4, %4)
  return (%6)
```

Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
Force-pushed 30ee238 to 8139da9
Lower logging level on debug info Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
There are some changes that do not conform to Python style guidelines:
--- /workspace/tests/modules/hub.py (original)
+++ /workspace/tests/modules/hub.py (reformatted)
@@ -191,7 +191,6 @@
conditional_script_model = torch.jit.script(conditional_model)
torch.jit.save(conditional_script_model, "conditional_scripted.jit.pt")
-
enc = BertTokenizer.from_pretrained("bert-base-uncased")
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_trt_intercompatibility.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
Reformatting /workspace/tests/py/test_api.py
Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_qat_trt_accuracy.py
Reformatting /workspace/tests/py/test_to_backend_api.py
ERROR: Some files do not conform to style guidelines
There are some changes that do not conform to C++ style guidelines:
diff --git a/workspace/tests/cpp/test_default_input_types.cpp b/tmp/changes.txt
index 752f51e..a79ddaf 100644
--- a/workspace/tests/cpp/test_default_input_types.cpp
+++ b/tmp/changes.txt
@@ -116,4 +116,5 @@ TEST_P(CppAPITests, InputsRespectUserSettingFP32WeightsFP16In) {
INSTANTIATE_TEST_SUITE_P(
CompiledModuleForwardIsCloseSuite,
CppAPITests,
- testing::Values(PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat} /*unused*/, 2e-5})));
+ testing::Values(
+ PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat} /*unused*/, 2e-5})));
ERROR: Some files do not conform to style guidelines
Code conforms to C++ style guidelines
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
Force-pushed af8d22d to 83ae991
Code conforms to C++ style guidelines
Code conforms to Python style guidelines
```
std::string set_attr_pattern = R"IR(
    graph(%self, %0):
        None = prim::SetAttr[name="_has_warned"](%self, %0)
```
Can you mention in a comment what this specific attribute `_has_warned` is? Where is it used? etc.
```
namespace lowering {
namespace passes {

void RemoveSetAttrs(const torch::jit::Module& mod, std::string method_name) {
```
Where is this lowering pass used?
There was a version of transformers that had this, and it was breaking the conversion process since setattr does not have a schema. But later versions don't use it, so I removed it from the set of active passes.
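For reference, assembled from the fragments quoted above, the pass under discussion looks roughly like this (a sketch; the includes, logging, and surrounding code in the actual file may differ):

```cpp
#include "torch/csrc/jit/api/module.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

namespace lowering {
namespace passes {

// Sketch: strip prim::SetAttr[name="_has_warned"] nodes from a method's graph.
// prim::SetAttr has no schema, so leaving these nodes in breaks conversion.
void RemoveSetAttrs(const torch::jit::Module& mod, std::string method_name) {
  auto g = mod.get_method(method_name).graph();

  std::string set_attr_pattern = R"IR(
    graph(%self, %0):
      None = prim::SetAttr[name="_has_warned"](%self, %0)
      return ())IR";
  std::string no_set_attr_pattern = R"IR(
    graph(%self, %0):
      return ())IR";

  torch::jit::SubgraphRewriter remove_set_attr;
  remove_set_attr.RegisterRewritePattern(set_attr_pattern, no_set_attr_pattern);
  remove_set_attr.runOnGraph(g);
}

} // namespace passes
} // namespace lowering
```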
Description
This PR adds support for aten::Int / prim::NumToTensor in a few limited cases.
prim::NumToTensor -> aten::Int
prim::NumToTensor -> X -> aten::Int
in cases where the tensors used are single-use and can safely be fused. A sketch exercising the first case is included below.

Fixes #513, Fixes #707
Partially: #867, #829, #785, #711, #660
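As a rough illustration of the first case, the lowering can be exercised on a hand-written graph along these lines (a sketch: RemoveUnnecessaryCasts is the pass touched by this PR, but the scaffolding here is assumed, loosely modeled on tests/core/lowering/test_remove_unnecessary_casts.cpp):

```cpp
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/irparser.h"

// Build a graph containing a single-use prim::NumToTensor -> aten::Int pair
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(R"IR(
  graph(%0: int):
    %1: Tensor = prim::NumToTensor(%0)
    %2: int = aten::Int(%1)
    %3: int = aten::add(%2, %2)
    return (%3))IR",
    g.get());

// Run the pass; afterwards the cast round-trip should be gone, leaving:
//   graph(%0: int):
//     %3: int = aten::add(%0, %0)
//     return (%3)
RemoveUnnecessaryCasts(g);
```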