Skip to content

Commit d0ec4a8

Browse files
author
Tobias Gysi
committed
[mlir][linalg] Add pad and hoist test pass.
Adding a padding and hoisting pattern, a test pass, and tests. The patch prepares the split of tiling/fusion and padding. Depends On D112255 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D112412
1 parent 6c2f26a commit d0ec4a8

File tree

4 files changed

+330
-0
lines changed

4 files changed

+330
-0
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,47 @@ using PaddingValueComputationFunction =
460460
/// OpOperand shall be marked as nofold to enable packing.
461461
using PaddingNoFoldComputationFunction = std::function<bool(OpOperand &)>;
462462

463+
/// Callback returning the number of loops to hoist the pad tensor operation
464+
/// defining the given OpOperand.
465+
using PaddingHoistComputationFunction = std::function<int64_t(OpOperand &)>;
466+
467+
struct LinalgPaddingOptions {
468+
/// Callback returning the padding value to use for a given OpOperand or
469+
/// failure for no padding. Padding operations are introduced if
470+
/// `paddingValueComputationFunction` is set and does not return failure.
471+
/// Padding all operands guarantees the operation is statically shaped and
472+
/// thus can be vectorized.
473+
PaddingValueComputationFunction paddingValueComputationFunction = nullptr;
474+
475+
LinalgPaddingOptions &
476+
setPaddingValueComputationFunction(PaddingValueComputationFunction fun) {
477+
paddingValueComputationFunction = std::move(fun);
478+
return *this;
479+
}
480+
481+
/// Callback returning true if the pad tensor operation defining the given
482+
/// OpOperand shall be marked as nofold to enable packing. A padding operation
483+
/// is only marked nofold if `paddingNoFoldComputationFunction` is set and
484+
/// returns true. Otherwise, the nofold attribute is set to false.
485+
PaddingNoFoldComputationFunction paddingNoFoldComputationFunction = nullptr;
486+
487+
LinalgPaddingOptions &
488+
setPaddingNoFoldComputationFunction(PaddingNoFoldComputationFunction fun) {
489+
paddingNoFoldComputationFunction = std::move(fun);
490+
return *this;
491+
}
492+
493+
/// Callback returning the number of loops to hoist the pad tensor operation
494+
/// defining the given OpOperand.
495+
PaddingHoistComputationFunction paddingHoistComputationFunction = nullptr;
496+
497+
LinalgPaddingOptions &
498+
setPaddingHoistComputationFunction(PaddingHoistComputationFunction fun) {
499+
paddingHoistComputationFunction = std::move(fun);
500+
return *this;
501+
}
502+
};
503+
463504
struct LinalgTilingOptions {
464505
/// Computation function that returns the tile sizes for each operation.
465506
/// Delayed construction of constant tile sizes should occur to interoperate
@@ -650,6 +691,35 @@ struct LinalgGenericTilingPattern : public LinalgBaseTilingPattern {
650691
}
651692
};
652693

694+
///
695+
/// Linalg padding pattern.
696+
///
697+
/// Apply the `padding` transformation as a pattern.
698+
/// `filter` controls LinalgTransformMarker matching and update when specified.
699+
/// See `padding` for more details.
700+
struct LinalgPaddingPattern : public RewritePattern {
701+
// Entry point to match any LinalgOp OpInterface.
702+
LinalgPaddingPattern(
703+
MLIRContext *context,
704+
LinalgPaddingOptions options = LinalgPaddingOptions(),
705+
LinalgTransformationFilter filter = LinalgTransformationFilter(),
706+
PatternBenefit benefit = 1);
707+
// Entry point to match a specific LinalgOp.
708+
LinalgPaddingPattern(
709+
StringRef opName, MLIRContext *context,
710+
LinalgPaddingOptions options = LinalgPaddingOptions(),
711+
LinalgTransformationFilter filter = LinalgTransformationFilter(),
712+
PatternBenefit benefit = 1);
713+
LogicalResult matchAndRewrite(Operation *op,
714+
PatternRewriter &rewriter) const override;
715+
716+
private:
717+
/// LinalgTransformMarker handles special attribute manipulations.
718+
LinalgTransformationFilter filter;
719+
/// Options to control padding and hoisting.
720+
LinalgPaddingOptions options;
721+
};
722+
653723
struct LinalgFusionOptions {
654724
/// List of operands indices to use for fusion.
655725
llvm::SmallSet<unsigned, 1> indicesToFuse = {};

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1717
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
1818
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
19+
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
1920
#include "mlir/Dialect/Linalg/Utils/Utils.h"
2021
#include "mlir/Dialect/SCF/Transforms.h"
2122
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -470,6 +471,64 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
470471
return success();
471472
}
472473

474+
/// Linalg padding pattern.
475+
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
476+
MLIRContext *context, LinalgPaddingOptions options,
477+
LinalgTransformationFilter filter, PatternBenefit benefit)
478+
: RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
479+
options(options) {}
480+
481+
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
482+
StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
483+
LinalgTransformationFilter filter, PatternBenefit benefit)
484+
: RewritePattern(opName, benefit, context, {}), filter(filter),
485+
options(options) {}
486+
487+
LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
488+
Operation *op, PatternRewriter &rewriter) const {
489+
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
490+
if (!linalgOp)
491+
return failure();
492+
if (!linalgOp.hasTensorSemantics())
493+
return failure();
494+
if (failed(filter.checkAndNotify(rewriter, op)))
495+
return failure();
496+
497+
// Pad the operation.
498+
LinalgOp paddedOp;
499+
FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
500+
rewriter, linalgOp, options.paddingValueComputationFunction,
501+
options.paddingNoFoldComputationFunction, paddedOp);
502+
if (failed(newResults))
503+
return failure();
504+
505+
// Compute the desired hoisting depths.
506+
SmallVector<int64_t> depths;
507+
if (options.paddingHoistComputationFunction) {
508+
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
509+
depths.push_back(options.paddingHoistComputationFunction(*opOperand));
510+
}
511+
512+
// Hoist the padding.
513+
for (auto en : enumerate(depths)) {
514+
OpOperand &opOperand = paddedOp->getOpOperand(en.index());
515+
auto padTensorOp = opOperand.get().getDefiningOp<PadTensorOp>();
516+
if (!padTensorOp || en.value() == 0)
517+
continue;
518+
PadTensorOp hoistedOp;
519+
FailureOr<Value> newResult =
520+
hoistPaddingOnTensors(padTensorOp, en.value(), hoistedOp);
521+
if (failed(newResult))
522+
continue;
523+
rewriter.replaceOp(padTensorOp, newResult.getValue());
524+
}
525+
526+
// Replace the original operation to pad.
527+
rewriter.replaceOp(op, newResults.getValue());
528+
filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
529+
return success();
530+
}
531+
473532
/// Linalg generic interchange pattern.
474533
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
475534
MLIRContext *context, ArrayRef<unsigned> interchangeVector,
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
// RUN: mlir-opt %s -test-linalg-transform-patterns="test-pad-pattern pack-paddings=1,1,0 hoist-paddings=2,1,0" -cse -canonicalize -split-input-file | FileCheck %s
2+
// RUN: mlir-opt %s -test-linalg-transform-patterns="test-pad-pattern pack-paddings=1,1,0 hoist-paddings=4,3,0" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-DOUBLE
3+
4+
// CHECK-DAG: #[[MAP0:[0-9a-z]+]] = affine_map<(d0) -> (5, -d0 + 24)>
5+
// CHECK-DAG: #[[MAP1:[0-9a-z]+]] = affine_map<(d0) -> (8, -d0 + 12)>
6+
// CHECK-DAG: #[[DIV6:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 6)>
7+
#map0 = affine_map<(d0) -> (5, -d0 + 24)>
8+
#map1 = affine_map<(d0) -> (8, -d0 + 12)>
9+
#map2 = affine_map<(d0) -> (7, -d0 + 25)>
10+
11+
// CHECK: single_tiling
12+
// CHECK-DOUBLE: single_tiling
13+
14+
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
15+
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>
16+
// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
17+
func @single_tiling(%arg0: tensor<24x12xf32>,
18+
%arg1: tensor<12x25xf32>,
19+
%arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
20+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
21+
// CHECK-DAG: %[[C5:.*]] = arith.constant 5
22+
// CHECK-DAG: %[[C8:.*]] = arith.constant 8
23+
%c0 = arith.constant 0 : index
24+
%c12 = arith.constant 12 : index
25+
%c25 = arith.constant 25 : index
26+
%c24 = arith.constant 24 : index
27+
%c6 = arith.constant 6 : index
28+
%c7 = arith.constant 7 : index
29+
%c5 = arith.constant 5 : index
30+
31+
// CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
32+
%0 = scf.for %arg3 = %c0 to %c24 step %c5 iter_args(%arg4 = %arg2) -> (tensor<24x25xf32>) {
33+
34+
// Packing the first input operand for all values of IV2 (IV2x5x6).
35+
// CHECK: = linalg.init_tensor [2, 5, 6]
36+
// CHECK: %[[PT0:.*]] = scf.for %[[P0IV2:[0-9a-z]+]] =
37+
// CHECK: %[[PIDX0:.*]] = affine.apply #[[DIV6]](%[[P0IV2]])
38+
// CHECK: %[[TS0:.*]] = affine.min #[[MAP0]](%[[IV0]])
39+
// CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
40+
// CHECK-SAME: %[[IV0]], %[[P0IV2]]
41+
// CHECK-SAME: %[[TS0]], 6
42+
// CHECK: %[[V0:.*]] = arith.subi %[[C5]], %[[TS0]]
43+
// CHECK: %[[T1:.*]] = linalg.pad_tensor %[[T0]] nofold {{.*}} high[%[[V0]]
44+
// CHECK: %[[T2:.*]] = tensor.insert_slice %[[T1:.*]] into %{{.*}}[%[[PIDX0]], 0, 0]
45+
// CHECK: scf.yield %[[T2:.*]]
46+
47+
// CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
48+
%1 = scf.for %arg5 = %c0 to %c25 step %c7 iter_args(%arg6 = %arg4) -> (tensor<24x25xf32>) {
49+
50+
// Packing the second input operand for all values of IV2 (IV2x6x8).
51+
// CHECK: = linalg.init_tensor [2, 6, 8]
52+
// CHECK: %[[PT1:.*]] = scf.for %[[P1IV2:[0-9a-z]+]] =
53+
// CHECK: %[[PIDX1:.*]] = affine.apply #[[DIV6]](%[[P1IV2]])
54+
// CHECK: %[[TS1:.*]] = affine.min #[[MAP1]](%[[IV1]])
55+
// CHECK: %[[T3:.*]] = tensor.extract_slice %[[ARG1]]
56+
// CHECK-SAME: %[[P1IV2]], %[[IV1]]
57+
// CHECK-SAME: 6, %[[TS1]]
58+
// CHECK: %[[V1:.*]] = arith.subi %[[C8]], %[[TS1]]
59+
// CHECK: %[[T4:.*]] = linalg.pad_tensor %[[T3]] nofold {{.*}} high[%[[C0]], %[[V1]]
60+
// CHECK: %[[T5:.*]] = tensor.insert_slice %[[T4:.*]] into %{{.*}}[%[[PIDX1]], 0, 0]
61+
// CHECK: scf.yield %[[T5:.*]]
62+
63+
// CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG4:.*]] =
64+
%2 = scf.for %arg7 = %c0 to %c12 step %c6 iter_args(%arg8 = %arg6) -> (tensor<24x25xf32>) {
65+
%3 = affine.min #map0(%arg3)
66+
// Index the packed operands.
67+
// CHECK-DAG: %[[IDX:.*]] = affine.apply #[[DIV6]](%[[IV2]])
68+
// CHECK-DAG: %[[T6:.*]] = tensor.extract_slice %[[PT0]][%[[IDX]]
69+
// CHECK-DAG: %[[T7:.*]] = tensor.extract_slice %[[PT1]][%[[IDX]]
70+
%4 = tensor.extract_slice %arg0[%arg3, %arg7] [%3, 6] [1, 1] : tensor<24x12xf32> to tensor<?x6xf32>
71+
%5 = affine.min #map1(%arg5)
72+
%6 = tensor.extract_slice %arg1[%arg7, %arg5] [6, %5] [1, 1] : tensor<12x25xf32> to tensor<6x?xf32>
73+
74+
// Pad the output operand without setting the nofold attribute.
75+
// CHECK-DAG: %[[T8:.*]] = tensor.extract_slice %[[ARG4]][%[[IV0]], %[[IV1]]
76+
// CHECK: %[[T9:.*]] = linalg.pad_tensor %[[T8]] low
77+
%7 = tensor.extract_slice %arg8[%arg3, %arg5] [%3, %5] [1, 1] : tensor<24x25xf32> to tensor<?x?xf32>
78+
79+
// Check matmul uses the packed input operands and the padded output operand.
80+
// CHECK: = linalg.matmul ins(%[[T6]], %[[T7]]{{.*}} outs(%[[T9]]
81+
%8 = linalg.matmul {__internal_linalg_transform__ = "pad"} ins(%4, %6 : tensor<?x6xf32>, tensor<6x?xf32>) outs(%7 : tensor<?x?xf32>) -> tensor<?x?xf32>
82+
%9 = tensor.insert_slice %8 into %arg8[%arg3, %arg5] [%3, %5] [1, 1] : tensor<?x?xf32> into tensor<24x25xf32>
83+
scf.yield %9 : tensor<24x25xf32>
84+
}
85+
scf.yield %2 : tensor<24x25xf32>
86+
}
87+
scf.yield %1 : tensor<24x25xf32>
88+
}
89+
return %0 : tensor<24x25xf32>
90+
}
91+
92+
// -----
93+
94+
#map0 = affine_map<(d0) -> (15, -d0 + 24)>
95+
#map1 = affine_map<(d0) -> (16, -d0 + 25)>
96+
#map2 = affine_map<(d0, d1) -> (5, -d0 + d1)>
97+
#map3 = affine_map<(d0, d1) -> (d0 + d1)>
98+
#map4 = affine_map<(d0, d1) -> (6, -d0 + d1)>
99+
100+
// CHECK: double_tiling
101+
// CHECK-DOUBLE: double_tiling
102+
103+
// CHECK-DOUBLE-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
104+
// CHECK-DOUBLE-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>
105+
// CHECK-DOUBLE-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
106+
func @double_tiling(%arg0: tensor<24x12xf32>,
107+
%arg1: tensor<12x25xf32>,
108+
%arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
109+
%c15 = arith.constant 15 : index
110+
%c16 = arith.constant 16 : index
111+
%c24 = arith.constant 24 : index
112+
%c25 = arith.constant 25 : index
113+
%c0 = arith.constant 0 : index
114+
%c5 = arith.constant 5 : index
115+
%c6 = arith.constant 6 : index
116+
117+
// Packing the first input operand.
118+
// CHECK-DOUBLE: = linalg.init_tensor
119+
// CHECK-DOUBLE: = linalg.pad_tensor {{.*}} nofold
120+
121+
// CHECK-DOUBLE: scf.for %[[IV0:[0-9a-zA-Z]*]] =
122+
%0 = scf.for %arg3 = %c0 to %c24 step %c15 iter_args(%arg4 = %arg2) -> (tensor<24x25xf32>) {
123+
124+
// Packing the second input operand.
125+
// CHECK-DOUBLE: = linalg.init_tensor
126+
// CHECK-DOUBLE: = linalg.pad_tensor {{.*}} nofold
127+
128+
// CHECK-DOUBLE: scf.for %[[IV1:[0-9a-zA-Z]*]] =
129+
%1 = scf.for %arg5 = %c0 to %c25 step %c16 iter_args(%arg6 = %arg4) -> (tensor<24x25xf32>) {
130+
%2 = affine.min #map0(%arg3)
131+
%3 = affine.min #map1(%arg5)
132+
%4 = tensor.extract_slice %arg6[%arg3, %arg5] [%2, %3] [1, 1] : tensor<24x25xf32> to tensor<?x?xf32>
133+
134+
// CHECK-DOUBLE: scf.for %[[IV2:[0-9a-zA-Z]*]] =
135+
%5 = scf.for %arg7 = %c0 to %2 step %c5 iter_args(%arg8 = %4) -> (tensor<?x?xf32>) {
136+
137+
// CHECK-DOUBLE: scf.for %[[IV3:[0-9a-zA-Z]*]] =
138+
%7 = scf.for %arg9 = %c0 to %3 step %c6 iter_args(%arg10 = %arg8) -> (tensor<?x?xf32>) {
139+
%8 = affine.min #map2(%arg7, %2)
140+
%9 = affine.apply #map3(%arg7, %arg3)
141+
%10 = tensor.extract_slice %arg0[%9, 0] [%8, 12] [1, 1] : tensor<24x12xf32> to tensor<?x12xf32>
142+
%11 = affine.min #map4(%arg9, %3)
143+
%12 = affine.apply #map3(%arg9, %arg5)
144+
%13 = tensor.extract_slice %arg1[0, %12] [12, %11] [1, 1] : tensor<12x25xf32> to tensor<12x?xf32>
145+
%14 = affine.min #map2(%arg7, %2)
146+
%15 = affine.min #map4(%arg9, %3)
147+
%16 = tensor.extract_slice %arg10[%arg7, %arg9] [%14, %15] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
148+
149+
// Pad the output operand and perform the multiplication.
150+
// CHECK-DOUBLE: = linalg.pad_tensor
151+
// CHECK-DOUBLE: = linalg.matmul
152+
%17 = linalg.matmul {__internal_linalg_transform__ = "pad"} ins(%10, %13 : tensor<?x12xf32>, tensor<12x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
153+
%18 = tensor.insert_slice %17 into %arg10[%arg7, %arg9] [%14, %15] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
154+
scf.yield %18 : tensor<?x?xf32>
155+
}
156+
scf.yield %7 : tensor<?x?xf32>
157+
}
158+
%6 = tensor.insert_slice %5 into %arg6[%arg3, %arg5] [%2, %3] [1, 1] : tensor<?x?xf32> into tensor<24x25xf32>
159+
scf.yield %6 : tensor<24x25xf32>
160+
}
161+
scf.yield %1 : tensor<24x25xf32>
162+
}
163+
return %0 : tensor<24x25xf32>
164+
}

mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ struct TestLinalgTransforms
9696
Option<int> testHoistPadding{*this, "test-hoist-padding",
9797
llvm::cl::desc("Test hoist padding"),
9898
llvm::cl::init(0)};
99+
Option<bool> testPadPattern{*this, "test-pad-pattern",
100+
llvm::cl::desc("Test pad pattern"),
101+
llvm::cl::init(false)};
99102
Option<bool> testTransformPadTensor{
100103
*this, "test-transform-pad-tensor",
101104
llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
@@ -117,6 +120,14 @@ struct TestLinalgTransforms
117120
*this, "nofold-operands",
118121
llvm::cl::desc("Operands to set nofold when test-tile-pattern"),
119122
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
123+
ListOption<int64_t> packPaddings{
124+
*this, "pack-paddings",
125+
llvm::cl::desc("Operand packing flags when test-pad-pattern"),
126+
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
127+
ListOption<int64_t> hoistPaddings{
128+
*this, "hoist-paddings",
129+
llvm::cl::desc("Operand hoisting depths when test-pad-pattern"),
130+
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
120131
ListOption<int64_t> peeledLoops{
121132
*this, "peeled-loops",
122133
llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
@@ -637,6 +648,30 @@ static void applyTilePattern(FuncOp funcOp, std::string loopType,
637648
(void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
638649
}
639650

651+
static void applyPadPattern(FuncOp funcOp, ArrayRef<int64_t> packPaddings,
652+
ArrayRef<int64_t> hoistPaddings) {
653+
MLIRContext *context = funcOp.getContext();
654+
RewritePatternSet padPattern(context);
655+
auto linalgPaddingOptions = linalg::LinalgPaddingOptions();
656+
auto packFunc = [&](OpOperand &opOperand) {
657+
return opOperand.getOperandNumber() < packPaddings.size()
658+
? packPaddings[opOperand.getOperandNumber()]
659+
: false;
660+
};
661+
auto hoistingFunc = [&](OpOperand &opOperand) {
662+
return opOperand.getOperandNumber() < hoistPaddings.size()
663+
? hoistPaddings[opOperand.getOperandNumber()]
664+
: 0;
665+
};
666+
linalgPaddingOptions.setPaddingValueComputationFunction(getNeutralOfLinalgOp);
667+
linalgPaddingOptions.setPaddingNoFoldComputationFunction(packFunc);
668+
linalgPaddingOptions.setPaddingHoistComputationFunction(hoistingFunc);
669+
padPattern.add<LinalgPaddingPattern>(
670+
context, linalgPaddingOptions,
671+
LinalgTransformationFilter(Identifier::get("pad", context)));
672+
(void)applyPatternsAndFoldGreedily(funcOp, std::move(padPattern));
673+
}
674+
640675
static void applyInterchangePattern(FuncOp funcOp,
641676
ArrayRef<unsigned> interchangeVector) {
642677
MLIRContext *context = funcOp.getContext();
@@ -780,6 +815,8 @@ void TestLinalgTransforms::runOnFunction() {
780815
}
781816
});
782817
}
818+
if (testPadPattern)
819+
return applyPadPattern(getFunction(), packPaddings, hoistPaddings);
783820
if (testInterchangePattern.hasValue())
784821
return applyInterchangePattern(getFunction(), testInterchangePattern);
785822
}

0 commit comments

Comments
 (0)