Skip to content

Commit 9f49509

Browse files
authored
[mlir] Add ContractionOpInterface utility functions for vector matrix multiplication (#68945)
1 parent df3478e commit 9f49509

File tree

4 files changed

+263
-9
lines changed

4 files changed

+263
-9
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,39 @@ def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
8686
/*methodBody=*/[{
8787
return mlir::isRowMajorBatchMatmul($_op.getIndexingMaps());
8888
}]>,
89+
InterfaceMethod<
90+
/*desc=*/[{
91+
Returns whether the given op has indexing maps that correspond to a
92+
vector-matrix multiplication.
93+
}],
94+
/*retTy=*/"bool",
95+
/*methodName=*/"isVecmat",
96+
/*args=*/(ins),
97+
/*methodBody=*/[{
98+
return mlir::isVecmat($_op.getIndexingMaps());
99+
}]>,
100+
InterfaceMethod<
101+
/*desc=*/[{
102+
Returns whether the given op has indexing maps that correspond to a
103+
matrix-vector multiplication.
104+
}],
105+
/*retTy=*/"bool",
106+
/*methodName=*/"isMatvec",
107+
/*args=*/(ins),
108+
/*methodBody=*/[{
109+
return mlir::isMatvec($_op.getIndexingMaps());
110+
}]>,
111+
InterfaceMethod<
112+
/*desc=*/[{
113+
Returns whether the given op has indexing maps that correspond to a
114+
batched matrix-vector multiplication.
115+
}],
116+
/*retTy=*/"bool",
117+
/*methodName=*/"isBatchMatvec",
118+
/*args=*/(ins),
119+
/*methodBody=*/[{
120+
return mlir::isBatchMatvec($_op.getIndexingMaps());
121+
}]>,
89122
];
90123
}
91124

mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,24 @@ bool isColumnMajorMatmul(ArrayAttr indexingMaps);
4949
/// the reduction.
5050
bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
5151

52+
/// Tests whether the given maps describe a vector matrix multiplication. The
53+
/// test is permutation-invariant. Note that this only checks the affine maps
54+
/// from an operation, so does not perform any checks on the math being
55+
/// performed within the reduction.
56+
bool isVecmat(ArrayAttr indexingMaps);
57+
58+
/// Tests whether the given maps describe a matrix vector multiplication. The
59+
/// test is permutation-invariant. Note that this only checks the affine maps
60+
/// from an operation, so does not perform any checks on the math being
61+
/// performed within the reduction.
62+
bool isMatvec(ArrayAttr indexingMaps);
63+
64+
/// Tests whether the given maps describe a batch matrix vector multiplication.
65+
/// The test is permutation-invariant. Note that this only checks the affine
66+
/// maps from an operation, so does not perform any checks on the math being
67+
/// performed within the reduction.
68+
bool isBatchMatvec(ArrayAttr indexingMaps);
69+
5270
/// Return positions in `iteratorTypes` that match `iteratorTypeName`.
5371
inline void findPositionsOfType(ArrayRef<utils::IteratorType> iteratorTypes,
5472
utils::IteratorType iteratorTypeName,

mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp

Lines changed: 82 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) {
2121
if (indexingMaps.size() != 3)
2222
return false;
2323

24-
auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
25-
auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
26-
auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
24+
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
25+
AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
26+
AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
2727

2828
if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
2929
map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
@@ -47,9 +47,9 @@ bool mlir::isColumnMajorMatmul(ArrayAttr indexingMaps) {
4747
if (indexingMaps.size() != 3)
4848
return false;
4949

50-
auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
51-
auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
52-
auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
50+
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
51+
AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
52+
AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
5353

5454
if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
5555
map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
@@ -73,9 +73,9 @@ bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
7373
if (indexingMaps.size() != 3)
7474
return false;
7575

76-
auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
77-
auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
78-
auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
76+
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
77+
AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
78+
AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
7979

8080
if (map0.getNumResults() != 3 || map1.getNumResults() != 3 ||
8181
map2.getNumResults() != 3 || map0.getNumInputs() != 4 ||
@@ -96,6 +96,79 @@ bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
9696
return indexingMaps == maps;
9797
}
9898

99+
bool mlir::isVecmat(ArrayAttr indexingMaps) {
100+
if (indexingMaps.size() != 3)
101+
return false;
102+
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
103+
AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
104+
AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
105+
106+
if (map0.getNumResults() != 1 || map1.getNumResults() != 2 ||
107+
map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
108+
map1.getNumInputs() != 2 || map2.getNumInputs() != 2) {
109+
return false;
110+
}
111+
112+
// Extract dimensions for K * KxN -> N
113+
AffineExpr k = map0.getResult(0);
114+
AffineExpr n = map2.getResult(0);
115+
auto *context = indexingMaps.getContext();
116+
auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context));
117+
auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, context));
118+
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context));
119+
auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
120+
return indexingMaps == maps;
121+
}
122+
123+
bool mlir::isMatvec(ArrayAttr indexingMaps) {
124+
if (indexingMaps.size() != 3)
125+
return false;
126+
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
127+
AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
128+
AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
129+
130+
if (map0.getNumResults() != 2 || map1.getNumResults() != 1 ||
131+
map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
132+
map1.getNumInputs() != 2 || map2.getNumInputs() != 2) {
133+
return false;
134+
}
135+
136+
// Extract dimensions for N*K * K -> N
137+
AffineExpr k = map1.getResult(0);
138+
AffineExpr n = map2.getResult(0);
139+
auto *context = indexingMaps.getContext();
140+
auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, context));
141+
auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context));
142+
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context));
143+
auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
144+
return indexingMaps == maps;
145+
}
146+
147+
bool mlir::isBatchMatvec(ArrayAttr indexingMaps) {
148+
if (indexingMaps.size() != 3)
149+
return false;
150+
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
151+
AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
152+
AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
153+
154+
if (map0.getNumResults() != 3 || map1.getNumResults() != 2 ||
155+
map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
156+
map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
157+
return false;
158+
}
159+
160+
// Extract dimensions for B*N*K * B*K -> B*N
161+
AffineExpr b = map0.getResult(0);
162+
AffineExpr k = map1.getResult(1);
163+
AffineExpr n = map2.getResult(1);
164+
auto *context = indexingMaps.getContext();
165+
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, n, k}, context));
166+
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context));
167+
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context));
168+
auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
169+
return indexingMaps == maps;
170+
}
171+
99172
Operation *mlir::clone(OpBuilder &b, Operation *op, TypeRange newResultTypes,
100173
ValueRange newOperands) {
101174
IRMapping bvm;

mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,4 +240,134 @@ TEST(isRowMajorBatchMatmul, FirstInputSwapped) {
240240
EXPECT_THAT(maps, Not(Truly(isRowMajorBatchMatmul)));
241241
}
242242

243+
TEST(isVecmat, Simple) {
244+
MLIRContext context;
245+
246+
AffineExpr k, n;
247+
bindDims(&context, k, n);
248+
auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
249+
auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
250+
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
251+
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
252+
253+
EXPECT_THAT(maps, Truly(isVecmat));
254+
}
255+
256+
TEST(isVecmat, BindingSwapped) {
257+
MLIRContext context;
258+
259+
AffineExpr k, n;
260+
bindDims(&context, n, k); // bind in different order
261+
auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
262+
auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
263+
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
264+
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
265+
266+
EXPECT_THAT(maps, Truly(isVecmat));
267+
}
268+
269+
TEST(isVecmat, WrongDimOrderMatrix) {
270+
MLIRContext context;
271+
272+
AffineExpr k, n;
273+
bindDims(&context, k, n);
274+
auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
275+
auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context));
276+
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
277+
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
278+
279+
EXPECT_THAT(maps, Not(Truly(isVecmat)));
280+
}
281+
282+
TEST(isMatvec, Simple) {
283+
MLIRContext context;
284+
285+
AffineExpr k, n;
286+
bindDims(&context, k, n);
287+
auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context));
288+
auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
289+
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
290+
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
291+
292+
EXPECT_THAT(maps, Truly(isMatvec));
293+
}
294+
295+
TEST(isMatvec, BindingSwapped) {
296+
MLIRContext context;
297+
298+
AffineExpr k, n;
299+
bindDims(&context, n, k); // bind in different order
300+
auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context));
301+
auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
302+
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
303+
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
304+
305+
EXPECT_THAT(maps, Truly(isMatvec));
306+
}
307+
308+
TEST(isMatvec, WrongDimOrderMatrix) {
309+
MLIRContext context;
310+
311+
AffineExpr k, n;
312+
bindDims(&context, k, n);
313+
auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
314+
auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
315+
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
316+
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
317+
318+
EXPECT_THAT(maps, Not(Truly(isMatvec)));
319+
}
320+
321+
TEST(isBatchMatvec, Simple) {
322+
MLIRContext context;
323+
324+
AffineExpr batch, k, n;
325+
bindDims(&context, batch, k, n);
326+
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context));
327+
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
328+
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
329+
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
330+
331+
EXPECT_THAT(maps, Truly(isBatchMatvec));
332+
}
333+
334+
TEST(isBatchMatvec, BindingSwapped) {
335+
MLIRContext context;
336+
337+
AffineExpr batch, k, n;
338+
bindDims(&context, batch, n, k); // bind in different order
339+
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context));
340+
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
341+
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
342+
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
343+
344+
EXPECT_THAT(maps, Truly(isBatchMatvec));
345+
}
346+
347+
TEST(isBatchMatvec, Matmul) {
348+
MLIRContext context;
349+
350+
AffineExpr m, n, k;
351+
bindDims(&context, m, n, k);
352+
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
353+
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
354+
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
355+
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
356+
357+
EXPECT_THAT(maps, Not(Truly(isBatchMatvec)));
358+
}
359+
360+
TEST(isBatchMatvec, WrongDimOrderMatrix) {
361+
MLIRContext context;
362+
363+
AffineExpr batch, k, n;
364+
bindDims(&context, batch, k, n);
365+
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k, n}, &context));
366+
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
367+
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
368+
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
369+
370+
EXPECT_THAT(maps, Not(Truly(isBatchMatvec)));
371+
}
372+
243373
} // namespace

0 commit comments

Comments
 (0)