Skip to content

Commit 9c8bbf9

Browse files
committed
[mlir] [linalg] Add pattern to swap transpose with broadcast
Add a pattern that implements: transpose(broadcast(input)) -> broadcast(transpose(input))
1 parent a139f84 commit 9c8bbf9

File tree

8 files changed

+214
-5
lines changed

8 files changed

+214
-5
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1709,15 +1709,21 @@ void populateSplitReductionPattern(
17091709
void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
17101710
bool transposeLHS = true);
17111711

1712+
/// Patterns to convert transpose(broadcast(input)) to
1713+
/// broadcast(transpose(input)).
1714+
void populateSwapTransposeWithBroadcastPatterns(RewritePatternSet &patterns);
1715+
17121716
/// Patterns to block pack Linalg matmul ops.
17131717
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
17141718
const ControlBlockPackMatmulFn &controlFn);
17151719

17161720
/// Adds patterns that reduce the rank of named contraction ops that have
1717-
/// unit dimensions in the operand(s) by converting to a sequence of `collapse_shape`,
1718-
/// `<corresponding linalg named op>`, `expand_shape` (if on tensors). For example a
1719-
/// `linalg.batch_matmul` with unit batch size will convert to `linalg.matmul`
1720-
/// and a `linalg.matvec` with with unit spatial dim in lhs will convert to a `linalg.dot`.
1721+
/// unit dimensions in the operand(s) by converting to a sequence of
1722+
/// `collapse_shape`,
1723+
/// `<corresponding linalg named op>`, `expand_shape` (if on tensors). For
1724+
/// example a `linalg.batch_matmul` with unit batch size will convert to
1725+
/// `linalg.matmul` and a `linalg.matvec` with unit spatial dim in lhs will
1726+
/// convert to a `linalg.dot`.
17211727
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
17221728

17231729
} // namespace linalg

mlir/include/mlir/Dialect/Utils/IndexingUtils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,14 @@ SmallVector<int64_t>
243243
computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
244244
ArrayRef<int64_t> desiredPositions);
245245

246+
/// Returns a permutation vector that drops the input dims in
247+
/// dropPositions from inputPerm.
248+
///
249+
/// For example, inputPerm = {2, 4, 0, 1, 3} and dropPositions = {1, 2} would
250+
/// result in a {2, 0, 1} permutation vector.
251+
SmallVector<int64_t> dropDims(ArrayRef<int64_t> inputPerm,
252+
ArrayRef<int64_t> dropPositions);
253+
246254
/// Helper to return a subset of `arrayAttr` as a vector of int64_t.
247255
// TODO: Port everything relevant to DenseArrayAttr and drop this util.
248256
SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
3333
SplitReduction.cpp
3434
SubsetInsertionOpInterfaceImpl.cpp
3535
SwapExtractSliceWithFillPatterns.cpp
36+
SwapTransposeWithBroadcast.cpp
3637
Tiling.cpp
3738
TilingInterfaceImpl.cpp
3839
Transforms.cpp
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
//===- SwapTransposeWithBroadcast.cpp - Swap transpose with broadcast op --===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// This file implements a pattern that swaps transpose with broadcast.
9+
//===----------------------------------------------------------------------===//
10+
11+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
12+
#include "mlir/Dialect/Utils/IndexingUtils.h"
13+
#include "mlir/IR/PatternMatch.h"
14+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15+
16+
#define DEBUG_TYPE "linalg-swap-transpose-with-broadcast"
17+
18+
using namespace mlir;
19+
using namespace mlir::linalg;
20+
21+
namespace {
22+
/// This pattern canonicalize transpose by swapping the order of
23+
/// broadcast and transpose:
24+
/// transpose(broadcast(input)) -> broadcast(transpose(input))
25+
struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
26+
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
27+
28+
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
29+
PatternRewriter &rewriter) const override {
30+
Value input = transposeOp.getInput();
31+
BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
32+
if (!input.hasOneUse() || !broadcastOp)
33+
return failure();
34+
35+
ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
36+
ArrayRef<int64_t> perms = transposeOp.getPermutation();
37+
38+
// Get new perms and new dimensions.
39+
SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
40+
SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
41+
SmallVector<int64_t> resultDimensions;
42+
for (unsigned i = 0; i < dimensions.size(); i++) {
43+
resultDimensions.push_back(invertPerm[dimensions[i]]);
44+
}
45+
46+
// Create transpose result.
47+
Value broadcastInput = broadcastOp.getInput();
48+
Location loc = transposeOp.getLoc();
49+
MLIRContext *ctx = transposeOp.getContext();
50+
SmallVector<OpFoldResult> dims;
51+
auto broadcastInputTy =
52+
mlir::cast<RankedTensorType>(broadcastInput.getType());
53+
for (unsigned i = 0; i < broadcastInputTy.getRank(); i++) {
54+
if (broadcastInputTy.isDynamicDim(i)) {
55+
dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
56+
->getResult(0));
57+
} else {
58+
dims.push_back(IntegerAttr::get(IndexType::get(ctx),
59+
broadcastInputTy.getDimSize(i)));
60+
}
61+
}
62+
SmallVector<OpFoldResult> transposeResultShapes =
63+
applyPermutation(dims, resultPerms);
64+
Value transposeInit = rewriter.create<tensor::EmptyOp>(
65+
transposeOp.getLoc(), transposeResultShapes,
66+
broadcastInputTy.getElementType());
67+
68+
// Create broadcast(transpose(input)).
69+
Value transposeResult =
70+
rewriter
71+
.create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
72+
resultPerms)
73+
->getResult(0);
74+
rewriter.replaceOpWithNewOp<BroadcastOp>(
75+
transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
76+
return success();
77+
}
78+
};
79+
} // namespace
80+
81+
/// Adds the pattern that rewrites transpose(broadcast(input)) into
/// broadcast(transpose(input)) to `patterns`.
void mlir::linalg::populateSwapTransposeWithBroadcastPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<SwapTransposeWithBroadcast>(context);
}

mlir/lib/Dialect/Utils/IndexingUtils.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,31 @@ mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
252252
return res;
253253
}
254254

255+
/// Removes from `inputPerm` every entry whose value appears in
/// `dropPositions` and renumbers the surviving entries: each kept value is
/// decreased by the number of dropped positions that are smaller than it.
///
/// For example, inputPerm = {2, 4, 0, 1, 3} with dropPositions = {1, 2}
/// yields {2, 0, 1}.
SmallVector<int64_t> mlir::dropDims(ArrayRef<int64_t> inputPerm,
                                    ArrayRef<int64_t> dropPositions) {
  assert(inputPerm.size() >= dropPositions.size() &&
         "expect inputPerm size large than position to drop");
  SmallVector<int64_t> res;
  for (int64_t permEntry : inputPerm) {
    // Value the entry takes once smaller dropped positions are removed.
    int64_t shifted = permEntry;
    bool keep = true;
    for (int64_t dropped : dropPositions) {
      if (dropped == permEntry) {
        // This entry refers to a dropped position; discard it.
        keep = false;
        break;
      }
      if (dropped < permEntry)
        --shifted;
    }
    if (keep)
      res.push_back(shifted);
  }
  return res;
}
279+
255280
SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
256281
unsigned dropFront,
257282
unsigned dropBack) {

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1017,7 +1017,7 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
10171017
return %0 : tensor<2x3xf32>
10181018
}
10191019

1020-
// ----
1020+
// -----
10211021

10221022
func.func @transpose_1d(%input: tensor<16xf32>,
10231023
%init: tensor<16xf32>) -> tensor<16xf32> {
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
//RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-swap-transpose-with-broadcast %s | FileCheck %s
2+
3+
func.func @broadcast_transpose_fold(%input: tensor<2x4x5xf32>,
4+
%init1: tensor<1x2x3x4x5x6xf32>,
5+
%init2: tensor<1x6x2x3x5x4xf32>) -> tensor<1x6x2x3x5x4xf32> {
6+
// CHECK-LABEL: @broadcast_transpose_fold
7+
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2x4x5xf32>
8+
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x2x3x4x5x6xf32>
9+
// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x6x2x3x5x4xf32>
10+
// CHECK: %[[TMP_INIT:.+]] = tensor.empty() : tensor<2x5x4xf32>
11+
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<2x4x5xf32>) outs(%[[TMP_INIT]] : tensor<2x5x4xf32>) permutation = [0, 2, 1]
12+
// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<2x5x4xf32>) outs(%[[INIT2]] : tensor<1x6x2x3x5x4xf32>) dimensions = [0, 3, 1]
13+
// CHECK: return %[[BROADCAST]] : tensor<1x6x2x3x5x4xf32>
14+
%broadcast = linalg.broadcast
15+
ins(%input : tensor<2x4x5xf32>)
16+
outs(%init1 : tensor<1x2x3x4x5x6xf32>)
17+
dimensions = [0, 2, 5]
18+
%transpose = linalg.transpose
19+
ins(%broadcast : tensor<1x2x3x4x5x6xf32>)
20+
outs(%init2 : tensor<1x6x2x3x5x4xf32>)
21+
permutation = [0, 5, 1, 2, 4, 3]
22+
func.return %transpose : tensor<1x6x2x3x5x4xf32>
23+
}
24+
25+
// -----
26+
27+
func.func @broadcast_transpose_fold_dynamic(%input: tensor<?x?x5xf32>,
28+
%init1: tensor<1x?x3x?x5x6xf32>,
29+
%init2: tensor<1x3x?x6x5x?xf32>) -> tensor<1x3x?x6x5x?xf32> {
30+
// CHECK-LABEL: @broadcast_transpose_fold_dynamic
31+
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x5xf32>
32+
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x?x3x?x5x6xf32>
33+
// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x3x?x6x5x?xf32>
34+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
35+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
36+
// CHECK: %[[DIM0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x5xf32>
37+
// CHECK: %[[DIM1:.+]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<?x?x5xf32>
38+
// CHECK: %[[TMP_INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]]) : tensor<?x5x?xf32>
39+
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<?x?x5xf32>) outs(%[[TMP_INIT]] : tensor<?x5x?xf32>) permutation = [1, 2, 0]
40+
// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<?x5x?xf32>) outs(%[[INIT2]] : tensor<1x3x?x6x5x?xf32>) dimensions = [0, 1, 3]
41+
// CHECK: return %[[BROADCAST]] : tensor<1x3x?x6x5x?xf32>
42+
%broadcast = linalg.broadcast
43+
ins(%input : tensor<?x?x5xf32>)
44+
outs(%init1 : tensor<1x?x3x?x5x6xf32>)
45+
dimensions = [0, 2, 5]
46+
%transpose = linalg.transpose
47+
ins(%broadcast : tensor<1x?x3x?x5x6xf32>)
48+
outs(%init2 : tensor<1x3x?x6x5x?xf32>)
49+
permutation = [0, 2, 3, 5, 4, 1]
50+
func.return %transpose : tensor<1x3x?x6x5x?xf32>
51+
}
52+
53+
// -----
54+
55+
func.func @broadcast_transpose_fold_2dim(%input: tensor<2xf32>,
56+
%init1: tensor<2x4xf32>,
57+
%init2: tensor<4x2xf32>) -> tensor<4x2xf32> {
58+
// CHECK-LABEL: @broadcast_transpose_fold_2dim
59+
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
60+
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
61+
// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<4x2xf32>
62+
// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<4x2xf32>) dimensions = [0]
63+
// CHECK: return %[[BROADCAST]] : tensor<4x2xf32>
64+
%broadcast = linalg.broadcast
65+
ins(%input : tensor<2xf32>)
66+
outs(%init1 : tensor<2x4xf32>)
67+
dimensions = [1]
68+
%transpose = linalg.transpose
69+
ins(%broadcast : tensor<2x4xf32>)
70+
outs(%init2 : tensor<4x2xf32>)
71+
permutation = [1, 0]
72+
func.return %transpose : tensor<4x2xf32>
73+
}

mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ struct TestLinalgTransforms
115115
llvm::cl::desc(
116116
"Test patterns to swap tensor.extract_slice(linalg.fill())"),
117117
llvm::cl::init(false)};
118+
Option<bool> testSwapTransposeWithBroadcast{
119+
*this, "test-swap-transpose-with-broadcast",
120+
llvm::cl::desc("Test patterns to swap transpose(broadcast(input))"),
121+
llvm::cl::init(false)};
118122
Option<bool> testEraseUnusedOperandsAndResults{
119123
*this, "test-erase-unused-operands-and-results",
120124
llvm::cl::desc("Test patterns to erase unused operands and results"),
@@ -195,6 +199,12 @@ static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) {
195199
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
196200
}
197201

202+
/// Greedily applies the swap-transpose-with-broadcast patterns to `funcOp`.
static void applySwapTransposeWithBroadcast(func::FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet swapPatterns(context);
  populateSwapTransposeWithBroadcastPatterns(swapPatterns);
  // The result of the greedy driver is deliberately discarded.
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(swapPatterns));
}
207+
198208
static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) {
199209
RewritePatternSet patterns(funcOp.getContext());
200210
populateEraseUnusedOperandsAndResultsPatterns(patterns);
@@ -227,6 +237,8 @@ void TestLinalgTransforms::runOnOperation() {
227237
return applyBubbleUpExtractSliceOpPattern(getOperation());
228238
if (testSwapExtractSliceWithFill)
229239
return applySwapExtractSliceWithFillPattern(getOperation());
240+
if (testSwapTransposeWithBroadcast)
241+
return applySwapTransposeWithBroadcast(getOperation());
230242
if (testEraseUnusedOperandsAndResults)
231243
return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
232244
if (testEraseUnnecessaryInputs)

0 commit comments

Comments
 (0)