Skip to content

Commit 1a1ec4a

Browse files
committed
Simplify validation + update tests
1 parent 1e00eb8 commit 1a1ec4a

File tree

2 files changed

+86
-190
lines changed

2 files changed

+86
-190
lines changed

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 16 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414

1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/Dialect/MemRef/IR/MemRef.h"
17+
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
1718
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1819
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1920
#include "mlir/Pass/Pass.h"
@@ -312,28 +313,6 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
312313
}
313314
};
314315

315-
/// Checks that the indexing maps of `contractOp` match the single layout
/// accepted for the rank of its rhs operand.
///
/// - rhs of rank 2: maps must equal {(m, k), (k, n), (m, n)}, i.e. a plain
///   GEMM without any transposition.
/// - any other rhs rank: maps must equal
///   {(m, k, vnni), (k, n, vnni), (m, n)}, i.e. the VNNI layout.
///
/// `rewriter` is accepted for signature compatibility and is not used.
/// Returns success() when the maps match the expected set, failure()
/// otherwise.
static LogicalResult validateDpasIndexing(PatternRewriter &rewriter,
                                          vector::ContractionOp contractOp) {
  MLIRContext *context = contractOp.getContext();

  // Dimension placeholders shared by both candidate layouts.
  AffineExpr dimM, dimN, dimK, dimVnni;
  bindDims(context, dimM, dimN, dimK, dimVnni);

  using ExprLists = ArrayRef<ArrayRef<AffineExpr>>;
  auto buildMaps = [&](ExprLists exprs) {
    return AffineMap::inferFromExprList(exprs, context);
  };

  // The rhs operand rank selects the expected data layout:
  // rank 2 is a standard GEMM, everything else is treated as VNNI.
  const bool isPlainGemm = contractOp.getRhsType().getRank() == 2;
  auto expectedMaps =
      isPlainGemm
          ? buildMaps({{dimM, dimK}, {dimK, dimN}, {dimM, dimN}})
          : buildMaps(
                {{dimM, dimK, dimVnni}, {dimK, dimN, dimVnni}, {dimM, dimN}});

  SmallVector<AffineMap, 4> actualMaps = contractOp.getIndexingMapsArray();
  return success(actualMaps == expectedMaps);
}
336-
337316
struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
338317
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
339318

@@ -349,48 +328,30 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
349328
VectorType accType = dyn_cast<VectorType>(acc.getType());
350329
if (!accType || accType.getRank() != 2)
351330
return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
331+
332+
// Accept only plain 2D data layout.
333+
// VNNI packing is left to later lowering.
352334
TypedValue<VectorType> lhs = contractOp.getLhs();
353-
VectorType lhsType = lhs.getType();
354-
int64_t lhsRank = lhsType.getRank();
355-
if (!(lhsRank == 2 || lhsRank == 3))
356-
return rewriter.notifyMatchFailure(contractOp,
357-
"Expects lhs 2D or 3D vector");
358335
TypedValue<VectorType> rhs = contractOp.getRhs();
359-
VectorType rhsType = rhs.getType();
360-
int64_t rhsRank = rhsType.getRank();
361-
if (!(rhsRank == 2 || rhsRank == 3))
336+
if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
362337
return rewriter.notifyMatchFailure(contractOp,
363-
"Expects rhs 2D or 3D vector");
364-
if (lhsRank != rhsRank)
365-
return rewriter.notifyMatchFailure(
366-
contractOp, "Expects lhs and rhs to be the same rank");
338+
"Expects lhs and rhs 2D vectors");
367339

368-
if (failed(validateDpasIndexing(rewriter, contractOp)))
340+
if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
369341
return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
370342

371-
// 3D shape implies VNNI layout verified by the earlier indexing check.
372-
bool isVnni = rhsRank == 3;
373-
auto rhsShape = rhsType.getShape();
374-
int64_t dimK = isVnni ? rhsShape[0] * rhsShape[2] : rhsShape[0];
375-
unsigned elemBitWidth = rhsType.getElementType().getIntOrFloatBitWidth();
376-
if (dimK != (8 * 32 / elemBitWidth))
343+
// TODO: Update shape validation to be target aware.
344+
auto rhsShape = rhs.getType().getShape();
345+
auto accShape = accType.getShape();
346+
int64_t dimM = accShape[0];
347+
int64_t dimN = accShape[1];
348+
int64_t dimK = rhsShape[0];
349+
if (dimM != 8 || dimN != 16 || dimK % 8 != 0)
377350
return rewriter.notifyMatchFailure(contractOp,
378-
"Invalid K-dimension size");
379-
if (isVnni && rhsShape[2] != (32 / elemBitWidth))
380-
return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI factor");
381-
382-
if (isVnni) {
383-
// Collapse contract lhs VNNI factor back into K-dim as dpas op expects
384-
// flat 2D shape for its lhs operand.
385-
auto lhsShape = lhsType.getShape();
386-
auto lhsFlatType = VectorType::get(
387-
{lhsShape[0], lhsShape[1] * lhsShape[2]}, lhsType.getElementType());
388-
lhs = rewriter.create<vector::ShapeCastOp>(loc, lhsFlatType, lhs)
389-
.getResult();
390-
}
351+
"Invalid operand dimensions");
391352

392353
auto dpasOp = rewriter.create<xegpu::DpasOp>(
393-
loc, contractOp.getResultType(), lhs, rhs, acc);
354+
loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc});
394355
rewriter.replaceOp(contractOp, dpasOp);
395356

396357
return success();

0 commit comments

Comments
 (0)