Skip to content

Commit c66f2d0

Browse files
devajithvsjpienaar
authored andcommitted
[mlir-query] Add function extraction feature to mlir-query
This enables specifying the extract modifier to extract all matches into a function. This currently does this very directly by converting all operands to function arguments (ones due to results of other matched ops are dropped) and all results as return values. Differential Revision: https://reviews.llvm.org/D158693
1 parent 5bd01ac commit c66f2d0

File tree

9 files changed

+207
-17
lines changed

9 files changed

+207
-17
lines changed

mlir/include/mlir/Query/Matcher/ErrorBuilder.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,12 @@ enum class ErrorType {
3737
None,
3838

3939
// Parser Errors
40+
ParserChainedExprInvalidArg,
41+
ParserChainedExprNoCloseParen,
42+
ParserChainedExprNoOpenParen,
4043
ParserFailedToBuildMatcher,
4144
ParserInvalidToken,
45+
ParserMalformedChainedExpr,
4246
ParserNoCloseParen,
4347
ParserNoCode,
4448
ParserNoComma,
@@ -50,9 +54,10 @@ enum class ErrorType {
5054

5155
// Registry Errors
5256
RegistryMatcherNotFound,
57+
RegistryNotBindable,
5358
RegistryValueNotFound,
5459
RegistryWrongArgCount,
55-
RegistryWrongArgType
60+
RegistryWrongArgType,
5661
};
5762

5863
void addError(Diagnostics *error, SourceRange range, ErrorType errorType,

mlir/include/mlir/Query/Matcher/MatchersInternal.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,15 @@ class DynMatcher {
6363

6464
bool match(Operation *op) const { return implementation->match(op); }
6565

66+
void setFunctionName(StringRef name) { functionName = name.str(); };
67+
68+
bool hasFunctionName() const { return !functionName.empty(); };
69+
70+
StringRef getFunctionName() const { return functionName; };
71+
6672
private:
6773
llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
74+
std::string functionName;
6875
};
6976

7077
} // namespace mlir::query::matcher

mlir/lib/Query/Matcher/Diagnostics.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
3838
return "Incorrect type for arg $0. (Expected = $1) != (Actual = $2)";
3939
case ErrorType::RegistryValueNotFound:
4040
return "Value not found: $0";
41+
case ErrorType::RegistryNotBindable:
42+
return "Matcher does not support binding.";
4143

4244
case ErrorType::ParserStringError:
4345
return "Error parsing string token: <$0>";
@@ -57,6 +59,14 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
5759
return "Unexpected end of code.";
5860
case ErrorType::ParserOverloadedType:
5961
return "Input value has unresolved overloaded type: $0";
62+
case ErrorType::ParserMalformedChainedExpr:
63+
return "Period not followed by valid chained call.";
64+
case ErrorType::ParserChainedExprInvalidArg:
65+
return "Missing/Invalid argument for the chained call.";
66+
case ErrorType::ParserChainedExprNoCloseParen:
67+
return "Missing ')' for the chained call.";
68+
case ErrorType::ParserChainedExprNoOpenParen:
69+
return "Missing '(' for the chained call.";
6070
case ErrorType::ParserFailedToBuildMatcher:
6171
return "Failed to build matcher: $0.";
6272

mlir/lib/Query/Matcher/Parser.cpp

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,17 @@ struct Parser::TokenInfo {
2626
text = newText;
2727
}
2828

29+
// Known identifiers.
30+
static const char *const ID_Extract;
31+
2932
llvm::StringRef text;
3033
TokenKind kind = TokenKind::Eof;
3134
SourceRange range;
3235
VariantValue value;
3336
};
3437

38+
const char *const Parser::TokenInfo::ID_Extract = "extract";
39+
3540
class Parser::CodeTokenizer {
3641
public:
3742
// Constructor with matcherCode and error
@@ -298,6 +303,36 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
298303
return parseMatcherExpressionImpl(nameToken, openToken, ctor, value);
299304
}
300305

306+
bool Parser::parseChainedExpression(std::string &argument) {
307+
// Parse the parenthesized argument to .extract("foo")
308+
// Note: EOF is handled inside the consume functions and would fail below when
309+
// checking token kind.
310+
const TokenInfo openToken = tokenizer->consumeNextToken();
311+
const TokenInfo argumentToken = tokenizer->consumeNextTokenIgnoreNewlines();
312+
const TokenInfo closeToken = tokenizer->consumeNextTokenIgnoreNewlines();
313+
314+
if (openToken.kind != TokenKind::OpenParen) {
315+
error->addError(openToken.range, ErrorType::ParserChainedExprNoOpenParen);
316+
return false;
317+
}
318+
319+
if (argumentToken.kind != TokenKind::Literal ||
320+
!argumentToken.value.isString()) {
321+
error->addError(argumentToken.range,
322+
ErrorType::ParserChainedExprInvalidArg);
323+
return false;
324+
}
325+
326+
if (closeToken.kind != TokenKind::CloseParen) {
327+
error->addError(closeToken.range, ErrorType::ParserChainedExprNoCloseParen);
328+
return false;
329+
}
330+
331+
// If all checks passed, extract the argument and return true.
332+
argument = argumentToken.value.getString();
333+
return true;
334+
}
335+
301336
// Parse the arguments of a matcher
302337
bool Parser::parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
303338
const TokenInfo &nameToken, TokenInfo &endToken) {
@@ -364,13 +399,34 @@ bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken,
364399
return false;
365400
}
366401

402+
std::string functionName;
403+
if (tokenizer->peekNextToken().kind == TokenKind::Period) {
404+
tokenizer->consumeNextToken();
405+
TokenInfo chainCallToken = tokenizer->consumeNextToken();
406+
if (chainCallToken.kind == TokenKind::CodeCompletion) {
407+
addCompletion(chainCallToken, MatcherCompletion("extract(\"", "extract"));
408+
return false;
409+
}
410+
411+
if (chainCallToken.kind != TokenKind::Ident ||
412+
chainCallToken.text != TokenInfo::ID_Extract) {
413+
error->addError(chainCallToken.range,
414+
ErrorType::ParserMalformedChainedExpr);
415+
return false;
416+
}
417+
418+
if (chainCallToken.text == TokenInfo::ID_Extract &&
419+
!parseChainedExpression(functionName))
420+
return false;
421+
}
422+
367423
if (!ctor)
368424
return false;
369425
// Merge the start and end infos.
370426
SourceRange matcherRange = nameToken.range;
371427
matcherRange.end = endToken.range.end;
372-
VariantMatcher result =
373-
sema->actOnMatcherExpression(*ctor, matcherRange, args, error);
428+
VariantMatcher result = sema->actOnMatcherExpression(
429+
*ctor, matcherRange, functionName, args, error);
374430
if (result.isNull())
375431
return false;
376432
*value = result;
@@ -470,9 +526,10 @@ Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) {
470526
}
471527

472528
VariantMatcher Parser::RegistrySema::actOnMatcherExpression(
473-
MatcherCtor ctor, SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
474-
Diagnostics *error) {
475-
return RegistryManager::constructMatcher(ctor, nameRange, args, error);
529+
MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
530+
llvm::ArrayRef<ParserValue> args, Diagnostics *error) {
531+
return RegistryManager::constructMatcher(ctor, nameRange, functionName, args,
532+
error);
476533
}
477534

478535
std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes(

mlir/lib/Query/Matcher/Parser.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,9 @@ class Parser {
6464

6565
// Process a matcher expression. The caller takes ownership of the Matcher
6666
// object returned.
67-
virtual VariantMatcher
68-
actOnMatcherExpression(MatcherCtor ctor, SourceRange nameRange,
69-
llvm::ArrayRef<ParserValue> args,
70-
Diagnostics *error) = 0;
67+
virtual VariantMatcher actOnMatcherExpression(
68+
MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
69+
llvm::ArrayRef<ParserValue> args, Diagnostics *error) = 0;
7170

7271
// Look up a matcher by name in the matcher name found by the parser.
7372
virtual std::optional<MatcherCtor>
@@ -93,10 +92,11 @@ class Parser {
9392
std::optional<MatcherCtor>
9493
lookupMatcherCtor(llvm::StringRef matcherName) override;
9594

96-
VariantMatcher actOnMatcherExpression(MatcherCtor ctor,
97-
SourceRange nameRange,
98-
llvm::ArrayRef<ParserValue> args,
99-
Diagnostics *error) override;
95+
VariantMatcher actOnMatcherExpression(MatcherCtor Ctor,
96+
SourceRange NameRange,
97+
StringRef functionName,
98+
ArrayRef<ParserValue> Args,
99+
Diagnostics *Error) override;
100100

101101
std::vector<ArgKind> getAcceptedCompletionTypes(
102102
llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) override;
@@ -153,6 +153,8 @@ class Parser {
153153
Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
154154
const NamedValueMap *namedValues, Diagnostics *error);
155155

156+
bool parseChainedExpression(std::string &argument);
157+
156158
bool parseExpressionImpl(VariantValue *value);
157159

158160
bool parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,

mlir/lib/Query/Matcher/RegistryManager.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,19 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
132132

133133
VariantMatcher RegistryManager::constructMatcher(
134134
MatcherCtor ctor, internal::SourceRange nameRange,
135-
llvm::ArrayRef<ParserValue> args, internal::Diagnostics *error) {
136-
return ctor->create(nameRange, args, error);
135+
llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args,
136+
internal::Diagnostics *error) {
137+
VariantMatcher out = ctor->create(nameRange, args, error);
138+
if (functionName.empty() || out.isNull())
139+
return out;
140+
141+
if (std::optional<DynMatcher> result = out.getDynMatcher()) {
142+
result->setFunctionName(functionName);
143+
return VariantMatcher::SingleMatcher(*result);
144+
}
145+
146+
error->addError(nameRange, internal::ErrorType::RegistryNotBindable);
147+
return {};
137148
}
138149

139150
} // namespace mlir::query::matcher

mlir/lib/Query/Matcher/RegistryManager.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class RegistryManager {
6161

6262
static VariantMatcher constructMatcher(MatcherCtor ctor,
6363
internal::SourceRange nameRange,
64+
llvm::StringRef functionName,
6465
ArrayRef<ParserValue> args,
6566
internal::Diagnostics *error);
6667
};

mlir/lib/Query/Query.cpp

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include "mlir/Query/Query.h"
1010
#include "QueryParser.h"
11+
#include "mlir/Dialect/Func/IR/FuncOps.h"
12+
#include "mlir/IR/IRMapping.h"
1113
#include "mlir/Query/Matcher/MatchFinder.h"
1214
#include "mlir/Query/QuerySession.h"
1315
#include "mlir/Support/LogicalResult.h"
@@ -34,6 +36,70 @@ static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
3436
"\"" + binding + "\" binds here");
3537
}
3638

39+
// TODO: Extract into a helper function that can be reused outside query
40+
// context.
41+
static Operation *extractFunction(std::vector<Operation *> &ops,
42+
MLIRContext *context,
43+
llvm::StringRef functionName) {
44+
context->loadDialect<func::FuncDialect>();
45+
OpBuilder builder(context);
46+
47+
// Collect data for function creation
48+
std::vector<Operation *> slice;
49+
std::vector<Value> values;
50+
std::vector<Type> outputTypes;
51+
52+
for (auto *op : ops) {
53+
// Return op's operands are propagated, but the op itself isn't needed.
54+
if (!isa<func::ReturnOp>(op))
55+
slice.push_back(op);
56+
57+
// All results are returned by the extracted function.
58+
outputTypes.insert(outputTypes.end(), op->getResults().getTypes().begin(),
59+
op->getResults().getTypes().end());
60+
61+
// Track all values that need to be taken as input to function.
62+
values.insert(values.end(), op->getOperands().begin(),
63+
op->getOperands().end());
64+
}
65+
66+
// Create the function
67+
FunctionType funcType =
68+
builder.getFunctionType(ValueRange(values), outputTypes);
69+
auto loc = builder.getUnknownLoc();
70+
func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
71+
72+
builder.setInsertionPointToEnd(funcOp.addEntryBlock());
73+
74+
// Map original values to function arguments
75+
IRMapping mapper;
76+
for (const auto &arg : llvm::enumerate(values))
77+
mapper.map(arg.value(), funcOp.getArgument(arg.index()));
78+
79+
// Clone operations and build function body
80+
std::vector<Operation *> clonedOps;
81+
std::vector<Value> clonedVals;
82+
for (Operation *slicedOp : slice) {
83+
Operation *clonedOp =
84+
clonedOps.emplace_back(builder.clone(*slicedOp, mapper));
85+
clonedVals.insert(clonedVals.end(), clonedOp->result_begin(),
86+
clonedOp->result_end());
87+
}
88+
// Add return operation
89+
builder.create<func::ReturnOp>(loc, clonedVals);
90+
91+
// Remove unused function arguments
92+
size_t currentIndex = 0;
93+
while (currentIndex < funcOp.getNumArguments()) {
94+
if (funcOp.getArgument(currentIndex).use_empty())
95+
funcOp.eraseArgument(currentIndex);
96+
else
97+
++currentIndex;
98+
}
99+
100+
return funcOp;
101+
}
102+
37103
Query::~Query() = default;
38104

39105
mlir::LogicalResult InvalidQuery::run(llvm::raw_ostream &os,
@@ -65,9 +131,21 @@ mlir::LogicalResult QuitQuery::run(llvm::raw_ostream &os,
65131

66132
mlir::LogicalResult MatchQuery::run(llvm::raw_ostream &os,
67133
QuerySession &qs) const {
134+
Operation *rootOp = qs.getRootOp();
68135
int matchCount = 0;
69136
std::vector<Operation *> matches =
70-
matcher::MatchFinder().getMatches(qs.getRootOp(), matcher);
137+
matcher::MatchFinder().getMatches(rootOp, matcher);
138+
139+
// An extract call is recognized by considering if the matcher has a name.
140+
// TODO: Consider making the extract more explicit.
141+
if (matcher.hasFunctionName()) {
142+
auto functionName = matcher.getFunctionName();
143+
Operation *function =
144+
extractFunction(matches, rootOp->getContext(), functionName);
145+
os << "\n" << *function << "\n\n";
146+
return mlir::success();
147+
}
148+
71149
os << "\n";
72150
for (Operation *op : matches) {
73151
os << "Match #" << ++matchCount << ":\n\n";
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: mlir-query %s -c "m hasOpName(\"arith.mulf\").extract(\"testmul\")" | FileCheck %s
2+
3+
// CHECK: func.func @testmul({{.*}}) -> (f32, f32, f32) {
4+
// CHECK: %[[MUL0:.*]] = arith.mulf {{.*}} : f32
5+
// CHECK: %[[MUL1:.*]] = arith.mulf {{.*}}, %[[MUL0]] : f32
6+
// CHECK: %[[MUL2:.*]] = arith.mulf {{.*}} : f32
7+
// CHECK-NEXT: return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
8+
9+
func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
10+
%sum0 = arith.addf %a, %b : f32
11+
%sub0 = arith.subf %sum0, %c : f32
12+
%mul0 = arith.mulf %a, %sub0 : f32
13+
%sum1 = arith.addf %b, %c : f32
14+
%mul1 = arith.mulf %sum1, %mul0 : f32
15+
%sub2 = arith.subf %mul1, %a : f32
16+
%sum2 = arith.addf %mul1, %b : f32
17+
%mul2 = arith.mulf %sub2, %sum2 : f32
18+
return %mul2 : f32
19+
}

0 commit comments

Comments
 (0)