@@ -73,6 +73,29 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
73
73
kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
74
74
}
75
75
76
+ // Validations for nd instruction arguments is successful if any of these are
77
+ // true:
78
+ // - tensor descriptor and the output vector shapes exactly match.
79
+ // - tensor descriptor has a sg_map attribute and the distributed vector shape
80
+ // matches the tensor descriptor shape when scaled using sg_map factors on
81
+ // each dimension.
82
+ static bool isArgShapesValid (ArrayRef<int64_t > descShape,
83
+ ArrayRef<int64_t > valShape, SGMapAttr sgMap) {
84
+ if (descShape == valShape)
85
+ return true ;
86
+
87
+ if (!sgMap)
88
+ return false ;
89
+
90
+ for (const auto &[factor, dim, expected] :
91
+ llvm::zip_equal (sgMap.getWiLayout (), valShape, descShape)) {
92
+ if (factor * dim != expected)
93
+ return false ;
94
+ }
95
+
96
+ return true ;
97
+ }
98
+
76
99
// ===----------------------------------------------------------------------===//
77
100
// XeGPU_CreateNdDescOp
78
101
// ===----------------------------------------------------------------------===//
@@ -210,13 +233,13 @@ LogicalResult PrefetchNdOp::verify() {
210
233
return emitOpError (" Expects a non-scattered TensorDesc.\n " );
211
234
212
235
if (!isReadHintOrNone (getL1HintAttr ()))
213
- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
236
+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
214
237
215
238
if (!isReadHintOrNone (getL2HintAttr ()))
216
- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
239
+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
217
240
218
241
if (!isReadHintOrNone (getL3HintAttr ()))
219
- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
242
+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
220
243
221
244
return success ();
222
245
}
@@ -238,13 +261,13 @@ LogicalResult LoadNdOp::verify() {
238
261
return emitOpError (" Invalid result, it should be a VectorType.\n " );
239
262
240
263
if (!isReadHintOrNone (getL1HintAttr ()))
241
- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
264
+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
242
265
243
266
if (!isReadHintOrNone (getL2HintAttr ()))
244
- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
267
+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
245
268
246
269
if (!isReadHintOrNone (getL3HintAttr ()))
247
- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
270
+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
248
271
249
272
auto array_len = tdescTy.getArrayLength ();
250
273
auto tdescShape = getShapeOf (tdescTy);
@@ -280,8 +303,9 @@ LogicalResult LoadNdOp::verify() {
280
303
auto it = tdescShape.begin ();
281
304
tdescShape.insert (it, array_len);
282
305
}
306
+ auto sgMap = tdescTy.getSGMapAttr ();
283
307
284
- if (tdescShape != valueShape)
308
+ if (! isArgShapesValid ( tdescShape, valueShape, sgMap) )
285
309
return emitOpError () << " Result shape doesn't match TensorDesc shape."
286
310
<< " The expected shape is " << makeString (tdescShape)
287
311
<< " . But the given shape is "
@@ -303,17 +327,26 @@ LogicalResult StoreNdOp::verify() {
303
327
return emitOpError (" Expects a non-scattered TensorDesc.\n " );
304
328
305
329
if (!valTy)
306
- return emitOpError (" Exepcting a VectorType result.\n " );
330
+ return emitOpError (" Expecting a VectorType result.\n " );
307
331
308
332
if (!isWriteHintOrNone (getL1HintAttr ()))
309
- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
333
+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
310
334
311
335
if (!isWriteHintOrNone (getL2HintAttr ()))
312
- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
336
+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
313
337
314
338
if (!isWriteHintOrNone (getL3HintAttr ()))
315
- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
339
+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
340
+
341
+ auto tdescShape = getShapeOf (dstTy);
342
+ auto valueShape = getShapeOf (valTy);
343
+ auto sgMap = dstTy.getSGMapAttr ();
316
344
345
+ if (!isArgShapesValid (tdescShape, valueShape, sgMap))
346
+ return emitOpError () << " Result shape doesn't match TensorDesc shape."
347
+ << " The expected shape is " << makeString (tdescShape)
348
+ << " . But the given shape is "
349
+ << makeString (valueShape) << " .\n " ;
317
350
return success ();
318
351
}
319
352
@@ -423,13 +456,13 @@ LogicalResult PrefetchOp::verify() {
423
456
return emitOpError (" Expects a scattered TensorDesc.\n " );
424
457
425
458
if (!isReadHintOrNone (getL1HintAttr ()))
426
- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
459
+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
427
460
428
461
if (!isReadHintOrNone (getL2HintAttr ()))
429
- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
462
+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
430
463
431
464
if (!isReadHintOrNone (getL3HintAttr ()))
432
- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
465
+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
433
466
434
467
return success ();
435
468
}
@@ -446,13 +479,13 @@ LogicalResult LoadGatherOp::verify() {
446
479
return emitOpError (" Expects a scattered TensorDesc.\n " );
447
480
448
481
if (!isReadHintOrNone (getL1HintAttr ()))
449
- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
482
+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
450
483
451
484
if (!isReadHintOrNone (getL2HintAttr ()))
452
- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
485
+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
453
486
454
487
if (!isReadHintOrNone (getL3HintAttr ()))
455
- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
488
+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
456
489
457
490
auto tdescElemTy = tdescTy.getElementType ();
458
491
auto valueElemTy = getElementType ();
@@ -490,13 +523,13 @@ LogicalResult StoreScatterOp::verify() {
490
523
return emitOpError (" Expects a scattered TensorDesc.\n " );
491
524
492
525
if (!isWriteHintOrNone (getL1HintAttr ()))
493
- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
526
+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
494
527
495
528
if (!isWriteHintOrNone (getL2HintAttr ()))
496
- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
529
+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
497
530
498
531
if (!isWriteHintOrNone (getL3HintAttr ()))
499
- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
532
+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
500
533
501
534
auto maskTy = getMaskType ();
502
535
auto valueTy = getValueType ();
0 commit comments