@@ -153,13 +153,22 @@ Tensor& quantize_per_tensor_out(
153
153
}
154
154
155
155
Tensor& quantize_per_tensor_tensor_args_out (
156
+ RuntimeContext& context,
156
157
const Tensor& input,
157
158
const Tensor& scale,
158
159
const Tensor& zero_point,
159
160
int64_t quant_min,
160
161
int64_t quant_max,
161
162
ScalarType dtype,
162
163
Tensor& out) {
164
+ // Temporary change to allow not fatal failure for now to unblock some
165
+ // expected failure tests that are dying instead of failure. Will revisit
166
+ // after ET_KERNEL_CHECK is fully implemented and properly allows non fatal
167
+ // failures.
168
+ if (scale.scalar_type () != ScalarType::Double) {
169
+ context.fail (torch::executor::Error::InvalidArgument);
170
+ return out;
171
+ }
163
172
ET_CHECK_MSG (
164
173
scale.scalar_type () == ScalarType::Double,
165
174
" Expected scale to be Double tensor received: %" PRId8,
@@ -188,36 +197,34 @@ Tensor& quantize_per_tensor_tensor_args_out(
188
197
return out;
189
198
}
190
199
191
- Tensor& quantize_per_tensor_out (
192
- RuntimeContext& context,
193
-
200
+ Tensor& quantize_per_tensor_tensor_args_out (
194
201
const Tensor& input,
195
- double scale,
196
- int64_t zero_point,
202
+ const Tensor& scale,
203
+ const Tensor& zero_point,
197
204
int64_t quant_min,
198
205
int64_t quant_max,
199
206
ScalarType dtype,
200
207
Tensor& out) {
201
- // TODO(larryliu): Add a context arg to the real op function and remove this
202
- // wrapper
203
- ( void ) context;
204
- return quantize_per_tensor_out (
205
- input, scale, zero_point, quant_min, quant_max, dtype, out) ;
208
+ auto context = torch::executor::RuntimeContext ();
209
+ auto & res = quantize_per_tensor_tensor_args_out (
210
+ context, input, scale, zero_point, quant_min, quant_max, dtype, out) ;
211
+ ET_CHECK (context. failure_state () == Error::Ok);
212
+ return res ;
206
213
}
207
214
208
- Tensor& quantize_per_tensor_tensor_args_out (
215
+ Tensor& quantize_per_tensor_out (
209
216
RuntimeContext& context,
210
217
const Tensor& input,
211
- const Tensor& scale,
212
- const Tensor& zero_point,
218
+ double scale,
219
+ int64_t zero_point,
213
220
int64_t quant_min,
214
221
int64_t quant_max,
215
222
ScalarType dtype,
216
223
Tensor& out) {
217
224
// TODO(larryliu): Add a context arg to the real op function and remove this
218
225
// wrapper
219
226
(void )context;
220
- return quantize_per_tensor_tensor_args_out (
227
+ return quantize_per_tensor_out (
221
228
input, scale, zero_point, quant_min, quant_max, dtype, out);
222
229
}
223
230
0 commit comments