Skip to content

Commit 8bb172d

Browse files
JacobSzwejbka authored and facebook-github-bot committed
remove exir.capture from model inventory
Summary: capture is deprecated; we should be using export. Also deleted some tests that weren't maintained and weren't passing for several reasons. Since this lib is basically on life support and we have coverage elsewhere, I didn't want to spend a ton of time debugging. Reviewed By: Jack-Khuu, angelayi Differential Revision: D54562526
1 parent 9d6bf72 commit 8bb172d

File tree

2 files changed

+39
-16
lines changed

2 files changed

+39
-16
lines changed

kernels/quantized/cpu/op_quantize.cpp

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,22 @@ Tensor& quantize_per_tensor_out(
153153
}
154154

155155
Tensor& quantize_per_tensor_tensor_args_out(
156+
RuntimeContext& context,
156157
const Tensor& input,
157158
const Tensor& scale,
158159
const Tensor& zero_point,
159160
int64_t quant_min,
160161
int64_t quant_max,
161162
ScalarType dtype,
162163
Tensor& out) {
164+
// Temporary change to allow not fatal failure for now to unblock some
165+
// expected failure tests that are dying instead of failure. Will revisit
166+
// after ET_KERNEL_CHECK is fully implemented and properly allows non fatal
167+
// failures.
168+
if (scale.scalar_type() != ScalarType::Double) {
169+
context.fail(torch::executor::Error::InvalidArgument);
170+
return out;
171+
}
163172
ET_CHECK_MSG(
164173
scale.scalar_type() == ScalarType::Double,
165174
"Expected scale to be Double tensor received: %" PRId8,
@@ -188,36 +197,34 @@ Tensor& quantize_per_tensor_tensor_args_out(
188197
return out;
189198
}
190199

191-
Tensor& quantize_per_tensor_out(
192-
RuntimeContext& context,
193-
200+
Tensor& quantize_per_tensor_tensor_args_out(
194201
const Tensor& input,
195-
double scale,
196-
int64_t zero_point,
202+
const Tensor& scale,
203+
const Tensor& zero_point,
197204
int64_t quant_min,
198205
int64_t quant_max,
199206
ScalarType dtype,
200207
Tensor& out) {
201-
// TODO(larryliu): Add a context arg to the real op function and remove this
202-
// wrapper
203-
(void)context;
204-
return quantize_per_tensor_out(
205-
input, scale, zero_point, quant_min, quant_max, dtype, out);
208+
auto context = torch::executor::RuntimeContext();
209+
auto& res = quantize_per_tensor_tensor_args_out(
210+
context, input, scale, zero_point, quant_min, quant_max, dtype, out);
211+
ET_CHECK(context.failure_state() == Error::Ok);
212+
return res;
206213
}
207214

208-
Tensor& quantize_per_tensor_tensor_args_out(
215+
Tensor& quantize_per_tensor_out(
209216
RuntimeContext& context,
210217
const Tensor& input,
211-
const Tensor& scale,
212-
const Tensor& zero_point,
218+
double scale,
219+
int64_t zero_point,
213220
int64_t quant_min,
214221
int64_t quant_max,
215222
ScalarType dtype,
216223
Tensor& out) {
217224
// TODO(larryliu): Add a context arg to the real op function and remove this
218225
// wrapper
219226
(void)context;
220-
return quantize_per_tensor_tensor_args_out(
227+
return quantize_per_tensor_out(
221228
input, scale, zero_point, quant_min, quant_max, dtype, out);
222229
}
223230

kernels/quantized/test/op_quantize_test.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,16 @@ TEST(OpQuantizeOutTest, TensorArgOverload) {
6868
Tensor out = tfo.zeros({3, 5});
6969
// 4 / 0.5 + 127
7070
Tensor expected = tfo.full({3, 5}, 135);
71+
auto context = torch::executor::KernelRuntimeContext();
7172
quantize_per_tensor_tensor_args_out(
72-
input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out);
73+
context,
74+
input,
75+
scale,
76+
zero_point,
77+
quant_min,
78+
quant_max,
79+
ScalarType::Byte,
80+
out);
7381

7482
EXPECT_TENSOR_EQ(out, expected);
7583
}
@@ -93,8 +101,16 @@ TEST(OpQuantizeOutTest, TestOutOfBounds) {
93101

94102
Tensor expected = tfo.full({1, 3, 256, 256}, 127);
95103

104+
auto context = torch::executor::KernelRuntimeContext();
96105
quantize_per_tensor_tensor_args_out(
97-
input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out);
106+
context,
107+
input,
108+
scale,
109+
zero_point,
110+
quant_min,
111+
quant_max,
112+
ScalarType::Char,
113+
out);
98114

99115
EXPECT_TENSOR_EQ(out, expected);
100116
}

0 commit comments

Comments (0)