Commit d35eca3

[XeGPU] Add sg_map for scatter verification
1 parent f8a56df commit d35eca3

File tree

3 files changed: +96 −7 lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 2 additions & 5 deletions
@@ -548,9 +548,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
   let hasVerifier = 1;
 }
 
-def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]>,
-                                           AllElementTypesMatch<["value", "TensorDesc"]>,
-                                           AllElementCountsMatch<["value", "TensorDesc"]>]> {
+def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
   let summary = "load a set of scattered data points from memory.";
 
   let description = [{ It (aka. load) load data per each work-item. The output
@@ -620,8 +618,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]
   let hasVerifier = 1;
 }
 
-def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementCountsMatch<["value", "TensorDesc"]>,
-                                              AllElementTypesMatch<["value", "TensorDesc"]>]> {
+def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
   let summary = "store data to scattered memory locations.";
   let description = [{ It (aka. store) stores data to scattered memory locations. The value is
   typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
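
With an sg_map attached, a load in SIMT mode reads a rank-2 tensor_desc but returns a rank-1 per-work-item vector, so the AllRanksMatch and AllElementCountsMatch traits no longer hold and are dropped (AllElementTypesMatch still applies). Adapted from the new tests below (cache hints omitted): a 4x2xf32 descriptor (8 elements, rank 2) legally yields a vector<2xf32> (2 elements, rank 1) per work-item.

  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
  %3 = xegpu.load %2, %1 <{transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2xf32>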

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 36 additions & 2 deletions
@@ -512,10 +512,27 @@ LogicalResult LoadGatherOp::verify() {
 
   if (tdescTy.getRank() == 2) {
     if (!getTransposeAttr())
-      return emitOpError("load_gather has to be transposed.");
+      return emitOpError("load has to be transposed.");
     transpose({1, 0}, tdescShape);
   }
 
+  if (auto sgMap = tdescTy.getSGMapAttr()) {
+    auto valueVecTy = cast<VectorType>(valueTy);
+    const int32_t wiData =
+        sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
+    if (valueVecTy.getRank() != 1)
+      return emitOpError("Load in SIMT should return a 1D vector.");
+    if (valueVecTy.getDimSize(0) != wiData ||
+        valueVecTy.getDimSize(0) != tdescTy.getChunkSize()) {
+      return emitOpError("Chunk size, vector size and wi_data must match.");
+    }
+    if (tdescTy.getRank() == 1) {
+      tdescShape = {1};
+    } else {
+      tdescShape = {tdescShape[0]};
+    }
+  }
+
   if (valueShape != tdescShape)
     return emitOpError("Unexpected result shape")
            << "(Expected shape: " << makeString(tdescShape)
@@ -551,10 +568,27 @@ LogicalResult StoreScatterOp::verify() {
 
   if (tdescTy.getRank() == 2) {
     if (!getTransposeAttr())
-      return emitOpError("load_gather has to be transposed.");
+      return emitOpError("Store op has to be transposed.");
     transpose({1, 0}, tdescShape);
   }
 
+  if (auto sgMap = tdescTy.getSGMapAttr()) {
+    auto valueVecTy = cast<VectorType>(valueTy);
+    const int32_t wiData =
+        sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
+    if (valueVecTy.getRank() != 1)
+      return emitOpError("Store in SIMT should return a 1D vector.");
+    if (valueVecTy.getDimSize(0) != wiData ||
+        valueVecTy.getDimSize(0) != tdescTy.getChunkSize()) {
+      return emitOpError("Chunk size, vector size and wi_data must match.");
+    }
+    if (tdescTy.getRank() == 1) {
+      tdescShape = {1};
+    } else {
+      tdescShape = {tdescShape[0]};
+    }
+  }
+
   if (valueShape != tdescShape)
     return emitOpError("Unexpected value shape")
            << "(Expected shape: " << makeString(tdescShape)

mlir/test/Dialect/XeGPU/XeGPUOps.mlir

Lines changed: 58 additions & 0 deletions
@@ -168,6 +168,64 @@ gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_load_with_sg_map(%[[arg0:.*]]: ui64) {
+gpu.func @test_load_with_sg_map(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+  %1 = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2xf32>
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2xf32>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_load_with_sg_map_2(%[[arg0:.*]]: ui64) {
+gpu.func @test_load_with_sg_map_2(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+  %1 = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1> -> vector<1xf32>
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1> -> vector<1xf32>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_store_with_sg_map(%[[arg0:.*]]: ui64) {
+gpu.func @test_store_with_sg_map(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+  %1 = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2xf32>
+  %2 = arith.constant dense<2.9>: vector<2xf32>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+  %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
+  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_store_with_sg_map_2(%[[arg0:.*]]: ui64) {
+gpu.func @test_store_with_sg_map_2(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+  %1 = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
+  %2 = arith.constant dense<2.9>: vector<1xf32>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1>
+  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1>
+  gpu.return
+}
+
+
+
 // CHECK: gpu.func @test_prefetch_vc(%[[arg0:.*]]: ui64) {
 gpu.func @test_prefetch_vc(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>