[mlir][vector] Group re-order patterns together #102856

Merged: 4 commits, Aug 16, 2024
@@ -144,9 +144,22 @@ void populateVectorTransferFullPartialPatterns(
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);

-/// Patterns that remove redundant vector broadcasts.
-void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
-                                         PatternBenefit benefit = 1);
+/// Patterns that remove redundant Vector Ops by re-ordering them with
+/// e.g. elementwise Ops:
+/// ```
+/// %at = vector.transpose %a, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
+/// %bt = vector.transpose %b, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
+/// %r = arith.addf %at, %bt : vector<2x4xf32>
+/// ```
+/// gets converted to:
+/// ```
+/// %0 = arith.addf %a, %b : vector<4x2xf32>
+/// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
+/// ```
+/// At the moment, these patterns are limited to vector.broadcast and
+/// vector.transpose.
+void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
+                                   PatternBenefit benefit = 1);
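The doc comment only shows the transpose case. For vector.broadcast, the same pattern family sinks the elementwise op below the broadcast; an illustrative sketch in the same spirit as the existing tests (not taken verbatim from the patch):

```mlir
// Elementwise op applied to two broadcasts of the same source shape ...
%ab = vector.broadcast %a : vector<4xf32> to vector<2x4xf32>
%bb = vector.broadcast %b : vector<4xf32> to vector<2x4xf32>
%r = arith.addf %ab, %bb : vector<2x4xf32>
```

would be rewritten to perform the addition on the narrow vector first:

```mlir
%0 = arith.addf %a, %b : vector<4xf32>
%r = vector.broadcast %0 : vector<4xf32> to vector<2x4xf32>
```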

/// Patterns that fold chained vector reductions. These patterns assume that
/// elementwise operations (e.g., `arith.addf` with vector operands) are
@@ -3452,7 +3452,7 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
if (!getDisableMultiReductionToContractPatterns())
vector::populateVectorReductionToContractPatterns(patterns);

-  vector::populateSinkVectorBroadcastPatterns(patterns);
+  vector::populateSinkVectorOpsPatterns(patterns);

patterns.add<linalg::LinalgCopyVTRForwardingPattern,
linalg::LinalgCopyVTWForwardingPattern>(ctx,
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

@@ -2030,8 +2030,7 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
void mlir::vector::populateVectorReductionToContractPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<MultiReduceToContract, CombineContractBroadcast,
-               CombineContractABTranspose, CombineContractResultTranspose,
-               ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
Member (comment on lines -2033 to -2034):

Note: populateVectorReductionToContractPatterns() is used a few times in IREE, so this change will likely cause some breakages there. You should probably share a fix on the Discord before landing this 🙂

Contributor:

Thanks for the heads-up! So is the expectation that you run the sink patterns before you run the contract-conversion patterns?

I agree the current grouping of patterns is ad hoc, but what is the expected path for users (IREE and others) to get back the same code that this one method would give?
Contributor (author):

"So is the expectation that you run the sink before you run contract conversion patterns."

See the summary :)

NOTES FOR DOWNSTREAM USERS

In order to preserve the current functionality, please make sure to add

  • populateSinkVectorOpsPatterns

wherever you are using populateVectorReductionToContractPatterns. Also, rename populateSinkVectorBroadcastPatterns to populateSinkVectorOpsPatterns.

I checked MLIR and IREE, and in both cases these patterns were required afterwards. In IREE, I ran ctest -R Codegen/SPIRV/ and ctest -R Codegen/LLVMCPU, and both pass 100%. This is my diff:

diff --git a/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp b/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp
index 1824fb08bb..e2c8805d27 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp
@@ -146,7 +146,7 @@ struct EmulateNarrowTypePass final
     }

     RewritePatternSet sinkBroadcast(ctx);
-    vector::populateSinkVectorBroadcastPatterns(sinkBroadcast);
+    vector::populateSinkVectorOpsPatterns(sinkBroadcast);
     if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                             std::move(sinkBroadcast)))) {
       getOperation()->emitOpError("failed in sinking of broadcasts");
diff --git a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
index 98814d1342..8aee5ba2c0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
@@ -409,6 +409,7 @@ void GenericVectorizationPass::runOnOperation() {
     vector::populateVectorTransferPermutationMapLoweringPatterns(
         vectorizationPatterns);
     vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
+    vector::populateSinkVectorOpsPatterns(vectorizationPatterns);
   }
   if (foldCastIntoContract) {
     vector::populateFoldArithExtensionPatterns(vectorizationPatterns);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp
index aa72280e5f..f00af42845 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp
@@ -96,6 +96,7 @@ public:
       vector::populateVectorTransferPermutationMapLoweringPatterns(
           contractionPatterns);
       vector::populateVectorReductionToContractPatterns(contractionPatterns);
+      vector::populateSinkVectorOpsPatterns(contractionPatterns);
       if (failed(applyPatternsAndFoldGreedily(
               funcOp, std::move(contractionPatterns)))) {
         return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp
index 493894675c..c24da49689 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp
@@ -314,6 +314,7 @@ public:
       // cancel them or embed into contract ops. Embedding in the flexible
       // contract ops will help to sustain the structure through various
       // transformations.
+      vector::populateSinkVectorOpsPatterns(patterns);
       vector::populateVectorReductionToContractPatterns(patterns);
       // Pull in patterns to canonicalize transfer ops.
       vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 6b7afaa9db..dc678f1dfe 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 6b7afaa9db8f904ebf0262774e38e54b36598782
+Subproject commit dc678f1dfe49cd4d0b9136ff2490a482ae91e786

@MaheshRavishankar anything else that I should check?

Btw, this re-grouping has been a bit tricky to verify 100%: there are no tests that would require these patterns to be run together (otherwise I wouldn't have been able to move the tests around as I did).

Contributor:

Thanks so much for doing this! Part of my concern here was also that this combination of things wasn't tested properly in-tree. Not your fault; it is just how this was done (there are implicit assumptions about which patterns go together that aren't always clear in the vector dialect).

+               CombineContractABTranspose, CombineContractResultTranspose>(
       patterns.getContext(), benefit);
}

@@ -2043,10 +2042,11 @@ void mlir::vector::
benefit);
}

-void mlir::vector::populateSinkVectorBroadcastPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnBroadcast>(
-      patterns.getContext(), benefit);
+void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
+                                                 PatternBenefit benefit) {
+  patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
+               ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
+                                                 benefit);
}
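As a plain-Python model (not MLIR) of why the ReorderCastOpsOnBroadcast member of this grouping is sound: sign-extending after a broadcast produces the same matrix as broadcasting the sign-extended values, while performing fewer casts. The helper names below are invented for illustration.

```python
def broadcast(vec, rows):
    """Model vector.broadcast: replicate a 1-D vector into `rows` rows."""
    return [list(vec) for _ in range(rows)]

def extsi(x):
    """Model arith.extsi i8 -> i32: Python ints already preserve the sign."""
    return int(x)

# i8 input vector.
a = [1, -2, 3, -4]

# Original order: broadcast to 2x4 first, then cast all 8 elements.
before = [[extsi(x) for x in row] for row in broadcast(a, 2)]

# Rewritten order: cast the 4 source elements once, then broadcast.
after = broadcast([extsi(x) for x in a], 2)

assert before == after
```

The rewrite is profitable precisely because the cast now runs on the smaller pre-broadcast vector.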

void mlir::vector::populateChainedVectorReductionFoldingPatterns(
122 changes: 0 additions & 122 deletions mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir

@@ -245,128 +245,6 @@ func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1x
}


//===----------------------------------------------------------------------===//
// [Pattern: ReorderCastOpsOnBroadcast]
//
// Reorder casting ops and vector ops. The casting ops have almost identical
// pattern, so only arith.extsi op is tested.
//
// TODO: Potential duplication with sink-vector-broadcast.mlir
//===----------------------------------------------------------------------===//

// -----

func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
// CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32>
%b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
return %r : vector<2x4xi32>
}

// -----

func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
// CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
%b = vector.broadcast %a : i8 to vector<2x4xi8>
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
return %r : vector<2x4xi32>
}

// -----

//===----------------------------------------------------------------------===//
// [Pattern: ReorderElementwiseOpsOnTranspose]
//
// TODO: Potential duplication with sink-vector-broadcast.mlir
//===----------------------------------------------------------------------===//
func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
// CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
%b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8>
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
return %r : vector<2x4xi32>
}

//===----------------------------------------------------------------------===//
// Reorder elementwise ops and vector ops.
// TODO: Potential duplication with sink-vector-broadcast.mlir
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func @transpose_elementwise_same_type
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
// CHECK: return %[[T]]

func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%r = arith.addf %at, %bt : vector<2x4xf32>
return %r : vector<2x4xf32>
}

// -----

// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
// CHECK: return %[[T]]
func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
%condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1>
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32>
return %r : vector<2x4xf32>
}

// -----

// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1>
// CHECK: return %[[T]]
func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xi1> {
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%r = arith.cmpf olt, %at, %bt : vector<2x4xf32>
return %r : vector<2x4xi1>
}

// -----

// CHECK-LABEL: func @transpose_elementwise_splat_constant
// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
// CHECK: return %[[T:.+]] : vector<6x4x2x3xf32>

func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> {
%b = arith.constant dense<5.0> : vector<6x4x2x3xf32>
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
%r = arith.addf %at, %b : vector<6x4x2x3xf32>
return %r : vector<6x4x2x3xf32>
}

// -----

// CHECK-LABEL: func @transpose_elementwise_diff_map
// CHECK: vector.transpose
// CHECK: vector.transpose
// CHECK: arith.addf
func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> {
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
%bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32>
%r = arith.addf %at, %bt : vector<6x4x2x3xf32>
return %r : vector<6x4x2x3xf32>
}

// -----

// CHECK-DAG: #[[$LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-sink-patterns -split-input-file | FileCheck %s

//-----------------------------------------------------------------------------
// [Pattern: ReorderElementwiseOpsOnBroadcast]
@@ -208,3 +208,115 @@ func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> {
%1 = vector.fma %0, %0, %0 : vector<1xf32>
return %1 : vector<1xf32>
}

//===----------------------------------------------------------------------===//
(Review comment from a Contributor: Diff-ed moved tests. Looks good. 👍)

// [Pattern: ReorderCastOpsOnBroadcast]
//
// Reorder casting ops and vector ops. The casting ops have almost identical
// pattern, so only arith.extsi op is tested.
//===----------------------------------------------------------------------===//

// -----

func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
// CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32>
%b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
return %r : vector<2x4xi32>
}

// -----

func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
// CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
%b = vector.broadcast %a : i8 to vector<2x4xi8>
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
return %r : vector<2x4xi32>
}

//===----------------------------------------------------------------------===//
// [Pattern: ReorderElementwiseOpsOnTranspose]
//===----------------------------------------------------------------------===//

func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
// CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
%b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8>
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
return %r : vector<2x4xi32>
}

// -----

// CHECK-LABEL: func @transpose_elementwise_same_type
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
// CHECK: return %[[T]]

func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%r = arith.addf %at, %bt : vector<2x4xf32>
return %r : vector<2x4xf32>
}

// -----

// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
// CHECK: return %[[T]]
func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
%condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1>
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32>
return %r : vector<2x4xf32>
}

// -----

// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1>
// CHECK: return %[[T]]
func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xi1> {
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%r = arith.cmpf olt, %at, %bt : vector<2x4xf32>
return %r : vector<2x4xi1>
}

// -----

// CHECK-LABEL: func @transpose_elementwise_splat_constant
// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
// CHECK: return %[[T:.+]] : vector<6x4x2x3xf32>

func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> {
%b = arith.constant dense<5.0> : vector<6x4x2x3xf32>
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
%r = arith.addf %at, %b : vector<6x4x2x3xf32>
return %r : vector<6x4x2x3xf32>
}

// -----

// CHECK-LABEL: func @transpose_elementwise_diff_map
// CHECK: vector.transpose
// CHECK: vector.transpose
// CHECK: arith.addf
func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> {
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
%bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32>
%r = arith.addf %at, %bt : vector<6x4x2x3xf32>
return %r : vector<6x4x2x3xf32>
}
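The moved transpose tests all rely on the algebraic fact that an elementwise op commutes with a shared transpose. A standalone plain-Python check of that fact (illustrative only, not part of the patch):

```python
def transpose(mat):
    """Model vector.transpose %m, [1, 0] on a 2-D vector."""
    return [list(col) for col in zip(*mat)]

def addf(x, y):
    """Model elementwise arith.addf on 2-D vectors of matching shape."""
    return [[a + b for a, b in zip(rx, ry)] for rx, ry in zip(x, y)]

a = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]        # 3x2
b = [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]]  # 3x2

# Before the rewrite: two transposes feed the elementwise op.
before = addf(transpose(a), transpose(b))

# After the rewrite: one elementwise op, then a single transpose.
after = transpose(addf(a, b))

assert before == after  # the two forms agree element for element
```

This is why the rewrite can replace two transposes with one: the elementwise op is applied pointwise, so permuting elements before or after it gives the same result.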
18 changes: 9 additions & 9 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

@@ -374,27 +374,27 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
}
};

-struct TestSinkVectorBroadcast
-    : public PassWrapper<TestSinkVectorBroadcast, OperationPass<func::FuncOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSinkVectorBroadcast)
+struct TestVectorSinkPatterns
+    : public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns)

-  TestSinkVectorBroadcast() = default;
-  TestSinkVectorBroadcast(const TestSinkVectorBroadcast &pass) = default;
+  TestVectorSinkPatterns() = default;
+  TestVectorSinkPatterns(const TestVectorSinkPatterns &pass) = default;

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect, affine::AffineDialect>();
}

-  StringRef getArgument() const final { return "test-sink-vector-broadcast"; }
+  StringRef getArgument() const final { return "test-vector-sink-patterns"; }

StringRef getDescription() const final {
return "Test lowering patterns that eliminate redundant brodacast "
"operations.";
"and transpose operations.";
}

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
-    populateSinkVectorBroadcastPatterns(patterns);
+    populateSinkVectorOpsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
@@ -919,7 +919,7 @@ void registerTestVectorLowerings() {

PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();

-  PassRegistration<TestSinkVectorBroadcast>();
+  PassRegistration<TestVectorSinkPatterns>();

PassRegistration<TestVectorReduceToContractPatternsPatterns>();
