Skip to content

Commit b6f4dd9

Browse files
authored
[mlir][transform] Implement FlattenElementwiseLinalgOp transform op (#81431)
A `transform.structured.flatten_elementwise` op is implemented for flattening the iteration space and (applicable) operands/results to a single dimension.
1 parent 95e0369 commit b6f4dd9

File tree

5 files changed

+241
-50
lines changed

5 files changed

+241
-50
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2295,6 +2295,49 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
22952295
}];
22962296
}
22972297

2298+
//===----------------------------------------------------------------------===//
2299+
// FlattenElementwiseLinalgOp
2300+
//===----------------------------------------------------------------------===//
2301+
2302+
def FlattenElementwiseLinalgOp : Op<Transform_Dialect,
2303+
"structured.flatten_elementwise",
2304+
[FunctionalStyleTransformOpTrait,
2305+
MemoryEffectsOpInterface,
2306+
TransformOpInterface,
2307+
TransformEachOpTrait,
2308+
ReportTrackingListenerFailuresOpTrait]> {
2309+
let description = [{
2310+
Flattens the iteration space and (applicable) operands of elementwise
2311+
linalg ops to a single dimension.
2312+
2313+
Returns one handle:
2314+
- Flattened linalg operation.
2315+
2316+
#### Return modes:
2317+
2318+
Returns a definite failure if target is not isolated from above.
2319+
Returns a silenceable failure if the pattern application failed.
2320+
}];
2321+
2322+
let arguments = (ins TransformHandleTypeInterface:$target);
2323+
let results = (outs TransformHandleTypeInterface:$transformed);
2324+
2325+
let assemblyFormat =
2326+
"$target attr-dict `:` functional-type($target, results)";
2327+
2328+
let builders = [
2329+
OpBuilder<(ins "Value":$target)>
2330+
];
2331+
2332+
let extraClassDeclaration = [{
2333+
::mlir::DiagnosedSilenceableFailure applyToOne(
2334+
::mlir::transform::TransformRewriter &rewriter,
2335+
::mlir::linalg::LinalgOp target,
2336+
::mlir::transform::ApplyToEachResultList &results,
2337+
::mlir::transform::TransformState &state);
2338+
}];
2339+
}
2340+
22982341
//===----------------------------------------------------------------------===//
22992342
// Transpose Conv2D
23002343
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,16 +1074,20 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
10741074
bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
10751075
ArrayRef<ReassociationIndices> dimSequences);
10761076

1077+
/// Bundles what `collapseOpIterationDims` hands back for a collapsed op.
struct CollapseResult {
1078+
/// Values to use in place of the results of the original op.
SmallVector<Value> results;
1079+
/// The op that was created with the collapsed iteration space.
LinalgOp collapsedOp;
1080+
};
1081+
10771082
/// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
10781083
/// to calling this method is that for each list in `foldedIterationDim`, the
10791084
/// sequence of dimensions is contiguous in domains of all `indexing_maps` of
10801085
/// the `linalgOp`. This can be checked using `areDimSequencesPreserved` method.
10811086
/// When valid, the method also collapses the operands of the op. Returns
10821087
/// replacement values of the results of the original `linalgOp` by inserting
10831088
/// reshapes to get back values of compatible types.
1084-
template <typename LinalgType>
1085-
FailureOr<SmallVector<Value>>
1086-
collapseOpIterationDims(LinalgType op,
1089+
FailureOr<CollapseResult>
1090+
collapseOpIterationDims(LinalgOp op,
10871091
ArrayRef<ReassociationIndices> foldedIterationDims,
10881092
RewriterBase &rewriter);
10891093

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3244,6 +3244,31 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
32443244
return DiagnosedSilenceableFailure::success();
32453245
}
32463246

3247+
//===----------------------------------------------------------------------===//
3248+
// FlattenElementwiseLinalgOp.
3249+
//===----------------------------------------------------------------------===//
3250+
3251+
/// Flattens the iteration space of a single elementwise `target` into one
/// dimension and records the flattened op in `results`.
///
/// - A target with at most one loop is already flat: it is left untouched and
///   forwarded as the result, so that the op always yields exactly one payload
///   op per successfully handled target.
/// - A target that is not elementwise, or whose dimensions cannot be
///   collapsed, produces a silenceable failure.
DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);

  // Nothing to flatten. The result handle still has to be populated, otherwise
  // the number of produced results does not match the op's declared results.
  if (target.getNumLoops() <= 1) {
    results.push_back(target);
    return DiagnosedSilenceableFailure::success();
  }

  if (!isElementwise(target)) {
    (void)rewriter.notifyMatchFailure(
        target, "only elementwise flattening is supported");
    return emitDefaultSilenceableFailure(target);
  }

  // A single reassociation group [0, 1, ..., numLoops - 1] folds every loop
  // into one dimension.
  ReassociationIndices reassociation(target.getNumLoops());
  std::iota(reassociation.begin(), reassociation.end(), 0);
  FailureOr<CollapseResult> maybeFlattened =
      collapseOpIterationDims(target, reassociation, rewriter);
  if (failed(maybeFlattened))
    return emitDefaultSilenceableFailure(target);

  results.push_back(maybeFlattened->collapsedOp);
  rewriter.replaceOp(target, maybeFlattened->results);
  return DiagnosedSilenceableFailure::success();
}
3271+
32473272
//===----------------------------------------------------------------------===//
32483273
// TransposeConv2DOp
32493274
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 67 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1446,24 +1446,20 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
14461446
}
14471447
}
14481448

1449-
template <typename LinalgType>
1450-
Operation *createCollapsedOp(LinalgType op,
1451-
const CollapsingInfo &collapsingInfo,
1452-
RewriterBase &rewriter) {
1453-
static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
1454-
"unsupported linalg op type to create");
1449+
void collapseOperandsAndResults(LinalgOp op,
1450+
const CollapsingInfo &collapsingInfo,
1451+
RewriterBase &rewriter,
1452+
SmallVectorImpl<Value> &inputOperands,
1453+
SmallVectorImpl<Value> &outputOperands,
1454+
SmallVectorImpl<Type> &resultTypes) {
14551455
Location loc = op->getLoc();
1456-
1457-
// Get the input operands.
1458-
SmallVector<Value> inputOperands =
1456+
inputOperands =
14591457
llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
14601458
return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
14611459
rewriter);
14621460
});
14631461

14641462
// Get the output operands and result types.
1465-
SmallVector<Type> resultTypes;
1466-
SmallVector<Value> outputOperands;
14671463
resultTypes.reserve(op.getNumDpsInits());
14681464
outputOperands.reserve(op.getNumDpsInits());
14691465
for (OpOperand &output : op.getDpsInitsMutable()) {
@@ -1475,41 +1471,69 @@ Operation *createCollapsedOp(LinalgType op,
14751471
if (!op.hasPureBufferSemantics())
14761472
resultTypes.push_back(newOutput.getType());
14771473
}
1474+
}
14781475

1479-
if (isa<linalg::CopyOp>(op)) {
1480-
return rewriter.create<linalg::CopyOp>(loc, inputOperands[0],
1481-
outputOperands[0]);
1482-
}
1476+
/// Clone a `LinalgOp` to a collapsed version of the same name.
/// This primary template is only a fallback that returns a null op; the
/// explicit specializations for `LinalgOp` and `GenericOp` do the actual work.
1477+
template <typename OpTy>
1478+
OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1479+
const CollapsingInfo &collapsingInfo) {
1480+
// No collapsing logic exists for an arbitrary `OpTy`; report that with null.
return nullptr;
1481+
}
14831482

1484-
// Get the iterator types for the operand.
1485-
SmallVector<utils::IteratorType> iteratorTypes =
1486-
getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
1483+
/// Collapse any `LinalgOp` that does not require any specialization such as
1484+
/// indexing_maps, iterator_types, etc.
1485+
template <>
1486+
LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1487+
const CollapsingInfo &collapsingInfo) {
1488+
SmallVector<Value> inputOperands, outputOperands;
1489+
SmallVector<Type> resultTypes;
1490+
collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1491+
outputOperands, resultTypes);
1492+
return cast<LinalgOp>(clone(
1493+
rewriter, origOp, resultTypes,
1494+
llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands))));
1495+
}
14871496

1488-
// Get the indexing maps.
1489-
auto indexingMaps =
1490-
llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) {
1497+
/// Collapse a `GenericOp`
1498+
template <>
1499+
GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1500+
GenericOp origOp,
1501+
const CollapsingInfo &collapsingInfo) {
1502+
SmallVector<Value> inputOperands, outputOperands;
1503+
SmallVector<Type> resultTypes;
1504+
collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1505+
outputOperands, resultTypes);
1506+
SmallVector<AffineMap> indexingMaps(
1507+
llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
14911508
return getCollapsedOpIndexingMap(map, collapsingInfo);
1492-
});
1509+
}));
1510+
1511+
SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
1512+
origOp.getIteratorTypesArray(), collapsingInfo));
14931513

1494-
Operation *collapsedOp = rewriter.create<linalg::GenericOp>(
1495-
loc, resultTypes, inputOperands, outputOperands, indexingMaps,
1514+
GenericOp collapsedOp = rewriter.create<linalg::GenericOp>(
1515+
origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
14961516
iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1497-
Block *origOpBlock = &op->getRegion(0).front();
1517+
Block *origOpBlock = &origOp->getRegion(0).front();
14981518
Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
14991519
rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
15001520
collapsedOpBlock->getArguments());
1501-
15021521
return collapsedOp;
15031522
}
15041523

1524+
/// Creates the collapsed counterpart of `op`, dispatching to the `GenericOp`
/// overload of `cloneToCollapsedOp` when `op` is a `GenericOp` and to the
/// `LinalgOp` one otherwise.
LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
                           RewriterBase &rewriter) {
  if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation()))
    return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
  return cloneToCollapsedOp(rewriter, op, collapsingInfo);
}
1532+
15051533
/// Implementation of fusion with reshape operation by collapsing dimensions.
1506-
template <typename LinalgType>
1507-
FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
1508-
LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
1534+
FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
1535+
LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
15091536
RewriterBase &rewriter) {
1510-
static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
1511-
"unsupported linalg op type to collapse");
1512-
15131537
// Bail on trivial no-op cases.
15141538
if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
15151539
llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
@@ -1538,8 +1562,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
15381562
}
15391563

15401564
// Bail on non-canonical ranges.
1541-
SmallVector<Range> loopRanges =
1542-
cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc());
1565+
SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
15431566
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
15441567
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
15451568
return cast<IntegerAttr>(attr).getInt() == value;
@@ -1555,8 +1578,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
15551578
op, "expected all loop ranges to have zero start and unit stride");
15561579
}
15571580

1558-
LinalgType collapsedOp = cast<LinalgType>(
1559-
createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter));
1581+
LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
15601582

15611583
Location loc = op->getLoc();
15621584
if (collapsedOp.hasIndexSemantics()) {
@@ -1597,7 +1619,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
15971619
results.push_back(collapsedOpResult);
15981620
}
15991621
}
1600-
return results;
1622+
return CollapseResult{results, collapsedOp};
16011623
}
16021624

16031625
namespace {
@@ -1629,15 +1651,14 @@ class FoldWithProducerReshapeOpByCollapsing
16291651
continue;
16301652
}
16311653

1632-
std::optional<SmallVector<Value>> replacements =
1633-
collapseOpIterationDims<linalg::GenericOp>(
1634-
genericOp, collapsableIterationDims, rewriter);
1635-
if (!replacements) {
1654+
std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
1655+
genericOp, collapsableIterationDims, rewriter);
1656+
if (!collapseResult) {
16361657
return rewriter.notifyMatchFailure(
16371658
genericOp, "failed to do the fusion by collapsing transformation");
16381659
}
16391660

1640-
rewriter.replaceOp(genericOp, *replacements);
1661+
rewriter.replaceOp(genericOp, collapseResult->results);
16411662
return success();
16421663
}
16431664
return failure();
@@ -1671,13 +1692,12 @@ class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
16711692
op, "specified dimensions cannot be collapsed");
16721693
}
16731694

1674-
std::optional<SmallVector<Value>> replacements =
1675-
collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
1676-
rewriter);
1677-
if (!replacements) {
1695+
std::optional<CollapseResult> collapseResult =
1696+
collapseOpIterationDims(op, collapsableIterationDims, rewriter);
1697+
if (!collapseResult) {
16781698
return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
16791699
}
1680-
rewriter.replaceOp(op, *replacements);
1700+
rewriter.replaceOp(op, collapseResult->results);
16811701
return success();
16821702
}
16831703

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @fill(
4+
// CHECK-SAME: %[[ARG0:.*]]: f32,
5+
// CHECK-SAME: %[[ARG1:.*]]: memref<32x7xf32>
6+
// CHECK-NEXT: %[[FLATTENED:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
7+
// CHECK-NEXT: linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : memref<224xf32>)
8+
func.func @fill(%cst: f32, %arg: memref<32x7xf32>) {
9+
linalg.fill ins(%cst: f32) outs(%arg: memref<32x7xf32>)
10+
return
11+
}
12+
13+
module attributes {transform.with_named_sequence} {
14+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
15+
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
16+
%flattened = transform.structured.flatten_elementwise %0
17+
: (!transform.any_op) -> !transform.any_op
18+
transform.yield
19+
}
20+
}
21+
22+
// -----
23+
24+
// CHECK-LABEL: func.func @fill_tensor(
25+
// CHECK-SAME: %[[ARG0:.*]]: f32,
26+
// CHECK-SAME: %[[ARG1:.*]]: tensor<32x7xf32>
27+
// CHECK-NEXT: %[[FLATTENED:.*]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
28+
// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : tensor<224xf32>)
29+
// CHECK-NEXT: %[[RESULT:.*]] = tensor.expand_shape %[[FLATTENED_RESULT]] {{\[}}[0, 1]]
30+
func.func @fill_tensor(%cst: f32, %arg: tensor<32x7xf32>) -> tensor<32x7xf32> {
31+
%0 = linalg.fill ins(%cst: f32) outs(%arg: tensor<32x7xf32>) -> tensor<32x7xf32>
32+
return %0 : tensor<32x7xf32>
33+
}
34+
35+
module attributes {transform.with_named_sequence} {
36+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
37+
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
38+
%flattened = transform.structured.flatten_elementwise %0
39+
: (!transform.any_op) -> !transform.any_op
40+
transform.yield
41+
}
42+
}
43+
44+
// -----
45+
46+
// CHECK-LABEL: func.func @map(
47+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
48+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32>
49+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32>
50+
// CHECK-NEXT: %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
51+
// CHECK-NEXT: %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
52+
// CHECK-NEXT: %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]]
53+
// CHECK-NEXT: linalg.map { arith.addf } ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>)
54+
func.func @map(%arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) {
55+
linalg.map {arith.addf} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>)
56+
return
57+
}
58+
59+
module attributes {transform.with_named_sequence} {
60+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
61+
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
62+
%flattened = transform.structured.flatten_elementwise %0
63+
: (!transform.any_op) -> !transform.any_op
64+
transform.yield
65+
}
66+
}
67+
68+
// -----
69+
70+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
71+
// CHECK-LABEL: func.func @generic
72+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
73+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32>
74+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32>
75+
// CHECK-NEXT: %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
76+
// CHECK-NEXT: %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
77+
// CHECK-NEXT: %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]]
78+
// CHECK-NEXT: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>)
79+
// CHECK-NEXT: ^bb0(%[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32)
80+
// CHECK-NEXT: %[[SUM:.*]] = arith.addf %[[A]], %[[B]]
81+
// CHECK-NEXT: linalg.yield %[[SUM]]
82+
#map = affine_map<(d0, d1) -> (d0, d1)>
83+
func.func @generic( %arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) {
84+
linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>) {
85+
^bb0(%a: f32, %b: f32, %c: f32):
86+
%0 = arith.addf %a, %b : f32
87+
linalg.yield %0 : f32
88+
}
89+
return
90+
}
91+
92+
module attributes {transform.with_named_sequence} {
93+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
94+
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
95+
%flattened = transform.structured.flatten_elementwise %0
96+
: (!transform.any_op) -> !transform.any_op
97+
transform.yield
98+
}
99+
}

0 commit comments

Comments
 (0)