Skip to content

Commit 9630d7c

Browse files
authored
[MLIR][XeGPU] add blocking support for reduce, broadcast, and transpose (#143389)
This PR adds blocking support for vector dialect operations (`reduce`, `broadcast`, and `transpose`) in XeGPU-based IR. Blocking is implemented by using the shape specified by `inst_data` as the target shape for unrolling. It is based on #140163.
1 parent bdcbe67 commit 9630d7c

File tree

2 files changed

+116
-0
lines changed

2 files changed

+116
-0
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,12 @@ XeGPUBlockingPass::getTileShape(Operation *op) const {
169169
if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
170170
return getTileShape(op->getOpResult(0));
171171

172+
if (isa<vector::MultiDimReductionOp>(op))
173+
return getTileShape(op->getOpOperand(0));
174+
175+
if (isa<vector::TransposeOp, vector::BroadcastOp>(op))
176+
return getTileShape(op->getOpResult(0));
177+
172178
return std::nullopt;
173179
}
174180

mlir/test/Dialect/XeGPU/xegpu-blocking.mlir

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,113 @@ gpu.module @test_kernel {
246246
gpu.return
247247
}
248248
}
249+
250+
// -----
251+
#l = #xegpu.layout<inst_data = [16, 16]>
252+
#r = #xegpu.layout<inst_data = [16]>
253+
254+
gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
255+
gpu.func @reduce_dim_0(%a: memref<16x512xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
256+
%acc = arith.constant dense<0.0> : vector<64xf32>
257+
%c64 = arith.constant 64 : index
258+
%block_id_x = gpu.block_id x
259+
%m = arith.muli %block_id_x, %c64 : index
260+
%0 = xegpu.create_nd_tdesc %a[0, %m] : memref<16x512xf32> -> !xegpu.tensor_desc<16x64xf32, #l>
261+
%1 = xegpu.load_nd %0: !xegpu.tensor_desc<16x64xf32, #l> -> vector<16x64xf32>
262+
// CHECK: vector.multi_reduction <add>, {{.*}}, [[ACC:%[0-9A-Za-z]+]] [0] : vector<16x16xf32> to vector<16xf32>
263+
// CHECK-COUNT-3: vector.multi_reduction <add>, {{.*}}, [[ACC]] [0] : vector<16x16xf32> to vector<16xf32>
264+
%2 = vector.multi_reduction <add>, %1, %acc {layout_result_0 = #r} [0]: vector<16x64xf32> to vector<64xf32>
265+
%3 = xegpu.create_nd_tdesc %b[%m] : memref<512xf32> -> !xegpu.tensor_desc<64xf32, #r>
266+
xegpu.store_nd %2, %3: vector<64xf32>, !xegpu.tensor_desc<64xf32, #r>
267+
gpu.return
268+
}
269+
}
270+
271+
// -----
272+
#l = #xegpu.layout<inst_data = [16, 16]>
273+
#r = #xegpu.layout<inst_data = [16]>
274+
275+
gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
276+
gpu.func @reduce_dim_1(%a: memref<512x32xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
277+
%c1 = arith.constant 1 : index
278+
%c32 = arith.constant 32 : index
279+
%acc = arith.constant dense<0.0> : vector<32xf32>
280+
281+
%block_id_x = gpu.block_id x
282+
%block_id_y = gpu.block_id y
283+
284+
%m = arith.muli %block_id_x, %c32 : index
285+
%n = arith.muli %block_id_y, %c32 : index
286+
%0 = xegpu.create_nd_tdesc %a[%m, %n] : memref<512x32xf32> -> !xegpu.tensor_desc<32x128xf32, #l>
287+
%1 = xegpu.load_nd %0: !xegpu.tensor_desc<32x128xf32, #l> -> vector<32x128xf32>
288+
289+
// CHECK: vector.multi_reduction <add>, {{.*}}, [[INIT:%[0-9A-Za-z]+]] [1] : vector<16x16xf32> to vector<16xf32>
290+
// CHECK-COUNT-1: vector.multi_reduction <add>, {{.*}}, [[INIT]] [1] : vector<16x16xf32> to vector<16xf32>
291+
292+
%2 = vector.multi_reduction <add>, %1, %acc {layout_result_0 = #r} [1]: vector<32x128xf32> to vector<32xf32>
293+
%3 = xegpu.create_nd_tdesc %b[%n] : memref<512xf32> -> !xegpu.tensor_desc<32xf32, #r>
294+
xegpu.store_nd %2, %3: vector<32xf32>, !xegpu.tensor_desc<32xf32, #r>
295+
gpu.return
296+
}
297+
}
298+
299+
// -----
300+
#r = #xegpu.layout<inst_data = [16]>
301+
#l = #xegpu.layout<inst_data = [16, 16]>
302+
303+
gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
304+
gpu.func @broadcast_dim_0(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
305+
306+
%c64 = arith.constant 64 : index
307+
%block_id_x = gpu.block_id x
308+
%m = arith.muli %block_id_x, %c64 : index
309+
%0 = xegpu.create_nd_tdesc %a[%m] : memref<512xf32> -> !xegpu.tensor_desc<64xf32, #r>
310+
%1 = xegpu.load_nd %0: !xegpu.tensor_desc<64xf32, #r> -> vector<64xf32>
311+
// CHECK-COUNT-4: vector.broadcast {{.*}} : vector<16xf32> to vector<16x16xf32>
312+
%2 = vector.broadcast %1 {layout_result_0 = #l} : vector<64xf32> to vector<16x64xf32>
313+
%3 = xegpu.create_nd_tdesc %b[0, %m] : memref<16x512xf32> -> !xegpu.tensor_desc<16x64xf32, #l>
314+
xegpu.store_nd %2, %3: vector<16x64xf32>, !xegpu.tensor_desc<16x64xf32, #l>
315+
gpu.return
316+
}
317+
}
318+
319+
// -----
320+
#r = #xegpu.layout<inst_data = [16]>
321+
#l = #xegpu.layout<inst_data = [16, 16]>
322+
323+
gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
324+
gpu.func @broadcast_dim_1(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
325+
326+
%c32 = arith.constant 32 : index
327+
%block_id_x = gpu.block_id x
328+
%m = arith.muli %block_id_x, %c32 : index
329+
%0 = xegpu.create_nd_tdesc %a[%m] : memref<512xf32> -> !xegpu.tensor_desc<32xf32, #r>
330+
%1 = xegpu.load_nd %0: !xegpu.tensor_desc<32xf32, #r> -> vector<32xf32>
331+
%11 = vector.shape_cast %1 : vector<32xf32> to vector<32x1xf32>
332+
// CHECK-COUNT-8: vector.broadcast {{.*}}: vector<16x1xf32> to vector<16x16xf32>
333+
%2 = vector.broadcast %11 {layout_result_0 = #l} : vector<32x1xf32> to vector<32x64xf32>
334+
%3 = xegpu.create_nd_tdesc %b[0, %m] : memref<16x512xf32> -> !xegpu.tensor_desc<32x64xf32, #l>
335+
xegpu.store_nd %2, %3: vector<32x64xf32>, !xegpu.tensor_desc<32x64xf32, #l>
336+
gpu.return
337+
}
338+
}
339+
340+
// -----
341+
#l = #xegpu.layout<inst_data = [16, 8]>
342+
#t = #xegpu.layout<inst_data = [8, 16]>
343+
344+
gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
345+
gpu.func @transpose(%a: memref<512x8xf32>, %b: memref<8x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
346+
347+
%c32 = arith.constant 32 : index
348+
%block_id_x = gpu.block_id x
349+
%m = arith.muli %block_id_x, %c32 : index
350+
%0 = xegpu.create_nd_tdesc %a[%m, 0] : memref<512x8xf32> -> !xegpu.tensor_desc<32x8xf32, #l>
351+
%1 = xegpu.load_nd %0: !xegpu.tensor_desc<32x8xf32, #l> -> vector<32x8xf32>
352+
// CHECK-COUNT-2: vector.transpose {{.*}} [1, 0] : vector<16x8xf32> to vector<8x16xf32>
353+
%2 = vector.transpose %1, [1, 0] {layout_result_0 = #t} : vector<32x8xf32> to vector<8x32xf32>
354+
%3 = xegpu.create_nd_tdesc %b[0, %m] : memref<8x512xf32> -> !xegpu.tensor_desc<8x32xf32, #t>
355+
xegpu.store_nd %2, %3: vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #t>
356+
gpu.return
357+
}
358+
}

0 commit comments

Comments
 (0)