Commit 49ea091

[MLIR][XeGPU] Allow some nd ops to have mismatched argument shapes for the distributed IR case.

4 files changed (+84, −28 lines)

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 1 addition & 2 deletions
@@ -327,8 +327,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "Tensor
   let hasVerifier = 1;
 }
 
-def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllShapesMatch<["value", "TensorDesc"]>,
-                                            AllElementTypesMatch<["value", "TensorDesc"]>]> {
+def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
   let summary = "stores a n-D block register region back to memory, currently only supports 2D";
 
   let description = [{

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 53 additions & 20 deletions
@@ -73,6 +73,29 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
          kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
 }
 
+// Validations for nd instruction arguments is successful if any of these are
+// true:
+// - tensor descriptor and the output vector shapes exactly match.
+// - tensor descriptor has a sg_map attribute and the distributed vector shape
+//   matches the tensor descriptor shape when scaled using sg_map factors on
+//   each dimension.
+static bool isArgShapesValid(ArrayRef<int64_t> descShape,
+                             ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
+  if (descShape == valShape)
+    return true;
+
+  if (!sgMap)
+    return false;
+
+  for (const auto &[factor, dim, expected] :
+       llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
+    if (factor * dim != expected)
+      return false;
+  }
+
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_CreateNdDescOp
 //===----------------------------------------------------------------------===//
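For illustration, here is a minimal standalone sketch of the new check in plain C++ (no MLIR/LLVM dependencies). isArgShapesValidSketch is a hypothetical stand-in for the helper above; the explicit index loop replaces llvm::zip_equal, which additionally asserts that all three ranges have the same length.

#include <cassert>
#include <cstdint>
#include <vector>

// A value shape is accepted if it equals the descriptor shape exactly, or
// if an sg_map is present and scaling each value dimension by its
// wi_layout factor recovers the corresponding descriptor dimension.
static bool isArgShapesValidSketch(const std::vector<int64_t> &descShape,
                                   const std::vector<int64_t> &valShape,
                                   const std::vector<int64_t> &wiLayout) {
  if (descShape == valShape)
    return true; // shapes match exactly: the non-distributed case
  if (wiLayout.empty())
    return false; // no sg_map: any mismatch is an error
  assert(wiLayout.size() == valShape.size() &&
         valShape.size() == descShape.size()); // mirrors zip_equal's contract
  for (size_t i = 0; i < valShape.size(); ++i)
    if (wiLayout[i] * valShape[i] != descShape[i])
      return false; // distributed dim must scale back to the desc dim
  return true;
}

int main() {
  // An 8x16 descriptor with wi_layout = [1, 16] distributes to an 8x1
  // vector per work item: 1 * 8 == 8 and 16 * 1 == 16.
  assert(isArgShapesValidSketch({8, 16}, {8, 1}, {1, 16}));
  // 8x2 does not scale back (16 * 2 != 16), so it is rejected.
  assert(!isArgShapesValidSketch({8, 16}, {8, 2}, {1, 16}));
  // Without an sg_map, only an exact shape match is accepted.
  assert(!isArgShapesValidSketch({8, 16}, {8, 1}, {}));
}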
@@ -210,13 +233,13 @@ LogicalResult PrefetchNdOp::verify() {
     return emitOpError("Expects a non-scattered TensorDesc.\n");
 
   if (!isReadHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
   if (!isReadHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
 
   if (!isReadHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   return success();
 }
@@ -238,13 +261,13 @@ LogicalResult LoadNdOp::verify() {
     return emitOpError("Invalid result, it should be a VectorType.\n");
 
   if (!isReadHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
   if (!isReadHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
 
   if (!isReadHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   auto array_len = tdescTy.getArrayLength();
   auto tdescShape = getShapeOf(tdescTy);
@@ -280,8 +303,9 @@ LogicalResult LoadNdOp::verify() {
     auto it = tdescShape.begin();
     tdescShape.insert(it, array_len);
   }
+  auto sgMap = tdescTy.getSGMapAttr();
 
-  if (tdescShape != valueShape)
+  if (!isArgShapesValid(tdescShape, valueShape, sgMap))
     return emitOpError() << "Result shape doesn't match TensorDesc shape."
                          << "The expected shape is " << makeString(tdescShape)
                          << ". But the given shape is "
@@ -303,17 +327,26 @@ LogicalResult StoreNdOp::verify() {
     return emitOpError("Expects a non-scattered TensorDesc.\n");
 
   if (!valTy)
-    return emitOpError("Exepcting a VectorType result.\n");
+    return emitOpError("Expecting a VectorType result.\n");
 
   if (!isWriteHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
   if (!isWriteHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
 
   if (!isWriteHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
+
+  auto tdescShape = getShapeOf(dstTy);
+  auto valueShape = getShapeOf(valTy);
+  auto sgMap = dstTy.getSGMapAttr();
 
+  if (!isArgShapesValid(tdescShape, valueShape, sgMap))
+    return emitOpError() << "Result shape doesn't match TensorDesc shape."
+                         << "The expected shape is " << makeString(tdescShape)
+                         << ". But the given shape is "
+                         << makeString(valueShape) << ".\n";
   return success();
 }
 
@@ -423,13 +456,13 @@ LogicalResult PrefetchOp::verify() {
     return emitOpError("Expects a scattered TensorDesc.\n");
 
   if (!isReadHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
   if (!isReadHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
 
   if (!isReadHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   return success();
 }
@@ -446,13 +479,13 @@ LogicalResult LoadGatherOp::verify() {
     return emitOpError("Expects a scattered TensorDesc.\n");
 
   if (!isReadHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
   if (!isReadHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
 
   if (!isReadHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   auto tdescElemTy = tdescTy.getElementType();
   auto valueElemTy = getElementType();
@@ -490,13 +523,13 @@ LogicalResult StoreScatterOp::verify() {
     return emitOpError("Expects a scattered TensorDesc.\n");
 
   if (!isWriteHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
   if (!isWriteHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
 
   if (!isWriteHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   auto maskTy = getMaskType();
   auto valueTy = getValueType();

mlir/test/Dialect/XeGPU/XeGPUOps.mlir

Lines changed: 24 additions & 0 deletions
@@ -86,6 +86,17 @@ gpu.func @test_load_nd_vc_2(%src: memref<8x16xf16>) {
   gpu.return
 }
 
+// load_nd args may have different shapes, validated against sg_map
+// CHECK: func @test_load_nd_vc_3(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+  gpu.return
+}
+
 // CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -108,6 +119,19 @@ gpu.func @test_store_nd_vc_2(%dst: memref<24x32xf16>) {
   gpu.return
 }
 
+// store_nd args may have different shapes, validated against sg_map
+// CHECK: func @test_store_nd_vc_3(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x2xf16>
+  %1 = arith.constant dense<1.0>: vector<24x2xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+    !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
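As a sanity check on the two new tests, the sg_map arithmetic works out as follows (an illustrative spot-check, not part of the commit):

#include <cassert>
int main() {
  // test_load_nd_vc_3: 8x16 desc, wi_layout [1, 16] -> per-lane vector 8x1.
  assert(1 * 8 == 8 && 16 * 1 == 16);
  // test_store_nd_vc_3: 24x32 desc, wi_layout [1, 16] -> per-lane vector 24x2.
  assert(1 * 24 == 24 && 16 * 2 == 32);
}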

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 6 additions & 6 deletions
@@ -32,7 +32,7 @@ func.func @test_create_nd_tdesc_vc_4(%src: memref<2x24x32xf32, 3>) {
 // -----
 func.func @test_prefetch_nd_vc_1(%src: memref<24x32xf16>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
-  // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+  // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
   xegpu.prefetch_nd %1 <{l1_hint = #xegpu.cache_hint<write_back>}>: !xegpu.tensor_desc<8x16xf16>
   return
 }
@@ -51,7 +51,7 @@ func.func @test_prefetch_nd_vc_2(%src: memref<24xf16>) {
 // -----
 func.func @test_load_nd_vc_1(%src: memref<8x16xf16>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+  // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
   %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<write_back>}>
     : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16>
   return
@@ -81,7 +81,7 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) {
 func.func @test_store_nd_vc_1(%dst: memref<24x32xf16>) {
   %1 = arith.constant dense<1.0>: vector<24x32xf16>
   %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16>
-  // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<streaming>}}
+  // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<streaming>}}
   xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<streaming>}>: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16>
   return
 }
@@ -147,7 +147,7 @@ func.func @test_prefetch_vc_2(%src: ui64) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex>
         -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+  // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
   xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<write_back>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   return
 }
@@ -168,7 +168,7 @@ func.func @test_load_gather_vc_2(%src: ui64) {
   %0 = arith.constant dense<1>: vector<4xi1>
   %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex>
         -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+  // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
   %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<write_back>}>
     : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
     -> vector<4x2xf32>
@@ -193,7 +193,7 @@ func.func @test_store_scatter_vc_2(%src: ui64) {
   %1 = arith.constant dense<2.9>: vector<4x2xf32>
   %2 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex>
         -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<streaming>}}
+  // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<streaming>}}
   xegpu.store %1, %2, %0 <{l1_hint = #xegpu.cache_hint<streaming>}> : vector<4x2xf32>,
     !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
   return
