
Commit cb9e65d

banach-space authored and chencha3 committed
[mlir][vector] Add support for masks in castAwayContractionLeadingOneDim (llvm#81906)
Updates `castAwayContractionLeadingOneDim` to inherit from `MaskableOpRewritePattern` so that this pattern can support masking. Builds on top of llvm#83827
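As a rough usage sketch (illustrative only, not part of this commit; it assumes the pre-existing `populateCastAwayVectorLeadingOneDimPatterns` entry point and the greedy rewrite driver), the pattern is registered exactly as before, and after this change it also fires on `vector.contract` ops wrapped in `vector.mask`:

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical driver helper, shown only to illustrate how the updated
// pattern is exercised; none of this code is added by the commit.
static mlir::LogicalResult dropLeadingUnitDims(mlir::Operation *rootOp) {
  mlir::RewritePatternSet patterns(rootOp->getContext());
  // Registers CastAwayContractionLeadingOneDim (among related patterns).
  mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
  return mlir::applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}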
1 parent 8551546 commit cb9e65d

File tree: 4 files changed, +114 −56 lines changed


mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h

Lines changed: 4 additions & 2 deletions
@@ -110,8 +110,10 @@ void transferOpflowOpt(RewriterBase &rewriter, Operation *rootOp);
 
 /// Cast away the leading unit dim, if exists, for the given contract op.
 /// Return success if the transformation applies; return failure otherwise.
-LogicalResult castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
-                                               RewriterBase &rewriter);
+FailureOr<Value>
+castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
+                                 MaskingOpInterface maskingOp,
+                                 RewriterBase &rewriter);
 
 } // namespace vector
 } // namespace mlir
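Call-site sketch for the new signature (assumed, not from this commit): the masking op, if any, is now passed in explicitly, and the returned `FailureOr<Value>` carries the replacement value instead of the function performing the replacement itself. Names such as `contractOp` and `rewriter` are placeholders local to this example; `getMaskingOp()` is assumed from `MaskableOpInterface`.

// Illustrative only; this commit does not add such a call site.
vector::MaskingOpInterface maskingOp;
auto maskableOp = cast<vector::MaskableOpInterface>(contractOp.getOperation());
if (maskableOp.isMasked())
  maskingOp = maskableOp.getMaskingOp();

FailureOr<Value> replacement =
    vector::castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter);
if (failed(replacement))
  return failure();
// Replace the enclosing vector.mask when present, otherwise the contraction.
rewriter.replaceOp(maskingOp ? maskingOp.getOperation()
                             : contractOp.getOperation(),
                   *replacement);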

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 5 additions & 5 deletions
@@ -127,8 +127,8 @@ SmallVector<OpFoldResult> getMixedSizesXfer(bool hasTensorSemantics,
 /// responsible for providing an updated ("rewritten") version of:
 /// a. the source Op when mask _is not_ present,
 /// b. the source Op and the masking Op when mask _is_ present.
-/// Note that the return value from `matchAndRewriteMaskableOp` depends on the
-/// case above.
+/// To use this pattern, implement `matchAndRewriteMaskableOp`. Note that
+/// the return value will depend on the case above.
 template <class SourceOp>
 struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
   using OpRewritePattern<SourceOp>::OpRewritePattern;
@@ -162,9 +162,9 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
   }
 
 public:
-  // Matches SourceOp that can potentially be masked with `maskingOp`. If the
-  // latter is present, returns an updated masking op (with a replacement for
-  // `sourceOp` nested inside). Otherwise, returns an updated `sourceOp`.
+  // Matches `sourceOp` that can potentially be masked with `maskingOp`. If the
+  // latter is present, returns a replacement for `maskingOp`. Otherwise,
+  // returns a replacement for `sourceOp`.
   virtual FailureOr<Value>
   matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
                             PatternRewriter &rewriter) const = 0;
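A minimal skeleton of a client pattern following the updated comment (illustrative sketch only; the struct name and the trivial body are hypothetical, and namespace qualification is assumed):

struct ExampleContractionRewrite
    : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp srcOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    // Build the replacement value here. When `maskingOp` is non-null, return
    // a value that replaces the enclosing vector.mask (e.g. by re-wrapping a
    // new maskable op with vector::maskOperation); otherwise return a value
    // that replaces `srcOp`. The base class performs the actual replacement.
    return failure();
  }
};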

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Lines changed: 31 additions & 19 deletions
@@ -329,12 +329,10 @@ struct CastAwayTransferWriteLeadingOneDim
 
 } // namespace
 
-LogicalResult
+FailureOr<Value>
 mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
+                                               MaskingOpInterface maskingOp,
                                                RewriterBase &rewriter) {
-  // TODO(#78787): Not supported masked op yet.
-  if (cast<MaskableOpInterface>(contractOp.getOperation()).isMasked())
-    return failure();
   VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
   if (oldAccType == nullptr)
     return failure();
@@ -368,6 +366,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
   SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                  contractOp.getAcc()};
   SmallVector<Value> newOperands;
+  auto loc = contractOp.getLoc();
 
   for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
     // Check if the dim to be dropped exists as a leading dim in the operand
@@ -405,7 +404,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
       map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                            contractOp.getContext());
       operands[it.index()] = rewriter.create<vector::TransposeOp>(
-          contractOp.getLoc(), operands[it.index()], perm);
+          loc, operands[it.index()], perm);
     }
   }
   // We have taken care to have the dim to be dropped be
@@ -429,18 +428,29 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
     // Extract if its a valid extraction, otherwise use the operand
    // without extraction.
     newOperands.push_back(
-        validExtract ? rewriter.create<vector::ExtractOp>(contractOp.getLoc(),
-                                                          operands[it.index()],
-                                                          splatZero(dropDim))
+        validExtract ? rewriter.create<vector::ExtractOp>(
+                           loc, operands[it.index()], splatZero(dropDim))
                      : operands[it.index()]);
   }
-  auto newContractOp = rewriter.create<vector::ContractionOp>(
-      contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
+
+  // Depending on whether this vector.contract is masked, the replacing Op
+  // should either be a new vector.contract Op or vector.mask Op.
+  Operation *newOp = rewriter.create<vector::ContractionOp>(
+      loc, newOperands[0], newOperands[1], newOperands[2],
       rewriter.getAffineMapArrayAttr(newIndexingMaps),
       rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
-  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-      contractOp, contractOp->getResultTypes()[0], newContractOp);
-  return success();
+
+  if (maskingOp) {
+    auto newMask = rewriter.create<vector::ExtractOp>(loc, maskingOp.getMask(),
+                                                      splatZero(dropDim));
+
+    newOp = mlir::vector::maskOperation(rewriter, newOp, newMask);
+  }
+
+  return rewriter
+      .create<vector::BroadcastOp>(loc, contractOp->getResultTypes()[0],
+                                   newOp->getResults()[0])
+      .getResult();
 }
 
 namespace {
@@ -450,12 +460,14 @@
 /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
 /// prior to extract.
 struct CastAwayContractionLeadingOneDim
-    : public OpRewritePattern<vector::ContractionOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
-                                PatternRewriter &rewriter) const override {
-    return castAwayContractionLeadingOneDim(contractOp, rewriter);
+    : public MaskableOpRewritePattern<vector::ContractionOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
+                            MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
+    return castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter);
   }
 };

mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir

Lines changed: 74 additions & 30 deletions
@@ -30,6 +30,80 @@ func.func @cast_away_contraction_leading_one_dims(%arg0: vector<1x16x8xf32>, %ar
 }
 
 // -----
+// CHECK: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dim_under_const_mask
+// CHECK: %[[MASK:.*]] = vector.constant_mask [15, 15, 8] : vector<16x16x8xi1>
+// CHECK: %[[R0:.*]] = vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32>
+// CHECK: %[[R1:.*]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK: %[[R2:.*]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.mask %[[MASK]] {
+// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
+// CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32>
+// CHECK: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
+// CHECK: return %[[RES]] : vector<1x16x16xf32>
+
+#contraction_accesses0 = [
+  affine_map<(l, i, j, k) -> (l, i, k)>,
+  affine_map<(l, i, j, k) -> (l, k, j)>,
+  affine_map<(l, i, j, k) -> (l, i, j)>
+]
+#contraction_trait0 = {
+  indexing_maps = #contraction_accesses0,
+  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+func.func @cast_away_contraction_leading_one_dim_under_const_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
+  %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
+  %0 = vector.mask %mask {
+    vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
+  } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
+  return %0 : vector<1x16x16xf32>
+}
+
+// -----
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dim_under_mask
+// CHECK: %[[R0:.*]] = vector.extract %{{.*}} : vector<16x8xf32> from vector<1x16x8xf32>
+// CHECK: %[[R1:.*]] = vector.extract %{{.*}} : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK: %[[R2:.*]] = vector.extract %{{.*}} : vector<16x16xf32> from vector<1x16x16xf32>
+// CHECK: %[[M:.*]] = vector.extract %{{.*}} : vector<16x16x8xi1> from vector<1x16x16x8xi1>
+// CHECK: %[[CONTRACT:.*]] = vector.mask %[[M]] {
+// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
+// CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
+// CHECK-NEXT: return %[[RES]] : vector<1x16x16xf32>
+
+#contraction_accesses0 = [
+  affine_map<(l, i, j, k) -> (l, i, k)>,
+  affine_map<(l, i, j, k) -> (l, k, j)>,
+  affine_map<(l, i, j, k) -> (l, i, j)>
+]
+#contraction_trait0 = {
+  indexing_maps = #contraction_accesses0,
+  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+func.func @cast_away_contraction_leading_one_dim_under_mask(
+  %arg0: vector<1x16x8xf32>,
+  %arg1: vector<1x8x16xf32>,
+  %arg2: vector<1x16x16xf32>,
+  %mask: vector<1x16x16x8xi1>) -> vector<1x16x16xf32> {
+  %0 = vector.mask %mask {
+    vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
+  } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
+  return %0: vector<1x16x16xf32>
+}
+
+// -----
+
 // CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1) -> (d1)>
 // CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0)>
@@ -164,36 +238,6 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
   return %0: vector<1x1x2x16xf32>
 }
 
-// -----
-
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-
-// CHECK-LABEL: not_insert_cast_for_contraction_under_mask
-// CHECK: %[[MASK:.+]] = vector.constant_mask
-// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
-// CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
-// CHECK-SAME: vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
-// CHECK: return %[[RET]] : vector<1x16x16xf32>
-
-#contraction_accesses0 = [
-  affine_map<(l, i, j, k) -> (l, i, k)>,
-  affine_map<(l, i, j, k) -> (l, k, j)>,
-  affine_map<(l, i, j, k) -> (l, i, j)>
-]
-#contraction_trait0 = {
-  indexing_maps = #contraction_accesses0,
-  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-}
-
-func.func @not_insert_cast_for_contraction_under_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
-  %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
-  %0 = vector.mask %mask {
-    vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
-  } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
-  return %0 : vector<1x16x16xf32>
-}
 
 // -----
 // CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
