Skip to content

Commit df1a5fb

Browse files
manuelcandalespytorchbot
authored andcommitted
op_clamp: add downcasting tests & fix (#5798)
Summary: Pull Request resolved: #5798 Reviewed By: swolchok Differential Revision: D63716405 fbshipit-source-id: 4987e9ad93f0b3f490432cf07ba19c2f26fc82e0 (cherry picked from commit 3aa6b14)
1 parent d8dacf3 commit df1a5fb

File tree

2 files changed

+58
-6
lines changed

2 files changed

+58
-6
lines changed

kernels/portable/cpu/op_clamp.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ Tensor& clamp_tensor_out(
218218
ET_SWITCH_REALHB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() {
219219
ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
220220
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
221+
using CTYPE_MINMAX = typename torch::executor::
222+
promote_types<CTYPE_MIN, CTYPE_MAX>::type;
223+
using CTYPE = typename torch::executor::
224+
promote_types<CTYPE_IN, CTYPE_MINMAX>::type;
221225
apply_ternary_elementwise_fn<
222226
CTYPE_IN,
223227
CTYPE_MIN,
@@ -227,16 +231,16 @@ Tensor& clamp_tensor_out(
227231
const CTYPE_IN val_in,
228232
const CTYPE_MIN val_min,
229233
const CTYPE_MAX val_max) {
230-
CTYPE_OUT val_out = static_cast<CTYPE_OUT>(val_in);
234+
CTYPE val_out = static_cast<CTYPE>(val_in);
231235
if (has_min) {
232-
val_out = utils::max_override(
233-
val_out, static_cast<CTYPE_OUT>(val_min));
236+
val_out =
237+
utils::max_override(val_out, static_cast<CTYPE>(val_min));
234238
}
235239
if (has_max) {
236-
val_out = utils::min_override(
237-
val_out, static_cast<CTYPE_OUT>(val_max));
240+
val_out =
241+
utils::min_override(val_out, static_cast<CTYPE>(val_max));
238242
}
239-
return val_out;
243+
return static_cast<CTYPE_OUT>(val_out);
240244
},
241245
in,
242246
min,

kernels/test/op_clamp_test.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,3 +484,51 @@ TEST_F(OpClampTensorOutTest, SmokeTest) {
484484
op_clamp_tensor_out(in, min, max, out);
485485
EXPECT_TENSOR_EQ(out, expected);
486486
}
487+
488+
TEST_F(OpClampTensorOutTest, DowncastingSmokeTest) {
489+
TensorFactory<ScalarType::Byte> tf_in;
490+
TensorFactory<ScalarType::Short> tf_min;
491+
TensorFactory<ScalarType::Int> tf_max;
492+
TensorFactory<ScalarType::Char> tf_out;
493+
494+
Tensor in = tf_in.make({}, {5});
495+
Tensor min = tf_min.make({}, {-129});
496+
Tensor max = tf_max.make({}, {300});
497+
Tensor out = tf_out.zeros({});
498+
Tensor expected = tf_out.make({}, {5});
499+
500+
op_clamp_tensor_out(in, min, max, out);
501+
EXPECT_TENSOR_EQ(out, expected);
502+
}
503+
504+
TEST_F(OpClampTensorOutTest, DowncastingSmokeTest2) {
505+
TensorFactory<ScalarType::Short> tf_in;
506+
TensorFactory<ScalarType::Short> tf_min;
507+
TensorFactory<ScalarType::Int> tf_max;
508+
TensorFactory<ScalarType::Char> tf_out;
509+
510+
Tensor in = tf_in.make({}, {301});
511+
Tensor min = tf_min.make({}, {-129});
512+
Tensor max = tf_max.make({}, {300});
513+
Tensor out = tf_out.zeros({});
514+
Tensor expected = tf_out.make({}, {44});
515+
516+
op_clamp_tensor_out(in, min, max, out);
517+
EXPECT_TENSOR_EQ(out, expected);
518+
}
519+
520+
TEST_F(OpClampTensorOutTest, DowncastingSmokeTest3) {
521+
TensorFactory<ScalarType::Short> tf_in;
522+
TensorFactory<ScalarType::Short> tf_min;
523+
TensorFactory<ScalarType::Int> tf_max;
524+
TensorFactory<ScalarType::Char> tf_out;
525+
526+
Tensor in = tf_in.make({}, {45});
527+
Tensor min = tf_min.make({}, {-129});
528+
Tensor max = tf_max.make({}, {300});
529+
Tensor out = tf_out.zeros({});
530+
Tensor expected = tf_out.make({}, {45});
531+
532+
op_clamp_tensor_out(in, min, max, out);
533+
EXPECT_TENSOR_EQ(out, expected);
534+
}

0 commit comments

Comments
 (0)