[MLIR][XEGPU] Add blocking support for scatter ops #144766


Merged
merged 8 commits into llvm:main on Jun 18, 2025

Conversation

Jianhui-Li
Contributor

Add blocking support for the scatter ops: create_tdesc, update_offset, prefetch, load, and store. It also enables load/store with a chunk size.

@llvmbot
Member

llvmbot commented Jun 18, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Jianhui Li (Jianhui-Li)

Changes

Add blocking support for the scatter ops: create_tdesc, update_offset, prefetch, load, and store. It also enables load/store with a chunk size.


Full diff: https://github.com/llvm/llvm-project/pull/144766.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp (+37-8)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-blocking.mlir (+102-11)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index a3826c56e1f62..02bfab8263847 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -134,11 +134,13 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
 
 std::optional<SmallVector<int64_t>>
 XeGPUBlockingPass::getTileShape(Operation *op) const {
-  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(op))
+  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
+          xegpu::UpdateOffsetOp>(op))
     return getTileShape(op->getOpResult(0));
-  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp>(op))
+  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
+          xegpu::LoadGatherOp>(op))
     return getTileShape(op->getOpOperand(0));
-  if (isa<xegpu::StoreNdOp>(op))
+  if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
     return getTileShape(op->getOpOperand(1));
 
   if (isa<xegpu::DpasOp>(op)) {
@@ -295,12 +297,39 @@ void XeGPUBlockingPass::runOnOperation() {
     Type elemTy = type.getElementType();
     Type newTy;
 
-    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type))
-      newTy = xegpu::TensorDescType::get(
-          ctx, tileShape, elemTy, tdescTy.getEncoding(),
-          tdescTy.getLayoutAttr().dropInstData());
-    else
+    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
+
+      Attribute encoding = tdescTy.getEncoding();
+      // If the encoding is a ScatterTensorDescAttr, we need to
+      // potentially adjust the chunk size based on the inst_data.
+      if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
+        auto scatterAttr =
+            mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
+        int64_t chunkSize = scatterAttr.getChunkSize().getInt();
+
+        if (chunkSize > 1) {
+          int64_t blockedChunkSize = chunkSize;
+          auto instData = tdescTy.getLayoutAttr().getInstData();
+          if (!instData.empty())
+            blockedChunkSize = instData.asArrayRef().back();
+
+          auto chunkSizeAttr = mlir::IntegerAttr::get(
+              mlir::IntegerType::get(ctx, 64), blockedChunkSize);
+
+          // To create a new attribute with a different chunk_size:
+          auto newEncoding = xegpu::ScatterTensorDescAttr::get(
+              ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
+
+          encoding = newEncoding;
+        }
+      }
+
+      newTy =
+          xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
+                                     tdescTy.getLayoutAttr().dropInstData());
+    } else {
       newTy = type.clone(tileShape, elemTy);
+    }
 
     std::optional<SmallVector<int64_t>> ratio =
         computeShapeRatio(type.getShape(), tileShape);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index 67d3bd9b393c0..f977ba3c11bcf 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -250,8 +250,7 @@ gpu.module @test_kernel {
 // -----
 #l = #xegpu.layout<inst_data = [16, 16]>
 #r = #xegpu.layout<inst_data = [16]>
-
-gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel  {
   gpu.func @reduce_dim_0(%a: memref<16x512xf32>, %b: memref<512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
     %acc = arith.constant dense<0.0> : vector<64xf32>
     %c64 = arith.constant 64 : index
@@ -271,8 +270,7 @@ gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #l = #xegpu.layout<inst_data = [16, 16]>
 #r = #xegpu.layout<inst_data = [16]>
-
-gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel   {
   gpu.func @reduce_dim_1(%a: memref<512x32xf32>, %b: memref<512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
     %c1 = arith.constant 1 : index
     %c32 = arith.constant 32 : index
@@ -299,8 +297,7 @@ gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #r = #xegpu.layout<inst_data = [16]>
 #l = #xegpu.layout<inst_data = [16, 16]>
-
-gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel   {
   gpu.func @broadcast_dim_0(%a: memref<512xf32>, %b: memref<16x512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
 
     %c64 = arith.constant 64 : index
@@ -319,8 +316,7 @@ gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #r = #xegpu.layout<inst_data = [16]>
 #l = #xegpu.layout<inst_data = [16, 16]>
-
-gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel  {
   gpu.func @broadcast_dim_1(%a: memref<512xf32>, %b: memref<16x512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
 
     %c32 = arith.constant 32 : index
@@ -340,8 +336,7 @@ gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #l = #xegpu.layout<inst_data = [16, 8]>
 #t = #xegpu.layout<inst_data = [8, 16]>
-
-gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel   {
   gpu.func @transpose(%a: memref<512x8xf32>, %b: memref<8x512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
 
     %c32 = arith.constant 32 : index
@@ -355,4 +350,100 @@ gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
     xegpu.store_nd %2, %3: vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #t>
     gpu.return
   }
-}
\ No newline at end of file
+}
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: test_prefetch_load_store_update
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+   // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
+   // CHECK-COUNT-2: xegpu.load  {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+  // CHECK-COUNT-2: xegpu.store  {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+
+  gpu.func @test_prefetch_load_store_update(%src: ui64)  {
+
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
+
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+   
+    %delta = arith.constant dense<[
+    32,   32,  32,  32,  32,  32,  32,  32,
+    32,   32,  32,  32,  32,  32,  32,  64,
+    128, 128, 128, 128, 128, 128, 128, 128,
+    128, 128, 128, 128, 128, 128, 128, 256 
+    ]> : vector<32xindex>
+    %new_tdesc = xegpu.update_offset %tdesc, %delta
+              : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>     
+ 
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+
+    %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
+    xegpu.store %st_vec, %tdesc, %mask: 
+                 vector<32xf32>, 
+                 !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, 
+                 vector<32xi1>
+  
+    gpu.return
+  }
+  
+}
+
+// -----
+
+gpu.module @test_kernel   {
+  // CHECK-LABEL: test_prefetch_load_store_update_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK-COUNT-4: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+   // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xindex>
+   // CHECK-COUNT-4: xegpu.load  {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<2x16xf32>
+  // CHECK-COUNT-4: xegpu.store  {{.*}} : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
+
+  gpu.func @test_prefetch_load_store_update_chunk(%src: ui64)  {
+
+    %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
+
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+   
+    %delta = arith.constant dense<[
+      32,   32,  32,  32,  32,  32,  32,  32,
+      32,   32,  32,  32,  32,  32,  32,  64,
+      128, 128, 128, 128, 128, 128, 128, 128,
+      128, 128, 128, 128, 128, 128, 128, 256 
+    ]> : vector<32xindex>
+    %new_tdesc = xegpu.update_offset %tdesc, %delta
+              : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xindex>     
+ 
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %ld_vec = xegpu.load %new_tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xi1> -> vector<4x32xf32>
+
+    %st_vec = arith.addf %ld_vec, %ld_vec : vector<4x32xf32>
+    xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: 
+                 vector<4x32xf32>, 
+                 !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, 
+                 vector<32xi1>
+  
+    gpu.return
+  }
+}
+
+


// -----
gpu.module @test_kernel {
Contributor

No need for a gpu.module or gpu.func for testing purposes; we just use func.func. We use gpu.module only if some GPU op is necessary for the test case (like thread IDs, etc.). Is that the case here?

Contributor Author

The test doesn't require a GPU op, but currently almost all test files use this style. We may consider changing it for all test cases in a separate PR.

Contributor

Can someone clarify the difference between XeGPU blocking vs. unrolling? Why do we need both?

Contributor Author

The unrolling only tests the patterns, so its blocking strategy can be very naive. The XeGPU blocking pass is built on top of the unrolling patterns with a more sophisticated blocking strategy. @chencha3

Contributor

XeGPU::unrollPattern follows the design of vector::unrollPattern. They provide a very generic interface, verifying that targetShape is valid and then unrolling the op based on it. xegpu-blocking is built on top of it: it simply focuses on passing the inst_data to the targetShape of the unroll pattern, with some logic for robustness.
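
For context, below is a rough sketch of how a blocking-style pass can feed inst_data-derived tile shapes into the generic unroll patterns. The names used (xegpu::UnrollOptions, setNativeShapeFn, populateXeGPUUnrollPatterns, applyPatternsGreedily) reflect my reading of the XeGPU/MLIR interfaces referenced in this discussion and should be treated as assumptions, not a copy of the actual pass code:

// Sketch only: not the actual XeGPUBlocking.cpp code. It illustrates the idea
// described above: the blocking pass derives the target shape from the
// layout's inst_data and hands it to the generic unroll patterns.
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static void applyBlockingSketch(Operation *root) {
  xegpu::UnrollOptions options;

  // Blocking strategy: use the layout's inst_data as the native (target)
  // shape. The real pass dispatches per op kind (create_tdesc, update_offset,
  // prefetch, load, store, dpas, ...), as shown in the diff above.
  options.setNativeShapeFn(
      [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        auto shapeOf = [](Value v) -> std::optional<SmallVector<int64_t>> {
          auto tdescTy = dyn_cast<xegpu::TensorDescType>(v.getType());
          if (!tdescTy)
            return std::nullopt;
          xegpu::LayoutAttr layout = tdescTy.getLayoutAttr();
          if (!layout)
            return std::nullopt;
          DenseI64ArrayAttr instData = layout.getInstData();
          if (!instData || instData.empty())
            return std::nullopt;
          return llvm::to_vector(instData.asArrayRef());
        };
        // Simplified: the first tensor_desc result or operand wins.
        for (Value v : op->getResults())
          if (auto shape = shapeOf(v))
            return shape;
        for (Value v : op->getOperands())
          if (auto shape = shapeOf(v))
            return shape;
        return std::nullopt;
      });

  RewritePatternSet patterns(root->getContext());
  xegpu::populateXeGPUUnrollPatterns(patterns, options);
  (void)applyPatternsGreedily(root, std::move(patterns));
}

The actual pass also propagates per-value layout information and handles ops without tensor_desc operands (e.g. vector ops); this sketch only shows the inst_data-to-targetShape plumbing.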

Attribute encoding = tdescTy.getEncoding();
// If the encoding is a ScatterTensorDescAttr, we need to
// potentially adjust the chunk size based on the inst_data.
if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
Contributor

maybe use isScattered?

// potentially adjust the chunk size based on the inst_data.
if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
auto scatterAttr =
mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
Contributor

use getEncodingAsScatterTensorDescAttr?

Contributor Author

Does this read better? I could use this code sequence.

  Attribute encoding = tdescTy.getEncoding();
  // If the encoding is a ScatterTensorDescAttr, we need to
  // potentially adjust the chunk size based on the inst_data.
  if (tdescTy.isScattered()) {
    auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
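
Putting the suggestions in this thread together, a possible shape of the final code is sketched below. This is only a sketch: it assumes the dyn_cast_if_present route discussed later in this thread, and a ScatterTensorDescAttr::get overload that accepts the chunk size as a plain integer (per the later nit in this review):

// Sketch of a possible final form; names and overloads are assumptions
// based on this review thread, not the merged code.
static xegpu::TensorDescType
getBlockedTensorDescType(MLIRContext *ctx, xegpu::TensorDescType tdescTy,
                         ArrayRef<int64_t> tileShape) {
  Attribute encoding = tdescTy.getEncoding();
  // For scattered descriptors, the chunk size may need to shrink to the
  // innermost inst_data entry so each blocked instruction stays legal.
  if (auto scatterAttr =
          llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding)) {
    int64_t chunkSize = scatterAttr.getChunkSize().getInt();
    if (chunkSize > 1) {
      int64_t blockedChunkSize = chunkSize;
      auto instData = tdescTy.getLayoutAttr().getInstData();
      if (instData && !instData.empty())
        blockedChunkSize = instData.asArrayRef().back();
      // Assumes a builder overload taking the chunk size as an integer,
      // as suggested later in this review.
      encoding = xegpu::ScatterTensorDescAttr::get(
          ctx, scatterAttr.getMemorySpace(), blockedChunkSize);
    }
  }
  return xegpu::TensorDescType::get(ctx, tileShape, tdescTy.getElementType(),
                                    encoding,
                                    tdescTy.getLayoutAttr().dropInstData());
}

Either spelling (isScattered() plus getEncodingAsScatterTensorDescAttr(), or the dyn_cast_if_present above) keeps the same behavior; the point is that only the chunk size changes in the new encoding while inst_data is dropped from the layout.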

gpu.func @test_prefetch_load_store_update(%src: ui64) {

%cst = arith.constant dense<[
Contributor

What happens with constants? Do they get blocked in this pass?

Contributor Author

Yes.

// potentially adjust the chunk size based on the inst_data.
if (tdescTy.isScattered()) {
auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
// mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
Contributor

nit: This line can be removed. But I like to use mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding) here; it is clearer about what happens to encoding.

Contributor Author

Changed back to using "dyn_cast_if_present".

int64_t blockedChunkSize = chunkSize;
auto instData = tdescTy.getLayoutAttr().getInstData();
if (!instData.empty())
blockedChunkSize = instData.asArrayRef().back();
Contributor

nit: maybe we can reuse the chunkSize variable here, so line 311 can be removed.

Contributor Author

The code purposely introduces the variable name "blockedChunkSize". It makes the code easier to understand: the logic here is to block the chunk size and use the blocked one.


// To create a new attribute with a different chunk_size:
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
Contributor

nit: there is no need to create chunkSizeAttr; you can simply pass blockedChunkSize here.

Contributor Author

fixed.

github-actions bot commented Jun 18, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@charithaintc charithaintc merged commit 118bfcd into llvm:main Jun 18, 2025
5 of 7 checks passed
4 participants