
Commit 118bfcd

[MLIR][XEGPU] Add blocking support for scatter ops (#144766)
Add blocking support for scatter ops: create_tdesc, update_offset, prefetch, load, and store. It also enables load/store with chunk size.
1 parent 51aa6a4 commit 118bfcd
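
In rough terms (shapes taken from the new test cases below; the SSA value names %src, %offsets, %off0, %off1, %t0, %t1 are hypothetical), blocking rewrites a scattered descriptor whose layout carries an inst_data hint into instruction-sized pieces, dropping the hint from the resulting types:

// Input: one scattered descriptor over 32 lanes, layout hint inst_data = [16].
%tdesc = xegpu.create_tdesc %src, %offsets : ui64, vector<32xindex>
    -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>

// After blocking (sketch): two 16-lane descriptors, inst_data dropped from the type.
%t0 = xegpu.create_tdesc %src, %off0 : ui64, vector<16xindex>
    -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
%t1 = xegpu.create_tdesc %src, %off1 : ui64, vector<16xindex>
    -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>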

File tree

3 files changed: +142 -27 lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp

Lines changed: 34 additions & 8 deletions
@@ -134,11 +134,13 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
 
 std::optional<SmallVector<int64_t>>
 XeGPUBlockingPass::getTileShape(Operation *op) const {
-  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(op))
+  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
+          xegpu::UpdateOffsetOp>(op))
     return getTileShape(op->getOpResult(0));
-  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp>(op))
+  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
+          xegpu::LoadGatherOp>(op))
     return getTileShape(op->getOpOperand(0));
-  if (isa<xegpu::StoreNdOp>(op))
+  if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
     return getTileShape(op->getOpOperand(1));
 
   if (isa<xegpu::DpasOp>(op)) {

@@ -295,12 +297,36 @@ void XeGPUBlockingPass::runOnOperation() {
     Type elemTy = type.getElementType();
     Type newTy;
 
-    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type))
-      newTy = xegpu::TensorDescType::get(
-          ctx, tileShape, elemTy, tdescTy.getEncoding(),
-          tdescTy.getLayoutAttr().dropInstData());
-    else
+    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
+
+      Attribute encoding = tdescTy.getEncoding();
+      // If the encoding is a ScatterTensorDescAttr, we need to
+      // potentially adjust the chunk size based on the inst_data.
+      if (tdescTy.isScattered()) {
+        auto scatterAttr =
+            llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding);
+        int64_t chunkSize = scatterAttr.getChunkSize().getInt();
+
+        if (chunkSize > 1) {
+          int64_t blockedChunkSize = chunkSize;
+          auto instData = tdescTy.getLayoutAttr().getInstData();
+          if (!instData.empty())
+            blockedChunkSize = instData.asArrayRef().back();
+
+          // To create a new attribute with a different chunk_size:
+          auto newEncoding = xegpu::ScatterTensorDescAttr::get(
+              ctx, scatterAttr.getMemorySpace().getValue(), blockedChunkSize);
+
+          encoding = newEncoding;
+        }
+      }
+
+      newTy =
+          xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
+                                     tdescTy.getLayoutAttr().dropInstData());
+    } else {
       newTy = type.clone(tileShape, elemTy);
+    }
 
     std::optional<SmallVector<int64_t>> ratio =
         computeShapeRatio(type.getShape(), tileShape);
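
When a chunk size is involved, the last inst_data entry becomes the blocked chunk size. A minimal sketch of the descriptor-type mapping, mirroring the chunked test case added below:

// Original descriptor type: 32 lanes, chunk_size = 4, layout hint inst_data = [16, 2].
!xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4 : i64>, #xegpu.layout<inst_data = [16, 2]>>

// Blocked tile type (sketch): tile shape [16, 2], chunk_size rewritten to 2, inst_data dropped.
!xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>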

mlir/test/Dialect/XeGPU/xegpu-blocking.mlir

Lines changed: 102 additions & 11 deletions
@@ -250,8 +250,7 @@ gpu.module @test_kernel {
 // -----
 #l = #xegpu.layout<inst_data = [16, 16]>
 #r = #xegpu.layout<inst_data = [16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
   gpu.func @reduce_dim_0(%a: memref<16x512xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
     %acc = arith.constant dense<0.0> : vector<64xf32>
     %c64 = arith.constant 64 : index

@@ -271,8 +270,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #l = #xegpu.layout<inst_data = [16, 16]>
 #r = #xegpu.layout<inst_data = [16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
   gpu.func @reduce_dim_1(%a: memref<512x32xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
     %c1 = arith.constant 1 : index
     %c32 = arith.constant 32 : index

@@ -299,8 +297,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #r = #xegpu.layout<inst_data = [16]>
 #l = #xegpu.layout<inst_data = [16, 16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
   gpu.func @broadcast_dim_0(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
 
     %c64 = arith.constant 64 : index

@@ -319,8 +316,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #r = #xegpu.layout<inst_data = [16]>
 #l = #xegpu.layout<inst_data = [16, 16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
   gpu.func @broadcast_dim_1(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
 
     %c32 = arith.constant 32 : index

@@ -340,8 +336,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #l = #xegpu.layout<inst_data = [16, 8]>
 #t = #xegpu.layout<inst_data = [8, 16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
   gpu.func @transpose(%a: memref<512x8xf32>, %b: memref<8x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
 
     %c32 = arith.constant 32 : index

@@ -355,4 +350,100 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
     xegpu.store_nd %2, %3: vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #t>
     gpu.return
   }
-}
+}
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: test_prefetch_load_store_update
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
+  // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+  // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+
+  gpu.func @test_prefetch_load_store_update(%src: ui64) {
+
+    %cst = arith.constant dense<[
+      0, 8, 16, 24, 32, 40, 48, 56,
+      64, 72, 80, 88, 96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+    %delta = arith.constant dense<[
+      32, 32, 32, 32, 32, 32, 32, 32,
+      32, 32, 32, 32, 32, 32, 32, 64,
+      128, 128, 128, 128, 128, 128, 128, 128,
+      128, 128, 128, 128, 128, 128, 128, 256
+    ]> : vector<32xindex>
+    %new_tdesc = xegpu.update_offset %tdesc, %delta
+              : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>
+
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+
+    %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
+    xegpu.store %st_vec, %tdesc, %mask:
+                 vector<32xf32>,
+                 !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
+                 vector<32xi1>
+
+    gpu.return
+  }
+
+}
+
+// -----
+
+gpu.module @test_kernel {
+  // CHECK-LABEL: test_prefetch_load_store_update_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK-COUNT-4: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xindex>
+  // CHECK-COUNT-4: xegpu.load {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<2x16xf32>
+  // CHECK-COUNT-4: xegpu.store {{.*}} : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
+
+  gpu.func @test_prefetch_load_store_update_chunk(%src: ui64) {
+
+    %cst = arith.constant dense<[
+      0, 8, 16, 24, 32, 40, 48, 56,
+      64, 72, 80, 88, 96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+
+    %delta = arith.constant dense<[
+      32, 32, 32, 32, 32, 32, 32, 32,
+      32, 32, 32, 32, 32, 32, 32, 64,
+      128, 128, 128, 128, 128, 128, 128, 128,
+      128, 128, 128, 128, 128, 128, 128, 256
+    ]> : vector<32xindex>
+    %new_tdesc = xegpu.update_offset %tdesc, %delta
+              : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xindex>
+
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %ld_vec = xegpu.load %new_tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xi1> -> vector<4x32xf32>
+
+    %st_vec = arith.addf %ld_vec, %ld_vec : vector<4x32xf32>
+    xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>:
+                 vector<4x32xf32>,
+                 !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>,
+                 vector<32xi1>
+
+    gpu.return
+  }
+}
+
mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Lines changed: 6 additions & 8 deletions
@@ -102,14 +102,14 @@ struct TestXeGPUUnrollingPatterns
           // attribute
           if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
             Attribute encoding = tdescTy.getEncoding();
-            auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
-                tdescTy.getLayout());
+            auto layout = tdescTy.getLayoutAttr();
 
             // If the encoding is a ScatterTensorDescAttr, we need to
             // potentially adjust the chunk size based on the inst_data.
-            if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
+            if (tdescTy.isScattered()) {
               auto scatterAttr =
-                  mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
+                  llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(
+                      encoding);
               int64_t chunkSize = scatterAttr.getChunkSize().getInt();
 
               if (chunkSize > 1) {

@@ -118,12 +118,10 @@ struct TestXeGPUUnrollingPatterns
                 if (!instData.empty())
                   blockedChunkSize = instData.asArrayRef().back();
 
-                auto chunkSizeAttr = mlir::IntegerAttr::get(
-                    mlir::IntegerType::get(ctx, 64), blockedChunkSize);
-
                 // To create a new attribute with a different chunk_size:
                 auto newEncoding = xegpu::ScatterTensorDescAttr::get(
-                    ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
+                    ctx, scatterAttr.getMemorySpace().getValue(),
+                    blockedChunkSize);
 
                 encoding = newEncoding;
               }
