Skip to content

Commit de55c2f

Browse files
committed
Revert "[mlir-query] Add function extraction feature to mlir-query"
This reverts commit c66f2d0. The bot is broken.
1 parent 82c1bfc commit de55c2f

File tree

9 files changed

+17
-207
lines changed

9 files changed

+17
-207
lines changed

mlir/include/mlir/Query/Matcher/ErrorBuilder.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,8 @@ enum class ErrorType {
3737
None,
3838

3939
// Parser Errors
40-
ParserChainedExprInvalidArg,
41-
ParserChainedExprNoCloseParen,
42-
ParserChainedExprNoOpenParen,
4340
ParserFailedToBuildMatcher,
4441
ParserInvalidToken,
45-
ParserMalformedChainedExpr,
4642
ParserNoCloseParen,
4743
ParserNoCode,
4844
ParserNoComma,
@@ -54,10 +50,9 @@ enum class ErrorType {
5450

5551
// Registry Errors
5652
RegistryMatcherNotFound,
57-
RegistryNotBindable,
5853
RegistryValueNotFound,
5954
RegistryWrongArgCount,
60-
RegistryWrongArgType,
55+
RegistryWrongArgType
6156
};
6257

6358
void addError(Diagnostics *error, SourceRange range, ErrorType errorType,

mlir/include/mlir/Query/Matcher/MatchersInternal.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,8 @@ class DynMatcher {
6363

6464
bool match(Operation *op) const { return implementation->match(op); }
6565

66-
void setFunctionName(StringRef name) { functionName = name.str(); };
67-
68-
bool hasFunctionName() const { return !functionName.empty(); };
69-
70-
StringRef getFunctionName() const { return functionName; };
71-
7266
private:
7367
llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
74-
std::string functionName;
7568
};
7669

7770
} // namespace mlir::query::matcher

mlir/lib/Query/Matcher/Diagnostics.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
3838
return "Incorrect type for arg $0. (Expected = $1) != (Actual = $2)";
3939
case ErrorType::RegistryValueNotFound:
4040
return "Value not found: $0";
41-
case ErrorType::RegistryNotBindable:
42-
return "Matcher does not support binding.";
4341

4442
case ErrorType::ParserStringError:
4543
return "Error parsing string token: <$0>";
@@ -59,14 +57,6 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
5957
return "Unexpected end of code.";
6058
case ErrorType::ParserOverloadedType:
6159
return "Input value has unresolved overloaded type: $0";
62-
case ErrorType::ParserMalformedChainedExpr:
63-
return "Period not followed by valid chained call.";
64-
case ErrorType::ParserChainedExprInvalidArg:
65-
return "Missing/Invalid argument for the chained call.";
66-
case ErrorType::ParserChainedExprNoCloseParen:
67-
return "Missing ')' for the chained call.";
68-
case ErrorType::ParserChainedExprNoOpenParen:
69-
return "Missing '(' for the chained call.";
7060
case ErrorType::ParserFailedToBuildMatcher:
7161
return "Failed to build matcher: $0.";
7262

mlir/lib/Query/Matcher/Parser.cpp

Lines changed: 5 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,12 @@ struct Parser::TokenInfo {
2626
text = newText;
2727
}
2828

29-
// Known identifiers.
30-
static const char *const ID_Extract;
31-
3229
llvm::StringRef text;
3330
TokenKind kind = TokenKind::Eof;
3431
SourceRange range;
3532
VariantValue value;
3633
};
3734

38-
const char *const Parser::TokenInfo::ID_Extract = "extract";
39-
4035
class Parser::CodeTokenizer {
4136
public:
4237
// Constructor with matcherCode and error
@@ -303,36 +298,6 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
303298
return parseMatcherExpressionImpl(nameToken, openToken, ctor, value);
304299
}
305300

306-
bool Parser::parseChainedExpression(std::string &argument) {
307-
// Parse the parenthesized argument to .extract("foo")
308-
// Note: EOF is handled inside the consume functions and would fail below when
309-
// checking token kind.
310-
const TokenInfo openToken = tokenizer->consumeNextToken();
311-
const TokenInfo argumentToken = tokenizer->consumeNextTokenIgnoreNewlines();
312-
const TokenInfo closeToken = tokenizer->consumeNextTokenIgnoreNewlines();
313-
314-
if (openToken.kind != TokenKind::OpenParen) {
315-
error->addError(openToken.range, ErrorType::ParserChainedExprNoOpenParen);
316-
return false;
317-
}
318-
319-
if (argumentToken.kind != TokenKind::Literal ||
320-
!argumentToken.value.isString()) {
321-
error->addError(argumentToken.range,
322-
ErrorType::ParserChainedExprInvalidArg);
323-
return false;
324-
}
325-
326-
if (closeToken.kind != TokenKind::CloseParen) {
327-
error->addError(closeToken.range, ErrorType::ParserChainedExprNoCloseParen);
328-
return false;
329-
}
330-
331-
// If all checks passed, extract the argument and return true.
332-
argument = argumentToken.value.getString();
333-
return true;
334-
}
335-
336301
// Parse the arguments of a matcher
337302
bool Parser::parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
338303
const TokenInfo &nameToken, TokenInfo &endToken) {
@@ -399,34 +364,13 @@ bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken,
399364
return false;
400365
}
401366

402-
std::string functionName;
403-
if (tokenizer->peekNextToken().kind == TokenKind::Period) {
404-
tokenizer->consumeNextToken();
405-
TokenInfo chainCallToken = tokenizer->consumeNextToken();
406-
if (chainCallToken.kind == TokenKind::CodeCompletion) {
407-
addCompletion(chainCallToken, MatcherCompletion("extract(\"", "extract"));
408-
return false;
409-
}
410-
411-
if (chainCallToken.kind != TokenKind::Ident ||
412-
chainCallToken.text != TokenInfo::ID_Extract) {
413-
error->addError(chainCallToken.range,
414-
ErrorType::ParserMalformedChainedExpr);
415-
return false;
416-
}
417-
418-
if (chainCallToken.text == TokenInfo::ID_Extract &&
419-
!parseChainedExpression(functionName))
420-
return false;
421-
}
422-
423367
if (!ctor)
424368
return false;
425369
// Merge the start and end infos.
426370
SourceRange matcherRange = nameToken.range;
427371
matcherRange.end = endToken.range.end;
428-
VariantMatcher result = sema->actOnMatcherExpression(
429-
*ctor, matcherRange, functionName, args, error);
372+
VariantMatcher result =
373+
sema->actOnMatcherExpression(*ctor, matcherRange, args, error);
430374
if (result.isNull())
431375
return false;
432376
*value = result;
@@ -526,10 +470,9 @@ Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) {
526470
}
527471

528472
VariantMatcher Parser::RegistrySema::actOnMatcherExpression(
529-
MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
530-
llvm::ArrayRef<ParserValue> args, Diagnostics *error) {
531-
return RegistryManager::constructMatcher(ctor, nameRange, functionName, args,
532-
error);
473+
MatcherCtor ctor, SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
474+
Diagnostics *error) {
475+
return RegistryManager::constructMatcher(ctor, nameRange, args, error);
533476
}
534477

535478
std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes(

mlir/lib/Query/Matcher/Parser.h

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,10 @@ class Parser {
6464

6565
// Process a matcher expression. The caller takes ownership of the Matcher
6666
// object returned.
67-
virtual VariantMatcher actOnMatcherExpression(
68-
MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
69-
llvm::ArrayRef<ParserValue> args, Diagnostics *error) = 0;
67+
virtual VariantMatcher
68+
actOnMatcherExpression(MatcherCtor ctor, SourceRange nameRange,
69+
llvm::ArrayRef<ParserValue> args,
70+
Diagnostics *error) = 0;
7071

7172
// Look up a matcher by name in the matcher name found by the parser.
7273
virtual std::optional<MatcherCtor>
@@ -92,11 +93,10 @@ class Parser {
9293
std::optional<MatcherCtor>
9394
lookupMatcherCtor(llvm::StringRef matcherName) override;
9495

95-
VariantMatcher actOnMatcherExpression(MatcherCtor Ctor,
96-
SourceRange NameRange,
97-
StringRef functionName,
98-
ArrayRef<ParserValue> Args,
99-
Diagnostics *Error) override;
96+
VariantMatcher actOnMatcherExpression(MatcherCtor ctor,
97+
SourceRange nameRange,
98+
llvm::ArrayRef<ParserValue> args,
99+
Diagnostics *error) override;
100100

101101
std::vector<ArgKind> getAcceptedCompletionTypes(
102102
llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) override;
@@ -153,8 +153,6 @@ class Parser {
153153
Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
154154
const NamedValueMap *namedValues, Diagnostics *error);
155155

156-
bool parseChainedExpression(std::string &argument);
157-
158156
bool parseExpressionImpl(VariantValue *value);
159157

160158
bool parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,

mlir/lib/Query/Matcher/RegistryManager.cpp

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -132,19 +132,8 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
132132

133133
VariantMatcher RegistryManager::constructMatcher(
134134
MatcherCtor ctor, internal::SourceRange nameRange,
135-
llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args,
136-
internal::Diagnostics *error) {
137-
VariantMatcher out = ctor->create(nameRange, args, error);
138-
if (functionName.empty() || out.isNull())
139-
return out;
140-
141-
if (std::optional<DynMatcher> result = out.getDynMatcher()) {
142-
result->setFunctionName(functionName);
143-
return VariantMatcher::SingleMatcher(*result);
144-
}
145-
146-
error->addError(nameRange, internal::ErrorType::RegistryNotBindable);
147-
return {};
135+
llvm::ArrayRef<ParserValue> args, internal::Diagnostics *error) {
136+
return ctor->create(nameRange, args, error);
148137
}
149138

150139
} // namespace mlir::query::matcher

mlir/lib/Query/Matcher/RegistryManager.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ class RegistryManager {
6161

6262
static VariantMatcher constructMatcher(MatcherCtor ctor,
6363
internal::SourceRange nameRange,
64-
llvm::StringRef functionName,
6564
ArrayRef<ParserValue> args,
6665
internal::Diagnostics *error);
6766
};

mlir/lib/Query/Query.cpp

Lines changed: 1 addition & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88

99
#include "mlir/Query/Query.h"
1010
#include "QueryParser.h"
11-
#include "mlir/Dialect/Func/IR/FuncOps.h"
12-
#include "mlir/IR/IRMapping.h"
1311
#include "mlir/Query/Matcher/MatchFinder.h"
1412
#include "mlir/Query/QuerySession.h"
1513
#include "mlir/Support/LogicalResult.h"
@@ -36,70 +34,6 @@ static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
3634
"\"" + binding + "\" binds here");
3735
}
3836

39-
// TODO: Extract into a helper function that can be reused outside query
40-
// context.
41-
static Operation *extractFunction(std::vector<Operation *> &ops,
42-
MLIRContext *context,
43-
llvm::StringRef functionName) {
44-
context->loadDialect<func::FuncDialect>();
45-
OpBuilder builder(context);
46-
47-
// Collect data for function creation
48-
std::vector<Operation *> slice;
49-
std::vector<Value> values;
50-
std::vector<Type> outputTypes;
51-
52-
for (auto *op : ops) {
53-
// Return op's operands are propagated, but the op itself isn't needed.
54-
if (!isa<func::ReturnOp>(op))
55-
slice.push_back(op);
56-
57-
// All results are returned by the extracted function.
58-
outputTypes.insert(outputTypes.end(), op->getResults().getTypes().begin(),
59-
op->getResults().getTypes().end());
60-
61-
// Track all values that need to be taken as input to function.
62-
values.insert(values.end(), op->getOperands().begin(),
63-
op->getOperands().end());
64-
}
65-
66-
// Create the function
67-
FunctionType funcType =
68-
builder.getFunctionType(ValueRange(values), outputTypes);
69-
auto loc = builder.getUnknownLoc();
70-
func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
71-
72-
builder.setInsertionPointToEnd(funcOp.addEntryBlock());
73-
74-
// Map original values to function arguments
75-
IRMapping mapper;
76-
for (const auto &arg : llvm::enumerate(values))
77-
mapper.map(arg.value(), funcOp.getArgument(arg.index()));
78-
79-
// Clone operations and build function body
80-
std::vector<Operation *> clonedOps;
81-
std::vector<Value> clonedVals;
82-
for (Operation *slicedOp : slice) {
83-
Operation *clonedOp =
84-
clonedOps.emplace_back(builder.clone(*slicedOp, mapper));
85-
clonedVals.insert(clonedVals.end(), clonedOp->result_begin(),
86-
clonedOp->result_end());
87-
}
88-
// Add return operation
89-
builder.create<func::ReturnOp>(loc, clonedVals);
90-
91-
// Remove unused function arguments
92-
size_t currentIndex = 0;
93-
while (currentIndex < funcOp.getNumArguments()) {
94-
if (funcOp.getArgument(currentIndex).use_empty())
95-
funcOp.eraseArgument(currentIndex);
96-
else
97-
++currentIndex;
98-
}
99-
100-
return funcOp;
101-
}
102-
10337
Query::~Query() = default;
10438

10539
mlir::LogicalResult InvalidQuery::run(llvm::raw_ostream &os,
@@ -131,21 +65,9 @@ mlir::LogicalResult QuitQuery::run(llvm::raw_ostream &os,
13165

13266
mlir::LogicalResult MatchQuery::run(llvm::raw_ostream &os,
13367
QuerySession &qs) const {
134-
Operation *rootOp = qs.getRootOp();
13568
int matchCount = 0;
13669
std::vector<Operation *> matches =
137-
matcher::MatchFinder().getMatches(rootOp, matcher);
138-
139-
// An extract call is recognized by considering if the matcher has a name.
140-
// TODO: Consider making the extract more explicit.
141-
if (matcher.hasFunctionName()) {
142-
auto functionName = matcher.getFunctionName();
143-
Operation *function =
144-
extractFunction(matches, rootOp->getContext(), functionName);
145-
os << "\n" << *function << "\n\n";
146-
return mlir::success();
147-
}
148-
70+
matcher::MatchFinder().getMatches(qs.getRootOp(), matcher);
14971
os << "\n";
15072
for (Operation *op : matches) {
15173
os << "Match #" << ++matchCount << ":\n\n";

mlir/test/mlir-query/function-extraction.mlir

Lines changed: 0 additions & 19 deletions
This file was deleted.

0 commit comments

Comments
 (0)