
[mlir] [linalg] Add pattern to swap transpose with broadcast #97063


Status: Merged (4 commits) on Jul 23, 2024
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -243,6 +243,14 @@ SmallVector<int64_t>
computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
ArrayRef<int64_t> desiredPositions);

/// Returns a permutation vector obtained by dropping from `inputPerm` the
/// entries whose values appear in `dropPositions`, renumbering the remaining
/// entries so the result is again a valid permutation.
///
/// For example, inputPerm = {2, 4, 0, 1, 3} and dropPositions = {1, 2} would
/// result in a {2, 0, 1} permutation vector.
Contributor:
Is the result correct? Shouldn't this be {2, 1, 3} or am I getting this wrong?

Contributor Author:
Sorry for not making this function clear. This function returns a new permutation after removing the entries of inputPerm whose values appear in dropPositions.

The removed values here are 2 and 1; after removing them we are left with {4, 0, 3}. To be a valid permutation, the returned perm has to be renumbered starting from 0, so the result is {2, 0, 1}.
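
[Editor's note] A minimal usage sketch of the example above (hypothetical snippet; the values are the ones from the docstring):

  SmallVector<int64_t> perm = {2, 4, 0, 1, 3};
  SmallVector<int64_t> drop = {1, 2};
  // The values 1 and 2 are removed, leaving {4, 0, 3}, which is then
  // renumbered from 0 so it stays a valid permutation.
  SmallVector<int64_t> res = mlir::dropDims(perm, drop);
  assert((res == SmallVector<int64_t>{2, 0, 1}));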

Contributor:
Thanks! I think we should already have this functionality implemented on AffineMap. Would you mind taking a look at the utilities in AffineMap.h? There are some drop... methods that might get you what you need. You can get an AffineMap from a permutation with: https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/IR/AffineMap.h#L103

Contributor Author (@cxy-1993), Jul 7, 2024:
Thank you for your suggestion; this approach is very interesting (at the least, my function name is not precise enough; I think dropDims would be better). We could indeed replace the permutation arithmetic with affine-map computations using applyPermutationMap and getPermutationMap. Before making the change, I would like to discuss one point: this would simplify the computation of the perm, but it would add compile time, since we would have to construct affine maps just to reuse the affine-map utilities. Is that the more reasonable approach? If so, all of the utilities in the permutation-utils family have counterparts on AffineMap; should we systematically replace them all?

Contributor:
Good point. My comment was motivated by the high proliferation of AffineMap and permutation utilities over the years (to the point that it is sometimes a challenge, even for people familiar with the code, to figure out whether something already exists). However, I think adding this one is justified, as it composes with the other utilities that work on permutations.
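
[Editor's note] For reference, a rough sketch of what the AffineMap-based route could look like. This is an illustration under the assumption that AffineMap::getPermutationMap, AffineMap::dropResults, AffineMap::getDimPosition, and compressUnusedDims behave as declared in mlir/IR/AffineMap.h; it is not the code that was merged (the PR kept the plain permutation utility declared below):

  // Sketch only: drop entries of a permutation via AffineMap utilities.
  SmallVector<int64_t> dropDimsViaAffineMap(ArrayRef<int64_t> inputPerm,
                                            ArrayRef<int64_t> dropPositions,
                                            MLIRContext *ctx) {
    AffineMap map = AffineMap::getPermutationMap(
        SmallVector<unsigned>(inputPerm.begin(), inputPerm.end()), ctx);
    // Drop every result that refers to one of the dropped dimensions.
    SmallVector<int64_t> resultsToDrop;
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i)
      if (llvm::is_contained(dropPositions, map.getDimPosition(i)))
        resultsToDrop.push_back(i);
    // Renumber the surviving dims so the result is a permutation again.
    map = compressUnusedDims(map.dropResults(resultsToDrop));
    SmallVector<int64_t> result;
    for (AffineExpr expr : map.getResults())
      result.push_back(cast<AffineDimExpr>(expr).getPosition());
    return result;
  }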

SmallVector<int64_t> dropDims(ArrayRef<int64_t> inputPerm,
ArrayRef<int64_t> dropPositions);

/// Helper to return a subset of `arrayAttr` as a vector of int64_t.
// TODO: Port everything relevant to DenseArrayAttr and drop this util.
SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
61 changes: 60 additions & 1 deletion mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1890,9 +1890,68 @@ struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
}
};

/// This pattern canonicalizes a transpose by swapping the order of
/// broadcast and transpose:
///   transpose(broadcast(input)) -> broadcast(transpose(input))
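///
/// For example (illustrative note; this mirrors the
/// @broadcast_transpose_fold_2dim test added in this PR):
///   %b = linalg.broadcast ins(%in : tensor<2xf32>)
///                         outs(%i1 : tensor<2x4xf32>) dimensions = [1]
///   %t = linalg.transpose ins(%b : tensor<2x4xf32>)
///                         outs(%i2 : tensor<4x2xf32>) permutation = [1, 0]
/// becomes a transpose of the small pre-broadcast tensor (here the identity,
/// which folds away) followed by:
///   linalg.broadcast ins(%in : tensor<2xf32>)
///                    outs(%i2 : tensor<4x2xf32>) dimensions = [0]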
struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
Value input = transposeOp.getInput();
BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
if (!input.hasOneUse() || !broadcastOp)
return failure();

ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
ArrayRef<int64_t> perms = transposeOp.getPermutation();

// Compute the permutation for the new transpose and remap the broadcast
// dimensions: a dim `d` of the broadcast result moves to position
// invertPerm[d] after the transpose, and the broadcast dims are dropped
// from the transpose permutation.
SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
SmallVector<int64_t> resultDimensions;
unsigned dimensionSize = dimensions.size();
for (unsigned i = 0; i < dimensionSize; ++i)
resultDimensions.push_back(invertPerm[dimensions[i]]);

// Create the init tensor for the new transpose, sized from the broadcast
// input (dynamic sizes via tensor.dim, static sizes as index attributes).
Value broadcastInput = broadcastOp.getInput();
Location loc = transposeOp.getLoc();
MLIRContext *ctx = transposeOp.getContext();
SmallVector<OpFoldResult> dims;
auto broadcastInputTy =
mlir::cast<RankedTensorType>(broadcastInput.getType());
unsigned inputRank = broadcastInputTy.getRank();
for (unsigned i = 0; i < inputRank; ++i) {
if (broadcastInputTy.isDynamicDim(i)) {
dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
->getResult(0));
} else {
dims.push_back(IntegerAttr::get(IndexType::get(ctx),
broadcastInputTy.getDimSize(i)));
}
}
SmallVector<OpFoldResult> transposeResultShapes =
applyPermutation(dims, resultPerms);
Value transposeInit = rewriter.create<tensor::EmptyOp>(
transposeOp.getLoc(), transposeResultShapes,
broadcastInputTy.getElementType());

// Create broadcast(transpose(input)).
Value transposeResult =
rewriter
.create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
resultPerms)
->getResult(0);
rewriter.replaceOpWithNewOp<BroadcastOp>(
transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
return success();
}
};

void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldTransposeWithTranspose>(context);
results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
}
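
// Illustrative usage note (not part of this file): these patterns are picked
// up by the standard canonicalizer (e.g. `mlir-opt --canonicalize`) or can be
// applied directly with the greedy driver, roughly:
//   RewritePatternSet patterns(ctx);
//   linalg::TransposeOp::getCanonicalizationPatterns(patterns, ctx);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));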

//===----------------------------------------------------------------------===//
Expand Down
26 changes: 26 additions & 0 deletions mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -252,6 +252,32 @@ mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
return res;
}

SmallVector<int64_t> mlir::dropDims(ArrayRef<int64_t> inputPerm,
ArrayRef<int64_t> dropPositions) {
assert(inputPerm.size() >= dropPositions.size() &&
"expected inputPerm size to be at least the number of positions to drop");
SmallVector<int64_t> res;
unsigned permSize = inputPerm.size();
for (unsigned inputIndex = 0; inputIndex < permSize; ++inputIndex) {
int64_t targetIndex = inputPerm[inputIndex];
bool shouldDrop = false;
unsigned dropSize = dropPositions.size();
for (unsigned dropIndex = 0; dropIndex < dropSize; dropIndex++) {
if (dropPositions[dropIndex] == inputPerm[inputIndex]) {
shouldDrop = true;
break;
}
if (dropPositions[dropIndex] < inputPerm[inputIndex]) {
targetIndex--;
}
}
if (!shouldDrop) {
res.push_back(targetIndex);
}
}
return res;
}

SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
unsigned dropFront,
unsigned dropBack) {
75 changes: 74 additions & 1 deletion mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1017,7 +1017,7 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
return %0 : tensor<2x3xf32>
}

// ----
// -----

func.func @transpose_1d(%input: tensor<16xf32>,
%init: tensor<16xf32>) -> tensor<16xf32> {
@@ -1096,3 +1096,76 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
func.return %transpose2 : tensor<3x4x5xf32>
}

// -----

func.func @broadcast_transpose_fold(%input: tensor<2x4x5xf32>,
%init1: tensor<1x2x3x4x5x6xf32>,
%init2: tensor<1x6x2x3x5x4xf32>) -> tensor<1x6x2x3x5x4xf32> {
// CHECK-LABEL: @broadcast_transpose_fold
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2x4x5xf32>
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x2x3x4x5x6xf32>
// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x6x2x3x5x4xf32>
// CHECK: %[[TMP_INIT:.+]] = tensor.empty() : tensor<2x5x4xf32>
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<2x4x5xf32>) outs(%[[TMP_INIT]] : tensor<2x5x4xf32>) permutation = [0, 2, 1]
// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<2x5x4xf32>) outs(%[[INIT2]] : tensor<1x6x2x3x5x4xf32>) dimensions = [0, 3, 1]
// CHECK: return %[[BROADCAST]] : tensor<1x6x2x3x5x4xf32>
%broadcast = linalg.broadcast
ins(%input : tensor<2x4x5xf32>)
outs(%init1 : tensor<1x2x3x4x5x6xf32>)
dimensions = [0, 2, 5]
%transpose = linalg.transpose
ins(%broadcast : tensor<1x2x3x4x5x6xf32>)
outs(%init2 : tensor<1x6x2x3x5x4xf32>)
permutation = [0, 5, 1, 2, 4, 3]
func.return %transpose : tensor<1x6x2x3x5x4xf32>
}
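
// Derivation note (illustrative, not part of the test file): with
// perm = [0, 5, 1, 2, 4, 3], invert(perm) = [0, 2, 3, 5, 4, 1]. Dropping the
// broadcast values {0, 2, 5} from perm leaves {1, 4, 3}, which renumbers to
// the new permutation [0, 2, 1]; the broadcast dims [0, 2, 5] map through
// invert(perm) to the new dimensions [0, 3, 1], matching the CHECK lines.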

// -----

func.func @broadcast_transpose_fold_dynamic(%input: tensor<?x?x5xf32>,
%init1: tensor<1x?x3x?x5x6xf32>,
%init2: tensor<1x3x?x6x5x?xf32>) -> tensor<1x3x?x6x5x?xf32> {
// CHECK-LABEL: @broadcast_transpose_fold_dynamic
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x5xf32>
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x?x3x?x5x6xf32>
// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x3x?x6x5x?xf32>
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x5xf32>
// CHECK: %[[DIM1:.+]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<?x?x5xf32>
// CHECK: %[[TMP_INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]]) : tensor<?x5x?xf32>
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<?x?x5xf32>) outs(%[[TMP_INIT]] : tensor<?x5x?xf32>) permutation = [1, 2, 0]
// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<?x5x?xf32>) outs(%[[INIT2]] : tensor<1x3x?x6x5x?xf32>) dimensions = [0, 1, 3]
// CHECK: return %[[BROADCAST]] : tensor<1x3x?x6x5x?xf32>
%broadcast = linalg.broadcast
ins(%input : tensor<?x?x5xf32>)
outs(%init1 : tensor<1x?x3x?x5x6xf32>)
dimensions = [0, 2, 5]
%transpose = linalg.transpose
ins(%broadcast : tensor<1x?x3x?x5x6xf32>)
outs(%init2 : tensor<1x3x?x6x5x?xf32>)
permutation = [0, 2, 3, 5, 4, 1]
func.return %transpose : tensor<1x3x?x6x5x?xf32>
}

// -----

func.func @broadcast_transpose_fold_2dim(%input: tensor<2xf32>,
%init1: tensor<2x4xf32>,
%init2: tensor<4x2xf32>) -> tensor<4x2xf32> {
// CHECK-LABEL: @broadcast_transpose_fold_2dim
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<4x2xf32>
// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<4x2xf32>) dimensions = [0]
// CHECK: return %[[BROADCAST]] : tensor<4x2xf32>
%broadcast = linalg.broadcast
ins(%input : tensor<2xf32>)
outs(%init1 : tensor<2x4xf32>)
dimensions = [1]
%transpose = linalg.transpose
ins(%broadcast : tensor<2x4xf32>)
outs(%init2 : tensor<4x2xf32>)
permutation = [1, 0]
func.return %transpose : tensor<4x2xf32>
}
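
// Note (illustrative, not part of the test file): after dropping the
// broadcast dim, the remaining permutation is the identity [0], so the
// swapped-in transpose folds away and only the broadcast survives, as the
// CHECK lines above expect.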