[mlir][vector] Add support for masks in castAwayContractionLeadingOneDim #81906

Merged
merged 1 commit into from Mar 22, 2024

@@ -110,8 +110,10 @@ void transferOpflowOpt(RewriterBase &rewriter, Operation *rootOp);

/// Cast away the leading unit dim, if it exists, for the given contract op.
/// Return success if the transformation applies; return failure otherwise.
LogicalResult castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
RewriterBase &rewriter);
FailureOr<Value>
castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
MaskingOpInterface maskingOp,
RewriterBase &rewriter);

} // namespace vector
} // namespace mlir
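
For illustration, a minimal sketch of how a caller might consume the new FailureOr<Value> entry point from inside a rewrite; `contractOp`, `maskingOp`, and `rewriter` are assumed to already be in scope, and the snippet is not part of this patch:

  // Sketch only: run the transform and forward failure to the caller.
  FailureOr<Value> newResult =
      vector::castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter);
  if (failed(newResult))
    return failure();
  // When a mask is present, the returned value replaces the vector.mask op;
  // otherwise it replaces the vector.contract op itself.
  Operation *opToReplace =
      maskingOp ? maskingOp.getOperation() : contractOp.getOperation();
  rewriter.replaceOp(opToReplace, *newResult);
  return success();

In practice this replacement boilerplate is exactly what MaskableOpRewritePattern (changed below) factors out of individual patterns.
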
10 changes: 5 additions & 5 deletions mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -127,8 +127,8 @@ SmallVector<OpFoldResult> getMixedSizesXfer(bool hasTensorSemantics,
/// responsible for providing an updated ("rewritten") version of:
/// a. the source Op when mask _is not_ present,
/// b. the source Op and the masking Op when mask _is_ present.
/// Note that the return value from `matchAndRewriteMaskableOp` depends on the
/// case above.
/// To use this pattern, implement `matchAndRewriteMaskableOp`. Note that
/// the return value will depend on the case above.
template <class SourceOp>
struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
using OpRewritePattern<SourceOp>::OpRewritePattern;
@@ -162,9 +162,9 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
}

public:
// Matches SourceOp that can potentially be masked with `maskingOp`. If the
// latter is present, returns an updated masking op (with a replacement for
// `sourceOp` nested inside). Otherwise, returns an updated `sourceOp`.
// Matches `sourceOp` that can potentially be masked with `maskingOp`. If the
// latter is present, returns a replacement for `maskingOp`. Otherwise,
// returns a replacement for `sourceOp`.
virtual FailureOr<Value>
matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
PatternRewriter &rewriter) const = 0;
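
Roughly, the base class wraps `matchAndRewriteMaskableOp` with dispatch along the lines of the simplified sketch below. The real implementation lives in the collapsed portion of this header; the names and structure here are illustrative only:

  LogicalResult matchAndRewrite(SourceOp sourceOp,
                                PatternRewriter &rewriter) const final {
    // Detect whether the source op is nested inside a vector.mask region.
    auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
    MaskingOpInterface maskingOp;
    if (maskableOp && maskableOp.isMasked())
      maskingOp = maskableOp.getMaskingOp();

    // Delegate to the derived pattern; it produces the replacement value.
    FailureOr<Value> newValue =
        matchAndRewriteMaskableOp(sourceOp, maskingOp, rewriter);
    if (failed(newValue))
      return failure();

    // Replace the masking op when a mask is present, the source op otherwise.
    rewriter.replaceOp(maskingOp ? maskingOp.getOperation()
                                 : sourceOp.getOperation(),
                       *newValue);
    return success();
  }
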
50 changes: 31 additions & 19 deletions mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -329,12 +329,10 @@ struct CastAwayTransferWriteLeadingOneDim

} // namespace

LogicalResult
FailureOr<Value>
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
MaskingOpInterface maskingOp,
RewriterBase &rewriter) {
// TODO(#78787): Masked op not supported yet.
if (cast<MaskableOpInterface>(contractOp.getOperation()).isMasked())
return failure();
VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
if (oldAccType == nullptr)
return failure();
@@ -368,6 +366,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
contractOp.getAcc()};
SmallVector<Value> newOperands;
auto loc = contractOp.getLoc();

for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
// Check if the dim to be dropped exists as a leading dim in the operand
@@ -405,7 +404,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
contractOp.getContext());
operands[it.index()] = rewriter.create<vector::TransposeOp>(
contractOp.getLoc(), operands[it.index()], perm);
loc, operands[it.index()], perm);
}
}
// We have taken care to have the dim to be dropped be
@@ -429,18 +428,29 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
// Extract if it's a valid extraction, otherwise use the operand
// without extraction.
newOperands.push_back(
validExtract ? rewriter.create<vector::ExtractOp>(contractOp.getLoc(),
operands[it.index()],
splatZero(dropDim))
validExtract ? rewriter.create<vector::ExtractOp>(
loc, operands[it.index()], splatZero(dropDim))
: operands[it.index()]);
}
auto newContractOp = rewriter.create<vector::ContractionOp>(
contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],

// Depending on whether this vector.contract is masked, the replacing Op
// should either be a new vector.contract Op or vector.mask Op.
Operation *newOp = rewriter.create<vector::ContractionOp>(
loc, newOperands[0], newOperands[1], newOperands[2],
rewriter.getAffineMapArrayAttr(newIndexingMaps),
rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
contractOp, contractOp->getResultTypes()[0], newContractOp);
return success();

if (maskingOp) {
auto newMask = rewriter.create<vector::ExtractOp>(loc, maskingOp.getMask(),
splatZero(dropDim));

newOp = mlir::vector::maskOperation(rewriter, newOp, newMask);
}

return rewriter
.create<vector::BroadcastOp>(loc, contractOp->getResultTypes()[0],
newOp->getResults()[0])
.getResult();
}

namespace {
@@ -450,12 +460,14 @@ namespace {
/// 1 dimensions. Also performs transpose of lhs and rhs operands if required
/// prior to extract.
struct CastAwayContractionLeadingOneDim
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
return castAwayContractionLeadingOneDim(contractOp, rewriter);
: public MaskableOpRewritePattern<vector::ContractionOp> {
using MaskableOpRewritePattern::MaskableOpRewritePattern;

FailureOr<Value>
matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
MaskingOpInterface maskingOp,
PatternRewriter &rewriter) const override {
return castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter);
}
};
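
For context, a hedged sketch of how this pattern is typically exercised, assuming the existing populateCastAwayVectorLeadingOneDimPatterns helper (not shown in this diff) still registers it, and that the snippet runs inside a pass's runOnOperation with `funcOp` available:

  // Sketch: collect the cast-away patterns and apply them greedily.
  RewritePatternSet patterns(&getContext());
  vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
    signalPassFailure();
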

104 changes: 74 additions & 30 deletions mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -30,6 +30,80 @@ func.func @cast_away_contraction_leading_one_dims(%arg0: vector<1x16x8xf32>, %ar
}

// -----
// CHECK: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dim_under_const_mask
// CHECK: %[[MASK:.*]] = vector.constant_mask [15, 15, 8] : vector<16x16x8xi1>
// CHECK: %[[R0:.*]] = vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32>
// CHECK: %[[R1:.*]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
// CHECK: %[[R2:.*]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
// CHECK: %[[CONTRACT:.*]] = vector.mask %[[MASK]] {
// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
// CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32>
// CHECK: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
// CHECK: return %[[RES]] : vector<1x16x16xf32>

#contraction_accesses0 = [
affine_map<(l, i, j, k) -> (l, i, k)>,
affine_map<(l, i, j, k) -> (l, k, j)>,
affine_map<(l, i, j, k) -> (l, i, j)>
]
#contraction_trait0 = {
indexing_maps = #contraction_accesses0,
iterator_types = ["parallel", "parallel", "parallel", "reduction"]
}

func.func @cast_away_contraction_leading_one_dim_under_const_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
%mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
%0 = vector.mask %mask {
vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
} : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
return %0 : vector<1x16x16xf32>
}

// -----
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dim_under_mask
// CHECK: %[[R0:.*]] = vector.extract %{{.*}} : vector<16x8xf32> from vector<1x16x8xf32>
// CHECK: %[[R1:.*]] = vector.extract %{{.*}} : vector<8x16xf32> from vector<1x8x16xf32>
// CHECK: %[[R2:.*]] = vector.extract %{{.*}} : vector<16x16xf32> from vector<1x16x16xf32>
// CHECK: %[[M:.*]] = vector.extract %{{.*}} : vector<16x16x8xi1> from vector<1x16x16x8xi1>
// CHECK: %[[CONTRACT:.*]] = vector.mask %[[M]] {
// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
// CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32>
// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
// CHECK-NEXT: return %[[RES]] : vector<1x16x16xf32>

#contraction_accesses0 = [
affine_map<(l, i, j, k) -> (l, i, k)>,
affine_map<(l, i, j, k) -> (l, k, j)>,
affine_map<(l, i, j, k) -> (l, i, j)>
]
#contraction_trait0 = {
indexing_maps = #contraction_accesses0,
iterator_types = ["parallel", "parallel", "parallel", "reduction"]
}

func.func @cast_away_contraction_leading_one_dim_under_mask(
%arg0: vector<1x16x8xf32>,
%arg1: vector<1x8x16xf32>,
%arg2: vector<1x16x16xf32>,
%mask: vector<1x16x16x8xi1>) -> vector<1x16x16xf32> {
%0 = vector.mask %mask {
vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
} : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
return %0: vector<1x16x16xf32>
}

// -----

// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0)>
@@ -164,36 +238,6 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
return %0: vector<1x1x2x16xf32>
}

// -----

// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>

// CHECK-LABEL: not_insert_cast_for_contraction_under_mask
// CHECK: %[[MASK:.+]] = vector.constant_mask
// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
// CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
// CHECK-SAME: vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
// CHECK: return %[[RET]] : vector<1x16x16xf32>

#contraction_accesses0 = [
affine_map<(l, i, j, k) -> (l, i, k)>,
affine_map<(l, i, j, k) -> (l, k, j)>,
affine_map<(l, i, j, k) -> (l, i, j)>
]
#contraction_trait0 = {
indexing_maps = #contraction_accesses0,
iterator_types = ["parallel", "parallel", "parallel", "reduction"]
}

func.func @not_insert_cast_for_contraction_under_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
%mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
%0 = vector.mask %mask {
vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
} : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
return %0 : vector<1x16x16xf32>
}

// -----
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
Expand Down