Skip to content

Commit 8a38922

Browse files
salilsdesaifacebook-github-bot
authored andcommitted
Revert D47573238 - "[ET][Portable] Dtype compliance: clamp" (#77)
Summary: Pull Request resolved: #77 D47573238 broke pytorch-edge-run-executorch-transcribe-bin Original commit changeset: c083d810365e Original Phabricator Diff: D47573238 Reviewed By: anticlockwise Differential Revision: D48456015 fbshipit-source-id: 1975e574c3e229e5d4d32bc661a4104dd32dab00
1 parent d708222 commit 8a38922

File tree

2 files changed

+26
-87
lines changed

2 files changed

+26
-87
lines changed

kernels/portable/cpu/op_clamp.cpp

Lines changed: 24 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -84,99 +84,38 @@ Tensor& clamp_out(
8484
Error err = resize_tensor(out, in.sizes());
8585
ET_CHECK_MSG(err == Error::Ok, "Could not resize output");
8686

87-
ScalarType in_type = in.scalar_type();
88-
ScalarType min_type = in_type;
89-
ScalarType max_type = in_type;
90-
ScalarType common_type = in_type;
91-
ScalarType out_type = out.scalar_type();
92-
93-
bool has_min = min_opt.has_value();
94-
if (has_min) {
95-
min_type = utils::get_scalar_dtype(min_opt.value());
96-
common_type = utils::promote_type_with_scalar(common_type, min_opt.value());
97-
}
98-
bool has_max = max_opt.has_value();
99-
if (has_max) {
100-
max_type = utils::get_scalar_dtype(max_opt.value());
101-
common_type = utils::promote_type_with_scalar(common_type, max_opt.value());
102-
}
103-
104-
ET_CHECK_MSG(
105-
has_min || has_max, "At least one of 'min' or 'max' must not be None");
87+
ET_CHECK_SAME_SHAPE_AND_DTYPE2(in, out);
10688

107-
ET_CHECK(common_type == out_type);
108-
109-
ET_SWITCH_REAL_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
89+
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "clamp", CTYPE, [&]() {
11090
// Extract optional min value
111-
CTYPE_OUT min = 0;
91+
CTYPE min = 0;
92+
bool has_min = min_opt.has_value();
11293
if (has_min) {
113-
ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "clamp", CTYPE_MIN, [&]() {
114-
CTYPE_MIN min_val = 0;
115-
ET_EXTRACT_SCALAR(min_opt.value(), min_val);
116-
if (isIntegralType(out_type, /*includeBool=*/false)) {
117-
if (static_cast<long>(min_val) <
118-
std::numeric_limits<CTYPE_OUT>::lowest() ||
119-
static_cast<long>(min_val) >
120-
std::numeric_limits<CTYPE_OUT>::max()) {
121-
ET_CHECK_MSG(false, "minimum value out of bounds");
122-
}
123-
}
124-
if (isFloatingType(out_type)) {
125-
if (std::isfinite(min_val) &&
126-
(static_cast<double>(min_val) <
127-
std::numeric_limits<CTYPE_OUT>::lowest() ||
128-
static_cast<double>(min_val) >
129-
std::numeric_limits<CTYPE_OUT>::max())) {
130-
ET_CHECK_MSG(false, "minimum value out of bounds");
131-
}
132-
}
133-
min = static_cast<CTYPE_OUT>(min_val);
134-
});
94+
bool ok = utils::extract_scalar<CTYPE>(min_opt.value(), &min);
95+
ET_CHECK_MSG(ok, "Invalid min value: wrong type or out of range");
13596
}
136-
13797
// Extract optional max value
138-
CTYPE_OUT max = 0;
98+
CTYPE max = 0;
99+
bool has_max = max_opt.has_value();
139100
if (has_max) {
140-
ET_SWITCH_SCALAR_OBJ_TYPES(max_type, ctx, "clamp", CTYPE_MAX, [&]() {
141-
CTYPE_MAX max_val = 0;
142-
ET_EXTRACT_SCALAR(max_opt.value(), max_val);
143-
if (isIntegralType(out_type, /*includeBool=*/false)) {
144-
if (static_cast<long>(max_val) <
145-
std::numeric_limits<CTYPE_OUT>::lowest() ||
146-
static_cast<long>(max_val) >
147-
std::numeric_limits<CTYPE_OUT>::max()) {
148-
ET_CHECK_MSG(false, "maximum value out of bounds");
149-
}
150-
}
151-
if (isFloatingType(out_type)) {
152-
if (std::isfinite(max_val) &&
153-
(static_cast<double>(max_val) <
154-
std::numeric_limits<CTYPE_OUT>::lowest() ||
155-
static_cast<double>(max_val) >
156-
std::numeric_limits<CTYPE_OUT>::max())) {
157-
ET_CHECK_MSG(false, "maximum value out of bounds");
158-
}
159-
}
160-
max = static_cast<CTYPE_OUT>(max_val);
161-
});
101+
bool ok = utils::extract_scalar<CTYPE>(max_opt.value(), &max);
102+
ET_CHECK_MSG(ok, "Invalid max value: wrong type or out of range");
162103
}
163104

164-
ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "clamp", CTYPE_IN, [&]() {
165-
apply_unary_map_fn(
166-
[has_min, min, has_max, max](const CTYPE_IN val_in) {
167-
CTYPE_OUT val_out = static_cast<CTYPE_OUT>(val_in);
168-
if (has_min) {
169-
val_out = max_override(val_out, min);
170-
}
171-
if (has_max) {
172-
val_out = min_override(val_out, max);
173-
}
174-
return val_out;
175-
},
176-
in.const_data_ptr<CTYPE_IN>(),
177-
out.mutable_data_ptr<CTYPE_OUT>(),
178-
in.numel());
179-
});
105+
apply_unary_map_fn(
106+
[has_min, min, has_max, max](const CTYPE val_in) {
107+
CTYPE val_out = val_in;
108+
if (has_min) {
109+
val_out = max_override(val_out, min);
110+
}
111+
if (has_max) {
112+
val_out = min_override(val_out, max);
113+
}
114+
return val_out;
115+
},
116+
in.const_data_ptr<CTYPE>(),
117+
out.mutable_data_ptr<CTYPE>(),
118+
in.numel());
180119
});
181120

182121
return out;

kernels/test/op_clamp_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,12 +303,12 @@ TEST(OpClampOutTest, ByteTensorFloatingPointClampDies) {
303303

304304
#ifndef USE_ATEN_LIB
305305
TEST(OpClampOutTest, IntTensorTooSmallClampDies) {
306-
// Cannot be represented by a int32_t.
306+
// Cannot be represented by a uint32_t.
307307
expect_bad_clamp_value_dies<ScalarType::Int>(-2147483649);
308308
}
309309

310310
TEST(OpClampOutTest, IntTensorTooLargeClampDies) {
311-
// Cannot be represented by a int32_t.
311+
// Cannot be represented by a uint32_t.
312312
expect_bad_clamp_value_dies<ScalarType::Int>(2147483648);
313313
}
314314
#endif

0 commit comments

Comments
 (0)