Skip to content

Commit e4a503a

Browse files
[mlir][Linalg] Introduce a ContractionOpInterface
This revision takes advantage of recent extensions to vectorization to refactor contraction detection into a bona fide Linalg interface. The mlit-linalg-ods-gen parser is extended to support adding such interfaces. The detection that was originally enabling vectorization is refactored to serve as both a test on a generic LinalgOp as well as to verify ops that declare to conform to that interface. This is plugged through Linalg transforms and strategies but it quickly becomes evident that the complexity and rigidity of the C++ class based templating does not pay for itself. Therefore, this revision changes the API for vectorization patterns to get rid of templates as much as possible. Variadic templates are relegated to the internals of LinalgTransformationFilter as much as possible and away from the user-facing APIs. It is expected other patterns / transformations will follow the same path and drop as much C++ templating as possible from the class definition. Differential revision: https://reviews.llvm.org/D95973
1 parent 727bd89 commit e4a503a

File tree

11 files changed

+486
-245
lines changed

11 files changed

+486
-245
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,21 @@
2222

2323
namespace mlir {
2424
namespace linalg {
25+
class LinalgOp;
2526

2627
/// Returns the values obtained by applying `map` to the list of values.
2728
SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
2829
AffineMap map, ValueRange values);
2930

31+
/// Checks whether `linalgOp` conforms to ContractionOpInterface.
32+
// TODO: embed within `isa<ContractionOpInterface>` if possible / natural.
33+
bool isaContractionOpInterface(LinalgOp linalgOp);
34+
3035
namespace detail {
3136

37+
/// Verify that `op` conforms to ContractionOpInterface.
38+
LogicalResult verifyContractionInterface(Operation *op);
39+
3240
/// Verify that `op` conforms to the invariants of StructuredOpInterface
3341
LogicalResult verifyStructuredOpInterface(Operation *op);
3442

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,28 @@
1515

1616
include "mlir/IR/OpBase.td"
1717

18-
// The linalg 'LinalgStructuredInterface' provides access to the 'LinalgOp'
19-
// interface.
18+
// The 'LinalgContractionOpInterface' provides access to the
19+
// 'ContractionOpInterface'.
20+
def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
21+
let description = [{
22+
A Linalg contraction is defined in general terms:
23+
1. Has 2 input and 1 output shapes.
24+
2. Has at least one reduction dimension.
25+
3. Has only projected permutation indexing maps.
26+
4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
27+
(AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
28+
operations that may change the type (e.g. for mixed-precision).
29+
As a consequence, when vectorization of such an op occurs, the only special
30+
behavior is that the (unique) MulOpType is vectorized into a
31+
`vector.contract`. All other ops are handled in a generic fashion.
32+
In the future, we may wish to allow more input arguments and elementwise and
33+
constant operations that do not involve the reduction dimension(s).
34+
}];
35+
let cppNamespace = "::mlir::linalg";
36+
let verify = [{ return detail::verifyContractionInterface($_op); }];
37+
}
38+
39+
// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
2040
def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
2141
let cppNamespace = "::mlir::linalg";
2242
let methods = [

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,43 @@
1-
ods_def<MatmulOp>:
1+
ods_def<MatmulOp>
2+
implements_interface<LinalgContractionOpInterface> :
23
def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
34
C(m, n) = std_addf<k>(std_mulf(A(m, k), B(k, n)));
45
}
56

6-
ods_def<MatmulColumnMajorOp>:
7+
ods_def<MatmulColumnMajorOp>
8+
implements_interface<LinalgContractionOpInterface> :
79
def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) {
810
C(n, m) = std_addf<k>(std_mulf(A(k, m), B(n, k)));
911
}
1012

11-
ods_def<MatmulI8I8I32Op>:
13+
ods_def<MatmulI8I8I32Op>
14+
implements_interface<LinalgContractionOpInterface> :
1215
def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) {
1316
// TODO: ideally something closer to
1417
// C(m, n) += cast<i32>(A(m, k)) * cast<i32>(B(k, n))
1518
C(m, n) = std_addi<k>(std_sexti32(std_muli(A(m, k), B(k, n))));
1619
}
1720

18-
ods_def<MatvecOp>:
21+
ods_def<MatvecOp>
22+
implements_interface<LinalgContractionOpInterface> :
1923
def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
2024
x(m) = std_addf<n>(std_mulf(A(m, n), y(n)));
2125
}
2226

23-
ods_def<VecmatOp>:
27+
ods_def<VecmatOp>
28+
implements_interface<LinalgContractionOpInterface> :
2429
def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) {
2530
x(n) = std_addf<m>(std_mulf(y(m), A(m, n)));
2631
}
2732

28-
ods_def<DotOp>:
33+
ods_def<DotOp>
34+
implements_interface<LinalgContractionOpInterface> :
2935
def dot(A: f32(M), B: f32(M)) -> (C: f32()) {
3036
C() = std_addf<m>(std_mulf(A(m), B(m)));
3137
}
3238

33-
ods_def<BatchMatmulOp>:
39+
ods_def<BatchMatmulOp>
40+
implements_interface<LinalgContractionOpInterface> :
3441
def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
3542
C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(b, k, n)));
3643
}

mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -35,23 +35,33 @@ template <template <typename> class PatternType, typename ConcreteOpType,
3535
typename OptionsType,
3636
typename = std::enable_if_t<std::is_member_function_pointer<
3737
decltype(&ConcreteOpType::getOperationName)>::value>>
38-
void sfinae_enqueue(OwningRewritePatternList &patterList, OptionsType options,
38+
void sfinae_enqueue(OwningRewritePatternList &patternList, OptionsType options,
3939
MLIRContext *context, StringRef opName,
4040
linalg::LinalgTransformationFilter m) {
4141
assert(opName == ConcreteOpType::getOperationName() &&
4242
"explicit name must match ConcreteOpType::getOperationName");
43-
patterList.insert<PatternType<ConcreteOpType>>(context, options, m);
43+
patternList.insert<PatternType<ConcreteOpType>>(context, options, m);
4444
}
4545

4646
/// SFINAE: Enqueue helper for OpType that do not have a `getOperationName`
4747
/// (e.g. LinalgOp, other interfaces, Operation*).
4848
template <template <typename> class PatternType, typename OpType,
4949
typename OptionsType>
50-
void sfinae_enqueue(OwningRewritePatternList &patterList, OptionsType options,
50+
void sfinae_enqueue(OwningRewritePatternList &patternList, OptionsType options,
5151
MLIRContext *context, StringRef opName,
5252
linalg::LinalgTransformationFilter m) {
5353
assert(!opName.empty() && "opName must not be empty");
54-
patterList.insert<PatternType<OpType>>(opName, context, options, m);
54+
patternList.insert<PatternType<OpType>>(opName, context, options, m);
55+
}
56+
57+
template <typename PatternType, typename OpType, typename OptionsType>
58+
void enqueue(OwningRewritePatternList &patternList, OptionsType options,
59+
MLIRContext *context, StringRef opName,
60+
linalg::LinalgTransformationFilter m) {
61+
if (!opName.empty())
62+
patternList.insert<PatternType>(opName, context, options, m);
63+
else
64+
patternList.insert<PatternType>(m.addOpFilter<OpType>(), options);
5565
}
5666

5767
/// Promotion transformation enqueues a particular stage-1 pattern for
@@ -112,13 +122,12 @@ struct Promote : public Transformation {
112122
/// Vectorization transformation enqueues a particular stage-1 pattern for
113123
/// `LinalgVectorizationPattern<LinalgOpType>` as well as copy to vector
114124
/// transfer rewrite forwarding patterns.
115-
template <typename LinalgOpType>
125+
template <typename LinalgOpType = LinalgOp>
116126
struct Vectorize : public Transformation {
117127
explicit Vectorize(
118128
linalg::LinalgVectorizationOptions options,
119129
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
120-
: Transformation(f), opName(LinalgOpType::getOperationName()),
121-
options(options) {}
130+
: Transformation(f), opName(), options(options) {}
122131

123132
Vectorize(StringRef name, linalg::LinalgVectorizationOptions options,
124133
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
@@ -128,7 +137,7 @@ struct Vectorize : public Transformation {
128137
buildRewritePatterns(MLIRContext *context,
129138
linalg::LinalgTransformationFilter m) override {
130139
OwningRewritePatternList vectorizationPatterns;
131-
sfinae_enqueue<linalg::LinalgVectorizationPattern, LinalgOpType>(
140+
enqueue<linalg::LinalgVectorizationPattern, LinalgOpType>(
132141
vectorizationPatterns, options, context, opName, m);
133142
vectorizationPatterns.insert<linalg::LinalgCopyVTRForwardingPattern,
134143
linalg::LinalgCopyVTWForwardingPattern>(
@@ -235,16 +244,6 @@ struct CodegenStrategy {
235244
linalg::LinalgVectorizationOptions(), f));
236245
return *this;
237246
}
238-
/// Append a pattern to rewrite `LinalgOpType` as a vector operation.
239-
template <typename LinalgOpType>
240-
CodegenStrategy &
241-
vectorize(StringRef opName,
242-
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
243-
transformationSequence.emplace_back(
244-
std::make_unique<Vectorize<LinalgOpType>>(
245-
opName, linalg::LinalgVectorizationOptions(), f));
246-
return *this;
247-
}
248247
/// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
249248
/// operation.
250249
template <typename LinalgOpType>
@@ -254,13 +253,21 @@ struct CodegenStrategy {
254253
return b ? vectorize<LinalgOpType>(f) : *this;
255254
return *this;
256255
}
256+
/// Append a pattern to rewrite `LinalgOpType` as a vector operation.
257+
CodegenStrategy &
258+
vectorize(StringRef opName,
259+
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
260+
assert(!opName.empty() && "expected an op name");
261+
transformationSequence.emplace_back(std::make_unique<Vectorize<LinalgOp>>(
262+
opName, linalg::LinalgVectorizationOptions(), f));
263+
return *this;
264+
}
257265
/// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
258266
/// operation.
259-
template <typename LinalgOpType>
260267
CodegenStrategy &
261268
vectorizeIf(bool b, StringRef opName,
262269
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
263-
return b ? vectorize<LinalgOpType>(opName, f) : *this;
270+
return b ? vectorize(opName, f) : *this;
264271
return *this;
265272
}
266273
/// Configure the post staged-patterns late vector transformations.

0 commit comments

Comments
 (0)