Skip to content

Commit 42944da

Browse files
authored
[mlir][vector] Group re-order patterns together (#102856)
Group all patterns that re-order vector.transpose and vector.broadcast Ops (*) under `populateSinkVectorOpsPatterns`. These patterns are normally used to "sink" redundant Vector Ops, hence grouping together. Example: ```mlir %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> %r = arith.addf %at, %bt : vector<2x4xf32> ``` would get converted to: ```mlir %0 = arith.addf %a, %b : vector<4x2xf32> %r = vector.transpose %0, [1, 0] : vector<2x4xf32> ``` This patch also moves all tests for these patterns so that all of them are: * run under one test-flag: `test-vector-sink-patterns`, * located in one file: "vector-sink.mlir". To facilitate this change: * `-test-sink-vector-broadcast` is renamed as `test-vector-sink-patterns`, * "sink-vector-broadcast.mlir" is renamed as "vector-sink.mlir", * tests for `ReorderCastOpsOnBroadcast` and `ReorderElementwiseOpsOnTranspose` patterns are moved from "vector-reduce-to-contract.mlir" to "vector-sink.mlir", * `ReorderElementwiseOpsOnTranspose` patterns are removed from `populateVectorReductionToContractPatterns` and added to (newly created) `populateSinkVectorOpsPatterns`, * `ReorderCastOpsOnBroadcast` patterns are removed from `populateVectorReductionToContractPatterns` - these are already present in `populateSinkVectorOpsPatterns`. This should allow us better layering and more straightforward testing. For the latter, the goal is to be able to easily identify which pattern a particular test is exercising (especially when it's a specific pattern). NOTES FOR DOWNSTREAM USERS In order to preserve the current functionality, please make sure to add * `populateSinkVectorOpsPatterns`, wherever you are using `populateVectorReductionToContractPatterns`. Also, rename `populateSinkVectorBroadcastPatterns` as `populateSinkVectorOpsPatterns`. (*) I didn't notice any other re-order patterns.
1 parent a434cac commit 42944da

File tree

6 files changed

+145
-142
lines changed

6 files changed

+145
-142
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,22 @@ void populateVectorTransferFullPartialPatterns(
144144
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
145145
RewritePatternSet &patterns, PatternBenefit benefit = 1);
146146

147-
/// Patterns that remove redundant vector broadcasts.
148-
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
149-
PatternBenefit benefit = 1);
147+
/// Patterns that remove redundant Vector Ops by re-ordering them with
148+
/// e.g. elementwise Ops:
149+
/// ```
150+
/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
151+
/// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
152+
/// %r = arith.addf %at, %bt : vector<2x4xf32>
153+
/// ```
154+
/// gets converted to:
155+
/// ```
156+
/// %0 = arith.addf %a, %b : vector<4x2xf32>
157+
/// %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
158+
/// ```
159+
/// At the moment, these patterns are limited to vector.broadcast and
160+
/// vector.transpose.
161+
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
162+
PatternBenefit benefit = 1);
150163

151164
/// Patterns that fold chained vector reductions. These patterns assume that
152165
/// elementwise operations (e.g., `arith.addf` with vector operands) are

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3452,7 +3452,7 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
34523452
if (!getDisableMultiReductionToContractPatterns())
34533453
vector::populateVectorReductionToContractPatterns(patterns);
34543454

3455-
vector::populateSinkVectorBroadcastPatterns(patterns);
3455+
vector::populateSinkVectorOpsPatterns(patterns);
34563456

34573457
patterns.add<linalg::LinalgCopyVTRForwardingPattern,
34583458
linalg::LinalgCopyVTWForwardingPattern>(ctx,

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2030,8 +2030,7 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
20302030
void mlir::vector::populateVectorReductionToContractPatterns(
20312031
RewritePatternSet &patterns, PatternBenefit benefit) {
20322032
patterns.add<MultiReduceToContract, CombineContractBroadcast,
2033-
CombineContractABTranspose, CombineContractResultTranspose,
2034-
ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
2033+
CombineContractABTranspose, CombineContractResultTranspose>(
20352034
patterns.getContext(), benefit);
20362035
}
20372036

@@ -2043,10 +2042,11 @@ void mlir::vector::
20432042
benefit);
20442043
}
20452044

2046-
void mlir::vector::populateSinkVectorBroadcastPatterns(
2047-
RewritePatternSet &patterns, PatternBenefit benefit) {
2048-
patterns.add<ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnBroadcast>(
2049-
patterns.getContext(), benefit);
2045+
void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
2046+
PatternBenefit benefit) {
2047+
patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
2048+
ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
2049+
benefit);
20502050
}
20512051

20522052
void mlir::vector::populateChainedVectorReductionFoldingPatterns(

mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir

Lines changed: 0 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -245,128 +245,6 @@ func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1x
245245
}
246246

247247

248-
//===----------------------------------------------------------------------===//
249-
// [Pattern: ReorderCastOpsOnBroadcast]
250-
//
251-
// Reorder casting ops and vector ops. The casting ops have almost identical
252-
// pattern, so only arith.extsi op is tested.
253-
//
254-
// TODO: Potential duplication with sink-vector-broadcast.mlir
255-
//===----------------------------------------------------------------------===//
256-
257-
// -----
258-
259-
func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
260-
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
261-
// CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32>
262-
%b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
263-
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
264-
return %r : vector<2x4xi32>
265-
}
266-
267-
// -----
268-
269-
func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
270-
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
271-
// CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
272-
%b = vector.broadcast %a : i8 to vector<2x4xi8>
273-
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
274-
return %r : vector<2x4xi32>
275-
}
276-
277-
// -----
278-
279-
//===----------------------------------------------------------------------===//
280-
// [Pattern: ReorderElementwiseOpsOnTranspose]
281-
//
282-
// TODO: Potential duplication with sink-vector-broadcast.mlir
283-
//===----------------------------------------------------------------------===//
284-
func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
285-
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
286-
// CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
287-
%b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8>
288-
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
289-
return %r : vector<2x4xi32>
290-
}
291-
292-
//===----------------------------------------------------------------------===//
293-
// Reorder elementwise ops and vector ops.
294-
// TODO: Potential duplication with sink-vector-broadcast.mlir
295-
//===----------------------------------------------------------------------===//
296-
297-
// -----
298-
299-
// CHECK-LABEL: func @transpose_elementwise_same_type
300-
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
301-
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
302-
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
303-
// CHECK: return %[[T]]
304-
305-
func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
306-
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
307-
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
308-
%r = arith.addf %at, %bt : vector<2x4xf32>
309-
return %r : vector<2x4xf32>
310-
}
311-
312-
// -----
313-
314-
// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
315-
// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
316-
// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
317-
// CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
318-
// CHECK: return %[[T]]
319-
func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
320-
%condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1>
321-
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
322-
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
323-
%r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32>
324-
return %r : vector<2x4xf32>
325-
}
326-
327-
// -----
328-
329-
// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
330-
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
331-
// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
332-
// CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1>
333-
// CHECK: return %[[T]]
334-
func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xi1> {
335-
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
336-
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
337-
%r = arith.cmpf olt, %at, %bt : vector<2x4xf32>
338-
return %r : vector<2x4xi1>
339-
}
340-
341-
// -----
342-
343-
// CHECK-LABEL: func @transpose_elementwise_splat_constant
344-
// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
345-
// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
346-
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32>
347-
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
348-
// CHECK: return %[[T:.+]] : vector<6x4x2x3xf32>
349-
350-
func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> {
351-
%b = arith.constant dense<5.0> : vector<6x4x2x3xf32>
352-
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
353-
%r = arith.addf %at, %b : vector<6x4x2x3xf32>
354-
return %r : vector<6x4x2x3xf32>
355-
}
356-
357-
// -----
358-
359-
// CHECK-LABEL: func @transpose_elementwise_diff_map
360-
// CHECK: vector.transpose
361-
// CHECK: vector.transpose
362-
// CHECK: arith.addf
363-
func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> {
364-
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
365-
%bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32>
366-
%r = arith.addf %at, %bt : vector<6x4x2x3xf32>
367-
return %r : vector<6x4x2x3xf32>
368-
}
369-
370248
// -----
371249

372250
// CHECK-DAG: #[[$LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>

mlir/test/Dialect/Vector/sink-vector-broadcast.mlir renamed to mlir/test/Dialect/Vector/vector-sink.mlir

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -test-vector-sink-patterns -split-input-file | FileCheck %s
22

33
//-----------------------------------------------------------------------------
44
// [Pattern: ReorderElementwiseOpsOnBroadcast]
@@ -208,3 +208,115 @@ func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> {
208208
%1 = vector.fma %0, %0, %0 : vector<1xf32>
209209
return %1 : vector<1xf32>
210210
}
211+
212+
//===----------------------------------------------------------------------===//
213+
// [Pattern: ReorderCastOpsOnBroadcast]
214+
//
215+
// Reorder casting ops and vector ops. The casting ops have almost identical
216+
// pattern, so only arith.extsi op is tested.
217+
//===----------------------------------------------------------------------===//
218+
219+
// -----
220+
221+
func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
222+
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
223+
// CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32>
224+
%b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
225+
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
226+
return %r : vector<2x4xi32>
227+
}
228+
229+
// -----
230+
231+
func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
232+
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
233+
// CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
234+
%b = vector.broadcast %a : i8 to vector<2x4xi8>
235+
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
236+
return %r : vector<2x4xi32>
237+
}
238+
239+
//===----------------------------------------------------------------------===//
240+
// [Pattern: ReorderElementwiseOpsOnTranspose]
241+
//===----------------------------------------------------------------------===//
242+
243+
func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
244+
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
245+
// CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
246+
%b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8>
247+
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
248+
return %r : vector<2x4xi32>
249+
}
250+
251+
// -----
252+
253+
// CHECK-LABEL: func @transpose_elementwise_same_type
254+
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
255+
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
256+
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
257+
// CHECK: return %[[T]]
258+
259+
func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
260+
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
261+
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
262+
%r = arith.addf %at, %bt : vector<2x4xf32>
263+
return %r : vector<2x4xf32>
264+
}
265+
266+
// -----
267+
268+
// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
269+
// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
270+
// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
271+
// CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
272+
// CHECK: return %[[T]]
273+
func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
274+
%condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1>
275+
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
276+
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
277+
%r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32>
278+
return %r : vector<2x4xf32>
279+
}
280+
281+
// -----
282+
283+
// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
284+
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
285+
// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
286+
// CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1>
287+
// CHECK: return %[[T]]
288+
func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xi1> {
289+
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
290+
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
291+
%r = arith.cmpf olt, %at, %bt : vector<2x4xf32>
292+
return %r : vector<2x4xi1>
293+
}
294+
295+
// -----
296+
297+
// CHECK-LABEL: func @transpose_elementwise_splat_constant
298+
// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
299+
// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
300+
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32>
301+
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
302+
// CHECK: return %[[T:.+]] : vector<6x4x2x3xf32>
303+
304+
func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> {
305+
%b = arith.constant dense<5.0> : vector<6x4x2x3xf32>
306+
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
307+
%r = arith.addf %at, %b : vector<6x4x2x3xf32>
308+
return %r : vector<6x4x2x3xf32>
309+
}
310+
311+
// -----
312+
313+
// CHECK-LABEL: func @transpose_elementwise_diff_map
314+
// CHECK: vector.transpose
315+
// CHECK: vector.transpose
316+
// CHECK: arith.addf
317+
func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> {
318+
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
319+
%bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32>
320+
%r = arith.addf %at, %bt : vector<6x4x2x3xf32>
321+
return %r : vector<6x4x2x3xf32>
322+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -374,27 +374,27 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
374374
}
375375
};
376376

377-
struct TestSinkVectorBroadcast
378-
: public PassWrapper<TestSinkVectorBroadcast, OperationPass<func::FuncOp>> {
379-
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSinkVectorBroadcast)
377+
struct TestVectorSinkPatterns
378+
: public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> {
379+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns)
380380

381-
TestSinkVectorBroadcast() = default;
382-
TestSinkVectorBroadcast(const TestSinkVectorBroadcast &pass) = default;
381+
TestVectorSinkPatterns() = default;
382+
TestVectorSinkPatterns(const TestVectorSinkPatterns &pass) = default;
383383

384384
void getDependentDialects(DialectRegistry &registry) const override {
385385
registry.insert<memref::MemRefDialect, affine::AffineDialect>();
386386
}
387387

388-
StringRef getArgument() const final { return "test-sink-vector-broadcast"; }
388+
StringRef getArgument() const final { return "test-vector-sink-patterns"; }
389389

390390
StringRef getDescription() const final {
391391
return "Test lowering patterns that eliminate redundant brodacast "
392-
"operations.";
392+
"and transpose operations.";
393393
}
394394

395395
void runOnOperation() override {
396396
RewritePatternSet patterns(&getContext());
397-
populateSinkVectorBroadcastPatterns(patterns);
397+
populateSinkVectorOpsPatterns(patterns);
398398
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
399399
}
400400
};
@@ -919,7 +919,7 @@ void registerTestVectorLowerings() {
919919

920920
PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
921921

922-
PassRegistration<TestSinkVectorBroadcast>();
922+
PassRegistration<TestVectorSinkPatterns>();
923923

924924
PassRegistration<TestVectorReduceToContractPatternsPatterns>();
925925

0 commit comments

Comments
 (0)