Skip to content

Commit 8b52a27

Browse files
committed
[MLIR][XeGPU] Add sg_map for scatter verification
1 parent f8a56df commit 8b52a27

File tree

4 files changed

+183
-10
lines changed

4 files changed

+183
-10
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -548,9 +548,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
548548
let hasVerifier = 1;
549549
}
550550

551-
def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]>,
552-
AllElementTypesMatch<["value", "TensorDesc"]>,
553-
AllElementCountsMatch<["value", "TensorDesc"]>]> {
551+
def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
554552
let summary = "load a set of scattered data points from memory.";
555553

556554
let description = [{ It (aka. load) load data per each work-item. The output
@@ -620,8 +618,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]
620618
let hasVerifier = 1;
621619
}
622620

623-
def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementCountsMatch<["value", "TensorDesc"]>,
624-
AllElementTypesMatch<["value", "TensorDesc"]>]> {
621+
def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
625622
let summary = "store data to scattered memory locations.";
626623
let description = [{ It (aka. store) stores data to scattered memory locations. The value is
627624
typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,22 @@ LogicalResult CreateDescOp::verify() {
453453
if (shape != tdescShape)
454454
return emitOpError("Incorrect TensorDesc shape. ")
455455
<< "Expected is " << makeString(shape) << "\n";
456-
456+
if (auto sgMap = tdescTy.getSGMapAttr()) {
457+
// A work-item's slice of the TensorDesc with shape [sg_size] or
458+
// [sg_size, chunk_size] will be [1] or [1, chunk_size] respectively,
459+
// the mapping should reflect that.
460+
if (sgMap.getWiData()[0] > 1)
461+
return emitOpError("TensorDesc's SG map only supports multiple elements "
462+
"contiguous along rows.");
463+
if (chunkSize != static_cast<int>(sgMap.getWiData()[1]))
464+
return emitOpError(
465+
"TensorDesc's chunkSize must match WI's data mapping.");
466+
if (int rank = tdescTy.getRank();
467+
(sgMap.getWiLayout()[2 - rank] != tdescShape[0]))
468+
return emitOpError("Detected a conflict between SG map's work-item "
469+
"layout and TensorDesc shape. Check the index of "
470+
"`subgroup_size` in WI layout map.");
471+
}
457472
return success();
458473
}
459474

@@ -512,10 +527,23 @@ LogicalResult LoadGatherOp::verify() {
512527

513528
if (tdescTy.getRank() == 2) {
514529
if (!getTransposeAttr())
515-
return emitOpError("load_gather has to be transposed.");
530+
return emitOpError("load of rank-2 tensor has to be transposed.");
516531
transpose({1, 0}, tdescShape);
517532
}
518533

534+
if (auto sgMap = tdescTy.getSGMapAttr()) {
535+
auto valueVecTy = cast<VectorType>(valueTy);
536+
const int32_t wiData =
537+
sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
538+
// All represent the same concept: a number of row elements to load.
539+
if (valueVecTy.getNumElements() != wiData ||
540+
valueVecTy.getNumElements() != tdescTy.getChunkSize()) {
541+
return emitOpError("Chunk size, vector size and wi_data must match.");
542+
}
543+
// Work-item's slice (i.e., vector shape to load) is [1] or [1, chunk_size].
544+
tdescShape[tdescTy.getRank() - 1] = 1;
545+
}
546+
519547
if (valueShape != tdescShape)
520548
return emitOpError("Unexpected result shape")
521549
<< "(Expected shape: " << makeString(tdescShape)
@@ -551,10 +579,23 @@ LogicalResult StoreScatterOp::verify() {
551579

552580
if (tdescTy.getRank() == 2) {
553581
if (!getTransposeAttr())
554-
return emitOpError("load_gather has to be transposed.");
582+
return emitOpError("Store of a rank-2 tensor has to be transposed.");
555583
transpose({1, 0}, tdescShape);
556584
}
557585

586+
if (auto sgMap = tdescTy.getSGMapAttr()) {
587+
auto valueVecTy = cast<VectorType>(valueTy);
588+
const int32_t wiData =
589+
sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
590+
// All represent the same concept: a number of row elements to store.
591+
if (valueVecTy.getNumElements() != wiData ||
592+
valueVecTy.getNumElements() != tdescTy.getChunkSize()) {
593+
return emitOpError("Chunk size, vector size and wi_data must match.");
594+
}
595+
// Work-item's slice (i.e., vector to store) is [1] or [1, chunk_size].
596+
tdescShape[tdescTy.getRank() - 1] = 1;
597+
}
598+
558599
if (valueShape != tdescShape)
559600
return emitOpError("Unexpected value shape")
560601
<< "(Expected shape: " << makeString(tdescShape)

mlir/test/Dialect/XeGPU/XeGPUOps.mlir

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,69 @@ gpu.func @test_create_tdesc_vc_1(%src: memref<?xf32, 3>) {
163163
gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) {
164164
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
165165
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
166-
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
167-
%1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
166+
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
167+
%1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
168168
gpu.return
169169
}
170170

171+
// CHECK: gpu.func @test_load_with_sg_map(%[[arg0:.*]]: ui64) {
172+
gpu.func @test_load_with_sg_map(%src: ui64) {
173+
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
174+
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
175+
//CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
176+
%1 = arith.constant dense<1>: vector<4xi1>
177+
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
178+
%2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
179+
//CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2x1xf32>
180+
%3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2x1xf32>
181+
gpu.return
182+
}
183+
184+
// CHECK: gpu.func @test_load_with_sg_map_2(%[[arg0:.*]]: ui64) {
185+
gpu.func @test_load_with_sg_map_2(%src: ui64) {
186+
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
187+
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
188+
//CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
189+
%1 = arith.constant dense<1>: vector<4xi1>
190+
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
191+
%2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
192+
//CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1> -> vector<1xf32>
193+
%3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1> -> vector<1xf32>
194+
gpu.return
195+
}
196+
197+
// CHECK: gpu.func @test_store_with_sg_map(%[[arg0:.*]]: ui64) {
198+
gpu.func @test_store_with_sg_map(%src: ui64) {
199+
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
200+
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
201+
//CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
202+
%1 = arith.constant dense<1>: vector<4xi1>
203+
//CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2x1xf32>
204+
%2 = arith.constant dense<2.9>: vector<2x1xf32>
205+
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
206+
%3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
207+
//CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
208+
xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
209+
gpu.return
210+
}
211+
212+
// CHECK: gpu.func @test_store_with_sg_map_2(%[[arg0:.*]]: ui64) {
213+
gpu.func @test_store_with_sg_map_2(%src: ui64) {
214+
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
215+
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
216+
//CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
217+
%1 = arith.constant dense<1>: vector<4xi1>
218+
//CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
219+
%2 = arith.constant dense<2.9>: vector<1xf32>
220+
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
221+
%3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
222+
//CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1>
223+
xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1>
224+
gpu.return
225+
}
226+
227+
228+
171229
// CHECK: gpu.func @test_prefetch_vc(%[[arg0:.*]]: ui64) {
172230
gpu.func @test_prefetch_vc(%src: ui64) {
173231
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,83 @@ func.func @test_prefetch_vc_2(%src: ui64) {
170170
return
171171
}
172172

173+
// -----
174+
func.func @test_create_tdesc_sg_map_1(%src: ui64) {
175+
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
176+
// expected-error@+1 {{Detected a conflict between SG map's work-item layout and TensorDesc shape. Check the index of `subgroup_size` in WI layout map}}
177+
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
178+
return
179+
}
180+
181+
// -----
182+
func.func @test_create_tdesc_sg_map_2(%src: ui64) {
183+
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
184+
// expected-error@+1 {{TensorDesc's SG map only supports multiple elements contiguous along rows}}
185+
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [2, 1]>>
186+
return
187+
}
188+
189+
// -----
190+
func.func @test_create_tdesc_sg_map_3(%src: ui64) {
191+
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
192+
// expected-error@+1 {{TensorDesc's chunkSize must match WI's data mapping}}
193+
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x3xf32, #xegpu.scatter_tdesc_attr<chunk_size = 3>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
194+
return
195+
}
196+
197+
// -----
198+
func.func @test_load_gather_sg_map_1(%src: ui64) {
199+
%0 = arith.constant dense<1>: vector<4xi1>
200+
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
201+
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
202+
// expected-error@+1 {{Unexpected result shape(Expected shape: [2, 1], Given shape: [1, 2])}}
203+
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<1x2xf32>
204+
return
205+
}
206+
207+
// -----
208+
func.func @test_load_gather_sg_map_2(%src: ui64) {
209+
%0 = arith.constant dense<1>: vector<4xi1>
210+
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
211+
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
212+
// expected-error@+1 {{Unexpected result shape(Expected shape: [2, 1], Given shape: [2])}}
213+
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2xf32>
214+
return
215+
}
216+
217+
// -----
218+
func.func @test_load_gather_sg_map_3(%src: ui64) {
219+
%0 = arith.constant dense<1>: vector<4xi1>
220+
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
221+
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
222+
// expected-error@+1 {{Chunk size, vector size and wi_data must match}}
223+
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<1xf32>
224+
return
225+
}
226+
227+
228+
// -----
229+
func.func @test_store_scatter_sg_map_1(%src: ui64) {
230+
%0 = arith.constant dense<1>: vector<4xi1>
231+
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
232+
%val = arith.constant dense<2.9>: vector<1x2xf32>
233+
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
234+
// expected-error@+1 {{Unexpected value shape(Expected shape: [2, 1], Given shape: [1, 2])}}
235+
xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : vector<1x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
236+
return
237+
}
238+
239+
// -----
240+
func.func @test_store_scatter_sg_map_2(%src: ui64) {
241+
%0 = arith.constant dense<1>: vector<4xi1>
242+
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
243+
%val = arith.constant dense<2.9>: vector<2xf32>
244+
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
245+
// expected-error@+1 {{Unexpected value shape(Expected shape: [2, 1], Given shape: [2])}}
246+
xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
247+
return
248+
}
249+
173250
// -----
174251
func.func @test_load_gather_vc_1(%src: memref<24x32xf16>) {
175252
%0 = arith.constant dense<1>: vector<4xi1>

0 commit comments

Comments
 (0)