@@ -51,6 +51,13 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
51
51
}
52
52
53
53
namespace internal {
54
// Alias that discards its first type argument and always names the second.
// It exists so that a parameter pack can appear in a pack-expansion pattern
// (which controls how many times the pattern is repeated) while every
// expanded element is the same type T.
+ template <typename Ignore, typename T>
55
+ using ignore_first_yield_second = T;
56
+
57
// The type produced by invoking an Op with sizeof...(Args) arguments, each of
// type CTYPE_COMMON. Args contributes only its length: each element of the
// pack is replaced by CTYPE_COMMON via ignore_first_yield_second before the
// argument list is handed to std::invoke_result_t.
+ template <typename CTYPE_COMMON, typename Op, typename ... Args>
58
+ using op_call_result =
59
+ std::invoke_result_t <Op, ignore_first_yield_second<Args, CTYPE_COMMON>...>;
60
+
54
61
template <
55
62
typename CTYPE_COMMON,
56
63
const char * op_name,
@@ -89,9 +96,16 @@ inline void apply_elementwise_fn(
89
96
inputs.first ->element_size (),
90
97
})...};
91
98
92
- const auto store_common_to_out =
93
- internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
94
- out, out_dtypes);
99
+ // NOTE: the result of compute_fun is not necessarily CTYPE_COMMON!
100
+ // For example, consider the possibility that compute_fun is a
101
+ // trigonometric function like acos, the common input type is bool,
102
+ // and the output type is float -- we would truncate acos(0) ~= 1.57
103
+ // to just 1. Conveniently, it costs us nothing at runtime to handle
104
+ // this correctly.
105
+ const auto store_compute_result_to_out =
106
+ internal::get_store_common_to_tensor_fn<
107
+ op_call_result<CTYPE_COMMON, Op, Args...>,
108
+ op_name>(out, out_dtypes);
95
109
char * const data_out = reinterpret_cast <char *>(out.mutable_data_ptr ());
96
110
const auto out_element_size = out.element_size ();
97
111
@@ -114,7 +128,8 @@ inline void apply_elementwise_fn(
114
128
.data_ptr [indexes[idx + 1 ] * input_info.element_size ]);
115
129
}
116
130
auto result = std::apply (compute_fun, loaded_inputs);
117
- store_common_to_out (result, &data_out[indexes[0 ] * out_element_size]);
131
+ store_compute_result_to_out (
132
+ result, &data_out[indexes[0 ] * out_element_size]);
118
133
}
119
134
});
120
135
}
0 commit comments