Skip to content

Commit debb403

Browse files
committed
[XeGPU] Add sg_map for scatter verification
1 parent f8a56df commit debb403

File tree

4 files changed

+155
-10
lines changed

4 files changed

+155
-10
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -548,9 +548,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
548548
let hasVerifier = 1;
549549
}
550550

551-
def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]>,
552-
AllElementTypesMatch<["value", "TensorDesc"]>,
553-
AllElementCountsMatch<["value", "TensorDesc"]>]> {
551+
def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
554552
let summary = "load a set of scattered data points from memory.";
555553

556554
let description = [{ It (aka. load) load data per each work-item. The output
@@ -620,8 +618,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]
620618
let hasVerifier = 1;
621619
}
622620

623-
def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementCountsMatch<["value", "TensorDesc"]>,
624-
AllElementTypesMatch<["value", "TensorDesc"]>]> {
621+
def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
625622
let summary = "store data to scattered memory locations.";
626623
let description = [{ It (aka. store) stores data to scattered memory locations. The value is
627624
typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,17 @@ LogicalResult CreateDescOp::verify() {
453453
if (shape != tdescShape)
454454
return emitOpError("Incorrect TensorDesc shape. ")
455455
<< "Expected is " << makeString(shape) << "\n";
456-
456+
if (auto sgMap = tdescTy.getSGMapAttr()) {
457+
if (sgMap.getWiData()[0] > 1)
458+
return emitOpError("TensorDesc cannot have wi_data[0] > 1.");
459+
if (chunkSize != static_cast<int>(sgMap.getWiData()[1]))
460+
return emitOpError("TensorDesc's chunkSize must match wi_data[1].");
461+
if (int rank = tdescTy.getRank(); (sgMap.getWiLayout()[2 - rank] == 1))
462+
return emitOpError("TensorDesc of a " + std::to_string(rank) +
463+
"D tensor must have wi_layout[" +
464+
std::to_string(2 - rank) + "] == tdescShape[" +
465+
std::to_string(2 - rank) + "].");
466+
}
457467
return success();
458468
}
459469

@@ -512,10 +522,21 @@ LogicalResult LoadGatherOp::verify() {
512522

513523
if (tdescTy.getRank() == 2) {
514524
if (!getTransposeAttr())
515-
return emitOpError("load_gather has to be transposed.");
525+
return emitOpError("load of rank-2 tensor has to be transposed.");
516526
transpose({1, 0}, tdescShape);
517527
}
518528

529+
if (auto sgMap = tdescTy.getSGMapAttr()) {
530+
auto valueVecTy = cast<VectorType>(valueTy);
531+
const int32_t wiData =
532+
sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
533+
if (valueVecTy.getNumElements() != wiData ||
534+
valueVecTy.getNumElements() != tdescTy.getChunkSize()) {
535+
return emitOpError("Chunk size, vector size and wi_data must match.");
536+
}
537+
tdescShape[tdescTy.getRank() - 1] = 1;
538+
}
539+
519540
if (valueShape != tdescShape)
520541
return emitOpError("Unexpected result shape")
521542
<< "(Expected shape: " << makeString(tdescShape)
@@ -551,10 +572,21 @@ LogicalResult StoreScatterOp::verify() {
551572

552573
if (tdescTy.getRank() == 2) {
553574
if (!getTransposeAttr())
554-
return emitOpError("load_gather has to be transposed.");
575+
return emitOpError("Store of a rank-2 tensor has to be transposed.");
555576
transpose({1, 0}, tdescShape);
556577
}
557578

579+
if (auto sgMap = tdescTy.getSGMapAttr()) {
580+
auto valueVecTy = cast<VectorType>(valueTy);
581+
const int32_t wiData =
582+
sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
583+
if (valueVecTy.getNumElements() != wiData ||
584+
valueVecTy.getNumElements() != tdescTy.getChunkSize()) {
585+
return emitOpError("Chunk size, vector size and wi_data must match.");
586+
}
587+
tdescShape[tdescTy.getRank() - 1] = 1;
588+
}
589+
558590
if (valueShape != tdescShape)
559591
return emitOpError("Unexpected value shape")
560592
<< "(Expected shape: " << makeString(tdescShape)

mlir/test/Dialect/XeGPU/XeGPUOps.mlir

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,69 @@ gpu.func @test_create_tdesc_vc_1(%src: memref<?xf32, 3>) {
163163
gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) {
164164
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
165165
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
166-
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
167-
%1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
166+
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
167+
%1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
168168
gpu.return
169169
}
170170

171+
// CHECK: gpu.func @test_load_with_sg_map(%[[arg0:.*]]: ui64) {
172+
gpu.func @test_load_with_sg_map(%src: ui64) {
173+
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
174+
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
175+
//CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
176+
%1 = arith.constant dense<1>: vector<4xi1>
177+
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
178+
%2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
179+
//CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2x1xf32>
180+
%3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2x1xf32>
181+
gpu.return
182+
}
183+
184+
// CHECK: gpu.func @test_load_with_sg_map_2(%[[arg0:.*]]: ui64) {
185+
gpu.func @test_load_with_sg_map_2(%src: ui64) {
186+
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
187+
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
188+
//CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
189+
%1 = arith.constant dense<1>: vector<4xi1>
190+
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
191+
%2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
192+
//CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1> -> vector<1xf32>
193+
%3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1> -> vector<1xf32>
194+
gpu.return
195+
}
196+
197+
// CHECK: gpu.func @test_store_with_sg_map(%[[arg0:.*]]: ui64) {
198+
gpu.func @test_store_with_sg_map(%src: ui64) {
199+
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
200+
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
201+
//CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
202+
%1 = arith.constant dense<1>: vector<4xi1>
203+
//CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2x1xf32>
204+
%2 = arith.constant dense<2.9>: vector<2x1xf32>
205+
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
206+
%3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
207+
//CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
208+
xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
209+
gpu.return
210+
}
211+
212+
// CHECK: gpu.func @test_store_with_sg_map_2(%[[arg0:.*]]: ui64) {
213+
gpu.func @test_store_with_sg_map_2(%src: ui64) {
214+
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
215+
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
216+
//CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
217+
%1 = arith.constant dense<1>: vector<4xi1>
218+
//CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
219+
%2 = arith.constant dense<2.9>: vector<1xf32>
220+
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
221+
%3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
222+
//CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1>
223+
xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1>
224+
gpu.return
225+
}
226+
227+
228+
171229
// CHECK: gpu.func @test_prefetch_vc(%[[arg0:.*]]: ui64) {
172230
gpu.func @test_prefetch_vc(%src: ui64) {
173231
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,64 @@ func.func @test_prefetch_vc_2(%src: ui64) {
170170
return
171171
}
172172

173+
// -----
174+
func.func @test_create_tdesc_sg_map_1(%src: ui64) {
175+
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
176+
// expected-error@+1 {{TensorDesc of a 1D tensor must have wi_layout[1] == tdescShape[1]}}
177+
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
178+
return
179+
}
180+
181+
// -----
182+
func.func @test_create_tdesc_sg_map_2(%src: ui64) {
183+
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
184+
// expected-error@+1 {{TensorDesc cannot have wi_data[0] > 1}}
185+
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [2, 1]>>
186+
return
187+
}
188+
189+
// -----
190+
func.func @test_load_gather_sg_map_1(%src: ui64) {
191+
%0 = arith.constant dense<1>: vector<4xi1>
192+
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
193+
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
194+
// expected-error@+1 {{Unexpected result shape(Expected shape: [2, 1], Given shape: [1, 2])}}
195+
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<1x2xf32>
196+
return
197+
}
198+
199+
// -----
200+
func.func @test_load_gather_sg_map_2(%src: ui64) {
201+
%0 = arith.constant dense<1>: vector<4xi1>
202+
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
203+
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
204+
// expected-error@+1 {{Unexpected result shape(Expected shape: [2, 1], Given shape: [2])}}
205+
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2xf32>
206+
return
207+
}
208+
209+
// -----
210+
func.func @test_store_scatter_sg_map_1(%src: ui64) {
211+
%0 = arith.constant dense<1>: vector<4xi1>
212+
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
213+
%val = arith.constant dense<2.9>: vector<1x2xf32>
214+
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
215+
// expected-error@+1 {{Unexpected value shape(Expected shape: [2, 1], Given shape: [1, 2])}}
216+
xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : vector<1x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
217+
return
218+
}
219+
220+
// -----
221+
func.func @test_store_scatter_sg_map_2(%src: ui64) {
222+
%0 = arith.constant dense<1>: vector<4xi1>
223+
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
224+
%val = arith.constant dense<2.9>: vector<2xf32>
225+
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
226+
// expected-error@+1 {{Unexpected value shape(Expected shape: [2, 1], Given shape: [2])}}
227+
xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
228+
return
229+
}
230+
173231
// -----
174232
func.func @test_load_gather_vc_1(%src: memref<24x32xf16>) {
175233
%0 = arith.constant dense<1>: vector<4xi1>

0 commit comments

Comments (0)