
[mlir][linalg] Move transpose_matmul to targeted transform op #89717

Merged
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -73,23 +73,6 @@ def ApplyTilingCanonicalizationPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}

def ApplyTransposeMatmulPatternsOp : Op<Transform_Dialect,
"apply_patterns.linalg.transpose_matmul",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collects patterns to convert Linalg matmul ops to transposed variants.

By default the LHS matrix is transposed. Set `inputToTranspose=<rhs>` to
instead transpose RHS matrix.
}];

let arguments = (ins
DefaultValuedAttr<TransposeMatmulInput,
"TransposeMatmulInput::lhs">:$inputToTranspose);

let assemblyFormat = "(`<` $inputToTranspose^ `>`)? attr-dict";
}

//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
@@ -2429,6 +2412,52 @@ def TransposeConv2DOp : Op<Transform_Dialect,
}];
}

//===----------------------------------------------------------------------===//
// TransposeMatmulOp
//===----------------------------------------------------------------------===//

def TransposeMatmulOp : Op<Transform_Dialect,
"structured.transpose_matmul",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Convert Linalg matmul ops to transposed variants.

By default the LHS matrix is transposed. Specify `<rhs>` to transpose
the RHS matrix instead.

#### Return modes:

This operation fails if `target` is unsupported, i.e., not a
`linalg.matmul` or `linalg.batch_matmul`. Otherwise, the operation succeeds
and returns a handle to the transposed matmul op.
}];

let arguments = (ins
TransformHandleTypeInterface:$target,
DefaultValuedAttr<TransposeMatmulInput,
"TransposeMatmulInput::lhs">:$inputToTranspose);
let results = (outs TransformHandleTypeInterface:$transformed);

let assemblyFormat = [{
$target (`<` $inputToTranspose^ `>`)?
attr-dict `:` functional-type($target, results)
}];

let builders = [
OpBuilder<(ins "Value":$target)>
];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::linalg::LinalgOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}

//===----------------------------------------------------------------------===//
// InsertSliceToCopyOp
//===----------------------------------------------------------------------===//
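For context, a minimal transform script driving the new targeted op might look like the sketch below (handle names are illustrative; the updated tests at the end of this diff exercise the same flow):

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    // Match the payload matmul ops, then rewrite each matched op in place.
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
    // The result is a handle to the newly created transposed matmul.
    %transposed = transform.structured.transpose_matmul %matmul
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
```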
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1244,6 +1244,14 @@ FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
linalg::Conv2DNhwcFhwcQOp op);

/// Convert Linalg matmul ops to transposed variants.
FailureOr<Operation *> transposeMatmul(RewriterBase &rewriter,
linalg::MatmulOp op,
bool transposeLHS = true);
FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
linalg::BatchMatmulOp op,
bool transposeLHS = true);

//===----------------------------------------------------------------------===//
// Rewrite patterns wrapping transformations.
// TODO: every single such pattern should be a close to noop wrapper around a
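As a rough before/after sketch of what `transposeMatmul` produces with the default `transposeLHS = true` (shapes and value names are illustrative, not taken from the patch):

```mlir
// Before: a plain matmul; %A (the LHS) will be transposed.
%r = linalg.matmul ins(%A, %B : tensor<16x8xf32>, tensor<8x16xf32>)
                   outs(%C : tensor<16x16xf32>) -> tensor<16x16xf32>

// After: an explicit transpose of the LHS feeding linalg.matmul_transpose_a.
%empty = tensor.empty() : tensor<8x16xf32>
%At = linalg.transpose ins(%A : tensor<16x8xf32>)
                       outs(%empty : tensor<8x16xf32>) permutation = [1, 0]
%r2 = linalg.matmul_transpose_a ins(%At, %B : tensor<8x16xf32>, tensor<8x16xf32>)
                                outs(%C : tensor<16x16xf32>) -> tensor<16x16xf32>
```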
32 changes: 26 additions & 6 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -199,12 +199,6 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
}

void transform::ApplyTransposeMatmulPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
linalg::populateTransposeMatmulPatterns(patterns, transposeLHS);
}

//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
@@ -3422,6 +3416,32 @@ DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// TransposeMatmulOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
transform::TransformRewriter &rewriter, linalg::LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
auto maybeTransformed =
TypeSwitch<Operation *, FailureOr<Operation *>>(target)
.Case([&](linalg::MatmulOp op) {
return transposeMatmul(rewriter, op, transposeLHS);
})
.Case([&](linalg::BatchMatmulOp op) {
return transposeBatchMatmul(rewriter, op, transposeLHS);
})
.Default([&](Operation *op) { return failure(); });
if (failed(maybeTransformed))
return emitSilenceableFailure(target->getLoc()) << "not supported";
// Handle to the new matmul op with a transposed input
results.push_back(*maybeTransformed);
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// InsertSliceToCopyOp
//===----------------------------------------------------------------------===//
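The TypeSwitch above dispatches to the plain or batched entry point and emits a silenceable "not supported" diagnostic for any other target. With `<rhs>` the other input is transposed instead; a sketch with illustrative shapes and names:

```mlir
// Transform script side: request transposition of the RHS.
%transposed = transform.structured.transpose_matmul %matmul <rhs>
  : (!transform.any_op) -> !transform.any_op

// Payload side, roughly: the RHS is transposed and a
// linalg.matmul_transpose_b consumes it.
%empty = tensor.empty() : tensor<16x8xf32>
%Bt = linalg.transpose ins(%B : tensor<8x16xf32>)
                       outs(%empty : tensor<16x8xf32>) permutation = [1, 0]
%r = linalg.matmul_transpose_b ins(%A, %Bt : tensor<16x8xf32>, tensor<16x8xf32>)
                               outs(%C : tensor<16x16xf32>) -> tensor<16x16xf32>
```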
175 changes: 98 additions & 77 deletions mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -18,7 +18,6 @@
using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Pattern to replace
///
/// linalg.matmul(a, b)
@@ -29,102 +28,124 @@ namespace {
///
/// By default the LHS is transposed. Set `transposeLHS=false` to
/// transpose RHS instead.
FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
linalg::MatmulOp matmulOp,
bool transposeLHS) {
if (!bufferization::hasTensorSemantics(matmulOp))
return rewriter.notifyMatchFailure(
matmulOp, "only matmul ops with tensors are supported");

Location loc = matmulOp.getLoc();
Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
auto type = cast<ShapedType>(input.getType());

SmallVector<Value> dynamicDims;
if (type.isDynamicDim(1))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
if (type.isDynamicDim(0))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));

ArrayRef<int64_t> shape = type.getShape();
Value empty = rewriter.create<tensor::EmptyOp>(
loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
dynamicDims);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, input, empty, ArrayRef<int64_t>{1, 0});
Operation *newMatmulOp;
if (transposeLHS) {
newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
loc, matmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
matmulOp.getOutputs());
} else {
newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
loc, matmulOp.getResultTypes(),
ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
matmulOp.getOutputs());
}
rewriter.replaceOp(matmulOp, newMatmulOp);
return newMatmulOp;
}

/// Pattern to replace
///
/// linalg.batch_matmul(a, b)
///
/// with
///
/// linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
///
/// Only the non-batch dimensions are transposed. By default the LHS is
/// transposed. Set `transposeLHS=false` to transpose RHS instead.
FailureOr<Operation *>
mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
linalg::BatchMatmulOp batchMatmulOp,
bool transposeLHS) {
if (!bufferization::hasTensorSemantics(batchMatmulOp))
return rewriter.notifyMatchFailure(
batchMatmulOp, "only matmul ops with tensors are supported");

Location loc = batchMatmulOp.getLoc();
Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
auto type = cast<ShapedType>(input.getType());

SmallVector<Value> dynamicDims;
if (type.isDynamicDim(0))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
if (type.isDynamicDim(2))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
if (type.isDynamicDim(1))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));

ArrayRef<int64_t> shape = type.getShape();
Value empty = rewriter.create<tensor::EmptyOp>(
loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
type.getElementType(), dynamicDims);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
Operation *newMatmulOp;
if (transposeLHS) {
newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
loc, batchMatmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
batchMatmulOp.getOutputs());
} else {
newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
loc, batchMatmulOp.getResultTypes(),
ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
batchMatmulOp.getOutputs());
}
rewriter.replaceOp(batchMatmulOp, newMatmulOp);
return newMatmulOp;
}

namespace {
struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
: OpRewritePattern(ctx), transposeLHS(transposeLHS) {}

LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
LogicalResult matchAndRewrite(linalg::MatmulOp op,
PatternRewriter &rewriter) const override {
if (!bufferization::hasTensorSemantics(matmulOp))
return rewriter.notifyMatchFailure(
matmulOp, "only matmul ops with tensors are supported");

Location loc = matmulOp.getLoc();
Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
auto type = cast<ShapedType>(input.getType());

SmallVector<Value> dynamicDims;
if (type.isDynamicDim(1))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
if (type.isDynamicDim(0))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));

ArrayRef<int64_t> shape = type.getShape();
Value empty = rewriter.create<tensor::EmptyOp>(
loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
dynamicDims);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, input, empty, ArrayRef<int64_t>{1, 0});
if (transposeLHS) {
rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
matmulOp, matmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
matmulOp.getOutputs());
} else {
rewriter.replaceOpWithNewOp<linalg::MatmulTransposeBOp>(
matmulOp, matmulOp.getResultTypes(),
ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
matmulOp.getOutputs());
if (failed(transposeMatmul(rewriter, op, transposeLHS))) {
return failure();
}

return success();
}

private:
bool transposeLHS;
};

/// Pattern to replace
///
/// linalg.batch_matmul(a, b)
///
/// with
///
/// linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
///
/// Only the non-batch dimensions are transposed. By default the LHS is
/// transposed. Set `transposeLHS=false` to transpose RHS instead.
struct TransposeBatchMatmul final
: public OpRewritePattern<linalg::BatchMatmulOp> {
TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
: OpRewritePattern(ctx), transposeLHS(transposeLHS) {}

LogicalResult matchAndRewrite(linalg::BatchMatmulOp batchMatmulOp,
LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
PatternRewriter &rewriter) const override {
if (!bufferization::hasTensorSemantics(batchMatmulOp))
return rewriter.notifyMatchFailure(
batchMatmulOp, "only matmul ops with tensors are supported");

Location loc = batchMatmulOp.getLoc();
Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
auto type = cast<ShapedType>(input.getType());

SmallVector<Value> dynamicDims;
if (type.isDynamicDim(0))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
if (type.isDynamicDim(2))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
if (type.isDynamicDim(1))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));

ArrayRef<int64_t> shape = type.getShape();
Value empty = rewriter.create<tensor::EmptyOp>(
loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
type.getElementType(), dynamicDims);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
if (transposeLHS) {
rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeAOp>(
batchMatmulOp, batchMatmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
batchMatmulOp.getOutputs());
} else {
rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeBOp>(
batchMatmulOp, batchMatmulOp.getResultTypes(),
ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
batchMatmulOp.getOutputs());
if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) {
return failure();
}

return success();
}

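For the batched case, only the two trailing (non-batch) dimensions are permuted, hence the `[0, 2, 1]` permutation in the code above. A rough before/after with illustrative shapes:

```mlir
// Before: the batch dimension (2) leads on every operand.
%r = linalg.batch_matmul ins(%A, %B : tensor<2x16x8xf32>, tensor<2x8x16xf32>)
                         outs(%C : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>

// After: the batch dimension stays in front; only the matrix dims swap.
%empty = tensor.empty() : tensor<2x8x16xf32>
%At = linalg.transpose ins(%A : tensor<2x16x8xf32>)
                       outs(%empty : tensor<2x8x16xf32>) permutation = [0, 2, 1]
%r2 = linalg.batch_matmul_transpose_a
        ins(%At, %B : tensor<2x8x16xf32>, tensor<2x8x16xf32>)
        outs(%C : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
```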
5 changes: 2 additions & 3 deletions mlir/test/Dialect/Linalg/transpose-matmul-a.mlir
@@ -2,10 +2,9 @@

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%matmul = transform.structured.match ops{["linalg.matmul", "linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.transpose_matmul %matmul : (!transform.any_op) -> (!transform.any_op)
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.linalg.transpose_matmul
} : !transform.any_op
transform.apply_cse to %0 : !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.canonicalization
5 changes: 2 additions & 3 deletions mlir/test/Dialect/Linalg/transpose-matmul-b.mlir
@@ -2,10 +2,9 @@

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%matmul = transform.structured.match ops{["linalg.matmul", "linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.transpose_matmul %matmul <rhs> : (!transform.any_op) -> (!transform.any_op)
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.linalg.transpose_matmul <rhs>
} : !transform.any_op
transform.apply_cse to %0 : !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.canonicalization