# remove exir.capture from model inventory #2302

Status: Closed · wants to merge 1 commit
**kernels/quantized/cpu/op_quantize.cpp** (35 changes: 21 additions & 14 deletions)
```diff
@@ -153,13 +153,22 @@ Tensor& quantize_per_tensor_out(
 }
 
 Tensor& quantize_per_tensor_tensor_args_out(
+    RuntimeContext& context,
     const Tensor& input,
     const Tensor& scale,
     const Tensor& zero_point,
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
     Tensor& out) {
+  // Temporary change to allow a non-fatal failure for now, to unblock some
+  // expected-failure tests that die instead of failing. Will revisit after
+  // ET_KERNEL_CHECK is fully implemented and properly allows non-fatal
+  // failures.
+  if (scale.scalar_type() != ScalarType::Double) {
+    context.fail(torch::executor::Error::InvalidArgument);
+    return out;
+  }
   ET_CHECK_MSG(
       scale.scalar_type() == ScalarType::Double,
       "Expected scale to be Double tensor received: %" PRId8,
```
```diff
@@ -188,36 +197,34 @@ Tensor& quantize_per_tensor_tensor_args_out(
   return out;
 }
 
-Tensor& quantize_per_tensor_out(
-    RuntimeContext& context,
+Tensor& quantize_per_tensor_tensor_args_out(
     const Tensor& input,
-    double scale,
-    int64_t zero_point,
+    const Tensor& scale,
+    const Tensor& zero_point,
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
     Tensor& out) {
-  // TODO(larryliu): Add a context arg to the real op function and remove this
-  // wrapper
-  (void)context;
-  return quantize_per_tensor_out(
-      input, scale, zero_point, quant_min, quant_max, dtype, out);
+  auto context = torch::executor::RuntimeContext();
+  auto& res = quantize_per_tensor_tensor_args_out(
+      context, input, scale, zero_point, quant_min, quant_max, dtype, out);
+  ET_CHECK(context.failure_state() == Error::Ok);
+  return res;
 }
 
-Tensor& quantize_per_tensor_tensor_args_out(
+Tensor& quantize_per_tensor_out(
     RuntimeContext& context,
     const Tensor& input,
-    const Tensor& scale,
-    const Tensor& zero_point,
+    double scale,
+    int64_t zero_point,
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
     Tensor& out) {
   // TODO(larryliu): Add a context arg to the real op function and remove this
   // wrapper
   (void)context;
-  return quantize_per_tensor_tensor_args_out(
+  return quantize_per_tensor_out(
       input, scale, zero_point, quant_min, quant_max, dtype, out);
 }
```
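The reordering above also sets up a wrapper pattern: the overload without a context constructs its own `RuntimeContext`, delegates to the context-taking kernel, and re-promotes any recorded failure to a fatal `ET_CHECK`, preserving the old crash-on-error behavior for context-free callers. A compact sketch of that shape, using the same simplified stand-ins as the previous sketch (not the real ExecuTorch types):

```cpp
#include <cassert>
#include <cstdlib>

enum class Error { Ok, InvalidArgument };

struct FakeContext {
  Error state = Error::Ok;
  void fail(Error e) { state = e; }
  Error failure_state() const { return state; }
};

// Context-taking kernel (stand-in for the tensor-args overload above).
int& op_out(FakeContext& ctx, bool args_ok, int& out) {
  if (!args_ok) { ctx.fail(Error::InvalidArgument); return out; }
  out = 1;
  return out;
}

// Context-free overload: owns a context, delegates, then re-promotes any
// recorded failure to a fatal abort, mirroring
// ET_CHECK(context.failure_state() == Error::Ok) in the diff.
int& op_out(bool args_ok, int& out) {
  FakeContext context;
  int& res = op_out(context, args_ok, out);
  if (context.failure_state() != Error::Ok) {
    std::abort();  // analogous to ET_CHECK failing
  }
  return res;
}

int main() {
  int out = 0;
  op_out(/*args_ok=*/true, out);
  assert(out == 1);
}
```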

**kernels/quantized/test/op_quantize_test.cpp** (20 changes: 18 additions & 2 deletions)
```diff
@@ -68,8 +68,16 @@ TEST(OpQuantizeOutTest, TensorArgOverload) {
   Tensor out = tfo.zeros({3, 5});
   // 4 / 0.5 + 127
   Tensor expected = tfo.full({3, 5}, 135);
+  auto context = torch::executor::KernelRuntimeContext();
   quantize_per_tensor_tensor_args_out(
-      input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out);
+      context,
+      input,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      ScalarType::Byte,
+      out);
 
   EXPECT_TENSOR_EQ(out, expected);
 }
```
```diff
@@ -93,8 +101,16 @@ TEST(OpQuantizeOutTest, TestOutOfBounds) {
 
   Tensor expected = tfo.full({1, 3, 256, 256}, 127);
 
+  auto context = torch::executor::KernelRuntimeContext();
   quantize_per_tensor_tensor_args_out(
-      input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out);
+      context,
+      input,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      ScalarType::Char,
+      out);
 
   EXPECT_TENSOR_EQ(out, expected);
 }
```
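With the context threaded through the test call sites, an expected-failure case no longer has to die inside a fatal check; it can assert the error recorded on the context instead. The following is a hypothetical test in the style of this file, not part of this PR: the test name, the Float-scale setup, and the `EXPECT_EQ` on `failure_state()` are illustrative assumptions built from the APIs visible in this diff.

```cpp
// Hypothetical expected-failure test (not in this diff). Passes a Float
// scale where the kernel requires a Double tensor, then checks the
// non-fatal error recorded on the context.
TEST(OpQuantizeOutTest, WrongScaleDtypeFailsNonFatally) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Byte> tfo;

  Tensor input = tf_float.full({3, 5}, 4);
  Tensor scale = tf_float.full({1}, 0.5);  // wrong dtype: kernel wants Double
  Tensor zero_point = tf_int.full({1}, 127);
  Tensor out = tfo.zeros({3, 5});

  auto context = torch::executor::KernelRuntimeContext();
  quantize_per_tensor_tensor_args_out(
      context, input, scale, zero_point, 0, 255, ScalarType::Byte, out);

  EXPECT_EQ(
      context.failure_state(), torch::executor::Error::InvalidArgument);
}
```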