@@ -213,6 +213,42 @@ static LogicalResult verifyBrgemmFlags(ArrayAttr flags, Operation *op,
213
213
return success ();
214
214
}
215
215
216
+ static bool isTypeSupported (Type outType, Type operandAType,
217
+ Type operandBType) {
218
+ if (!outType.isF32 () && !outType.isSignedInteger (32 ))
219
+ return false ;
220
+
221
+ if (outType.isF32 ()) {
222
+ if (!(operandAType.isF32 () && operandBType.isF32 ()) &&
223
+ !(operandAType.isBF16 () && operandBType.isBF16 ()))
224
+ return false ;
225
+ }
226
+ if (outType.isSignedInteger (32 )) {
227
+ if (!(operandAType.isSignedInteger (8 ) ||
228
+ operandAType.isUnsignedInteger (8 )) &&
229
+ (operandBType.isSignedInteger (8 ) || operandBType.isUnsignedInteger (8 )))
230
+ return false ;
231
+ }
232
+ return true ;
233
+ }
234
+
235
+ // TODO(haixin): could use compiler-wide VNNI utils?
236
+ static bool isInVnniLayout (ShapedType type) {
237
+ if (!type.getElementType ().isBF16 () &&
238
+ !type.getElementType ().isSignedInteger (8 ) &&
239
+ !type.getElementType ().isUnsignedInteger (8 ))
240
+ return false ;
241
+
242
+ auto blockingFactor = 0 ;
243
+ if (type.getElementType ().isBF16 ())
244
+ blockingFactor = 2 ;
245
+ else if (type.getElementType ().isSignedInteger (8 ) ||
246
+ type.getElementType ().isUnsignedInteger (8 ))
247
+ blockingFactor = 4 ;
248
+
249
+ return type.getShape ().back () == blockingFactor;
250
+ }
251
+
216
252
// ///////////////////////////////////////////////////
217
253
// Start of BrgemmOp
218
254
@@ -308,9 +344,8 @@ static inline ArrayRef<int64_t> getShapedValueShape(Value val) {
308
344
assert ((llvm::isa<TensorType>(val.getType ()) ||
309
345
llvm::isa<MemRefType>(val.getType ())) &&
310
346
" Expecting shaped value" );
311
- if (auto tensorTy = dyn_cast_or_null<TensorType>(val.getType ())) {
347
+ if (auto tensorTy = dyn_cast_or_null<TensorType>(val.getType ()))
312
348
return tensorTy.getShape ();
313
- }
314
349
auto memrefTy = dyn_cast_or_null<MemRefType>(val.getType ());
315
350
return memrefTy.getShape ();
316
351
}
@@ -331,15 +366,27 @@ LogicalResult BrgemmOp::verify() {
331
366
return op.emitOpError ()
332
367
<< " expect inputs and its related info to be size 2\n " ;
333
368
369
+ auto elemTypeA = getElementTypeOrSelf (ins[0 ]);
370
+ auto elemTypeB = getElementTypeOrSelf (ins[1 ]);
371
+ auto elemTypeC = getElementTypeOrSelf (out);
372
+ if (!isTypeSupported (elemTypeC, elemTypeA, elemTypeB))
373
+ return op.emitOpError () << " unsupported input matrix types\n " ;
374
+
334
375
ArrayRef<int64_t > dimA = getShapedValueShape (ins[0 ]);
335
376
ArrayRef<int64_t > dimB = getShapedValueShape (ins[1 ]);
336
377
ArrayRef<int64_t > dimC = getShapedValueShape (out);
337
378
if (dimA.size () != 3 )
338
379
return op.emitOpError () << " expect input A to be 3D\n " ;
339
- if (dimB.size () != 3 && dimB.size () != 4 )
340
- return op.emitOpError () << " expect input B to be 3D or 4D\n " ;
341
- if (dimB.size () == 4 && (dimB[3 ] != 2 && dimB[3 ] != 4 ))
342
- return op.emitOpError () << " expect input B vnni step to be 2 or 4\n " ;
380
+ if (!elemTypeB.isF32 ()) {
381
+ if (dimB.size () != 4 ||
382
+ !isInVnniLayout (dyn_cast<ShapedType>(ins[1 ].getType ())))
383
+ return op.emitOpError ()
384
+ << " expect a 4d VNNI input B for non-F32 operand: " << ins[1 ];
385
+ } else {
386
+ if (dimB.size () != 3 )
387
+ return op.emitOpError ()
388
+ << " expect a 3d input B for F32 operand: " << ins[1 ];
389
+ }
343
390
if (dimC.size () != 2 )
344
391
return op.emitOpError () << " expect input C to be 2D\n " ;
345
392
for (auto dim : batchDims)
@@ -558,42 +605,6 @@ LogicalResult BrgemmDispatchOp::verify() {
558
605
// ///////////////////////////////////////////////////
559
606
// Start of BrgemmExecuteOp
560
607
561
// Removed side of the diff: the MemRefType-only version of isInVnniLayout.
// Returns false unless the element type is BF16 or a signed/unsigned 8-bit
// integer; otherwise compares the last shape dimension against the blocking
// factor (2 for BF16, 4 for 8-bit integers).
// NOTE(review): getShape().back() is called without checking that the shape
// is non-empty -- confirm callers never pass a rank-0 memref.
- // TODO(haixin): could use compiler-wide VNNI utils?
562
- static bool isInVnniLayout (MemRefType memref) {
563
- if (!memref.getElementType ().isBF16 () &&
564
- !memref.getElementType ().isSignedInteger (8 ) &&
565
- !memref.getElementType ().isUnsignedInteger (8 ))
566
- return false ;
567
-
568
- auto blockingFactor = 0 ;
569
- if (memref.getElementType ().isBF16 ())
570
- blockingFactor = 2 ;
571
- else if (memref.getElementType ().isSignedInteger (8 ) ||
572
- memref.getElementType ().isUnsignedInteger (8 ))
573
- blockingFactor = 4 ;
574
-
575
- return memref.getShape ().back () == blockingFactor;
576
- }
577
-
578
// Removed side of the diff: the previous location of isTypeSupported.
// Rejects any output type other than f32 or signed i32. For an f32 output the
// operands must be f32 x f32 or bf16 x bf16.
// NOTE(review): for a signed i32 output the condition below only returns
// false when operand A is NOT an 8-bit integer while operand B IS one; other
// non-8-bit combinations fall through to `return true`. Presumably the intent
// was to require both operands to be 8-bit integers -- confirm.
- static bool isTypeSupported (Type outType, Type operandAType,
579
- Type operandBType) {
580
- if (!outType.isF32 () && !outType.isSignedInteger (32 ))
581
- return false ;
582
-
583
- if (outType.isF32 ()) {
584
- if (!(operandAType.isF32 () && operandBType.isF32 ()) &&
585
- !(operandAType.isBF16 () && operandBType.isBF16 ()))
586
- return false ;
587
- }
588
- if (outType.isSignedInteger (32 )) {
589
- if (!(operandAType.isSignedInteger (8 ) ||
590
- operandAType.isUnsignedInteger (8 )) &&
591
- (operandBType.isSignedInteger (8 ) || operandBType.isUnsignedInteger (8 )))
592
- return false ;
593
- }
594
- return true ;
595
- }
596
-
597
608
LogicalResult BrgemmExecuteOp::verify () {
598
609
BrgemmExecuteOp &brgemmOp = *this ;
599
610
0 commit comments