@@ -84,99 +84,38 @@ Tensor& clamp_out(
84
84
Error err = resize_tensor (out, in.sizes ());
85
85
ET_CHECK_MSG (err == Error::Ok, " Could not resize output" );
86
86
87
- ScalarType in_type = in.scalar_type ();
88
- ScalarType min_type = in_type;
89
- ScalarType max_type = in_type;
90
- ScalarType common_type = in_type;
91
- ScalarType out_type = out.scalar_type ();
92
-
93
- bool has_min = min_opt.has_value ();
94
- if (has_min) {
95
- min_type = utils::get_scalar_dtype (min_opt.value ());
96
- common_type = utils::promote_type_with_scalar (common_type, min_opt.value ());
97
- }
98
- bool has_max = max_opt.has_value ();
99
- if (has_max) {
100
- max_type = utils::get_scalar_dtype (max_opt.value ());
101
- common_type = utils::promote_type_with_scalar (common_type, max_opt.value ());
102
- }
103
-
104
- ET_CHECK_MSG (
105
- has_min || has_max, " At least one of 'min' or 'max' must not be None" );
87
+ ET_CHECK_SAME_SHAPE_AND_DTYPE2 (in, out);
106
88
107
- ET_CHECK (common_type == out_type);
108
-
109
- ET_SWITCH_REAL_TYPES (out_type, ctx, " clamp" , CTYPE_OUT, [&]() {
89
+ ET_SWITCH_REAL_TYPES (in.scalar_type (), ctx, " clamp" , CTYPE, [&]() {
110
90
// Extract optional min value
111
- CTYPE_OUT min = 0 ;
91
+ CTYPE min = 0 ;
92
+ bool has_min = min_opt.has_value ();
112
93
if (has_min) {
113
- ET_SWITCH_SCALAR_OBJ_TYPES (min_type, ctx, " clamp" , CTYPE_MIN, [&]() {
114
- CTYPE_MIN min_val = 0 ;
115
- ET_EXTRACT_SCALAR (min_opt.value (), min_val);
116
- if (isIntegralType (out_type, /* includeBool=*/ false )) {
117
- if (static_cast <long >(min_val) <
118
- std::numeric_limits<CTYPE_OUT>::lowest () ||
119
- static_cast <long >(min_val) >
120
- std::numeric_limits<CTYPE_OUT>::max ()) {
121
- ET_CHECK_MSG (false , " minimum value out of bounds" );
122
- }
123
- }
124
- if (isFloatingType (out_type)) {
125
- if (std::isfinite (min_val) &&
126
- (static_cast <double >(min_val) <
127
- std::numeric_limits<CTYPE_OUT>::lowest () ||
128
- static_cast <double >(min_val) >
129
- std::numeric_limits<CTYPE_OUT>::max ())) {
130
- ET_CHECK_MSG (false , " minimum value out of bounds" );
131
- }
132
- }
133
- min = static_cast <CTYPE_OUT>(min_val);
134
- });
94
+ bool ok = utils::extract_scalar<CTYPE>(min_opt.value (), &min);
95
+ ET_CHECK_MSG (ok, " Invalid min value: wrong type or out of range" );
135
96
}
136
-
137
97
// Extract optional max value
138
- CTYPE_OUT max = 0 ;
98
+ CTYPE max = 0 ;
99
+ bool has_max = max_opt.has_value ();
139
100
if (has_max) {
140
- ET_SWITCH_SCALAR_OBJ_TYPES (max_type, ctx, " clamp" , CTYPE_MAX, [&]() {
141
- CTYPE_MAX max_val = 0 ;
142
- ET_EXTRACT_SCALAR (max_opt.value (), max_val);
143
- if (isIntegralType (out_type, /* includeBool=*/ false )) {
144
- if (static_cast <long >(max_val) <
145
- std::numeric_limits<CTYPE_OUT>::lowest () ||
146
- static_cast <long >(max_val) >
147
- std::numeric_limits<CTYPE_OUT>::max ()) {
148
- ET_CHECK_MSG (false , " maximum value out of bounds" );
149
- }
150
- }
151
- if (isFloatingType (out_type)) {
152
- if (std::isfinite (max_val) &&
153
- (static_cast <double >(max_val) <
154
- std::numeric_limits<CTYPE_OUT>::lowest () ||
155
- static_cast <double >(max_val) >
156
- std::numeric_limits<CTYPE_OUT>::max ())) {
157
- ET_CHECK_MSG (false , " maximum value out of bounds" );
158
- }
159
- }
160
- max = static_cast <CTYPE_OUT>(max_val);
161
- });
101
+ bool ok = utils::extract_scalar<CTYPE>(max_opt.value (), &max);
102
+ ET_CHECK_MSG (ok, " Invalid max value: wrong type or out of range" );
162
103
}
163
104
164
- ET_SWITCH_REAL_TYPES_AND (Bool, in_type, ctx, " clamp" , CTYPE_IN, [&]() {
165
- apply_unary_map_fn (
166
- [has_min, min, has_max, max](const CTYPE_IN val_in) {
167
- CTYPE_OUT val_out = static_cast <CTYPE_OUT>(val_in);
168
- if (has_min) {
169
- val_out = max_override (val_out, min);
170
- }
171
- if (has_max) {
172
- val_out = min_override (val_out, max);
173
- }
174
- return val_out;
175
- },
176
- in.const_data_ptr <CTYPE_IN>(),
177
- out.mutable_data_ptr <CTYPE_OUT>(),
178
- in.numel ());
179
- });
105
+ apply_unary_map_fn (
106
+ [has_min, min, has_max, max](const CTYPE val_in) {
107
+ CTYPE val_out = val_in;
108
+ if (has_min) {
109
+ val_out = max_override (val_out, min);
110
+ }
111
+ if (has_max) {
112
+ val_out = min_override (val_out, max);
113
+ }
114
+ return val_out;
115
+ },
116
+ in.const_data_ptr <CTYPE>(),
117
+ out.mutable_data_ptr <CTYPE>(),
118
+ in.numel ());
180
119
});
181
120
182
121
return out;
0 commit comments