Skip to content

Commit ef30176

Browse files
authored
Include tensor shapes in get_broadcast_target_size error message (#7944)
This is the motivating example for #7902. Test Plan: Injected failure to new broadcast_test and saw shapes in error message.
1 parent 77f18b2 commit ef30176

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

kernels/portable/cpu/util/broadcast_util.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,22 @@ ET_NODISCARD Error get_broadcast_target_size(
213213
Tensor::SizesType* out_sizes,
214214
const size_t out_sizes_len,
215215
size_t* out_dim) {
216-
ET_CHECK_OR_RETURN_ERROR(
217-
tensors_are_broadcastable_between(a_size, b_size),
218-
InvalidArgument,
219-
"Two input tensors should be broadcastable.\n");
216+
if ET_UNLIKELY (!tensors_are_broadcastable_between(a_size, b_size)) {
217+
#ifdef ET_LOG_ENABLED
218+
const auto a_shape_str = tensor_shape_to_c_string(
219+
executorch::runtime::Span<const Tensor::SizesType>(
220+
a_size.data(), a_size.size()));
221+
const auto b_shape_str = tensor_shape_to_c_string(
222+
executorch::runtime::Span<const Tensor::SizesType>(
223+
b_size.data(), b_size.size()));
224+
#endif
225+
ET_LOG(
226+
Error,
227+
"Two input tensors should be broadcastable but got shapes %s and %s.",
228+
a_shape_str.data(),
229+
b_shape_str.data());
230+
return executorch::runtime::Error::InvalidArgument;
231+
}
220232

221233
auto a_dim = a_size.size();
222234
auto b_dim = b_size.size();

kernels/portable/cpu/util/test/broadcast_test.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,15 @@ TEST(BroadcastUtilTest, GetBroadcastTargetSize) {
129129
EXPECT_TRUE(
130130
ArrayRef<Tensor::SizesType>(expected_output_size, expected_output_dim)
131131
.equals(ArrayRef<Tensor::SizesType>({5, 2, 2})));
132+
133+
Tensor c = tf.zeros({4, 5});
134+
err = get_broadcast_target_size(
135+
a,
136+
c,
137+
expected_output_size,
138+
torch::executor::kTensorDimensionLimit,
139+
&expected_output_dim);
140+
EXPECT_EQ(err, torch::executor::Error::InvalidArgument);
132141
}
133142

134143
size_t linearize_indexes(size_t* indexes, size_t indexes_len, const Tensor& t) {

0 commit comments

Comments
 (0)