@@ -117,77 +117,68 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
117
117
118
118
// Change intermediate op output type
119
119
LOG_GRAPH (user->schema ());
120
-
121
120
torch::jit::Node* new_node;
122
- switch (user->kind ()) {
123
- // Use this to handle special cases where the scalar version of the intermediate operator
124
- // has a different schema than the original
125
- case c10::aten::add:
126
- new_node = g->create (
127
- user->kind (),
128
- torch::jit::ArrayRef<torch::jit::Value*>({user->inputs ()[0 ], user->inputs ()[1 ]}),
129
- 1 );
130
- new_node->insertAfter (user);
131
- new_node->outputs ()[0 ]->setType (c10::IntType::get ());
132
- user->outputs ()[0 ]->replaceAllUsesWith (new_node->outputs ()[0 ]);
133
- user->destroy ();
134
- break ;
135
- case c10::aten::floor_divide:
136
- new_node = g->create (c10::aten::floordiv, user->inputs (), 1 );
137
- new_node->insertAfter (user);
138
- new_node->outputs ()[0 ]->setType (c10::IntType::get ());
139
- user->outputs ()[0 ]->replaceAllUsesWith (new_node->outputs ()[0 ]);
140
- user->destroy ();
141
- break ;
142
- case c10::aten::div:
143
- // If the first two entries to aten::div are non-Tensors,
144
- // there cannot be a rounding mode specified (3rd entry)
145
- if (!user->inputs ()[0 ]->type ()->isSubtypeOf (c10::TensorType::get ()) &&
146
- !user->inputs ()[1 ]->type ()->isSubtypeOf (c10::TensorType::get ()) &&
147
- user->inputs ().size () == 3 &&
148
- user->inputs ()[2 ]->type ()->isSubtypeOf (c10::StringType::get ()) &&
149
- torch::jit::toIValue (user->inputs ()[2 ]).has_value ()) {
150
- // Select the first 2 entries of the inputs, corresponding to the values
151
- auto div_args = user->inputs ().slice (0 , 2 );
152
-
153
- // Depending on the rounding mode, create the appropriate nodes
154
- if (torch::jit::toIValue (user->inputs ()[2 ]).value ().toStringRef () == " trunc" ) {
155
- // Truncate case (round result towards 0)
156
- torch::jit::Node* new_node_div;
157
- // Create node which simply divides the two entries
158
- new_node_div = g->create (c10::aten::div, div_args, 1 );
159
- new_node_div->insertAfter (user);
160
- new_node_div->outputs ()[0 ]->setType (c10::FloatType::get ());
161
-
162
- // Create node which casts the result to an integer, effectively truncating
163
- new_node = g->create (c10::aten::Int, new_node_div->outputs (), 1 );
164
- new_node->insertAfter (new_node_div);
165
- new_node->outputs ()[0 ]->setType (c10::IntType::get ());
166
-
167
- user->outputs ()[0 ]->replaceAllUsesWith (new_node->outputs ()[0 ]);
168
- user->destroy ();
169
- break ;
170
-
171
- } else if (torch::jit::toIValue (user->inputs ()[2 ]).value ().toStringRef () == " floor" ) {
172
- // Floor case (round result down)
173
- // Replace aten::div with aten::floordiv
174
- new_node = g->create (c10::aten::floordiv, div_args, 1 );
175
- new_node->insertAfter (user);
176
- new_node->outputs ()[0 ]->setType (c10::IntType::get ());
177
-
178
- user->outputs ()[0 ]->replaceAllUsesWith (new_node->outputs ()[0 ]);
179
- user->destroy ();
180
- break ;
181
- }
121
+ // Use this to handle special cases where the scalar version of the intermediate operator
122
+ // has a different schema than the original
123
+ if (user->kind () == c10::Symbol::fromQualString (" aten::add" )) {
124
+ new_node = g->create (
125
+ c10::Symbol::fromQualString (" aten::add" ),
126
+ torch::jit::ArrayRef<torch::jit::Value*>({user->inputs ()[0 ], user->inputs ()[1 ]}),
127
+ 1 );
128
+ new_node->insertAfter (user);
129
+ new_node->outputs ()[0 ]->setType (c10::IntType::get ());
130
+ user->outputs ()[0 ]->replaceAllUsesWith (new_node->outputs ()[0 ]);
131
+ user->destroy ();
132
+ } else if (user->kind () == c10::Symbol::fromQualString (" aten::floordiv" )) {
133
+ new_node = g->create (c10::aten::floordiv, user->inputs (), 1 );
134
+ new_node->insertAfter (user);
135
+ new_node->outputs ()[0 ]->setType (c10::IntType::get ());
136
+ user->outputs ()[0 ]->replaceAllUsesWith (new_node->outputs ()[0 ]);
137
+ user->destroy ();
138
+ } else if (user->kind () == c10::Symbol::fromQualString (" aten::div" )) {
139
+ // If the first two entries to aten::div are non-Tensors,
140
+ // there cannot be a rounding mode specified (3rd entry)
141
+ if (!user->inputs ()[0 ]->type ()->isSubtypeOf (c10::TensorType::get ()) &&
142
+ !user->inputs ()[1 ]->type ()->isSubtypeOf (c10::TensorType::get ()) &&
143
+ user->inputs ().size () == 3 &&
144
+ user->inputs ()[2 ]->type ()->isSubtypeOf (c10::StringType::get ()) &&
145
+ torch::jit::toIValue (user->inputs ()[2 ]).has_value ()) {
146
+ // Select the first 2 entries of the inputs, corresponding to the values
147
+ auto div_args = user->inputs ().slice (0 , 2 );
148
+
149
+ // Depending on the rounding mode, create the appropriate nodes
150
+ if (torch::jit::toIValue (user->inputs ()[2 ]).value ().toStringRef () == " trunc" ) {
151
+ // Truncate case (round result towards 0)
152
+ torch::jit::Node* new_node_div;
153
+ // Create node which simply divides the two entries
154
+ new_node_div = g->create (c10::aten::div, div_args, 1 );
155
+ new_node_div->insertAfter (user);
156
+ new_node_div->outputs ()[0 ]->setType (c10::FloatType::get ());
157
+
158
+ // Create node which casts the result to an integer, effectively truncating
159
+ new_node = g->create (c10::aten::Int, new_node_div->outputs (), 1 );
160
+ new_node->insertAfter (new_node_div);
161
+ new_node->outputs ()[0 ]->setType (c10::IntType::get ());
162
+
163
+ user->outputs ()[0 ]->replaceAllUsesWith (new_node->outputs ()[0 ]);
164
+ user->destroy ();
165
+ } else if (torch::jit::toIValue (user->inputs ()[2 ]).value ().toStringRef () == " floor" ) {
166
+ // Floor case (round result down)
167
+ // Replace aten::div with aten::floordiv
168
+ new_node = g->create (c10::aten::floordiv, div_args, 1 );
169
+ new_node->insertAfter (user);
170
+ new_node->outputs ()[0 ]->setType (c10::IntType::get ());
171
+
172
+ user->outputs ()[0 ]->replaceAllUsesWith (new_node->outputs ()[0 ]);
173
+ user->destroy ();
182
174
}
183
-
184
- default :
185
- new_node = g->create (user->kind (), user->inputs (), 1 );
186
- new_node->insertAfter (user);
187
- new_node->outputs ()[0 ]->setType (c10::IntType::get ());
188
- user->outputs ()[0 ]->replaceAllUsesWith (new_node->outputs ()[0 ]);
189
- user->destroy ();
190
- break ;
175
+ }
176
+ } else {
177
+ new_node = g->create (user->kind (), user->inputs (), 1 );
178
+ new_node->insertAfter (user);
179
+ new_node->outputs ()[0 ]->setType (c10::IntType::get ());
180
+ user->outputs ()[0 ]->replaceAllUsesWith (new_node->outputs ()[0 ]);
181
+ user->destroy ();
191
182
}
192
183
193
184
LOG_GRAPH (" New intermediate operation: " << *new_node);
0 commit comments