@@ -1,5 +1,5 @@
-#include "torch/csrc/jit/passes/subgraph_rewrite.h"
 #include "torch/csrc/jit/ir/constants.h"
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
 
 #include "core/util/prelude.h"
 
@@ -10,7 +10,6 @@ namespace core {
 namespace lowering {
 namespace passes {
 
-
 // Presumably this is safe since torch::jit::EraseNumberTypesOnBlock exists which just
 // removes prim::TensorToNum, aten::Float, aten::Int and prim::NumToTensor nodes outright
 void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph) {
@@ -77,8 +76,8 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
             if (user->output()->uses().size() == 1) {
               auto potential_cast = user->output()->uses()[0].user;
               // The downstream user is aten::Int
-              if (potential_cast->kind() == c10::Symbol::fromQualString("aten::Int")
-                  || potential_cast->kind() == c10::Symbol::fromQualString("aten::Float")) {
+              if (potential_cast->kind() == c10::Symbol::fromQualString("aten::Int") ||
+                  potential_cast->kind() == c10::Symbol::fromQualString("aten::Float")) {
                 LOG_GRAPH("Downstream user is aten::Int/aten::Float");
                 auto arg = use.offset;
 
@@ -88,13 +87,11 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
                     LOG_GRAPH("Input " << k << " is a Tensor");
                     if (user->inputs()[k]->node()->kind() == c10::Symbol::fromQualString("prim::NumToTensor")) {
                       auto num_to_tensor = user->inputs()[k]->node();
-
-                      LOG_GRAPH("Found a prim::NumToTensor / aten::[Int/Float] pair with an intermediate operation:\n"
-                                << *(*it)
-                                << *num_to_tensor
-                                << *user
-                                << *potential_cast);
-
+
+                      LOG_GRAPH(
+                          "Found a prim::NumToTensor / aten::[Int/Float] pair with an intermediate operation:\n"
+                          << *(*it) << *num_to_tensor << *user << *potential_cast);
+
                       // Replace the Tensor Constant with a scalar constant
                       LOG_GRAPH("Deleting 0-dim Tensor: " << **it);
                       torch::jit::WithInsertPoint gaurd(*it);
@@ -126,19 +123,16 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
                         // has a different schema than the original
                         case c10::aten::add:
                           new_node = g->create(
-                            user->kind(),
-                            torch::jit::ArrayRef<torch::jit::Value*>({user->inputs()[0], user->inputs()[1]}),
-                            1);
+                              user->kind(),
+                              torch::jit::ArrayRef<torch::jit::Value*>({user->inputs()[0], user->inputs()[1]}),
+                              1);
                           new_node->insertAfter(user);
                           new_node->outputs()[0]->setType(c10::IntType::get());
                           user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
                           user->destroy();
                           break;
                         default:
-                          new_node = g->create(
-                            user->kind(),
-                            user->inputs(),
-                            1);
+                          new_node = g->create(user->kind(), user->inputs(), 1);
                           new_node->insertAfter(user);
                           new_node->outputs()[0]->setType(c10::IntType::get());
                           user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
@@ -148,7 +142,7 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
 
                       LOG_GRAPH("New intermediate operation: " << *new_node);
                       LOG_GRAPH(new_node->schema());
-
+
                       // Delete aten::Int
                       LOG_GRAPH("Deleting aten::[Int/Float]: " << *potential_cast);
                       potential_cast->output()->replaceAllUsesWith(potential_cast->inputs()[0]);
@@ -163,12 +157,11 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
           }
         }
       }
-      }
+    }
   }
 }
   LOG_ERROR("Post removing single use 0-dim Tensor operations: " << *g);
 }
 
-
 } // namespace passes
 } // namespace lowering
 } // namespace core
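Note: the hunks above are formatting-only changes; the body of RemoveUnnecessaryCasts itself is not part of this diff. For context on the subgraph_rewrite.h header reordered in the first hunk, a cast-removal lowering pass of this kind is typically written as a pattern/replacement pair for torch::jit::SubgraphRewriter. The sketch below is illustrative only: the function name RemoveIntCastRoundTrip and the IR pattern strings are assumptions, not contents of this file. It shows how a redundant prim::NumToTensor / aten::Int round trip could be erased, leaving downstream users reading the original int directly.

// Illustrative sketch, not part of the diff above.
// Rewrites  %t = prim::NumToTensor(%x); %y = aten::Int(%t)  so that every
// use of %y is replaced by a direct use of %x.
#include <memory>
#include <string>

#include "torch/csrc/jit/passes/subgraph_rewrite.h"

void RemoveIntCastRoundTrip(std::shared_ptr<torch::jit::Graph>& graph) {
  // Pattern to match: an int that is boxed into a 0-dim Tensor and
  // immediately unboxed again.
  std::string cast_pattern = R"IR(
    graph(%x: int):
      %t: Tensor = prim::NumToTensor(%x)
      %y: int = aten::Int(%t)
      return (%y))IR";
  // Replacement: just forward the original int.
  std::string no_cast_pattern = R"IR(
    graph(%x: int):
      return (%x))IR";

  torch::jit::SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(cast_pattern, no_cast_pattern);
  rewriter.runOnGraph(graph);
}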