Skip to content

Commit e63908b

Browse files
committed
refactor: Apply linting
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 46ac757 commit e63908b

File tree

2 files changed

+16
-25
lines changed

2 files changed

+16
-25
lines changed

core/lowering/passes/remove_unnecessary_casts.cpp

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#include "torch/csrc/jit/passes/subgraph_rewrite.h"
21
#include "torch/csrc/jit/ir/constants.h"
2+
#include "torch/csrc/jit/passes/subgraph_rewrite.h"
33

44
#include "core/util/prelude.h"
55

@@ -10,7 +10,6 @@ namespace core {
1010
namespace lowering {
1111
namespace passes {
1212

13-
1413
// Presumably this is safe since torch::jit::EraseNumberTypesOnBlock exists which just
1514
// removes prim::TensorToNum, aten::Float, aten::Int and prim::NumToTensor nodes outright
1615
void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph) {
@@ -77,8 +76,8 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
7776
if (user->output()->uses().size() == 1) {
7877
auto potential_cast = user->output()->uses()[0].user;
7978
// The downstream user is aten::Int
80-
if (potential_cast->kind() == c10::Symbol::fromQualString("aten::Int")
81-
|| potential_cast->kind() == c10::Symbol::fromQualString("aten::Float")) {
79+
if (potential_cast->kind() == c10::Symbol::fromQualString("aten::Int") ||
80+
potential_cast->kind() == c10::Symbol::fromQualString("aten::Float")) {
8281
LOG_GRAPH("Downstream user is aten::Int/aten::Float");
8382
auto arg = use.offset;
8483

@@ -88,13 +87,11 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
8887
LOG_GRAPH("Input " << k << " is a Tensor");
8988
if (user->inputs()[k]->node()->kind() == c10::Symbol::fromQualString("prim::NumToTensor")) {
9089
auto num_to_tensor = user->inputs()[k]->node();
91-
92-
LOG_GRAPH("Found a prim::NumToTensor / aten::[Int/Float] pair with an intermediate operation:\n "
93-
<< *(*it)
94-
<< *num_to_tensor
95-
<< *user
96-
<< *potential_cast);
97-
90+
91+
LOG_GRAPH(
92+
"Found a prim::NumToTensor / aten::[Int/Float] pair with an intermediate operation:\n "
93+
<< *(*it) << *num_to_tensor << *user << *potential_cast);
94+
9895
// Replace the Tensor Constant with a scalar constant
9996
LOG_GRAPH("Deleting 0-dim Tensor: " << **it);
10097
torch::jit::WithInsertPoint gaurd(*it);
@@ -126,19 +123,16 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
126123
// has a different schema than the original
127124
case c10::aten::add:
128125
new_node = g->create(
129-
user->kind(),
130-
torch::jit::ArrayRef<torch::jit::Value*>({user->inputs()[0], user->inputs()[1]}),
131-
1);
126+
user->kind(),
127+
torch::jit::ArrayRef<torch::jit::Value*>({user->inputs()[0], user->inputs()[1]}),
128+
1);
132129
new_node->insertAfter(user);
133130
new_node->outputs()[0]->setType(c10::IntType::get());
134131
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
135132
user->destroy();
136133
break;
137134
default:
138-
new_node = g->create(
139-
user->kind(),
140-
user->inputs(),
141-
1);
135+
new_node = g->create(user->kind(), user->inputs(), 1);
142136
new_node->insertAfter(user);
143137
new_node->outputs()[0]->setType(c10::IntType::get());
144138
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
@@ -148,7 +142,7 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
148142

149143
LOG_GRAPH("New intermediate operation: " << *new_node);
150144
LOG_GRAPH(new_node->schema());
151-
145+
152146
// Delete aten::Int
153147
LOG_GRAPH("Deleting aten::[Int/Float]: " << *potential_cast);
154148
potential_cast->output()->replaceAllUsesWith(potential_cast->inputs()[0]);
@@ -163,12 +157,11 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
163157
}
164158
}
165159
}
166-
}
160+
}
167161
}
168162
LOG_ERROR("Post removing single use 0-dim Tensor operations: " << *g);
169163
}
170164

171-
172165
} // namespace passes
173166
} // namespace lowering
174167
} // namespace core

tests/core/lowering/test_remove_unnecessary_casts.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,7 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsIntCorrectly) {
102102

103103
auto first_op = *(sg->block()->nodes().begin());
104104
torch::jit::WithInsertPoint guard(first_op);
105-
torch::jit::Value* r = sg->insertConstant(
106-
c10::scalar_to_tensor(8), c10::nullopt, first_op->scope());
105+
torch::jit::Value* r = sg->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, first_op->scope());
107106
r->copyMetadata(first_op->output());
108107
r->setType(c10::TensorType::get());
109108
first_op->output()->replaceAllUsesWith(r);
@@ -141,8 +140,7 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsFloatCorrectly) {
141140

142141
auto first_op = *(sg->block()->nodes().begin());
143142
torch::jit::WithInsertPoint guard(first_op);
144-
torch::jit::Value* r = sg->insertConstant(
145-
c10::scalar_to_tensor(8.0), c10::nullopt, first_op->scope());
143+
torch::jit::Value* r = sg->insertConstant(c10::scalar_to_tensor(8.0), c10::nullopt, first_op->scope());
146144
r->copyMetadata(first_op->output());
147145
r->setType(c10::TensorType::get());
148146
first_op->output()->replaceAllUsesWith(r);

0 commit comments

Comments
 (0)