Skip to content

Commit 1f16f5e

Browse files
jorgep31415 authored and facebook-github-bot committed
Fix BinaryOp broadcasting for packed dim
Summary: As copyrightly pointed out, broadcasting was not working properly for the example below. I root-caused it to confusion between `sizes()` vs `gpu_sizes()` once again! These concepts are explained in #2520 We should use the CPU size, not the GPU size, to detect when we should broadcast across the packed-dim texel's elements. For example, given inputs `torch.ones(2, 3)` and `torch.ones(2, 1)` and `GPUMemoryLayout::WIDTH_PACKED`, we have CPU widths 3 and 1, respectively. These are aligned up to GPU widths 4 and 4, and hence we were failing to broadcast along the packed-dim texel's elements. ## torch.ones(2, 3) ``` (2, 3) = (H, W) = sizes [[1 1 1] [1 1 1]] -> (W, H) = (3, 2) → (4, 2) = gpu_sizes -> extents = (1, 2) [1 1 1 0] [1 1 1 0] ``` ## torch.ones(2, 1) ``` (2, 1) = (H, W) = sizes [[1] [1]] -> (W, H) = (1, 2) → (4, 2) = gpu_sizes -> extents = (1, 2) [1 0 0 0] [1 0 0 0] -> (broadcast from this change) [1 1 1 1] [1 1 1 1] ``` ## torch.ones(2, 3) + torch.ones(2, 1) Ignore the final element of each texel as it's just padding we never read. ``` No broadcast: [1 1 1 0] [1 1 1 0] + [1 0 0 0] [1 0 0 0] = [2 1 1 0] [2 1 1 0] Broadcast: [1 1 1 0] [1 1 1 0] + [1 1 1 1] [1 1 1 1] = [2 2 2 1] [2 2 2 1] ``` Differential Revision: D55278527
1 parent 6a0a6c7 commit 1f16f5e

File tree

6 files changed

+53
-23
lines changed

6 files changed

+53
-23
lines changed

backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,12 @@ layout(set = 0, binding = 5) uniform PRECISION restrict OtherSizes {
3636
}
3737
other_sizes;
3838

39-
layout(set = 0, binding = 6) uniform PRECISION restrict Alpha {
39+
layout(set = 0, binding = 6) uniform PRECISION restrict BroadcastFlag {
40+
bool data;
41+
}
42+
broadcast_flag;
43+
44+
layout(set = 0, binding = 7) uniform PRECISION restrict Alpha {
4045
float data;
4146
}
4247
alpha;
@@ -63,8 +68,7 @@ void main() {
6368
COORD_TO_POS_${PACKING}(other_coord, other_sizes.data),
6469
0));
6570

66-
// Detect broadcasting
67-
if (PACKED_DIM_${PACKING}(other_sizes.data) < PACKED_DIM_${PACKING}(in_sizes.data)) {
71+
if (broadcast_flag.data) {
6872
other_texel = other_texel.xxxx;
6973
}
7074

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ void add_binary_op_node(
7373
alpha_val = extract_scalar<float>(graph.get_val(alpha));
7474
}
7575

76+
const bool broadcast_flag = is_packed_dim_broadcasted(t_in1, t_in2);
77+
7678
std::stringstream kernel_name;
7779
kernel_name << "binary_" << op_name;
7880
apply_memory_layout_suffix(kernel_name, t_out);
@@ -90,6 +92,7 @@ void add_binary_op_node(
9092
{t_out.gpu_sizes_ubo(),
9193
t_in1.gpu_sizes_ubo(),
9294
t_in2.gpu_sizes_ubo(),
95+
graph.create_params_buffer(broadcast_flag),
9396
graph.create_params_buffer(alpha_val)},
9497
// Resizing
9598
resize_binary_op_node));

backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,25 @@ bool check_broadcastable(const vTensor& t1, const vTensor& t2) {
9393
}
9494

9595
//
96-
// Work Group Size Calculation Utilities
96+
// Broadcast flag functions
97+
//
98+
99+
bool is_packed_dim_broadcasted(const vTensor& t1, const vTensor& t2) {
100+
switch (t1.gpu_memory_layout()) {
101+
case api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED:
102+
return api::utils::val_at(-3, t1.sizes()) >
103+
api::utils::val_at(-3, t2.sizes());
104+
case api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED:
105+
return api::utils::val_at(-2, t1.sizes()) >
106+
api::utils::val_at(-2, t2.sizes());
107+
case api::GPUMemoryLayout::TENSOR_WIDTH_PACKED:
108+
return api::utils::val_at(-1, t1.sizes()) >
109+
api::utils::val_at(-1, t2.sizes());
110+
}
111+
}
112+
113+
//
114+
// Work group size calculation functions
97115
//
98116

99117
api::utils::uvec3 adaptive_work_group_size(

backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,13 @@ bool check_same_memory_layout(
5050
bool check_broadcastable(const vTensor& t1, const vTensor& t2);
5151

5252
//
53-
// Work Group Size Calculation Utilities
53+
// Broadcast flag functions
54+
//
55+
56+
bool is_packed_dim_broadcasted(const vTensor& t1, const vTensor& t2);
57+
58+
//
59+
// Work group size calculation functions
5460
//
5561

5662
api::utils::uvec3 adaptive_work_group_size(

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,16 +146,19 @@ class AddModule(torch.nn.Module):
146146
def __init__(self):
147147
super().__init__()
148148

149-
def forward(self, x, y):
149+
def forward(self, x, y, w):
150150
z = x + y
151151
z = z + x
152152
z = z + x
153+
z = z + w
154+
z = z + 3 # test scalar broadcasting
153155
return z
154156

155157
add_module = AddModule()
156158
sample_inputs = (
157159
torch.rand(size=(2, 3), dtype=torch.float32),
158160
torch.rand(size=(2, 3), dtype=torch.float32),
161+
torch.rand(size=(2, 1), dtype=torch.float32), # test broadcasting
159162
)
160163

161164
self.lower_module_and_test_output(add_module, sample_inputs)

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
549549
std::vector<int64_t> size_big = {12, 64, 64};
550550
std::vector<int64_t> size_small = {12, 64, 64};
551551

552-
// Build graph
552+
// Build graph and regularly check allocation counts
553553

554554
IOValueRef a = graph.add_input_tensor(
555555
size_big,
@@ -560,9 +560,8 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
560560
api::kFloat,
561561
/*shared_object_idx = */ 4);
562562

563-
// Allocation count will be 6:
564-
// 4: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for each staging shader
565-
// 2: staging buffer for each input tensor
563+
// +4: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for each staging shader
564+
// +2: staging buffer for each input tensor
566565
EXPECT_TRUE(get_vma_allocation_count() == 6);
567566

568567
ValueRef c = graph.add_tensor(
@@ -578,11 +577,10 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
578577
api::kFloat,
579578
/*shared_object_idx = */ 2);
580579

581-
// Allocation count will be 11, 5 are new:
582-
// 2: out.gpu_sizes_ubo(), alpha UBO for arithmetic shader
583-
// 2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() uniform buffer for staging shader
584-
// 1: staging buffer for the input tensor
585-
EXPECT_TRUE(get_vma_allocation_count() == 11);
580+
// +3: out.gpu_sizes_ubo(), alpha UBO, broadcast UBO for arithmetic shader
581+
// +2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() uniform buffer for staging shader
582+
// +1: staging buffer for the input tensor
583+
EXPECT_TRUE(get_vma_allocation_count() == 12);
586584

587585
ValueRef e = graph.add_tensor(
588586
size_big,
@@ -596,18 +594,16 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
596594
out.value = e;
597595
out.staging = graph.set_output_tensor(out.value);
598596

599-
// Allocation count will be 15, 4 are new:
600-
// 1: alpha UBO for arithmetic shader
601-
// 2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for staging shader
602-
// 1 staging buffer for the input tensor
603-
EXPECT_TRUE(get_vma_allocation_count() == 15);
597+
// +2: alpha UBO, broadcast UBO for arithmetic shader
598+
// +2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for staging shader
599+
// +1 staging buffer for the input tensor
600+
EXPECT_TRUE(get_vma_allocation_count() == 17);
604601

605602
graph.prepare();
606603
graph.encode_execute();
607604

608-
// Allocation count will be 18, 3 are new:
609-
// 3: shared memory allocations for tensors
610-
EXPECT_TRUE(get_vma_allocation_count() == 18);
605+
// +3: shared memory allocations for tensors
606+
EXPECT_TRUE(get_vma_allocation_count() == 20);
611607

612608
// Run graph
613609

0 commit comments

Comments
 (0)