Skip to content

Commit 58cda1d

Browse files
committed
Fix tests and add type clarity in util function
1 parent f2c6334 commit 58cda1d

File tree

2 files changed

+21
-21
lines changed

2 files changed

+21
-21
lines changed

mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) {
2121
if (indexingMaps.size() != 3)
2222
return false;
2323

24-
auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
25-
auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
26-
auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
24+
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
25+
AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
26+
AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
2727

2828
if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
2929
map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
@@ -47,9 +47,9 @@ bool mlir::isColumnMajorMatmul(ArrayAttr indexingMaps) {
4747
if (indexingMaps.size() != 3)
4848
return false;
4949

50-
auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
51-
auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
52-
auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
50+
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
51+
AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
52+
AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
5353

5454
if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
5555
map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
@@ -73,9 +73,9 @@ bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
7373
if (indexingMaps.size() != 3)
7474
return false;
7575

76-
auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
77-
auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
78-
auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
76+
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
77+
AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
78+
AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
7979

8080
if (map0.getNumResults() != 3 || map1.getNumResults() != 3 ||
8181
map2.getNumResults() != 3 || map0.getNumInputs() != 4 ||
@@ -99,9 +99,9 @@ bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
9999
bool mlir::isVecmat(ArrayAttr indexingMaps) {
100100
if (indexingMaps.size() != 3)
101101
return false;
102-
auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
103-
auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
104-
auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
102+
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
103+
AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
104+
AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
105105

106106
if (map0.getNumResults() != 1 || map1.getNumResults() != 2 ||
107107
map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
@@ -123,9 +123,9 @@ bool mlir::isVecmat(ArrayAttr indexingMaps) {
123123
bool mlir::isMatvec(ArrayAttr indexingMaps) {
124124
if (indexingMaps.size() != 3)
125125
return false;
126-
auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
127-
auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
128-
auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
126+
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
127+
AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
128+
AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
129129

130130
if (map0.getNumResults() != 2 || map1.getNumResults() != 1 ||
131131
map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
@@ -147,9 +147,9 @@ bool mlir::isMatvec(ArrayAttr indexingMaps) {
147147
bool mlir::isBatchMatvec(ArrayAttr indexingMaps) {
148148
if (indexingMaps.size() != 3)
149149
return false;
150-
auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
151-
auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
152-
auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
150+
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
151+
AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
152+
AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
153153

154154
if (map0.getNumResults() != 3 || map1.getNumResults() != 2 ||
155155
map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||

mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ TEST(isVecmat, BindingSwapped) {
257257
MLIRContext context;
258258

259259
AffineExpr k, n;
260-
bindDims(&context, k, n); // bind in different order
260+
bindDims(&context, n, k); // bind in different order
261261
auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
262262
auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
263263
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
@@ -296,7 +296,7 @@ TEST(isMatvec, BindingSwapped) {
296296
MLIRContext context;
297297

298298
AffineExpr k, n;
299-
bindDims(&context, k, n); // bind in different order
299+
bindDims(&context, n, k); // bind in different order
300300
auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context));
301301
auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
302302
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
@@ -335,7 +335,7 @@ TEST(isBatchMatvec, BindingSwapped) {
335335
MLIRContext context;
336336

337337
AffineExpr batch, k, n;
338-
bindDims(&context, batch, k, n); // bind in different order
338+
bindDims(&context, batch, n, k); // bind in different order
339339
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context));
340340
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
341341
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));

0 commit comments

Comments
 (0)