[mlir][linalg] Add patterns to convert matmul to transposed variants #89075

Merged
@@ -8,3 +8,12 @@ def MatchInterfaceEnum : I32EnumAttr<"MatchInterfaceEnum", "An interface to matc
]>{
let cppNamespace = "mlir::transform";
}

def TransposeMatmulInput : I32EnumAttr<"TransposeMatmulInput",
"Input to transpose when converting matmul ops to transposed variants",
[
I32EnumAttrCase<"lhs", 0>,
I32EnumAttrCase<"rhs", 1>,
]>{
let cppNamespace = "mlir::transform";
}
@@ -73,6 +73,23 @@ def ApplyTilingCanonicalizationPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}

def ApplyTransposeMatmulPatternsOp : Op<Transform_Dialect,
"apply_patterns.linalg.transpose_matmul",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collects patterns to convert Linalg matmul ops to transposed variants.

By default the LHS matrix is transposed. Set `inputToTranspose=<rhs>` to
transpose the RHS matrix instead.
}];

let arguments = (ins
DefaultValuedAttr<TransposeMatmulInput,
"TransposeMatmulInput::lhs">:$inputToTranspose);

let assemblyFormat = "(`<` $inputToTranspose^ `>`)? attr-dict";
}

//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
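For reference, a minimal sketch of how the new pattern op reads in a transform script (the enclosing named sequence and the %func handle are assumed context; the complete setup appears in the tests at the end of this diff):

  transform.apply_patterns to %func {
    // Default: transpose the LHS of each matmul.
    transform.apply_patterns.linalg.transpose_matmul
  } : !transform.any_op

  transform.apply_patterns to %func {
    // Select the RHS instead.
    transform.apply_patterns.linalg.transpose_matmul <rhs>
  } : !transform.any_op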
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1616,6 +1616,10 @@ void populateSplitReductionPattern(
const ControlSplitReductionFn &controlSplitReductionFn,
bool useAlloc = false);

/// Patterns to convert Linalg matmul ops to transposed variants.
void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
bool transposeLHS = true);

} // namespace linalg
} // namespace mlir

6 changes: 6 additions & 0 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -199,6 +199,12 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
}

void transform::ApplyTransposeMatmulPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
linalg::populateTransposeMatmulPatterns(patterns, transposeLHS);
}

//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
InlineScalarOperands.cpp
Interchange.cpp
Loops.cpp
TransposeMatmul.cpp
MeshShardingInterfaceImpl.cpp
NamedOpConversions.cpp
Padding.cpp
Expand Down
140 changes: 140 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -0,0 +1,140 @@
//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This is intended to be a simple high-level (target-agnostic) matmul
// transposition transformation.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "linalg-transpose-matmul"

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Pattern to replace
///
/// linalg.matmul(a, b)
///
/// with
///
/// linalg.matmul_transpose_a(linalg.transpose(a), b)
///
/// By default the LHS is transposed. Set `transposeLHS=false` to
/// transpose the RHS instead.
struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
: OpRewritePattern(ctx), transposeLHS(transposeLHS) {}

LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
PatternRewriter &rewriter) const override {
if (!bufferization::hasTensorSemantics(matmulOp))
return rewriter.notifyMatchFailure(
matmulOp, "only matmul ops with tensors are supported");

Location loc = matmulOp.getLoc();
Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
auto type = cast<ShapedType>(input.getType());

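// Collect any dynamic dimension sizes in the order they appear in the
// transposed result shape {dim1, dim0}, as tensor.empty expects.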
SmallVector<Value> dynamicDims;
if (type.isDynamicDim(1))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
if (type.isDynamicDim(0))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));

ArrayRef<int64_t> shape = type.getShape();
Value empty = rewriter.create<tensor::EmptyOp>(
loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
dynamicDims);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, input, empty, ArrayRef<int64_t>{1, 0});
if (transposeLHS) {
rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
matmulOp, matmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
matmulOp.getOutputs());
} else {
rewriter.replaceOpWithNewOp<linalg::MatmulTransposeBOp>(
matmulOp, matmulOp.getResultTypes(),
ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
matmulOp.getOutputs());
}

return success();
}

private:
bool transposeLHS;
};

/// Pattern to replace
///
/// linalg.batch_matmul(a, b)
///
/// with
///
/// linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
///
/// Only the non-batch dimensions are transposed. By default the LHS is
/// transposed. Set `transposeLHS=false` to transpose the RHS instead.
struct TransposeBatchMatmul final
: public OpRewritePattern<linalg::BatchMatmulOp> {
TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
: OpRewritePattern(ctx), transposeLHS(transposeLHS) {}

LogicalResult matchAndRewrite(linalg::BatchMatmulOp batchMatmulOp,
PatternRewriter &rewriter) const override {
if (!bufferization::hasTensorSemantics(batchMatmulOp))
return rewriter.notifyMatchFailure(
batchMatmulOp, "only matmul ops with tensors are supported");

Location loc = batchMatmulOp.getLoc();
Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
auto type = cast<ShapedType>(input.getType());

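// Collect any dynamic dimension sizes in the order they appear in the
// transposed result shape {dim0, dim2, dim1}, as tensor.empty expects.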
SmallVector<Value> dynamicDims;
if (type.isDynamicDim(0))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
if (type.isDynamicDim(2))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
if (type.isDynamicDim(1))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));

ArrayRef<int64_t> shape = type.getShape();
Value empty = rewriter.create<tensor::EmptyOp>(
loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
type.getElementType(), dynamicDims);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
if (transposeLHS) {
rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeAOp>(
batchMatmulOp, batchMatmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
batchMatmulOp.getOutputs());
} else {
rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeBOp>(
batchMatmulOp, batchMatmulOp.getResultTypes(),
ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
batchMatmulOp.getOutputs());
}

return success();
}

private:
bool transposeLHS;
};
} // namespace

void mlir::linalg::populateTransposeMatmulPatterns(RewritePatternSet &patterns,
bool transposeLHS) {
patterns.add<TransposeMatmul, TransposeBatchMatmul>(patterns.getContext(),
transposeLHS);
}
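To make the rewrite concrete, here is a hedged before/after sketch for the non-batch pattern with static shapes (%A, %B, %C and the SSA names are hypothetical; the rewriter's output will differ in detail). The input

  %0 = linalg.matmul ins(%A, %B : tensor<16x8xf32>, tensor<8x16xf32>)
                     outs(%C : tensor<16x16xf32>) -> tensor<16x16xf32>

becomes, with the default transposeLHS = true, roughly

  %empty = tensor.empty() : tensor<8x16xf32>
  %At = linalg.transpose ins(%A : tensor<16x8xf32>)
                         outs(%empty : tensor<8x16xf32>) permutation = [1, 0]
  %0 = linalg.matmul_transpose_a
         ins(%At, %B : tensor<8x16xf32>, tensor<8x16xf32>)
         outs(%C : tensor<16x16xf32>) -> tensor<16x16xf32>

The batch pattern is analogous but leaves the leading batch dimension in place, using permutation = [0, 2, 1].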
15 changes: 15 additions & 0 deletions mlir/test/Dialect/Linalg/transpose-matmul-a.mlir
@@ -0,0 +1,15 @@
// RUN: mlir-opt %s

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.linalg.transpose_matmul
} : !transform.any_op
transform.apply_cse to %0 : !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.canonicalization
} : !transform.any_op
transform.yield
}
}
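This file holds only the transform script; a hypothetical payload function it could be applied to (illustration only, not part of the patch) would look like:

  func.func @matmul(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>,
                    %C: tensor<16x16xf32>) -> tensor<16x16xf32> {
    %0 = linalg.matmul ins(%A, %B : tensor<16x8xf32>, tensor<8x16xf32>)
                       outs(%C : tensor<16x16xf32>) -> tensor<16x16xf32>
    return %0 : tensor<16x16xf32>
  }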
15 changes: 15 additions & 0 deletions mlir/test/Dialect/Linalg/transpose-matmul-b.mlir
@@ -0,0 +1,15 @@
// RUN: mlir-opt %s

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.linalg.transpose_matmul <rhs>
} : !transform.any_op
transform.apply_cse to %0 : !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.canonicalization
} : !transform.any_op
transform.yield
}
}
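Applied to the same hypothetical payload shown after the previous test, this <rhs> script steers the rewrite to the B-transposed variant, roughly:

  %empty = tensor.empty() : tensor<16x8xf32>
  %Bt = linalg.transpose ins(%B : tensor<8x16xf32>)
                         outs(%empty : tensor<16x8xf32>) permutation = [1, 0]
  %0 = linalg.matmul_transpose_b
         ins(%A, %Bt : tensor<16x8xf32>, tensor<16x8xf32>)
         outs(%C : tensor<16x16xf32>) -> tensor<16x16xf32>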