Skip to content

Commit 25c5b67

Browse files
jorgep31415facebook-github-bot
authored andcommitted
Fix BinaryOp broadcasting for packed dim (#2653)
Summary: Pull Request resolved: #2653 As copyrightly pointed out, broadcasting was not working properly for the example below. I root caused the to confusion between `sizes()` vs `gpu_sizes()` once again! These concepts are explained in #2520 We should use the CPU size, not the GPU size to detect when we should broadcast across the packed-dim texel's elements. # Example Given inputs `torch.ones(2, 3)` and `torch.ones(2, 1)` and `GPUMemoryLayout::WIDTH_PACKED`, we have CPU widths 3 and 1, respectively. These are aligned up to GPU widths 4 and 4, and hence we were failing to broadcast along the packed-dim texel's elements. ## torch.ones(2, 3) ``` (2, 3) = (H, W) = sizes [[1 1 1] [1 1 1]] -> (W, H) = (3, 2) → (4, 2) = gpu_sizes -> extents = (1, 2) [1 1 1 0] [1 1 1 0] ``` ## torch.ones(2, 1) ``` (2, 1) = (H, W) = sizes [[1] [1]] -> (W, H) = (1, 2) → (4, 2) = gpu_sizes -> extents = (1, 2) [1 0 0 0] [1 0 0 0] -> (broadcast from this change) [1 1 1 1] [1 1 1 1] ``` ## torch.ones(2, 3) + torch.ones(2, 1) Ignore the final element of each texel as it's just padding we never read. ``` No broadcast: [1 1 1 0] [1 1 1 0] + [1 0 0 0] [1 0 0 0] = [2 1 1 0] [2 1 1 0] Broadcast: [1 1 1 0] [1 1 1 0] + [1 1 1 1] [1 1 1 1] = [2 2 2 1] [2 2 2 1] ``` # Cleanup Remove unneeded `check_broadcastable()` since this is caught earlier in the PyTorch compiler pipeline. For example, `torch.ones(2, 3) + torch.ones(2, 2)` triggers this error: ``` TorchRuntimeError: Failed running call_function <built-in function add>(*(FakeTensor(..., size=(2, 3)), FakeTensor(..., size=(2, 2))), **{}): Attempting to broadcast a dimension of length 2 at -1! Mismatching argument at index 1 had torch.Size([2, 2]); but expected shape should be broadcastable to [2, 3] ``` bypass-github-export-checks Reviewed By: SS-JIA Differential Revision: D55278527 fbshipit-source-id: abb8a83924370b21dbbabdd5f1f4af8f502edc1f
1 parent 04d568d commit 25c5b67

File tree

6 files changed

+64
-40
lines changed

6 files changed

+64
-40
lines changed

backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,12 @@ layout(set = 0, binding = 5) uniform PRECISION restrict OtherSizes {
3636
}
3737
other_sizes;
3838

39-
layout(set = 0, binding = 6) uniform PRECISION restrict Alpha {
39+
layout(set = 0, binding = 6) uniform PRECISION restrict BroadcastParams {
40+
ivec2 data;
41+
}
42+
broadcast_params;
43+
44+
layout(set = 0, binding = 7) uniform PRECISION restrict Alpha {
4045
float data;
4146
}
4247
alpha;
@@ -63,8 +68,11 @@ void main() {
6368
COORD_TO_POS_${PACKING}(other_coord, other_sizes.data),
6469
0));
6570

66-
// Detect broadcasting
67-
if (PACKED_DIM_${PACKING}(other_sizes.data) < PACKED_DIM_${PACKING}(in_sizes.data)) {
71+
// Check boolean broadcast flags; we use ivec2 instead of bvec2 for alignment.
72+
if (broadcast_params.data.x > 0) {
73+
in_texel = in_texel.xxxx;
74+
}
75+
if (broadcast_params.data.y > 0) {
6876
other_texel = other_texel.xxxx;
6977
}
7078

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ void check_binary_op_args(
2424
const vTensor& other,
2525
const vTensor& out) {
2626
VK_CHECK_COND(check_same_memory_layout(self, other, out));
27-
VK_CHECK_COND(check_broadcastable(self, other));
2827
std::vector<int64_t> broadcasted_sizes =
2928
calculate_broadcasted_output_size(self, other);
3029
VK_CHECK_COND(out.sizes() == broadcasted_sizes);
@@ -36,6 +35,8 @@ void resize_binary_op_node(
3635
const std::vector<ValueRef>& extra_args) {
3736
(void)extra_args;
3837
vTensor& out = graph->get_val(args[0].refs[0]).toTensor();
38+
39+
// TODO(T183442143): Verify tensors are broadcastable.
3940
vTensor& self = graph->get_val(args[1].refs[0]).toTensor();
4041
vTensor& other = graph->get_val(args[1].refs[1]).toTensor();
4142

@@ -73,6 +74,9 @@ void add_binary_op_node(
7374
alpha_val = extract_scalar<float>(graph.get_val(alpha));
7475
}
7576

77+
const api::utils::ivec2 broadcast_params =
78+
create_broadcast_params(t_in1, t_in2);
79+
7680
std::stringstream kernel_name;
7781
kernel_name << "binary_" << op_name;
7882
apply_memory_layout_suffix(kernel_name, t_out);
@@ -90,6 +94,7 @@ void add_binary_op_node(
9094
{t_out.gpu_sizes_ubo(),
9195
t_in1.gpu_sizes_ubo(),
9296
t_in2.gpu_sizes_ubo(),
97+
graph.create_params_buffer(broadcast_params),
9398
graph.create_params_buffer(alpha_val)},
9499
// Resizing
95100
resize_binary_op_node));

backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -72,28 +72,35 @@ bool check_same_memory_layout(
7272
return (t1.gpu_memory_layout() == t3.gpu_memory_layout());
7373
}
7474

75-
bool check_broadcastable(const vTensor& t1, const vTensor& t2) {
76-
size_t ndim = std::max(t1.sizes().size(), t2.sizes().size());
75+
//
76+
// Broadcast flag functions
77+
//
7778

78-
// Match the sizes in reverse because sizes are in NCHW order
79-
for (int i = -1; i >= -ndim; --i) {
80-
int64_t t1_size = api::utils::val_at(i, t1.sizes());
81-
int64_t t2_size = api::utils::val_at(i, t2.sizes());
82-
// If the sizes are not equal, one of them must be 1
83-
if (t1_size != t2_size) {
84-
if (t1_size > 1 && t2_size != 1) {
85-
return false;
86-
} else if (t2_size > 1 && t1_size != 1) {
87-
return false;
88-
}
89-
}
79+
bool is_packed_dim_broadcasted(const vTensor& sndr, const vTensor& rcvr) {
80+
// We assume that the tensors are broadcastable. If values aren't equal at
81+
// some index, then the value of rcvr is 1 and hence should be broadcasted.
82+
switch (sndr.gpu_memory_layout()) {
83+
case api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED:
84+
return api::utils::val_at(-3, sndr.sizes()) >
85+
api::utils::val_at(-3, rcvr.sizes());
86+
case api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED:
87+
return api::utils::val_at(-2, sndr.sizes()) >
88+
api::utils::val_at(-2, rcvr.sizes());
89+
case api::GPUMemoryLayout::TENSOR_WIDTH_PACKED:
90+
return api::utils::val_at(-1, sndr.sizes()) >
91+
api::utils::val_at(-1, rcvr.sizes());
9092
}
93+
}
9194

92-
return true;
95+
api::utils::ivec2 create_broadcast_params(
96+
const vTensor& t1,
97+
const vTensor& t2) {
98+
return api::utils::make_ivec2(
99+
{is_packed_dim_broadcasted(t2, t1), is_packed_dim_broadcasted(t1, t2)});
93100
}
94101

95102
//
96-
// Work Group Size Calculation Utilities
103+
// Work group size calculation functions
97104
//
98105

99106
api::utils::uvec3 adaptive_work_group_size(

backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,14 @@ bool check_same_memory_layout(
4747
const vTensor& t2,
4848
const vTensor& t3);
4949

50-
bool check_broadcastable(const vTensor& t1, const vTensor& t2);
50+
//
51+
// Broadcast flag functions
52+
//
53+
54+
api::utils::ivec2 create_broadcast_params(const vTensor& t1, const vTensor& t2);
5155

5256
//
53-
// Work Group Size Calculation Utilities
57+
// Work group size calculation functions
5458
//
5559

5660
api::utils::uvec3 adaptive_work_group_size(

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,16 +146,20 @@ class AddModule(torch.nn.Module):
146146
def __init__(self):
147147
super().__init__()
148148

149-
def forward(self, x, y):
149+
def forward(self, x, y, w):
150150
z = x + y
151151
z = z + x
152152
z = z + x
153+
z = z + w
154+
z = w + z
155+
z = z + 3 # test scalar broadcasting
153156
return z
154157

155158
add_module = AddModule()
156159
sample_inputs = (
157160
torch.rand(size=(2, 3), dtype=torch.float32),
158161
torch.rand(size=(2, 3), dtype=torch.float32),
162+
torch.rand(size=(2, 1), dtype=torch.float32), # test broadcasting
159163
)
160164

161165
self.lower_module_and_test_output(add_module, sample_inputs)

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
549549
std::vector<int64_t> size_big = {12, 64, 64};
550550
std::vector<int64_t> size_small = {12, 64, 64};
551551

552-
// Build graph
552+
// Build graph and regularly check allocation counts
553553

554554
IOValueRef a = graph.add_input_tensor(
555555
size_big,
@@ -560,9 +560,8 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
560560
api::kFloat,
561561
/*shared_object_idx = */ 4);
562562

563-
// Allocation count will be 6:
564-
// 4: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for each staging shader
565-
// 2: staging buffer for each input tensor
563+
// +4: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for each staging shader
564+
// +2: staging buffer for each input tensor
566565
EXPECT_TRUE(get_vma_allocation_count() == 6);
567566

568567
ValueRef c = graph.add_tensor(
@@ -578,11 +577,10 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
578577
api::kFloat,
579578
/*shared_object_idx = */ 2);
580579

581-
// Allocation count will be 11, 5 are new:
582-
// 2: out.gpu_sizes_ubo(), alpha UBO for arithmetic shader
583-
// 2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() uniform buffer for staging shader
584-
// 1: staging buffer for the input tensor
585-
EXPECT_TRUE(get_vma_allocation_count() == 11);
580+
// +3: out.gpu_sizes_ubo(), alpha UBO, broadcast UBO for arithmetic shader
581+
// +2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() uniform buffer for staging shader
582+
// +1: staging buffer for the input tensor
583+
EXPECT_TRUE(get_vma_allocation_count() == 12);
586584

587585
ValueRef e = graph.add_tensor(
588586
size_big,
@@ -596,18 +594,16 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
596594
out.value = e;
597595
out.staging = graph.set_output_tensor(out.value);
598596

599-
// Allocation count will be 15, 4 are new:
600-
// 1: alpha UBO for arithmetic shader
601-
// 2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for staging shader
602-
// 1 staging buffer for the input tensor
603-
EXPECT_TRUE(get_vma_allocation_count() == 15);
597+
// +2: alpha UBO, broadcast UBO for arithmetic shader
598+
// +2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for staging shader
599+
// +1 staging buffer for the input tensor
600+
EXPECT_TRUE(get_vma_allocation_count() == 17);
604601

605602
graph.prepare();
606603
graph.encode_execute();
607604

608-
// Allocation count will be 18, 3 are new:
609-
// 3: shared memory allocations for tensors
610-
EXPECT_TRUE(get_vma_allocation_count() == 18);
605+
// +3: shared memory allocations for tensors
606+
EXPECT_TRUE(get_vma_allocation_count() == 20);
611607

612608
// Run graph
613609

0 commit comments

Comments
 (0)