[MLIR][XeGPU] Add sg_map for scatter verification #124300


Merged: 1 commit merged into llvm:main on Jan 30, 2025

Conversation

akroviakov (Contributor)

This PR adds sg_map to the verification of scatter ops in XeGPU.
The documentation says that chunk_size "indicates the number of continuous elements accessed for each offset"; it also notes that scatter ops are SG-level.
Hence, when an operation is distributed to work-items, a 1-D load means a work-item reads one element, while a 2-D load means a work-item loads chunk_size elements, i.e., the second dimension of the tdesc. The changes in this PR enforce this documented behavior whenever the sg_map attribute is present (i.e., in the distributed case). An illustrative sketch of both cases follows.
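
A minimal sketch of the two distributed cases, modeled on the tests added in this PR; %src : ui64, %offsets : vector<4xindex>, and %mask : vector<4xi1> are assumed to be defined as in those tests:

```mlir
// 1-D tdesc (no chunking): each work-item loads a single element.
%t1 = xegpu.create_tdesc %src, %offsets : ui64, vector<4xindex>
    -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>,
         #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
%v1 = xegpu.load %t1, %mask
    : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>,
        #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>,
      vector<4xi1> -> vector<1xf32>

// 2-D tdesc with chunk_size = 2: each work-item loads chunk_size elements,
// and the load must carry the transpose attribute.
%t2 = xegpu.create_tdesc %src, %offsets : ui64, vector<4xindex>
    -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>,
         #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
%v2 = xegpu.load %t2, %mask <{transpose}>
    : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>,
        #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>,
      vector<4xi1> -> vector<2xf32>
```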

llvmbot (Member) commented Jan 24, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Artem Kroviakov (akroviakov)


Full diff: https://github.com/llvm/llvm-project/pull/124300.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+2-5)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+36-2)
  • (modified) mlir/test/Dialect/XeGPU/XeGPUOps.mlir (+58)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index a2bfa721f2515b..c2335eecc3781d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -548,9 +548,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
   let hasVerifier = 1;
 }
 
-def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]>,
-                                    AllElementTypesMatch<["value", "TensorDesc"]>,
-                                   AllElementCountsMatch<["value", "TensorDesc"]>]> {
+def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
   let summary = "load a set of scattered data points from memory.";
 
   let description = [{ It (aka. load) load data per each work-item. The output
@@ -620,8 +618,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]
   let hasVerifier = 1;
 }
 
-def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementCountsMatch<["value", "TensorDesc"]>,
-                                              AllElementTypesMatch<["value", "TensorDesc"]>]> {
+def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
   let summary = "store data to scattered memory locations.";
   let description = [{ It (aka. store) stores data to scattered memory locations. The value is
   typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 15c435f1fa257b..ff4993688ddf2c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -512,10 +512,27 @@ LogicalResult LoadGatherOp::verify() {
 
   if (tdescTy.getRank() == 2) {
     if (!getTransposeAttr())
-      return emitOpError("load_gather has to be transposed.");
+      return emitOpError("load has to be transposed.");
     transpose({1, 0}, tdescShape);
   }
 
+  if (auto sgMap = tdescTy.getSGMapAttr()) {
+    auto valueVecTy = cast<VectorType>(valueTy);
+    const int32_t wiData =
+        sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
+    if (valueVecTy.getRank() != 1)
+      return emitOpError("Load in SIMT should return a 1D vector.");
+    if (valueVecTy.getDimSize(0) != wiData ||
+        valueVecTy.getDimSize(0) != tdescTy.getChunkSize()) {
+      return emitOpError("Chunk size, vector size and wi_data must match.");
+    }
+    if (tdescTy.getRank() == 1) {
+      tdescShape = {1};
+    } else {
+      tdescShape = {tdescShape[0]};
+    }
+  }
+
   if (valueShape != tdescShape)
     return emitOpError("Unexpected result shape")
            << "(Expected shape: " << makeString(tdescShape)
@@ -551,10 +568,27 @@ LogicalResult StoreScatterOp::verify() {
 
   if (tdescTy.getRank() == 2) {
     if (!getTransposeAttr())
-      return emitOpError("load_gather has to be transposed.");
+      return emitOpError("Store op has to be transposed.");
     transpose({1, 0}, tdescShape);
   }
 
+  if (auto sgMap = tdescTy.getSGMapAttr()) {
+    auto valueVecTy = cast<VectorType>(valueTy);
+    const int32_t wiData =
+        sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
+    if (valueVecTy.getRank() != 1)
+      return emitOpError("Store in SIMT should return a 1D vector.");
+    if (valueVecTy.getDimSize(0) != wiData ||
+        valueVecTy.getDimSize(0) != tdescTy.getChunkSize()) {
+      return emitOpError("Chunk size, vector size and wi_data must match.");
+    }
+    if (tdescTy.getRank() == 1) {
+      tdescShape = {1};
+    } else {
+      tdescShape = {tdescShape[0]};
+    }
+  }
+
   if (valueShape != tdescShape)
     return emitOpError("Unexpected value shape")
            << "(Expected shape: " << makeString(tdescShape)
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index d7174a489888a4..04317690020714 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -168,6 +168,64 @@ gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_load_with_sg_map(%[[arg0:.*]]: ui64) {
+gpu.func @test_load_with_sg_map(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+  %1 = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2xf32> 
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2xf32>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_load_with_sg_map_2(%[[arg0:.*]]: ui64) {
+gpu.func @test_load_with_sg_map_2(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+  %1 = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1> -> vector<1xf32> 
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1> -> vector<1xf32>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_store_with_sg_map(%[[arg0:.*]]: ui64) {
+gpu.func @test_store_with_sg_map(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+  %1 = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2xf32>
+  %2 = arith.constant dense<2.9>: vector<2xf32>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>> 
+  %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
+  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_store_with_sg_map_2(%[[arg0:.*]]: ui64) {
+gpu.func @test_store_with_sg_map_2(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+  %1 = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
+  %2 = arith.constant dense<2.9>: vector<1xf32>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>> 
+  %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1>
+  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1>
+  gpu.return
+}
+
+
+
 // CHECK: gpu.func @test_prefetch_vc(%[[arg0:.*]]: ui64) {
 gpu.func @test_prefetch_vc(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>

adam-smnk (Contributor)

> if an operation is distributed to work-items, a 1-d load means a work-item reads one element

I think one work item can still read multiple elements with wi_data = [1, N]? @chencha3

adam-smnk (Contributor) left a comment

Could you also add negative test cases?

akroviakov (Contributor, Author) commented Jan 27, 2025

> I think one work item can still read multiple elements with wi_data = [1, N]?

Sure. In that case, according to the documentation, a vector of shape <1xN> is returned to the work-item; the transpose attribute means the result vector is then <Nx1>.
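
A concrete sketch (assuming %tdesc and %mask are defined as in the tests above): with wi_layout = [4, 1], wi_data = [1, 2], and chunk_size = 2, each work-item conceptually owns a <1x2> slice of the <4x2> tdesc; with transpose that slice becomes <2x1>, which the verifier added in this PR accepts in its flattened 1-D form:

```mlir
// Each work-item receives chunk_size (= wi_data = 2) elements as a 1-D vector.
%per_wi = xegpu.load %tdesc, %mask <{transpose}>
    : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>,
        #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>,
      vector<4xi1> -> vector<2xf32>
```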

akroviakov force-pushed the xegpu-scatter-verification branch from d35eca3 to debb403 on January 28, 2025 at 12:37
akroviakov (Contributor, Author)

Added invalid examples and sg_map awareness to create_tdesc, and aligned the rules with the doc. A sketch of such a negative case is shown below.
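
For reference, a hypothetical sketch of one such negative case (%tdesc and %mask are assumed defined; the diagnostic text is the one emitted by the verifier in this diff, and the actual tests added may differ):

```mlir
// The result is rank-1, but its size (4) matches neither chunk_size (2)
// nor wi_data (2), so verification fails.
// expected-error@+1 {{Chunk size, vector size and wi_data must match.}}
%bad = xegpu.load %tdesc, %mask <{transpose}>
    : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>,
        #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>,
      vector<4xi1> -> vector<4xf32>
```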

akroviakov force-pushed the xegpu-scatter-verification branch from debb403 to 64348e0 on January 28, 2025 at 13:27
akroviakov force-pushed the xegpu-scatter-verification branch from 64348e0 to f5a3a0d on January 30, 2025 at 11:54
adam-smnk (Contributor) left a comment

Looks pretty good already, thanks for the tweaks. Just left a couple of comments regarding readability and testing.

akroviakov force-pushed the xegpu-scatter-verification branch from f5a3a0d to 303ee55 on January 30, 2025 at 13:45
akroviakov force-pushed the xegpu-scatter-verification branch from 303ee55 to 8b52a27 on January 30, 2025 at 13:49
akroviakov changed the title from "[XeGPU] Add sg_map for scatter verification" to "[MLIR][XeGPU] Add sg_map for scatter verification" on Jan 30, 2025
adam-smnk merged commit e436bf6 into llvm:main on Jan 30, 2025
8 checks passed