Skip to content

Commit 37682e7

Browse files
committed
Add test cases
Add test cases for variadic matchers Relocate variadic matchers
1 parent 11792a6 commit 37682e7

File tree

7 files changed

+60
-19
lines changed

7 files changed

+60
-19
lines changed

mlir/include/mlir/Query/Matcher/MatchersInternal.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,12 @@ struct VariadicOperatorMatcherFunc {
214214
}
215215
};
216216

217+
namespace internal {
218+
const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
219+
anyOf = {DynMatcher::AnyOf};
220+
const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
221+
allOf = {DynMatcher::AllOf};
222+
} // namespace internal
217223
} // namespace mlir::query::matcher
218224

219225
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H

mlir/include/mlir/Query/Matcher/SliceMatchers.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -186,14 +186,6 @@ class PredicateForwardSliceMatcher {
186186
bool inclusive;
187187
};
188188

189-
namespace internal {
190-
const matcher::VariadicOperatorMatcherFunc<1,
191-
std::numeric_limits<unsigned>::max()>
192-
anyOf = {matcher::DynMatcher::AnyOf};
193-
const matcher::VariadicOperatorMatcherFunc<1,
194-
std::numeric_limits<unsigned>::max()>
195-
allOf = {matcher::DynMatcher::AllOf};
196-
} // namespace internal
197189
/// Matches transitive defs of a top-level operation up to N levels.
198190
template <typename Matcher>
199191
inline BackwardSliceMatcher<Matcher>

mlir/lib/Query/Matcher/MatchersInternal.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8-
//
9-
// Implements the base layer of the matcher framework.
10-
//
11-
//===----------------------------------------------------------------------===//
8+
129
#include "mlir/Query/Matcher/MatchersInternal.h"
1310
#include "llvm/ADT/SetVector.h"
1411

mlir/lib/Query/Matcher/VariantValue.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,9 @@ class VariantMatcher::VariadicOpPayload : public VariantMatcher::Payload {
4646

4747
std::string getTypeAsString() const override {
4848
std::string inner;
49-
for (size_t i = 0, e = args.size(); i != e; ++i) {
50-
if (i != 0)
51-
inner += "&";
52-
inner += args[i].getTypeAsString();
53-
}
49+
llvm::interleave(
50+
args, [&](auto const &arg) { inner += arg.getTypeAsString(); },
51+
[&] { inner += " & "; });
5452
return inner;
5553
}
5654

mlir/test/mlir-query/complex-test.mlir renamed to mlir/test/mlir-query/backward-slice-union.mlir

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-query %s -c "m getAllDefinitions(hasOpName(\"arith.addf\"),2)" | FileCheck %s
1+
// RUN: mlir-query %s -c "m anyOf(getAllDefinitions(hasOpName(\"arith.addf\"),2),getAllDefinitions(hasOpName(\"tensor.extract\"),1))" | FileCheck %s
22

33
#map = affine_map<(d0, d1) -> (d0, d1)>
44
func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) {
@@ -19,14 +19,23 @@ func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>)
1919
}
2020

2121
// CHECK: Match #1:
22-
2322
// CHECK: %[[LINALG:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
2423
// CHECK-SAME: ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>)
24+
25+
// CHECK: {{.*}}.mlir:7:10: note: "root" binds here
2526
// CHECK: %[[ADDF1:.*]] = arith.addf %in, %in : f32
2627

2728
// CHECK: Match #2:
29+
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[LINALG]] {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
30+
// CHECK: %[[C2:.*]] = arith.constant {{.*}} : index
2831

32+
// CHECK: {{.*}}.mlir:14:18: note: "root" binds here
33+
// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32>
34+
35+
// CHECK: Match #3:
2936
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[LINALG]] {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
3037
// CHECK: %[[C2:.*]] = arith.constant {{.*}} : index
3138
// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32>
39+
40+
// CHECK: {{.*}}.mlir:15:10: note: "root" binds here
3241
// CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: mlir-query %s -c "m getUsersByPredicate(anyOf(hasOpName(\"memref.alloc\"),isConstantOp()),hasOpName(\"affine.load\"),true)" | FileCheck %s
2+
3+
func.func @slice_depth1_loop_nest_with_offsets() {
4+
%0 = memref.alloc() : memref<100xf32>
5+
%cst = arith.constant 7.000000e+00 : f32
6+
affine.for %i0 = 0 to 16 {
7+
%a0 = affine.apply affine_map<(d0) -> (d0 + 2)>(%i0)
8+
affine.store %cst, %0[%a0] : memref<100xf32>
9+
}
10+
affine.for %i1 = 4 to 8 {
11+
%a1 = affine.apply affine_map<(d0) -> (d0 - 1)>(%i1)
12+
%1 = affine.load %0[%a1] : memref<100xf32>
13+
}
14+
return
15+
}
16+
17+
// CHECK: Match #1:
18+
// CHECK: {{.*}}.mlir:4:8: note: "root" binds here
19+
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<100xf32>
20+
21+
// CHECK: affine.store %cst, %0[%a0] : memref<100xf32>
22+
23+
// CHECK: Match #2:
24+
// CHECK: {{.*}}.mlir:5:10: note: "root" binds here
25+
// CHECK: %[[CST:.*]] = arith.constant 7.000000e+00 : f32
26+
27+
// CHECK: affine.store %[[CST]], %0[%a0] : memref<100xf32>
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
2+
// RUN: mlir-query %s -c "m allOf(hasOpName(\"memref.alloca\"), hasOpAttrName(\"alignment\"))" | FileCheck %s
3+
4+
func.func @dynamic_alloca(%arg0: index, %arg1: index) -> memref<?x?xf32> {
5+
%0 = memref.alloca(%arg0, %arg1) : memref<?x?xf32>
6+
memref.alloca(%arg0, %arg1) {alignment = 32} : memref<?x?xf32>
7+
return %0 : memref<?x?xf32>
8+
}
9+
10+
// CHECK: Match #1:
11+
// CHECK: {{.*}}.mlir:6:3: note: "root" binds here
12+
// CHECK: memref.alloca(%arg0, %arg1) {alignment = 32} : memref<?x?xf32>

0 commit comments

Comments
 (0)