
Commit 33b7833

[MLIR][Linalg] Add more specialize patterns (#91153)
Currently only linalg.copy is recognized when trying to specialize a linalg.generic back to a named op. This diff enables recognition of more generic-to-named-op conversions, e.g. linalg.fill and elementwise unary/binary ops.
1 parent 7ecdf62 commit 33b7833

File tree

6 files changed (+307, -0 lines)


mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h

Lines changed: 16 additions & 0 deletions
@@ -28,6 +28,7 @@ namespace mlir {
 namespace linalg {
 class IteratorTypeAttr;
 class LinalgOp;
+class GenericOp;

 namespace detail {
 /// Implementation of the method that check if given operands
@@ -115,6 +116,21 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp);
 /// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
 bool isaCopyOpInterface(LinalgOp linalgOp);

+/// Checks whether a given `genericOp` is semantically equivalent to a single
+/// linalg elementwise unary op, e.g. linalg.exp.
+/// A linalg.generic body could be a series of unary elementwise ops, e.g.
+/// `exp(neg(x))`, such as formed by linalg op fusion. Here we restrict it to
+/// detecting cases where the body is a single computation op.
+bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a single linalg
+/// elementwise binary op, e.g. linalg.sub.
+bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
+/// Returns the scalar fill value if true.
+std::optional<Value> isaFillOpInterface(GenericOp genericOp);
+
 namespace detail {

 /// Returns true if the block contains a contraction of the following form:
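
A minimal sketch of how the new predicates might be used outside the specialization pass; `classifyGenerics` and its surrounding setup are hypothetical, not part of this commit:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

using namespace mlir;

// Hypothetical helper: classify each linalg.generic in `funcOp` using the
// predicates added by this commit.
static void classifyGenerics(func::FuncOp funcOp) {
  funcOp.walk([](linalg::GenericOp genericOp) {
    if (std::optional<Value> fillValue =
            linalg::isaFillOpInterface(genericOp)) {
      // Semantically a linalg.fill broadcasting the scalar `*fillValue`.
      (void)*fillValue;
    } else if (linalg::isaElemwiseSingleUnaryOpInterface(genericOp)) {
      // Body is a single unary computation op (e.g. math.exp) plus the yield.
    } else if (linalg::isaElemwiseSingleBinaryOpInterface(genericOp)) {
      // Body is a single binary computation op (e.g. arith.subf) plus the
      // yield.
    }
  });
}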

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 93 additions & 0 deletions
@@ -70,6 +70,99 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
   return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
 }

+//===----------------------------------------------------------------------===//
+// FillOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
+  // Structural.
+  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
+      genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
+    return std::nullopt;
+
+  // Input should be referenced and init should not.
+  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
+      genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+    return std::nullopt;
+
+  OpOperand *value = genericOp.getDpsInputOperand(0);
+  if (!genericOp.isScalar(value))
+    return std::nullopt;
+
+  Block *body = genericOp.getBody();
+  if (body->getOperations().size() != 1)
+    return std::nullopt;
+
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+      yieldOp->getOperand(0) != body->getArgument(0))
+    return std::nullopt;
+  return value->get();
+}
+
+//===----------------------------------------------------------------------===//
+// Elementwise Single Unary/Binary-OpInterface implementation
+//===----------------------------------------------------------------------===//
+static bool
+isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
+                                          unsigned arity) {
+  // Check that all loops are parallel and the op has pure tensor semantics.
+  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
+      genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
+    return false;
+
+  // Check there are `arity` inputs, one output, and all indexing maps are
+  // identities.
+  if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
+      !llvm::all_of(genericOp.getIndexingMapsArray(),
+                    [](AffineMap map) { return map.isIdentity(); }))
+    return false;
+
+  // Init should not be referenced for elementwise operations.
+  if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+    return false;
+
+  // A linalg.generic could be a series of elementwise ops, e.g. exp(neg(x)),
+  // such as resulting from producer-consumer fusion. Here, we restrict to two
+  // ops in the body, where the first is the elementwise single op and the
+  // second a yield.
+  Block *body = genericOp.getBody();
+  if (body->getOperations().size() != 2)
+    return false;
+
+  Operation *op = &body->front();
+  if (op->getNumOperands() != arity || op->getNumResults() != 1)
+    return false;
+
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+      yieldOp->getOperand(0).getDefiningOp() != op)
+    return false;
+  return true;
+}
+
+bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp genericOp) {
+  // All basic elemwise checks.
+  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 1))
+    return false;
+
+  // Check the input is actually used.
+  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)))
+    return false;
+  return true;
+}
+
+bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp genericOp) {
+  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 2))
+    return false;
+
+  // Check both inputs are used (elementwise).
+  OpOperand *inputOpOperand0 = genericOp.getDpsInputOperand(0);
+  OpOperand *inputOpOperand1 = genericOp.getDpsInputOperand(1);
+  if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
+      !genericOp.payloadUsesValueFromOperand(inputOpOperand1))
+    return false;
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // ContractionOpInterface implementation
 //===----------------------------------------------------------------------===//
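
Since `isaFillOpInterface` returns the matched scalar, a caller can build the replacement from it directly. A minimal sketch, assuming a `PatternRewriter` named `rewriter` and the matched `genericOp` are in scope (this snippet is illustrative, not code from this commit):

// Sketch: rebuild the matched op as linalg.fill from the returned scalar.
if (std::optional<Value> fillValue = linalg::isaFillOpInterface(genericOp)) {
  rewriter.replaceOpWithNewOp<linalg::FillOp>(
      genericOp, /*inputs=*/*fillValue,
      /*outputs=*/genericOp.getDpsInits()[0]);
}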

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 72 additions & 0 deletions
@@ -14,19 +14,91 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "llvm/Support/Debug.h"

 #define DEBUG_TYPE "linalg-specialization"

+#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
+  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
+      genericOp,                                                               \
+      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
+                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
+      ValueRange{genericOp.getDpsInits()[0]}))
+
+#define REPLACE_UNARY_OP(NEWOP)                                                \
+  (rewriter.replaceOpWithNewOp<NEWOP>(genericOp,                               \
+                                      ValueRange{genericOp.getDpsInputs()[0]}, \
+                                      ValueRange{genericOp.getDpsInits()[0]}))
+
 using namespace mlir;
 using namespace mlir::linalg;

+// Given an elementwise single binary linalg generic op, checks whether the
+// binary op accesses its operands swapped, i.e. this differentiates between a
+// linalg.generic body that contains:
+//   ^bb0(%a: f32, %b: f32, %c : f32):
+//     %0 = arith.subf %a, %b : f32
+//     linalg.yield %0: f32
+// and:
+//   ^bb0(%a: f32, %b: f32, %c : f32):
+//     %0 = arith.subf %b, %a : f32
+//     linalg.yield %0: f32
+// The former is linalg.sub(a,b), the latter linalg.sub(b,a).
+static bool areBinOpsSwapped(GenericOp genericOp) {
+  Block *body = genericOp.getBody();
+  Operation *op = &body->front();
+  bool swapped = false;
+  if (op->getOpOperand(0).get() != body->getArgument(0)) {
+    swapped = true;
+    assert(op->getOpOperand(0).get() == body->getArgument(1) &&
+           op->getOpOperand(1).get() == body->getArgument(0) &&
+           "binary op uses just one block arg");
+  }
+  return swapped;
+}
+
 FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                       GenericOp genericOp) {
   if (isaCopyOpInterface(genericOp)) {
     LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
         genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
     return namedOp;
   }
+
+  if (isaFillOpInterface(genericOp)) {
+    LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
+        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
+    return namedOp;
+  }
+
+  if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
+    Operation *op = &genericOp.getBody()->front();
+    if (isa<math::ExpOp>(op)) {
+      LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
+      return namedOp;
+    }
+  }
+
+  if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
+    bool swap = areBinOpsSwapped(genericOp);
+    Operation *op = &genericOp.getBody()->front();
+    if (isa<arith::AddFOp>(op)) {
+      LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
+      return namedOp;
+    }
+    if (isa<arith::SubFOp>(op)) {
+      LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
+      return namedOp;
+    }
+    if (isa<arith::MulFOp>(op)) {
+      LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
+      return namedOp;
+    }
+    if (isa<arith::DivFOp>(op)) {
+      LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
+      return namedOp;
+    }
+  }
   return failure();
 }
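
`specializeGenericOp` is invoked here by `transform.structured.specialize` (see the tests below), but it could also plausibly be driven by the greedy pattern rewriter. A hedged sketch under that assumption; `SpecializeGenericsPattern` is a hypothetical name, and the `using namespace mlir` / `using namespace mlir::linalg` declarations from this file are assumed:

// Hypothetical wrapper pattern (not part of this commit) so the
// specialization can run under the greedy pattern rewriter.
struct SpecializeGenericsPattern : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // specializeGenericOp already performs the replacement on success.
    if (failed(specializeGenericOp(rewriter, genericOp)))
      return failure();
    return success();
  }
};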

mlir/test/Dialect/Linalg/transform-op-specialize.mlir

Lines changed: 25 additions & 0 deletions
@@ -141,3 +141,28 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> ()>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @linalg_generic_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst : f32) outs(%arg0 : tensor<7x7xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<7x7xf32>
+  return %0 : tensor<7x7xf32>
+}
+// CHECK-LABEL: linalg_generic_fill
+// CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xf32>) -> tensor<7x7xf32>
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @specialize_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.addf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_add
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_sub(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.subf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_sub
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.sub ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_sub_swapped_operands(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.subf %in_0, %in : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_sub_swapped_operands
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.sub ins(%[[ARG1]], %[[ARG0]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_mul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_mul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.mul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.divf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_div
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @specialize_exp(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+         ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = math.exp %in : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: specialize_exp
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}