Skip to content

Commit 58b44c8

Browse files
committed
Reapply "Reapply "[mlir-query] Add function extraction feature to mlir-query""
Fix ASAN by erasing the op extracted post printing. This reverts commit 732a5cb.
1 parent 5b4759f commit 58b44c8

File tree

10 files changed

+209
-17
lines changed

10 files changed

+209
-17
lines changed

mlir/include/mlir/Query/Matcher/ErrorBuilder.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,12 @@ enum class ErrorType {
3737
None,
3838

3939
// Parser Errors
40+
ParserChainedExprInvalidArg,
41+
ParserChainedExprNoCloseParen,
42+
ParserChainedExprNoOpenParen,
4043
ParserFailedToBuildMatcher,
4144
ParserInvalidToken,
45+
ParserMalformedChainedExpr,
4246
ParserNoCloseParen,
4347
ParserNoCode,
4448
ParserNoComma,
@@ -50,9 +54,10 @@ enum class ErrorType {
5054

5155
// Registry Errors
5256
RegistryMatcherNotFound,
57+
RegistryNotBindable,
5358
RegistryValueNotFound,
5459
RegistryWrongArgCount,
55-
RegistryWrongArgType
60+
RegistryWrongArgType,
5661
};
5762

5863
void addError(Diagnostics *error, SourceRange range, ErrorType errorType,

mlir/include/mlir/Query/Matcher/MatchersInternal.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,15 @@ class DynMatcher {
6363

6464
bool match(Operation *op) const { return implementation->match(op); }
6565

66+
// Stores a copy of `name` as the function name associated with this matcher.
void setFunctionName(StringRef name) { functionName = name.str(); }
67+
68+
// Returns true when a non-empty function name is stored on this matcher.
bool hasFunctionName() const { return !functionName.empty(); }
69+
70+
// Returns a non-owning view of the stored function name; the view is empty
// when no name has been set.
StringRef getFunctionName() const { return functionName; }
71+
6672
private:
6773
llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
74+
std::string functionName;
6875
};
6976

7077
} // namespace mlir::query::matcher

mlir/lib/Query/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_mlir_library(MLIRQuery
66
${MLIR_MAIN_INCLUDE_DIR}/mlir/Query
77

88
LINK_LIBS PUBLIC
9+
MLIRFuncDialect
910
MLIRQueryMatcher
1011
)
1112

mlir/lib/Query/Matcher/Diagnostics.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
3838
return "Incorrect type for arg $0. (Expected = $1) != (Actual = $2)";
3939
case ErrorType::RegistryValueNotFound:
4040
return "Value not found: $0";
41+
case ErrorType::RegistryNotBindable:
42+
return "Matcher does not support binding.";
4143

4244
case ErrorType::ParserStringError:
4345
return "Error parsing string token: <$0>";
@@ -57,6 +59,14 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
5759
return "Unexpected end of code.";
5860
case ErrorType::ParserOverloadedType:
5961
return "Input value has unresolved overloaded type: $0";
62+
case ErrorType::ParserMalformedChainedExpr:
63+
return "Period not followed by valid chained call.";
64+
case ErrorType::ParserChainedExprInvalidArg:
65+
return "Missing/Invalid argument for the chained call.";
66+
case ErrorType::ParserChainedExprNoCloseParen:
67+
return "Missing ')' for the chained call.";
68+
case ErrorType::ParserChainedExprNoOpenParen:
69+
return "Missing '(' for the chained call.";
6070
case ErrorType::ParserFailedToBuildMatcher:
6171
return "Failed to build matcher: $0.";
6272

mlir/lib/Query/Matcher/Parser.cpp

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,17 @@ struct Parser::TokenInfo {
2626
text = newText;
2727
}
2828

29+
// Known identifiers.
30+
static const char *const ID_Extract;
31+
2932
llvm::StringRef text;
3033
TokenKind kind = TokenKind::Eof;
3134
SourceRange range;
3235
VariantValue value;
3336
};
3437

38+
const char *const Parser::TokenInfo::ID_Extract = "extract";
39+
3540
class Parser::CodeTokenizer {
3641
public:
3742
// Constructor with matcherCode and error
@@ -298,6 +303,36 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
298303
return parseMatcherExpressionImpl(nameToken, openToken, ctor, value);
299304
}
300305

306+
bool Parser::parseChainedExpression(std::string &argument) {
307+
// Parse the parenthesized argument to .extract("foo")
308+
// Note: EOF is handled inside the consume functions and would fail below when
309+
// checking token kind.
310+
const TokenInfo openToken = tokenizer->consumeNextToken();
311+
const TokenInfo argumentToken = tokenizer->consumeNextTokenIgnoreNewlines();
312+
const TokenInfo closeToken = tokenizer->consumeNextTokenIgnoreNewlines();
313+
314+
if (openToken.kind != TokenKind::OpenParen) {
315+
error->addError(openToken.range, ErrorType::ParserChainedExprNoOpenParen);
316+
return false;
317+
}
318+
319+
if (argumentToken.kind != TokenKind::Literal ||
320+
!argumentToken.value.isString()) {
321+
error->addError(argumentToken.range,
322+
ErrorType::ParserChainedExprInvalidArg);
323+
return false;
324+
}
325+
326+
if (closeToken.kind != TokenKind::CloseParen) {
327+
error->addError(closeToken.range, ErrorType::ParserChainedExprNoCloseParen);
328+
return false;
329+
}
330+
331+
// If all checks passed, extract the argument and return true.
332+
argument = argumentToken.value.getString();
333+
return true;
334+
}
335+
301336
// Parse the arguments of a matcher
302337
bool Parser::parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
303338
const TokenInfo &nameToken, TokenInfo &endToken) {
@@ -364,13 +399,34 @@ bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken,
364399
return false;
365400
}
366401

402+
std::string functionName;
403+
if (tokenizer->peekNextToken().kind == TokenKind::Period) {
404+
tokenizer->consumeNextToken();
405+
TokenInfo chainCallToken = tokenizer->consumeNextToken();
406+
if (chainCallToken.kind == TokenKind::CodeCompletion) {
407+
addCompletion(chainCallToken, MatcherCompletion("extract(\"", "extract"));
408+
return false;
409+
}
410+
411+
if (chainCallToken.kind != TokenKind::Ident ||
412+
chainCallToken.text != TokenInfo::ID_Extract) {
413+
error->addError(chainCallToken.range,
414+
ErrorType::ParserMalformedChainedExpr);
415+
return false;
416+
}
417+
418+
if (chainCallToken.text == TokenInfo::ID_Extract &&
419+
!parseChainedExpression(functionName))
420+
return false;
421+
}
422+
367423
if (!ctor)
368424
return false;
369425
// Merge the start and end infos.
370426
SourceRange matcherRange = nameToken.range;
371427
matcherRange.end = endToken.range.end;
372-
VariantMatcher result =
373-
sema->actOnMatcherExpression(*ctor, matcherRange, args, error);
428+
VariantMatcher result = sema->actOnMatcherExpression(
429+
*ctor, matcherRange, functionName, args, error);
374430
if (result.isNull())
375431
return false;
376432
*value = result;
@@ -470,9 +526,10 @@ Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) {
470526
}
471527

472528
VariantMatcher Parser::RegistrySema::actOnMatcherExpression(
473-
MatcherCtor ctor, SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
474-
Diagnostics *error) {
475-
return RegistryManager::constructMatcher(ctor, nameRange, args, error);
529+
MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
530+
llvm::ArrayRef<ParserValue> args, Diagnostics *error) {
531+
return RegistryManager::constructMatcher(ctor, nameRange, functionName, args,
532+
error);
476533
}
477534

478535
std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes(

mlir/lib/Query/Matcher/Parser.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,9 @@ class Parser {
6464

6565
// Process a matcher expression. The caller takes ownership of the Matcher
6666
// object returned.
67-
virtual VariantMatcher
68-
actOnMatcherExpression(MatcherCtor ctor, SourceRange nameRange,
69-
llvm::ArrayRef<ParserValue> args,
70-
Diagnostics *error) = 0;
67+
// `functionName` is the name passed to a chained .extract("...") call on the
// matcher expression; it is empty when no such call was parsed.
virtual VariantMatcher actOnMatcherExpression(
68+
MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
69+
llvm::ArrayRef<ParserValue> args, Diagnostics *error) = 0;
7170

7271
// Look up a matcher by name in the matcher name found by the parser.
7372
virtual std::optional<MatcherCtor>
@@ -93,10 +92,11 @@ class Parser {
9392
std::optional<MatcherCtor>
9493
lookupMatcherCtor(llvm::StringRef matcherName) override;
9594

96-
VariantMatcher actOnMatcherExpression(MatcherCtor ctor,
97-
SourceRange nameRange,
98-
llvm::ArrayRef<ParserValue> args,
99-
Diagnostics *error) override;
95+
// Registry-backed implementation of Sema::actOnMatcherExpression.
// `functionName` carries the argument of a chained .extract("...") call and
// is empty when the expression has none.
VariantMatcher actOnMatcherExpression(MatcherCtor ctor,
                                      SourceRange nameRange,
                                      StringRef functionName,
                                      ArrayRef<ParserValue> args,
                                      Diagnostics *error) override;
100100

101101
std::vector<ArgKind> getAcceptedCompletionTypes(
102102
llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) override;
@@ -153,6 +153,8 @@ class Parser {
153153
Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
154154
const NamedValueMap *namedValues, Diagnostics *error);
155155

156+
bool parseChainedExpression(std::string &argument);
157+
156158
bool parseExpressionImpl(VariantValue *value);
157159

158160
bool parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,

mlir/lib/Query/Matcher/RegistryManager.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,19 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
132132

133133
VariantMatcher RegistryManager::constructMatcher(
134134
MatcherCtor ctor, internal::SourceRange nameRange,
135-
llvm::ArrayRef<ParserValue> args, internal::Diagnostics *error) {
136-
return ctor->create(nameRange, args, error);
135+
llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args,
136+
internal::Diagnostics *error) {
137+
VariantMatcher out = ctor->create(nameRange, args, error);
138+
if (functionName.empty() || out.isNull())
139+
return out;
140+
141+
if (std::optional<DynMatcher> result = out.getDynMatcher()) {
142+
result->setFunctionName(functionName);
143+
return VariantMatcher::SingleMatcher(*result);
144+
}
145+
146+
error->addError(nameRange, internal::ErrorType::RegistryNotBindable);
147+
return {};
137148
}
138149

139150
} // namespace mlir::query::matcher

mlir/lib/Query/Matcher/RegistryManager.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class RegistryManager {
6161

6262
static VariantMatcher constructMatcher(MatcherCtor ctor,
6363
internal::SourceRange nameRange,
64+
llvm::StringRef functionName,
6465
ArrayRef<ParserValue> args,
6566
internal::Diagnostics *error);
6667
};

mlir/lib/Query/Query.cpp

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include "mlir/Query/Query.h"
1010
#include "QueryParser.h"
11+
#include "mlir/Dialect/Func/IR/FuncOps.h"
12+
#include "mlir/IR/IRMapping.h"
1113
#include "mlir/Query/Matcher/MatchFinder.h"
1214
#include "mlir/Query/QuerySession.h"
1315
#include "mlir/Support/LogicalResult.h"
@@ -34,6 +36,70 @@ static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
3436
"\"" + binding + "\" binds here");
3537
}
3638

39+
// TODO: Extract into a helper function that can be reused outside query
40+
// context.
41+
static Operation *extractFunction(std::vector<Operation *> &ops,
42+
MLIRContext *context,
43+
llvm::StringRef functionName) {
44+
context->loadDialect<func::FuncDialect>();
45+
OpBuilder builder(context);
46+
47+
// Collect data for function creation
48+
std::vector<Operation *> slice;
49+
std::vector<Value> values;
50+
std::vector<Type> outputTypes;
51+
52+
for (auto *op : ops) {
53+
// Return op's operands are propagated, but the op itself isn't needed.
54+
if (!isa<func::ReturnOp>(op))
55+
slice.push_back(op);
56+
57+
// All results are returned by the extracted function.
58+
outputTypes.insert(outputTypes.end(), op->getResults().getTypes().begin(),
59+
op->getResults().getTypes().end());
60+
61+
// Track all values that need to be taken as input to function.
62+
values.insert(values.end(), op->getOperands().begin(),
63+
op->getOperands().end());
64+
}
65+
66+
// Create the function
67+
FunctionType funcType =
68+
builder.getFunctionType(ValueRange(values), outputTypes);
69+
auto loc = builder.getUnknownLoc();
70+
func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
71+
72+
builder.setInsertionPointToEnd(funcOp.addEntryBlock());
73+
74+
// Map original values to function arguments
75+
IRMapping mapper;
76+
for (const auto &arg : llvm::enumerate(values))
77+
mapper.map(arg.value(), funcOp.getArgument(arg.index()));
78+
79+
// Clone operations and build function body
80+
std::vector<Operation *> clonedOps;
81+
std::vector<Value> clonedVals;
82+
for (Operation *slicedOp : slice) {
83+
Operation *clonedOp =
84+
clonedOps.emplace_back(builder.clone(*slicedOp, mapper));
85+
clonedVals.insert(clonedVals.end(), clonedOp->result_begin(),
86+
clonedOp->result_end());
87+
}
88+
// Add return operation
89+
builder.create<func::ReturnOp>(loc, clonedVals);
90+
91+
// Remove unused function arguments
92+
size_t currentIndex = 0;
93+
while (currentIndex < funcOp.getNumArguments()) {
94+
if (funcOp.getArgument(currentIndex).use_empty())
95+
funcOp.eraseArgument(currentIndex);
96+
else
97+
++currentIndex;
98+
}
99+
100+
return funcOp;
101+
}
102+
37103
Query::~Query() = default;
38104

39105
mlir::LogicalResult InvalidQuery::run(llvm::raw_ostream &os,
@@ -65,9 +131,22 @@ mlir::LogicalResult QuitQuery::run(llvm::raw_ostream &os,
65131

66132
mlir::LogicalResult MatchQuery::run(llvm::raw_ostream &os,
67133
QuerySession &qs) const {
134+
Operation *rootOp = qs.getRootOp();
68135
int matchCount = 0;
69136
std::vector<Operation *> matches =
70-
matcher::MatchFinder().getMatches(qs.getRootOp(), matcher);
137+
matcher::MatchFinder().getMatches(rootOp, matcher);
138+
139+
// An extract call is recognized by considering if the matcher has a name.
140+
// TODO: Consider making the extract more explicit.
141+
if (matcher.hasFunctionName()) {
142+
auto functionName = matcher.getFunctionName();
143+
Operation *function =
144+
extractFunction(matches, rootOp->getContext(), functionName);
145+
os << "\n" << *function << "\n\n";
146+
function->erase();
147+
return mlir::success();
148+
}
149+
71150
os << "\n";
72151
for (Operation *op : matches) {
73152
os << "Match #" << ++matchCount << ":\n\n";
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: mlir-query %s -c "m hasOpName(\"arith.mulf\").extract(\"testmul\")" | FileCheck %s
2+
3+
// CHECK: func.func @testmul({{.*}}) -> (f32, f32, f32) {
4+
// CHECK: %[[MUL0:.*]] = arith.mulf {{.*}} : f32
5+
// CHECK: %[[MUL1:.*]] = arith.mulf {{.*}}, %[[MUL0]] : f32
6+
// CHECK: %[[MUL2:.*]] = arith.mulf {{.*}} : f32
7+
// CHECK-NEXT: return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
8+
9+
func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
10+
%sum0 = arith.addf %a, %b : f32
11+
%sub0 = arith.subf %sum0, %c : f32
12+
%mul0 = arith.mulf %a, %sub0 : f32
13+
%sum1 = arith.addf %b, %c : f32
14+
%mul1 = arith.mulf %sum1, %mul0 : f32
15+
%sub2 = arith.subf %mul1, %a : f32
16+
%sum2 = arith.addf %mul1, %b : f32
17+
%mul2 = arith.mulf %sub2, %sum2 : f32
18+
return %mul2 : f32
19+
}

0 commit comments

Comments
 (0)