Skip to content

Commit ad558f0

Browse files
committed
[mlir][vector] Group tests for re-order patterns
Moves all tests for patterns that re-order vector.transpose and vector.broadcast Ops (*) under one test-flag: * `test-vector-reorder-patterns`. To facilitate this, * `-test-sink-vector-broadcast` is renamed as `test-vector-reorder-patterns`, * "sink-vector-broadcast.mlir" is renamed as "vector-reorder.mlir", * tests for `ReorderCastOpsOnBroadcast` and `ReorderElementwiseOpsOnTranspose` patterns are moved from "vector-reduce-to-contract.mlir" to "vector-reorder.mlir", * `ReorderElementwiseOpsOnTranspose` patterns are removed from `populateVectorReductionToContractPatterns` and added to (newly created) `populateReoderVectorTransposePatterns`. * `ReorderCastOpsOnBroadcast` patterns are removed from `populateVectorReductionToContractPatterns` - these are already present in `populateSinkVectorBroadcastPatter`. This should allow us better layering and more straightforward testing. For the latter, the goal is to be able to easily identify which pattern a particular test is exercising (especially when it's a specific pattern). Note for downstream users: in order to preserve the current functionality, please make sure to add * `populateReoderVectorTransposePatterns` and `populateSinkVectorBroadcastPatter`, wherever you are using `populateVectorReductionToContractPatterns` (*) I didn't notice any other re-order patterns.
1 parent efe3db2 commit ad558f0

File tree

6 files changed

+128
-127
lines changed

6 files changed

+128
-127
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
148148
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
149149
PatternBenefit benefit = 1);
150150

151+
/// Patterns that re-order transpose Ops.
152+
void populateReoderVectorTransposePatterns(RewritePatternSet &patterns,
153+
PatternBenefit benefit = 1);
154+
151155
/// Patterns that fold chained vector reductions. These patterns assume that
152156
/// elementwise operations (e.g., `arith.addf` with vector operands) are
153157
/// cheaper than vector reduction.

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3452,6 +3452,7 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
34523452
if (!getDisableMultiReductionToContractPatterns())
34533453
vector::populateVectorReductionToContractPatterns(patterns);
34543454

3455+
vector::populateReoderVectorTransposePatterns(patterns);
34553456
vector::populateSinkVectorBroadcastPatterns(patterns);
34563457

34573458
patterns.add<linalg::LinalgCopyVTRForwardingPattern,

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2030,8 +2030,7 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
20302030
void mlir::vector::populateVectorReductionToContractPatterns(
20312031
RewritePatternSet &patterns, PatternBenefit benefit) {
20322032
patterns.add<MultiReduceToContract, CombineContractBroadcast,
2033-
CombineContractABTranspose, CombineContractResultTranspose,
2034-
ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
2033+
CombineContractABTranspose, CombineContractResultTranspose>(
20352034
patterns.getContext(), benefit);
20362035
}
20372036

@@ -2049,6 +2048,12 @@ void mlir::vector::populateSinkVectorBroadcastPatterns(
20492048
patterns.getContext(), benefit);
20502049
}
20512050

2051+
void mlir::vector::populateReoderVectorTransposePatterns(
2052+
RewritePatternSet &patterns, PatternBenefit benefit) {
2053+
patterns.add<ReorderElementwiseOpsOnTranspose>(patterns.getContext(),
2054+
benefit);
2055+
}
2056+
20522057
void mlir::vector::populateChainedVectorReductionFoldingPatterns(
20532058
RewritePatternSet &patterns, PatternBenefit benefit) {
20542059
patterns.add<ChainedReduction>(patterns.getContext(), benefit);

mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir

Lines changed: 0 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -245,128 +245,6 @@ func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1x
245245
}
246246

247247

248-
//===----------------------------------------------------------------------===//
249-
// [Pattern: ReorderCastOpsOnBroadcast]
250-
//
251-
// Reorder casting ops and vector ops. The casting ops have almost identical
252-
// pattern, so only arith.extsi op is tested.
253-
//
254-
// TODO: Potential duplication with sink-vector-broadcast.mlir
255-
//===----------------------------------------------------------------------===//
256-
257-
// -----
258-
259-
func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
260-
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
261-
// CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32>
262-
%b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
263-
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
264-
return %r : vector<2x4xi32>
265-
}
266-
267-
// -----
268-
269-
func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
270-
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
271-
// CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
272-
%b = vector.broadcast %a : i8 to vector<2x4xi8>
273-
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
274-
return %r : vector<2x4xi32>
275-
}
276-
277-
// -----
278-
279-
//===----------------------------------------------------------------------===//
280-
// [Pattern: ReorderElementwiseOpsOnTranspose]
281-
//
282-
// TODO: Potential duplication with sink-vector-broadcast.mlir
283-
//===----------------------------------------------------------------------===//
284-
func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
285-
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
286-
// CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
287-
%b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8>
288-
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
289-
return %r : vector<2x4xi32>
290-
}
291-
292-
//===----------------------------------------------------------------------===//
293-
// Reorder elementwise ops and vector ops.
294-
// TODO: Potential duplication with sink-vector-broadcast.mlir
295-
//===----------------------------------------------------------------------===//
296-
297-
// -----
298-
299-
// CHECK-LABEL: func @transpose_elementwise_same_type
300-
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
301-
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
302-
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
303-
// CHECK: return %[[T]]
304-
305-
func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
306-
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
307-
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
308-
%r = arith.addf %at, %bt : vector<2x4xf32>
309-
return %r : vector<2x4xf32>
310-
}
311-
312-
// -----
313-
314-
// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
315-
// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
316-
// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
317-
// CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
318-
// CHECK: return %[[T]]
319-
func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
320-
%condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1>
321-
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
322-
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
323-
%r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32>
324-
return %r : vector<2x4xf32>
325-
}
326-
327-
// -----
328-
329-
// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
330-
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
331-
// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
332-
// CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1>
333-
// CHECK: return %[[T]]
334-
func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xi1> {
335-
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
336-
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
337-
%r = arith.cmpf olt, %at, %bt : vector<2x4xf32>
338-
return %r : vector<2x4xi1>
339-
}
340-
341-
// -----
342-
343-
// CHECK-LABEL: func @transpose_elementwise_splat_constant
344-
// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
345-
// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
346-
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32>
347-
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
348-
// CHECK: return %[[T:.+]] : vector<6x4x2x3xf32>
349-
350-
func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> {
351-
%b = arith.constant dense<5.0> : vector<6x4x2x3xf32>
352-
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
353-
%r = arith.addf %at, %b : vector<6x4x2x3xf32>
354-
return %r : vector<6x4x2x3xf32>
355-
}
356-
357-
// -----
358-
359-
// CHECK-LABEL: func @transpose_elementwise_diff_map
360-
// CHECK: vector.transpose
361-
// CHECK: vector.transpose
362-
// CHECK: arith.addf
363-
func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> {
364-
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
365-
%bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32>
366-
%r = arith.addf %at, %bt : vector<6x4x2x3xf32>
367-
return %r : vector<6x4x2x3xf32>
368-
}
369-
370248
// -----
371249

372250
// CHECK-DAG: #[[$LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>

mlir/test/Dialect/Vector/sink-vector-broadcast.mlir renamed to mlir/test/Dialect/Vector/vector-reorder.mlir

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -test-vector-reorder-patterns -split-input-file | FileCheck %s
22

33
//-----------------------------------------------------------------------------
44
// [Pattern: ReorderElementwiseOpsOnBroadcast]
@@ -208,3 +208,115 @@ func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> {
208208
%1 = vector.fma %0, %0, %0 : vector<1xf32>
209209
return %1 : vector<1xf32>
210210
}
211+
212+
//===----------------------------------------------------------------------===//
213+
// [Pattern: ReorderCastOpsOnBroadcast]
214+
//
215+
// Reorder casting ops and vector ops. The casting ops have almost identical
216+
// pattern, so only arith.extsi op is tested.
217+
//===----------------------------------------------------------------------===//
218+
219+
// -----
220+
221+
func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
222+
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
223+
// CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32>
224+
%b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
225+
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
226+
return %r : vector<2x4xi32>
227+
}
228+
229+
// -----
230+
231+
func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
232+
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
233+
// CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
234+
%b = vector.broadcast %a : i8 to vector<2x4xi8>
235+
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
236+
return %r : vector<2x4xi32>
237+
}
238+
239+
//===----------------------------------------------------------------------===//
240+
// [Pattern: ReorderElementwiseOpsOnTranspose]
241+
//===----------------------------------------------------------------------===//
242+
243+
func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
244+
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
245+
// CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
246+
%b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8>
247+
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
248+
return %r : vector<2x4xi32>
249+
}
250+
251+
// -----
252+
253+
// CHECK-LABEL: func @transpose_elementwise_same_type
254+
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
255+
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
256+
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
257+
// CHECK: return %[[T]]
258+
259+
func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
260+
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
261+
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
262+
%r = arith.addf %at, %bt : vector<2x4xf32>
263+
return %r : vector<2x4xf32>
264+
}
265+
266+
// -----
267+
268+
// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
269+
// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
270+
// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
271+
// CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
272+
// CHECK: return %[[T]]
273+
func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
274+
%condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1>
275+
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
276+
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
277+
%r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32>
278+
return %r : vector<2x4xf32>
279+
}
280+
281+
// -----
282+
283+
// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
284+
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
285+
// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
286+
// CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1>
287+
// CHECK: return %[[T]]
288+
func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xi1> {
289+
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
290+
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
291+
%r = arith.cmpf olt, %at, %bt : vector<2x4xf32>
292+
return %r : vector<2x4xi1>
293+
}
294+
295+
// -----
296+
297+
// CHECK-LABEL: func @transpose_elementwise_splat_constant
298+
// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
299+
// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
300+
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32>
301+
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
302+
// CHECK: return %[[T:.+]] : vector<6x4x2x3xf32>
303+
304+
func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> {
305+
%b = arith.constant dense<5.0> : vector<6x4x2x3xf32>
306+
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
307+
%r = arith.addf %at, %b : vector<6x4x2x3xf32>
308+
return %r : vector<6x4x2x3xf32>
309+
}
310+
311+
// -----
312+
313+
// CHECK-LABEL: func @transpose_elementwise_diff_map
314+
// CHECK: vector.transpose
315+
// CHECK: vector.transpose
316+
// CHECK: arith.addf
317+
func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> {
318+
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
319+
%bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32>
320+
%r = arith.addf %at, %bt : vector<6x4x2x3xf32>
321+
return %r : vector<6x4x2x3xf32>
322+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,16 +385,17 @@ struct TestSinkVectorBroadcast
385385
registry.insert<memref::MemRefDialect, affine::AffineDialect>();
386386
}
387387

388-
StringRef getArgument() const final { return "test-sink-vector-broadcast"; }
388+
StringRef getArgument() const final { return "test-vector-reorder-patterns"; }
389389

390390
StringRef getDescription() const final {
391391
return "Test lowering patterns that eliminate redundant brodacast "
392-
"operations.";
392+
"and transpose operations.";
393393
}
394394

395395
void runOnOperation() override {
396396
RewritePatternSet patterns(&getContext());
397397
populateSinkVectorBroadcastPatterns(patterns);
398+
populateReoderVectorTransposePatterns(patterns);
398399
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
399400
}
400401
};

0 commit comments

Comments
 (0)