Skip to content

Commit c0a354d

Browse files
[mlir][vector] Fix invalid IR in ContractionOpLowering (llvm#78130)
If a rewrite pattern returns "failure", it must not have modified the IR. This commit fixes `Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir` when running with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`. ``` * Pattern (anonymous namespace)::ContractionOpToOuterProductOpLowering : 'vector.contract -> ()' { Trying to match "(anonymous namespace)::ContractionOpToOuterProductOpLowering" ** Insert : 'vector.transpose'(0x5625b3a8cb30) ** Insert : 'vector.transpose'(0x5625b3a8cbc0) "(anonymous namespace)::ContractionOpToOuterProductOpLowering" result 0 } -> failure : pattern failed to match } -> failure : pattern failed to match LLVM ERROR: pattern returned failure but IR did change ``` Note: `vector-contract-to-outerproduct-transforms-unsupported.mlir` is merged into `vector-contract-to-outerproduct-matvec-transforms.mlir`. The `greedy pattern application failed` error is no longer produced. This error indicates that the greedy pattern rewrite did not converge; it does not mean that a pattern could not be applied.
1 parent 480cc41 commit c0a354d

File tree

3 files changed

+118
-88
lines changed

3 files changed

+118
-88
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 104 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -426,16 +426,8 @@ struct UnrolledOuterProductGenerator
426426
}
427427

428428
FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
429-
VectorType lhsType, int reductionDim,
429+
VectorType lhsType, int reductionSize,
430430
std::optional<Value> maybeMask = std::nullopt) {
431-
// Unrolling a scalable dimension would be incorrect - bail out.
432-
if (lhsType.getScalableDims()[reductionDim])
433-
return failure();
434-
435-
int reductionSize = lhsType.getDimSize(reductionDim);
436-
assert(reductionSize > 0 &&
437-
"Reduction dim must be a known static size to allow unrolling");
438-
439431
// Incremental support for masking.
440432
if (mask && !maybeMask.has_value())
441433
return failure();
@@ -458,49 +450,93 @@ struct UnrolledOuterProductGenerator
458450
return res;
459451
}
460452

453+
/// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
454+
/// dimension `reductionDim`. If the dimension is a scalable dimension,
455+
/// returns "nullopt".
456+
std::optional<int64_t> getReductionSize(VectorType vecType,
457+
int64_t reductionDim) {
458+
// Cannot unroll scalable dimension.
459+
if (vecType.getScalableDims()[reductionDim])
460+
return std::nullopt;
461+
int64_t reductionSize = vecType.getDimSize(reductionDim);
462+
assert(reductionSize > 0 &&
463+
"Reduction dim must be a known static size to allow unrolling");
464+
return reductionSize;
465+
}
466+
461467
/// Two outer parallel, one inner reduction (matmat flavor).
462468
FailureOr<Value> matmat() {
463469
if (!iters({Par(), Par(), Red()}))
464470
return failure();
465471
// Set up the parallel/reduction structure in the right form.
466472
AffineExpr m, n, k;
467473
bindDims(rewriter.getContext(), m, n, k);
468-
Value transposedMask = t(mask, {2, 0, 1});
474+
469475
// Classical row-major matmul: Just permute the lhs.
470-
if (layout({{m, k}, {k, n}, {m, n}}))
471-
return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
472-
transposedMask);
476+
if (layout({{m, k}, {k, n}, {m, n}})) {
477+
if (auto reductionSize = getReductionSize(lhsType, 1)) {
478+
// Note: `t` creates new IR. It must be nested within this `if` check
479+
// so that no IR is created when the pattern returns "failure".
480+
Value tLhs = t(lhs);
481+
Value tMask = t(mask, {2, 0, 1});
482+
return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
483+
}
484+
}
473485
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
474486
if (layout({{m, k}, {n, k}, {m, n}})) {
475-
Value tlhs = t(lhs);
476-
return outerProd(tlhs, t(rhs), res, lhsType, /*reductionDim=*/1,
477-
transposedMask);
487+
if (auto reductionSize = getReductionSize(lhsType, 1)) {
488+
Value tLhs = t(lhs);
489+
Value tRhs = t(rhs);
490+
Value tMask = t(mask, {2, 0, 1});
491+
return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
492+
}
478493
}
479494
// No need to permute anything.
480-
if (layout({{k, m}, {k, n}, {m, n}}))
481-
return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
482-
transposedMask);
495+
if (layout({{k, m}, {k, n}, {m, n}})) {
496+
if (auto reductionSize = getReductionSize(lhsType, 0)) {
497+
Value tMask = t(mask, {2, 0, 1});
498+
return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
499+
}
500+
}
483501
// Just permute the rhs.
484-
if (layout({{k, m}, {n, k}, {m, n}}))
485-
return outerProd(lhs, t(rhs), res, lhsType, /*reductionDim=*/0,
486-
transposedMask);
502+
if (layout({{k, m}, {n, k}, {m, n}})) {
503+
if (auto reductionSize = getReductionSize(lhsType, 0)) {
504+
Value tRhs = t(rhs);
505+
Value tMask = t(mask, {2, 0, 1});
506+
return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
507+
}
508+
}
487509
// Transposed output: swap RHS and LHS.
488510
// Classical row-major matmul: permute the lhs.
489-
if (layout({{m, k}, {k, n}, {n, m}}))
490-
return outerProd(rhs, t(lhs), res, lhsType, /*reductionDim=*/1,
491-
transposedMask);
511+
if (layout({{m, k}, {k, n}, {n, m}})) {
512+
if (auto reductionSize = getReductionSize(lhsType, 1)) {
513+
Value tLhs = t(lhs);
514+
Value tMask = t(mask, {2, 0, 1});
515+
return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
516+
}
517+
}
492518
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
493519
if (layout({{m, k}, {n, k}, {n, m}})) {
494-
Value trhs = t(rhs);
495-
return outerProd(trhs, t(lhs), res, lhsType, /*reductionDim=*/1,
496-
transposedMask);
520+
if (auto reductionSize = getReductionSize(lhsType, 1)) {
521+
Value tRhs = t(rhs);
522+
Value tLhs = t(lhs);
523+
Value tMask = t(mask, {2, 0, 1});
524+
return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
525+
}
526+
}
527+
if (layout({{k, m}, {k, n}, {n, m}})) {
528+
if (auto reductionSize = getReductionSize(lhsType, 0)) {
529+
Value tMask = t(mask, {2, 0, 1});
530+
return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
531+
}
532+
}
533+
if (layout({{k, m}, {n, k}, {n, m}})) {
534+
if (auto reductionSize = getReductionSize(lhsType, 0)) {
535+
Value tRhs = t(rhs);
536+
Value tMask = t(mask, {2, 0, 1});
537+
return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
538+
}
497539
}
498-
if (layout({{k, m}, {k, n}, {n, m}}))
499-
return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
500-
transposedMask);
501-
if (layout({{k, m}, {n, k}, {n, m}}))
502-
return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
503-
transposedMask);
504540
return failure();
505541
}
506542

@@ -514,24 +550,37 @@ struct UnrolledOuterProductGenerator
514550
return failure();
515551
AffineExpr m, k;
516552
bindDims(rewriter.getContext(), m, k);
517-
Value transposedMask = t(mask);
518553

519554
// Case mat-vec: transpose.
520-
if (layout({{m, k}, {k}, {m}}))
521-
return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
522-
transposedMask);
555+
if (layout({{m, k}, {k}, {m}})) {
556+
if (auto reductionSize = getReductionSize(lhsType, 1)) {
557+
Value tLhs = t(lhs);
558+
Value tMask = t(mask);
559+
return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
560+
}
561+
}
523562
// Case mat-trans-vec: ready to go.
524-
if (layout({{k, m}, {k}, {m}}))
525-
return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
526-
transposedMask);
563+
if (layout({{k, m}, {k}, {m}})) {
564+
if (auto reductionSize = getReductionSize(lhsType, 0)) {
565+
Value tMask = t(mask);
566+
return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
567+
}
568+
}
527569
// Case vec-mat: swap and transpose.
528-
if (layout({{k}, {m, k}, {m}}))
529-
return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
530-
transposedMask);
570+
if (layout({{k}, {m, k}, {m}})) {
571+
if (auto reductionSize = getReductionSize(lhsType, 0)) {
572+
Value tRhs = t(rhs);
573+
Value tMask = t(mask);
574+
return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
575+
}
576+
}
531577
// Case vec-mat-trans: swap and ready to go.
532-
if (layout({{k}, {k, m}, {m}}))
533-
return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
534-
transposedMask);
578+
if (layout({{k}, {k, m}, {m}})) {
579+
if (auto reductionSize = getReductionSize(lhsType, 0)) {
580+
Value tMask = t(mask);
581+
return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
582+
}
583+
}
535584
return failure();
536585
}
537586

@@ -547,16 +596,20 @@ struct UnrolledOuterProductGenerator
547596

548597
// Case mat-vec: transpose.
549598
if (layout({{m, k}, {k}, {m}}))
550-
return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1, mask);
599+
if (auto reductionSize = getReductionSize(lhsType, 1))
600+
return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
551601
// Case mat-trans-vec: ready to go.
552602
if (layout({{k, m}, {k}, {m}}))
553-
return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0, mask);
603+
if (auto reductionSize = getReductionSize(lhsType, 0))
604+
return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
554605
// Case vec-mat: swap and transpose.
555606
if (layout({{k}, {m, k}, {m}}))
556-
return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0, mask);
607+
if (auto reductionSize = getReductionSize(lhsType, 0))
608+
return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
557609
// Case vec-mat-trans: swap and ready to go.
558610
if (layout({{k}, {k, m}, {m}}))
559-
return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0, mask);
611+
if (auto reductionSize = getReductionSize(lhsType, 0))
612+
return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
560613
return failure();
561614
}
562615

mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,8 @@ func.func @masked_matvec_k_mk_m(%A: vector<4x2xf32>,
320320
%x: vector<2xf32>,
321321
%b: vector<4xf32>,
322322
%mask: vector<4x2xi1>) -> vector<4xf32> {
323-
// CHECK: vector.transpose %[[MASK]]
324323
// CHECK: vector.transpose %[[A]]
324+
// CHECK: vector.transpose %[[MASK]]
325325
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
326326
%res = vector.mask %mask {
327327
vector.contract #matvec_trait_3 %x, %A, %b
@@ -339,8 +339,8 @@ func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
339339
%x: vector<2xf32>,
340340
%b: vector<[4]xf32>,
341341
%mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
342-
// CHECK: vector.transpose %[[MASK]]
343342
// CHECK: vector.transpose %[[A]]
343+
// CHECK: vector.transpose %[[MASK]]
344344
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
345345
%res = vector.mask %mask {
346346
vector.contract #matvec_trait_3 %x, %A, %b
@@ -641,6 +641,18 @@ func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
641641
return %res : vector<[4]xf32>
642642
}
643643

644+
// Unrolling scalable reduction dim is not supported - bail out
645+
// CHECK-LABEL: @masked_extract_contract2_scalable_reduction_dim(
646+
// CHECK: vector.contract {{.*}} : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32>
647+
func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
648+
%arg1: vector<[3]xf32>,
649+
%arg2: vector<[2]xf32>,
650+
%m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
651+
%0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
652+
: vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
653+
return %0 : vector<[2]xf32>
654+
}
655+
644656
// ============================================================================
645657
// TD sequence
646658
// ============================================================================

mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir

Lines changed: 0 additions & 35 deletions
This file was deleted.

0 commit comments

Comments
 (0)