14
14
15
15
#include " mlir/Dialect/Arith/IR/Arith.h"
16
16
#include " mlir/Dialect/MemRef/IR/MemRef.h"
17
+ #include " mlir/Dialect/Utils/StructuredOpsUtils.h"
17
18
#include " mlir/Dialect/Vector/IR/VectorOps.h"
18
19
#include " mlir/Dialect/XeGPU/IR/XeGPU.h"
19
20
#include " mlir/Pass/Pass.h"
@@ -312,28 +313,6 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
312
313
}
313
314
};
314
315
315
- static LogicalResult validateDpasIndexing (PatternRewriter &rewriter,
316
- vector::ContractionOp contractOp) {
317
- MLIRContext *ctx = contractOp.getContext ();
318
- SmallVector<AffineMap, 4 > maps = contractOp.getIndexingMapsArray ();
319
-
320
- // Operand rank defines expected data layout:
321
- // - 2D for standard GEMM
322
- // - 3D for VNNI layout
323
- using MapList = ArrayRef<ArrayRef<AffineExpr>>;
324
- auto infer = [&](MapList m) { return AffineMap::inferFromExprList (m, ctx); };
325
- AffineExpr m, n, k, vnni;
326
- bindDims (ctx, m, n, k, vnni);
327
-
328
- if (contractOp.getRhsType ().getRank () == 2 ) {
329
- // Require plain GEMM without any transposition.
330
- return success (maps == infer ({{m, k}, {k, n}, {m, n}}));
331
- }
332
-
333
- // Require VNNI layout.
334
- return success (maps == infer ({{m, k, vnni}, {k, n, vnni}, {m, n}}));
335
- }
336
-
337
316
struct ContractionLowering : public OpRewritePattern <vector::ContractionOp> {
338
317
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
339
318
@@ -349,48 +328,30 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
349
328
VectorType accType = dyn_cast<VectorType>(acc.getType ());
350
329
if (!accType || accType.getRank () != 2 )
351
330
return rewriter.notifyMatchFailure (contractOp, " Expects acc 2D vector" );
331
+
332
+ // Accept only plain 2D data layout.
333
+ // VNNI packing is left to later lowering.
352
334
TypedValue<VectorType> lhs = contractOp.getLhs ();
353
- VectorType lhsType = lhs.getType ();
354
- int64_t lhsRank = lhsType.getRank ();
355
- if (!(lhsRank == 2 || lhsRank == 3 ))
356
- return rewriter.notifyMatchFailure (contractOp,
357
- " Expects lhs 2D or 3D vector" );
358
335
TypedValue<VectorType> rhs = contractOp.getRhs ();
359
- VectorType rhsType = rhs.getType ();
360
- int64_t rhsRank = rhsType.getRank ();
361
- if (!(rhsRank == 2 || rhsRank == 3 ))
336
+ if (lhs.getType ().getRank () != 2 || rhs.getType ().getRank () != 2 )
362
337
return rewriter.notifyMatchFailure (contractOp,
363
- " Expects rhs 2D or 3D vector" );
364
- if (lhsRank != rhsRank)
365
- return rewriter.notifyMatchFailure (
366
- contractOp, " Expects lhs and rhs to be the same rank" );
338
+ " Expects lhs and rhs 2D vectors" );
367
339
368
- if (failed ( validateDpasIndexing (rewriter, contractOp)))
340
+ if (! isRowMajorMatmul ( contractOp. getIndexingMapsAttr ( )))
369
341
return rewriter.notifyMatchFailure (contractOp, " Invalid indexing maps" );
370
342
371
- // 3D shape implies VNNI layout verified by the earlier indexing check.
372
- bool isVnni = rhsRank == 3 ;
373
- auto rhsShape = rhsType.getShape ();
374
- int64_t dimK = isVnni ? rhsShape[0 ] * rhsShape[2 ] : rhsShape[0 ];
375
- unsigned elemBitWidth = rhsType.getElementType ().getIntOrFloatBitWidth ();
376
- if (dimK != (8 * 32 / elemBitWidth))
343
+ // TODO: Update shape validation to be target aware.
344
+ auto rhsShape = rhs.getType ().getShape ();
345
+ auto accShape = accType.getShape ();
346
+ int64_t dimM = accShape[0 ];
347
+ int64_t dimN = accShape[1 ];
348
+ int64_t dimK = rhsShape[0 ];
349
+ if (dimM != 8 || dimN != 16 || dimK % 8 != 0 )
377
350
return rewriter.notifyMatchFailure (contractOp,
378
- " Invalid K-dimension size" );
379
- if (isVnni && rhsShape[2 ] != (32 / elemBitWidth))
380
- return rewriter.notifyMatchFailure (contractOp, " Invalid VNNI factor" );
381
-
382
- if (isVnni) {
383
- // Collapse contract lhs VNNI factor back into K-dim as dpas op expects
384
- // flat 2D shape for its lhs operand.
385
- auto lhsShape = lhsType.getShape ();
386
- auto lhsFlatType = VectorType::get (
387
- {lhsShape[0 ], lhsShape[1 ] * lhsShape[2 ]}, lhsType.getElementType ());
388
- lhs = rewriter.create <vector::ShapeCastOp>(loc, lhsFlatType, lhs)
389
- .getResult ();
390
- }
351
+ " Invalid operand dimensions" );
391
352
392
353
auto dpasOp = rewriter.create <xegpu::DpasOp>(
393
- loc, contractOp.getResultType (), lhs, rhs, acc);
354
+ loc, TypeRange{ contractOp.getResultType ()}, ValueRange{ lhs, rhs, acc} );
394
355
rewriter.replaceOp (contractOp, dpasOp);
395
356
396
357
return success ();
0 commit comments