Skip to content

Commit 214a323

Browse files
committed
conflicts fixed
1 parent 23d38a7 commit 214a323

File tree

9 files changed

+82
-118
lines changed

9 files changed

+82
-118
lines changed

mlir/include/mlir/Query/Matcher/ExtraMatchers.h

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
//===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
1+
//===- ExtraMatchers.h - Various common matchers ---------------------*- C++
2+
//-*-===//
23
//
34
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
45
// See https://llvm.org/LICENSE.txt for license information.
@@ -36,7 +37,6 @@ class BackwardSliceMatcher {
3637
bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
3738
QueryOptions &options, unsigned tempHops) {
3839

39-
bool validSlice = true;
4040
if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
4141
return false;
4242
}
@@ -56,16 +56,12 @@ class BackwardSliceMatcher {
5656
Operation *parentOp = block->getParentOp();
5757

5858
if (parentOp && backwardSlice.count(parentOp) == 0) {
59-
if (parentOp->getNumRegions() == 1 &&
60-
parentOp->getRegion(0).getBlocks().size() == 1) {
61-
validSlice = false;
62-
return;
63-
};
64-
matches(parentOp, backwardSlice, options, tempHops - 1);
59+
assert(parentOp->getNumRegions() == 1 &&
60+
parentOp->getRegion(0).getBlocks().size() == 1);
61+
matches(parentOp, backwardSlice, options, tempHops-1);
6562
}
6663
} else {
67-
validSlice = false;
68-
return;
64+
llvm_unreachable("No definingOp and not a block argument.");
6965
}
7066
};
7167

@@ -78,22 +74,13 @@ class BackwardSliceMatcher {
7874
for (OpOperand &operand : op->getOpOperands()) {
7975
if (!descendents.contains(operand.get().getParentRegion()))
8076
processValue(operand.get());
81-
if (!validSlice)
82-
return;
8377
}
8478
});
8579
});
8680
}
8781

88-
llvm::for_each(op->getOperands(), [&](Value operand) {
89-
processValue(operand);
90-
if (!validSlice)
91-
return;
92-
});
82+
llvm::for_each(op->getOperands(), processValue);
9383
backwardSlice.insert(op);
94-
if (!validSlice) {
95-
return false;
96-
}
9784
return true;
9885
}
9986

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,61 @@
1-
//===- MatchFinder.h - ------------------------------------------*- C++ -*-===//
2-
//
3-
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4-
// See https://llvm.org/LICENSE.txt for license information.
5-
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6-
//
71
//===----------------------------------------------------------------------===//
8-
//
9-
// This file contains the MatchFinder class, which is used to find operations
10-
// that match a given matcher.
11-
//
2+
// MatchFinder.h
123
//===----------------------------------------------------------------------===//
134

145
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
156
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
167

178
#include "MatchersInternal.h"
18-
#include "mlir/IR/Operation.h"
9+
#include "mlir/Query/QuerySession.h"
10+
#include "llvm/ADT/SetVector.h"
11+
#include "llvm/Support/SourceMgr.h"
12+
#include "llvm/Support/raw_ostream.h"
1913

2014
namespace mlir::query::matcher {
2115

22-
// MatchFinder is used to find all operations that match a given matcher.
2316
class MatchFinder {
17+
private:
18+
// Base print function with binding text
19+
static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
20+
mlir::Operation *op, const std::string &binding) {
21+
auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
22+
auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
23+
qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
24+
qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
25+
"\"" + binding + "\" binds here");
26+
};
27+
2428
public:
25-
// Returns all operations that match the given matcher.
2629
static SetVector<Operation *>
27-
getMatches(Operation *root, QueryOptions &options, DynMatcher matcher) {
28-
SetVector<Operation *> backwardSlice;
30+
getMatches(Operation *root, QueryOptions &options, DynMatcher matcher,
31+
llvm::raw_ostream &os, QuerySession &qs) {
32+
unsigned matchCount = 0;
33+
SetVector<Operation *> matchedOps;
34+
SetVector<Operation *> tempStorage;
35+
2936
root->walk([&](Operation *subOp) {
3037
if (matcher.match(subOp)) {
31-
backwardSlice.insert(subOp);
38+
matchedOps.insert(subOp);
39+
os << "Match #" << ++matchCount << ":\n\n";
40+
printMatch(os, qs, subOp, "root");
3241
} else {
33-
matcher.match(subOp, backwardSlice, options);
34-
////
42+
SmallVector<Operation *> printingOps;
43+
size_t sizeBefore = matchedOps.size();
44+
if (matcher.match(subOp, tempStorage, options)) {
45+
os << "Match #" << ++matchCount << ":\n\n";
46+
SmallVector<Operation *> printingOps(tempStorage.takeVector());
47+
for (auto op : printingOps) {
48+
printMatch(os, qs, op, ""); // Using version without binding text
49+
matchedOps.insert(op);
50+
}
51+
printingOps.clear();
52+
}
3553
}
3654
});
37-
return backwardSlice;
55+
return matchedOps;
3856
}
3957
};
4058

4159
} // namespace mlir::query::matcher
4260

43-
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
61+
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H

mlir/include/mlir/Query/Matcher/MatchersInternal.h

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
22
//
3-
// Part of the LLVM Project, under the Apache License v2.0 wIth LLVM Exceptions.
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WItH LLVM-exception
66
//
@@ -11,11 +11,6 @@
1111

1212
#include "mlir/IR/Matchers.h"
1313
#include "llvm/ADT/IntrusiveRefCntPtr.h"
14-
#include "llvm/ADT/MapVector.h"
15-
#include <memory>
16-
#include <stack>
17-
#include <unordered_set>
18-
#include <vector>
1914

2015
namespace mlir {
2116
namespace query {
@@ -112,4 +107,4 @@ class DynMatcher {
112107

113108
} // namespace mlir::query::matcher
114109

115-
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
110+
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H

mlir/include/mlir/Query/Query.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#define MLIR_TOOLS_MLIRQUERY_QUERY_H
1111

1212
#include "Matcher/VariantValue.h"
13-
#include "mlir/Analysis/SliceAnalysis.h"
1413
#include "llvm/ADT/IntrusiveRefCntPtr.h"
1514
#include "llvm/ADT/StringRef.h"
1615
#include "llvm/LineEditor/LineEditor.h"

mlir/include/mlir/Query/QuerySession.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,18 @@
99
#ifndef MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
1010
#define MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
1111

12+
#include "Matcher/VariantValue.h"
1213
#include "mlir/IR/Operation.h"
1314
#include "mlir/Query/Matcher/Registry.h"
1415
#include "llvm/ADT/StringMap.h"
1516
#include "llvm/Support/SourceMgr.h"
1617

18+
namespace mlir::query::matcher {
19+
class Registry;
20+
}
21+
1722
namespace mlir::query {
1823

19-
class Registry;
2024
// Represents the state for a particular mlir-query session.
2125
class QuerySession {
2226
public:

mlir/lib/Query/Matcher/Parser.cpp

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ class Parser::CodeTokenizer {
157157
}
158158

159159
void consumeNumberLiteral(TokenInfo *result) {
160-
bool isFloatingLiteral = false;
161160
unsigned length = 1;
162161
if (code.size() > 1) {
163162
// Consume the 'x' or 'b' radix modifier, if present.
@@ -170,39 +169,17 @@ class Parser::CodeTokenizer {
170169
while (length < code.size() && isdigit(code[length]))
171170
++length;
172171

173-
// Try to recognize a floating point literal.
174-
while (length < code.size()) {
175-
char c = code[length];
176-
if (c == '-' || c == '+' || c == '.' || isdigit(c)) {
177-
isFloatingLiteral = true;
178-
length++;
179-
} else {
180-
break;
181-
}
182-
}
183-
184172
result->text = code.take_front(length);
185173
code = code.drop_front(length);
186174

187-
if (isFloatingLiteral) {
188-
char *end;
189-
errno = 0;
190-
std::string text = result->text.str();
191-
double doubleValue = strtod(text.c_str(), &end);
192-
if (*end == 0 && errno == 0) {
193-
result->kind = TokenKind::Literal;
194-
result->value = static_cast<double>(doubleValue);
195-
return;
196-
}
197-
} else {
198-
unsigned value;
199-
if (!result->text.getAsInteger(0, value)) {
200-
result->kind = TokenKind::Literal;
201-
result->value = value;
202-
return;
203-
}
175+
unsigned value;
176+
if (!result->text.getAsInteger(0, value)) {
177+
result->kind = TokenKind::Literal;
178+
result->value = static_cast<unsigned>(value);
179+
return;
204180
}
205181
}
182+
206183
// Consume a string literal, handle escape sequences and missing closing
207184
// quote.
208185
void consumeStringLiteral(TokenInfo *result) {

mlir/lib/Query/Query.cpp

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
#include "llvm/ADT/SetVector.h"
1616
#include "llvm/Support/SourceMgr.h"
1717
#include "llvm/Support/raw_ostream.h"
18-
#include <unordered_map>
19-
#include <unordered_set>
2018

2119
namespace mlir::query {
2220

@@ -29,15 +27,6 @@ complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
2927
return QueryParser::complete(line, pos, qs);
3028
}
3129

32-
static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
33-
mlir::Operation *op, const std::string &binding) {
34-
auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
35-
auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
36-
qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
37-
qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
38-
"\"" + binding + "\" binds here");
39-
}
40-
4130
// TODO: Extract into a helper function that can be reused outside query
4231
// context.
4332
static Operation *extractFunction(std::vector<Operation *> &ops,
@@ -150,8 +139,8 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
150139

151140
QueryOptions options;
152141
parseQueryOptions(qs, options);
153-
auto matches =
154-
matcher::MatchFinder().getMatches(rootOp, options, std::move(matcher));
142+
auto matches = matcher::MatchFinder().getMatches(rootOp, options,
143+
std::move(matcher), os, qs);
155144

156145
// An extract call is recognized by considering if the matcher has a name.
157146
// TODO: Consider making the extract more explicit.
@@ -164,14 +153,6 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
164153
// return mlir::success();
165154
// }
166155

167-
os << "\n";
168-
for (Operation *op : matches) {
169-
os << "Match #" << ++matchCount << ":\n\n";
170-
// Placeholder "root" binding for the initial draft.
171-
printMatch(os, qs, op, "root");
172-
}
173-
os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
174-
175156
return mlir::success();
176157
}
177158

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
// RUN: mlir-query %s -c "m definedBy(hasOpName(())" | FileCheck %s
22

3+
func.func @region_control_flow(%arg: memref<2xf32>, %cond: i1) attributes {test.ptr = "func"} {
4+
%0 = memref.alloca() {test.ptr = "alloca_1"} : memref<8x64xf32>
5+
%1 = memref.alloca() {test.ptr = "alloca_2"} : memref<8x64xf32>
6+
%2 = memref.alloc() {test.ptr = "alloc_1"} : memref<8x64xf32>
37

4-
func.func @matrix_multiply(%A: memref<4x4xf32>, %B: memref<4x4xf32>, %C: memref<4x4xf32>) {
5-
%c0 = arith.constant 0 : index
6-
%c4 = arith.constant 4 : index
7-
%c1 = arith.constant 1 : index
8+
%3 = scf.if %cond -> (memref<8x64xf32>) {
9+
scf.yield %0 : memref<8x64xf32>
10+
} else {
11+
scf.yield %0 : memref<8x64xf32>
12+
} {test.ptr = "if_alloca"}
813

9-
scf.for %i = %c0 to %c4 step %c1 {
10-
scf.for %j = %c0 to %c4 step %c1 {
11-
%sum_init = arith.constant 0.0 : f32
12-
%sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %sum_init) -> (f32) {
13-
%a_ik = memref.load %A[%i, %k] : memref<4x4xf32>
14-
%b_kj = memref.load %B[%k, %j] : memref<4x4xf32>
15-
%prod = arith.mulf %a_ik, %b_kj : f32
16-
%new_acc = arith.addf %acc, %prod : f32
17-
scf.yield %new_acc : f32
18-
}
19-
memref.store %sum, %C[%i, %j] : memref<4x4xf32>
20-
}
21-
}
14+
%4 = scf.if %cond -> (memref<8x64xf32>) {
15+
scf.yield %0 : memref<8x64xf32>
16+
} else {
17+
scf.yield %1 : memref<8x64xf32>
18+
} {test.ptr = "if_alloca_merge"}
19+
20+
%5 = scf.if %cond -> (memref<8x64xf32>) {
21+
scf.yield %2 : memref<8x64xf32>
22+
} else {
23+
scf.yield %2 : memref<8x64xf32>
24+
} {test.ptr = "if_alloc"}
2225
return
23-
}
26+
}

mlir/test/mlir-query/function-extraction.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
// CHECK: %[[MUL0:.*]] = arith.mulf {{.*}} : f32
55
// CHECK: %[[MUL1:.*]] = arith.mulf {{.*}}, %[[MUL0]] : f32
66
// CHECK: %[[MUL2:.*]] = arith.mulf {{.*}} : f32
7-
// CHECK-NEXT: return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
7+
// CHECK-NEXT: return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32S
88

99
func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
1010
%sum0 = arith.addf %a, %b : f32

0 commit comments

Comments
 (0)