Skip to content

Commit 64c4dcb

Browse files
author
Nicolas Vasilache
committed
[mlir][Linalg] Extend linalg vectorization to MatmulOp
Summary: This is a simple extension to allow vectorization to work not only on GenericLinalgOp but more generally across named ops too. For now, this still only vectorizes matmul-like ops but is a step towards more generic vectorization of Linalg ops. Reviewers: ftynse Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72942
1 parent a8a9c8e commit 64c4dcb

File tree

5 files changed

+55
-29
lines changed

5 files changed

+55
-29
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,11 @@ class LinalgOpToAffineLoops<string OpType> : NativeCodeCall<
8686
//===----------------------------------------------------------------------===//
8787
// Linalg to vector patterns precondition and DRR.
8888
//===----------------------------------------------------------------------===//
89-
def PreconditionVectorizeGenericLinalgOp : CPred<
90-
"succeeded(vectorizeGenericLinalgOpPrecondition(op))">;
91-
def VectorizeGenericLinalgOp : NativeCodeCall<
92-
"vectorizeGenericLinalgOp($_builder, op)">;
89+
def PreconditionVectorizeLinalgOp : CPred<
90+
"succeeded(vectorizeLinalgOpPrecondition(op))">;
91+
def VectorizeLinalgOp : NativeCodeCall<
92+
"vectorizeLinalgOp($_builder, op)">;
93+
9394

9495
//===----------------------------------------------------------------------===//
9596
// Linalg generic permutation patterns precondition and DRR.

mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ template <typename ConcreteOp>
7979
LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op);
8080

8181
/// Rewrite a linalg.generic into a suitable vector.contraction op.
82-
LogicalResult vectorizeGenericLinalgOpPrecondition(Operation *op);
83-
SmallVector<Value, 0> vectorizeGenericLinalgOp(PatternRewriter &rewriter,
84-
Operation *op);
82+
LogicalResult vectorizeLinalgOpPrecondition(Operation *op);
83+
SmallVector<Value, 0> vectorizeLinalgOp(PatternRewriter &rewriter,
84+
Operation *op);
8585

8686
/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
8787
/// and `iterator_types` permutated according to `permutation`.

mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,20 @@ static bool isMatmul(linalg::GenericOp genericOp) {
158158
genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp);
159159
}
160160

161-
LogicalResult
162-
mlir::linalg::vectorizeGenericLinalgOpPrecondition(Operation *op) {
163-
// TODO(ntv): This is in fact much more general than just vectorization for
164-
// matmul ops.
161+
// TODO(ntv): This is in fact much more general than just vectorization for
162+
// matmul ops.
163+
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
164+
auto linalgOp = cast<linalg::LinalgOp>(op);
165+
// All types must be static shape to go to vector.
166+
for (Value operand : linalgOp.getInputsAndOutputBuffers())
167+
if (!operand.getType().cast<ShapedType>().hasStaticShape())
168+
return failure();
169+
for (Type outputTensorType : linalgOp.getOutputTensorTypes())
170+
if (!outputTensorType.cast<ShapedType>().hasStaticShape())
171+
return failure();
172+
if (isa<linalg::MatmulOp>(op))
173+
return success();
174+
165175
auto genericOp = dyn_cast<linalg::GenericOp>(op);
166176
if (!genericOp || !isMatmul(genericOp))
167177
return failure();
@@ -179,30 +189,29 @@ mlir::linalg::vectorizeGenericLinalgOpPrecondition(Operation *op) {
179189
return success();
180190
}
181191

182-
SmallVector<Value, 0>
183-
mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter,
184-
Operation *op) {
192+
SmallVector<Value, 0> mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter,
193+
Operation *op) {
185194
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
186195
"]: Rewrite linalg op as vector.contract: "
187196
<< *op << ":\n");
188197

189-
assert(succeeded(vectorizeGenericLinalgOpPrecondition(op)) &&
198+
assert(succeeded(vectorizeLinalgOpPrecondition(op)) &&
190199
"DRR failure case must be a precondition");
191200

192-
auto genericOp = cast<linalg::GenericOp>(op);
193-
assert(genericOp.hasBufferSemantics() &&
201+
auto linalgOp = cast<linalg::LinalgOp>(op);
202+
assert(linalgOp.hasBufferSemantics() &&
194203
"expected linalg op with buffer semantics");
195204
edsc::ScopedContext scope(rewriter, op->getLoc());
196205
using edsc::intrinsics::std_load;
197206
using edsc::intrinsics::std_store;
198207
using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>;
199208
using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
200-
auto vA = std_load(vector_type_cast(genericOp.getInput(0)));
201-
auto vB = std_load(vector_type_cast(genericOp.getInput(1)));
202-
auto vectorMemRefC = vector_type_cast(genericOp.getOutputBuffer(0));
209+
auto vA = std_load(vector_type_cast(linalgOp.getInput(0)));
210+
auto vB = std_load(vector_type_cast(linalgOp.getInput(1)));
211+
auto vectorMemRefC = vector_type_cast(linalgOp.getOutputBuffer(0));
203212
auto vC = std_load(vectorMemRefC);
204-
auto vRes = vector_contract(vA, vB, vC, genericOp.indexing_maps(),
205-
genericOp.iterator_types());
213+
auto vRes = vector_contract(vA, vB, vC, linalgOp.indexing_maps(),
214+
linalgOp.iterator_types());
206215
std_store(vRes, vectorMemRefC);
207216
return {};
208217
}

mlir/test/Dialect/Linalg/transform-patterns.mlir

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ func @fusion_test(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
173173
affine_map<(m, n, k) -> (m, n)>
174174
],
175175
iterator_types = ["parallel", "parallel", "reduction"],
176-
__internal_linalg_transform__ = "_marked_matmul_"
176+
__internal_linalg_transform__ = "VECTORIZE"
177177
}
178178
func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
179179
%C: memref<8x32xf32>) {
@@ -185,7 +185,6 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
185185
} : memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>
186186
return
187187
}
188-
189188
// CHECK-LABEL: func @vectorization_test
190189
// CHECK: vector.type_cast %{{.*}} : memref<8x16xf32> to memref<vector<8x16xf32>>
191190
// CHECK: load %{{.*}}[] : memref<vector<8x16xf32>>
@@ -195,6 +194,17 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
195194
// CHECK: load %{{.*}}[] : memref<vector<8x32xf32>>
196195
// CHECK: vector.contract {indexing_maps = [#[[mk]], #[[kn]], #[[mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
197196
// CHECK: store %{{.*}}, %{{.*}}[] : memref<vector<8x32xf32>>
197+
198+
func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
199+
%C: memref<8x32xf32>) {
200+
linalg.matmul(%A, %B, %C) { __internal_linalg_transform__ = "VECTORIZE"} :
201+
memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>
202+
return
203+
}
204+
// CHECK-LABEL: func @vectorization_test_2
205+
// CHECK: vector.contract {{.*}} :
206+
// vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
207+
198208
func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
199209
%d = mulf %a, %b: f32
200210
%e = addf %c, %d: f32
@@ -213,7 +223,6 @@ func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
213223
library_call = "linalg_matmul",
214224
iterator_types = ["parallel", "parallel", "reduction"]
215225
}
216-
217226
func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
218227
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
219228
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {

mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,17 @@ def : Pattern<(DotOp:$op $_, $_, $_),
9999
//===----------------------------------------------------------------------===//
100100
// Linalg to vector contraction patterns.
101101
//===----------------------------------------------------------------------===//
102+
def : Pattern<(MatmulOp:$op $_, $_, $_),
103+
[(VectorizeLinalgOp)],
104+
[(Constraint<And<[
105+
HasLinalgTransformMarker<"VECTORIZE">,
106+
PreconditionVectorizeLinalgOp
107+
]>>)]>;
102108
def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
103-
[(VectorizeGenericLinalgOp)],
104-
[(Constraint<And<[
105-
HasLinalgTransformMarker<"_marked_matmul_">,
106-
PreconditionVectorizeGenericLinalgOp
109+
[(VectorizeLinalgOp)],
110+
[(Constraint<And<[
111+
HasLinalgTransformMarker<"VECTORIZE">,
112+
PreconditionVectorizeLinalgOp
107113
]>>)]>;
108114

109115
//===----------------------------------------------------------------------===//
@@ -135,4 +141,5 @@ def : Pat<(MatmulOp:$op $_, $_, $_),
135141
HasOperandsOfType<"SubViewOp">,
136142
HasLinalgTransformMarker<"_promote_views_">]>>
137143
)]>;
144+
138145
#endif // TEST_LINALG_TRANSFORMS_PATTERNS

0 commit comments

Comments
 (0)