-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][XeGPU] Add sg_map for scatter verification #124300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu Author: Artem Kroviakov (akroviakov) ChangesThis PR adds sg_map to the verification of scatter ops in XeGPU. Full diff: https://github.com/llvm/llvm-project/pull/124300.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index a2bfa721f2515b..c2335eecc3781d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -548,9 +548,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
let hasVerifier = 1;
}
-def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]>,
- AllElementTypesMatch<["value", "TensorDesc"]>,
- AllElementCountsMatch<["value", "TensorDesc"]>]> {
+def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
let summary = "load a set of scattered data points from memory.";
let description = [{ It (aka. load) load data per each work-item. The output
@@ -620,8 +618,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]
let hasVerifier = 1;
}
-def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementCountsMatch<["value", "TensorDesc"]>,
- AllElementTypesMatch<["value", "TensorDesc"]>]> {
+def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
let summary = "store data to scattered memory locations.";
let description = [{ It (aka. store) stores data to scattered memory locations. The value is
typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 15c435f1fa257b..ff4993688ddf2c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -512,10 +512,27 @@ LogicalResult LoadGatherOp::verify() {
if (tdescTy.getRank() == 2) {
if (!getTransposeAttr())
- return emitOpError("load_gather has to be transposed.");
+ return emitOpError("load has to be transposed.");
transpose({1, 0}, tdescShape);
}
+ if (auto sgMap = tdescTy.getSGMapAttr()) {
+ auto valueVecTy = cast<VectorType>(valueTy);
+ const int32_t wiData =
+ sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
+ if (valueVecTy.getRank() != 1)
+ return emitOpError("Load in SIMT should return a 1D vector.");
+ if (valueVecTy.getDimSize(0) != wiData ||
+ valueVecTy.getDimSize(0) != tdescTy.getChunkSize()) {
+ return emitOpError("Chunk size, vector size and wi_data must match.");
+ }
+ if (tdescTy.getRank() == 1) {
+ tdescShape = {1};
+ } else {
+ tdescShape = {tdescShape[0]};
+ }
+ }
+
if (valueShape != tdescShape)
return emitOpError("Unexpected result shape")
<< "(Expected shape: " << makeString(tdescShape)
@@ -551,10 +568,27 @@ LogicalResult StoreScatterOp::verify() {
if (tdescTy.getRank() == 2) {
if (!getTransposeAttr())
- return emitOpError("load_gather has to be transposed.");
+ return emitOpError("Store op has to be transposed.");
transpose({1, 0}, tdescShape);
}
+ if (auto sgMap = tdescTy.getSGMapAttr()) {
+ auto valueVecTy = cast<VectorType>(valueTy);
+ const int32_t wiData =
+ sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
+ if (valueVecTy.getRank() != 1)
+ return emitOpError("Store in SIMT should return a 1D vector.");
+ if (valueVecTy.getDimSize(0) != wiData ||
+ valueVecTy.getDimSize(0) != tdescTy.getChunkSize()) {
+ return emitOpError("Chunk size, vector size and wi_data must match.");
+ }
+ if (tdescTy.getRank() == 1) {
+ tdescShape = {1};
+ } else {
+ tdescShape = {tdescShape[0]};
+ }
+ }
+
if (valueShape != tdescShape)
return emitOpError("Unexpected value shape")
<< "(Expected shape: " << makeString(tdescShape)
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index d7174a489888a4..04317690020714 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -168,6 +168,64 @@ gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) {
gpu.return
}
+// CHECK: gpu.func @test_load_with_sg_map(%[[arg0:.*]]: ui64) {
+gpu.func @test_load_with_sg_map(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+ %1 = arith.constant dense<1>: vector<4xi1>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+ %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+ //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2xf32>
+ %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2xf32>
+ gpu.return
+}
+
+// CHECK: gpu.func @test_load_with_sg_map_2(%[[arg0:.*]]: ui64) {
+gpu.func @test_load_with_sg_map_2(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+ %1 = arith.constant dense<1>: vector<4xi1>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+ %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+ //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1> -> vector<1xf32>
+ %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1> -> vector<1xf32>
+ gpu.return
+}
+
+// CHECK: gpu.func @test_store_with_sg_map(%[[arg0:.*]]: ui64) {
+gpu.func @test_store_with_sg_map(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+ %1 = arith.constant dense<1>: vector<4xi1>
+ //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2xf32>
+ %2 = arith.constant dense<2.9>: vector<2xf32>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+ %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+ //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
+ xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
+ gpu.return
+}
+
+// CHECK: gpu.func @test_store_with_sg_map_2(%[[arg0:.*]]: ui64) {
+gpu.func @test_store_with_sg_map_2(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+ %1 = arith.constant dense<1>: vector<4xi1>
+ //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
+ %2 = arith.constant dense<2.9>: vector<1xf32>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+ %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+ //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1>
+ xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1>
+ gpu.return
+}
+
+
+
// CHECK: gpu.func @test_prefetch_vc(%[[arg0:.*]]: ui64) {
gpu.func @test_prefetch_vc(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
|
I think one work item can still read multiple elements with |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add also negative test cases?
Sure, in this case, according to the documentation, a vector with dimensions |
d35eca3
to
debb403
Compare
Added invalid examples and |
debb403
to
64348e0
Compare
64348e0
to
f5a3a0d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks pretty good already, thanks for the tweaks
Just left a couple of comments regarding readability and testing
f5a3a0d
to
303ee55
Compare
303ee55
to
8b52a27
Compare
This PR adds sg_map to the verification of scatter ops in XeGPU.
The documentation says
chunk_size: indicates the number of continuous elements accessed for each offset
, it also mentions the fact that scatter ops are SG-level.Hence, if an operation is distributed to work-items, a 1-d load means a work-item reads one element, a 2-d load means a work-item loads
chunk-size
or second dimension of tdesc elements. The changes in this PR reflect the documentation with the presence of sg_map attribute (i.e., distributed case).