[MLIR][XEGPU] Add blocking support for scatter ops #144766

Merged · 8 commits · Jun 18, 2025
42 changes: 34 additions & 8 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
Contributor:

Can someone clarify the difference between xegpu blocking vs. unrolling? Why do we need two?

Contributor Author:

The unrolling pass is there to test the patterns, so its blocking strategy can be very naive. The XeGPU blocking pass is built on top of the unrolling patterns with a more sophisticated blocking strategy. @chencha3

Contributor:

XeGPU::unrollPattern follows the design of vector::unrollPattern. They provide a very generic interface: verify that the targetShape is valid, then unroll the op based on it. xegpu-blocking is built on top of that; it simply focuses on passing the inst_data as the targetShape of the unroll patterns, with some extra logic for robustness.
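
To make that division of labor concrete, here is a framework-free toy sketch (not the MLIR API: the simplified computeShapeRatio signature and the pickTargetShape helper are illustrative assumptions). The unroll layer only checks that a target shape evenly tiles the original shape; the blocking layer is the part that derives that target shape from inst_data.

```cpp
// Toy sketch of the layering described above (not the MLIR implementation).
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// "Unroll" layer: generic, it only verifies that targetShape evenly tiles shape
// and reports how many tiles result per dimension.
std::optional<std::vector<int64_t>>
computeShapeRatio(const std::vector<int64_t> &shape,
                  const std::vector<int64_t> &targetShape) {
  if (shape.size() != targetShape.size())
    return std::nullopt;
  std::vector<int64_t> ratio;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (targetShape[i] <= 0 || shape[i] % targetShape[i] != 0)
      return std::nullopt; // not unrollable with this target shape
    ratio.push_back(shape[i] / targetShape[i]);
  }
  return ratio;
}

// "Blocking" layer: chooses the target shape; here it is simply the inst_data.
std::vector<int64_t> pickTargetShape(const std::vector<int64_t> &instData) {
  return instData;
}

int main() {
  std::vector<int64_t> descShape = {32, 4}; // e.g. tensor_desc<32x4xf32>
  std::vector<int64_t> instData = {16, 2};  // #xegpu.layout<inst_data = [16, 2]>
  auto ratio = computeShapeRatio(descShape, pickTargetShape(instData));
  assert(ratio && (*ratio)[0] == 2 && (*ratio)[1] == 2); // 2x2 = 4 tiles of 16x2
  return 0;
}
```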

@@ -134,11 +134,13 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {

std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(Operation *op) const {
if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(op))
if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
xegpu::UpdateOffsetOp>(op))
return getTileShape(op->getOpResult(0));
if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp>(op))
if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
xegpu::LoadGatherOp>(op))
return getTileShape(op->getOpOperand(0));
if (isa<xegpu::StoreNdOp>(op))
if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
return getTileShape(op->getOpOperand(1));

if (isa<xegpu::DpasOp>(op)) {
@@ -295,12 +297,36 @@ void XeGPUBlockingPass::runOnOperation() {
Type elemTy = type.getElementType();
Type newTy;

if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type))
newTy = xegpu::TensorDescType::get(
ctx, tileShape, elemTy, tdescTy.getEncoding(),
tdescTy.getLayoutAttr().dropInstData());
else
if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {

Attribute encoding = tdescTy.getEncoding();
// If the encoding is a ScatterTensorDescAttr, we need to
// potentially adjust the chunk size based on the inst_data.
if (tdescTy.isScattered()) {
auto scatterAttr =
llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding);
int64_t chunkSize = scatterAttr.getChunkSize().getInt();

if (chunkSize > 1) {
int64_t blockedChunkSize = chunkSize;
auto instData = tdescTy.getLayoutAttr().getInstData();
if (!instData.empty())
blockedChunkSize = instData.asArrayRef().back();
Contributor:

nit: maybe we can reuse the chunkSize variable here, so line 311 can be removed.

Contributor Author:

The code purposely introduces the variable name "blockedChunkSize". It makes the code easier to understand: the logic here is to block the chunk size and then use the blocked value.
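
To put concrete numbers on the rule being discussed, here is a standalone sketch that mirrors the diff's logic (blockChunkSize is a hypothetical helper name, not the pass code): for a scattered descriptor with chunk_size = 4 and inst_data = [16, 2], the last inst_data entry becomes the blocked chunk size, so each tile ends up with chunk_size = 2.

```cpp
// Standalone illustration of the chunk-size blocking rule (assumed helper name).
#include <cassert>
#include <cstdint>
#include <vector>

int64_t blockChunkSize(int64_t chunkSize, const std::vector<int64_t> &instData) {
  // Keep the original chunk size unless inst_data overrides it.
  int64_t blockedChunkSize = chunkSize;
  if (chunkSize > 1 && !instData.empty())
    blockedChunkSize = instData.back(); // last inst_data dim is the chunk dimension
  return blockedChunkSize;
}

int main() {
  // tensor_desc<32x4xf32, chunk_size = 4> with inst_data = [16, 2] -> chunk_size = 2
  assert(blockChunkSize(/*chunkSize=*/4, {16, 2}) == 2);
  // Descriptors with chunk_size = 1 are left untouched.
  assert(blockChunkSize(/*chunkSize=*/1, {16}) == 1);
  return 0;
}
```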


// To create a new attribute with a different chunk_size:
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
ctx, scatterAttr.getMemorySpace().getValue(), blockedChunkSize);

encoding = newEncoding;
}
}

newTy =
xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
tdescTy.getLayoutAttr().dropInstData());
} else {
newTy = type.clone(tileShape, elemTy);
}

std::optional<SmallVector<int64_t>> ratio =
computeShapeRatio(type.getShape(), tileShape);
113 changes: 102 additions & 11 deletions mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -250,8 +250,7 @@ gpu.module @test_kernel {
// -----
#l = #xegpu.layout<inst_data = [16, 16]>
#r = #xegpu.layout<inst_data = [16]>

gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
gpu.module @test_kernel {
gpu.func @reduce_dim_0(%a: memref<16x512xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%acc = arith.constant dense<0.0> : vector<64xf32>
%c64 = arith.constant 64 : index
@@ -271,8 +270,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
// -----
#l = #xegpu.layout<inst_data = [16, 16]>
#r = #xegpu.layout<inst_data = [16]>

gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
gpu.module @test_kernel {
gpu.func @reduce_dim_1(%a: memref<512x32xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
@@ -299,8 +297,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
// -----
#r = #xegpu.layout<inst_data = [16]>
#l = #xegpu.layout<inst_data = [16, 16]>

gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
gpu.module @test_kernel {
gpu.func @broadcast_dim_0(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {

%c64 = arith.constant 64 : index
@@ -319,8 +316,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
// -----
#r = #xegpu.layout<inst_data = [16]>
#l = #xegpu.layout<inst_data = [16, 16]>

gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
gpu.module @test_kernel {
gpu.func @broadcast_dim_1(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {

%c32 = arith.constant 32 : index
@@ -340,8 +336,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
// -----
#l = #xegpu.layout<inst_data = [16, 8]>
#t = #xegpu.layout<inst_data = [8, 16]>

gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
gpu.module @test_kernel {
gpu.func @transpose(%a: memref<512x8xf32>, %b: memref<8x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {

%c32 = arith.constant 32 : index
@@ -355,4 +350,100 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
xegpu.store_nd %2, %3: vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #t>
gpu.return
}
}
}

// -----
Contributor:

No need for a gpu.module or gpu.func for testing purposes; we just use func.func. We use gpu.module if some gpu op is necessary for the test case (like thread IDs, etc.). Is that the case here?

Contributor Author:

The test doesn't require a gpu op, but currently almost all test files use this style. We may consider changing it for all test cases in a separate PR.

gpu.module @test_kernel {
// CHECK-LABEL: test_prefetch_load_store_update
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
// CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
// CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
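// Note: the 32-element descriptors with inst_data = [16] split into 32/16 = 2
// tiles, which is why each op above is expected exactly twice (CHECK-COUNT-2).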

gpu.func @test_prefetch_load_store_update(%src: ui64) {

Contributor:

What happens with constants? Do they get blocked in this pass?

Contributor Author:

Yes.

%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
128, 136, 144, 152, 160, 168, 176, 184,
192, 200, 208, 216, 224, 232, 240, 248
]> : vector<32xindex>
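// With inst_data = [16], this vector<32xindex> offset constant is expected to be
// split into two vector<16xindex> constants by the blocking pass as well.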

%tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>

%delta = arith.constant dense<[
32, 32, 32, 32, 32, 32, 32, 32,
32, 32, 32, 32, 32, 32, 32, 64,
128, 128, 128, 128, 128, 128, 128, 128,
128, 128, 128, 128, 128, 128, 128, 256
]> : vector<32xindex>
%new_tdesc = xegpu.update_offset %tdesc, %delta
: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>

%c17 = arith.constant 17: index
%mask = vector.create_mask %c17: vector<32xi1>

%ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>

%st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
xegpu.store %st_vec, %tdesc, %mask:
vector<32xf32>,
!xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
vector<32xi1>

gpu.return
}

}

// -----

gpu.module @test_kernel {
// CHECK-LABEL: test_prefetch_load_store_update_chunk
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
// CHECK-COUNT-4: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
// CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xindex>
// CHECK-COUNT-4: xegpu.load {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<2x16xf32>
// CHECK-COUNT-4: xegpu.store {{.*}} : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
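// Note: the 32x4 descriptors with inst_data = [16, 2] split into (32/16) x (4/2) = 4
// tiles, and the blocked chunk_size drops from 4 to 2 (hence CHECK-COUNT-4 above).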

gpu.func @test_prefetch_load_store_update_chunk(%src: ui64) {

%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
128, 136, 144, 152, 160, 168, 176, 184,
192, 200, 208, 216, 224, 232, 240, 248
]> : vector<32xindex>

%tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
xegpu.prefetch %tdesc: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>

%delta = arith.constant dense<[
32, 32, 32, 32, 32, 32, 32, 32,
32, 32, 32, 32, 32, 32, 32, 64,
128, 128, 128, 128, 128, 128, 128, 128,
128, 128, 128, 128, 128, 128, 128, 256
]> : vector<32xindex>
%new_tdesc = xegpu.update_offset %tdesc, %delta
: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xindex>

%c17 = arith.constant 17: index
%mask = vector.create_mask %c17: vector<32xi1>

%ld_vec = xegpu.load %new_tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xi1> -> vector<4x32xf32>

%st_vec = arith.addf %ld_vec, %ld_vec : vector<4x32xf32>
xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>:
vector<4x32xf32>,
!xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>,
vector<32xi1>

gpu.return
}
}


14 changes: 6 additions & 8 deletions mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -102,14 +102,14 @@ struct TestXeGPUUnrollingPatterns
// attribute
if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
Attribute encoding = tdescTy.getEncoding();
auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
tdescTy.getLayout());
auto layout = tdescTy.getLayoutAttr();

// If the encoding is a ScatterTensorDescAttr, we need to
// potentially adjust the chunk size based on the inst_data.
if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
if (tdescTy.isScattered()) {
auto scatterAttr =
mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(
encoding);
int64_t chunkSize = scatterAttr.getChunkSize().getInt();

if (chunkSize > 1) {
@@ -118,12 +118,10 @@ struct TestXeGPUUnrollingPatterns
if (!instData.empty())
blockedChunkSize = instData.asArrayRef().back();

auto chunkSizeAttr = mlir::IntegerAttr::get(
mlir::IntegerType::get(ctx, 64), blockedChunkSize);

// To create a new attribute with a different chunk_size:
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
ctx, scatterAttr.getMemorySpace().getValue(),
blockedChunkSize);

encoding = newEncoding;
}