Skip to content

Commit 7271c1b

Browse files
rdzhabarovjpienaar
authored andcommitted
[DDR] Introduce implicit equality check for the source pattern operands with the same name.
This CL allows user to specify the same name for the operands in the source pattern which implicitly enforces equality on operands with the same name. E.g., Pat<(OpA $a, $b, $a) ... > would create a matching rule for checking equality for the first and the last operands. Equality of the operands is enforced at any depth, e.g., OpA ($a, $b, OpB($a, $c, OpC ($a))). Example usage: Pat<(Reshape $arg0, (Shape $arg0)), (replaceWithValue $arg0)> Note, this feature only covers operands but not attributes. Current use cases are based on the operand equality and explicitly add the constraint into the pattern. Attribute equality will be worked out on the different CL. Differential Revision: https://reviews.llvm.org/D89254
1 parent ab870f3 commit 7271c1b

File tree

5 files changed

+250
-21
lines changed

5 files changed

+250
-21
lines changed

mlir/include/mlir/TableGen/Pattern.h

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#include "llvm/ADT/StringMap.h"
2222
#include "llvm/ADT/StringSet.h"
2323

24+
#include <unordered_map>
25+
2426
namespace llvm {
2527
class DagInit;
2628
class Init;
@@ -228,6 +230,9 @@ class SymbolInfoMap {
228230
// value bound by this symbol.
229231
std::string getVarDecl(StringRef name) const;
230232

233+
// Returns a variable name for the symbol named as `name`.
234+
std::string getVarName(StringRef name) const;
235+
231236
private:
232237
// Allow SymbolInfoMap to access private methods.
233238
friend class SymbolInfoMap;
@@ -285,9 +290,12 @@ class SymbolInfoMap {
285290
Kind kind; // The kind of the bound entity
286291
// The argument index (for `Attr` and `Operand` only)
287292
Optional<int> argIndex;
293+
// Alternative name for the symbol. It is used in case the name
294+
// is not unique. Applicable for `Operand` only.
295+
Optional<std::string> alternativeName;
288296
};
289297

290-
using BaseT = llvm::StringMap<SymbolInfo>;
298+
using BaseT = std::unordered_multimap<std::string, SymbolInfo>;
291299

292300
// Iterators for accessing all symbols.
293301
using iterator = BaseT::iterator;
@@ -300,7 +308,7 @@ class SymbolInfoMap {
300308
const_iterator end() const { return symbolInfoMap.end(); }
301309

302310
// Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
303-
// Returns false if `symbol` is already bound.
311+
// Returns false if `symbol` is already bound and symbols are not operands.
304312
bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);
305313

306314
// Binds the given `symbol` to the results the given `op`. Returns false if
@@ -317,6 +325,18 @@ class SymbolInfoMap {
317325
// Returns an iterator to the information of the given symbol named as `key`.
318326
const_iterator find(StringRef key) const;
319327

328+
// Returns an iterator to the information of the given symbol named as `key`,
329+
// with index `argIndex` for operator `op`.
330+
const_iterator findBoundSymbol(StringRef key, const Operator &op,
331+
int argIndex) const;
332+
333+
// Returns the bounds of a range that includes all the elements which
334+
// bind to the `key`.
335+
std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key);
336+
337+
// Returns number of times symbol named as `key` was used.
338+
int count(StringRef key) const;
339+
320340
// Returns the number of static values of the given `symbol` corresponds to.
321341
// A static value is an operand/result declared in ODS. Normally a symbol only
322342
// represents one static value, but symbols bound to op results can represent
@@ -338,6 +358,9 @@ class SymbolInfoMap {
338358
std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
339359
const char *separator = ", ") const;
340360

361+
// Assign alternative unique names to Operands that have equal names.
362+
void assignUniqueAlternativeNames();
363+
341364
// Splits the given `symbol` into a value pack name and an index. Returns the
342365
// value pack name and writes the index to `index` on success. Returns
343366
// `symbol` itself if it does not contain an index.
@@ -347,7 +370,7 @@ class SymbolInfoMap {
347370
static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
348371

349372
private:
350-
llvm::StringMap<SymbolInfo> symbolInfoMap;
373+
BaseT symbolInfoMap;
351374

352375
// Pattern instantiation location. This is intended to be used as parameter
353376
// to PrintFatalError() to report errors.

mlir/lib/TableGen/Pattern.cpp

Lines changed: 97 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,10 @@ int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
208208
llvm_unreachable("unknown kind");
209209
}
210210

211+
std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
212+
return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
213+
}
214+
211215
std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
212216
LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
213217
switch (kind) {
@@ -219,8 +223,9 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
219223
case Kind::Operand: {
220224
// Use operand range for captured operands (to support potential variadic
221225
// operands).
222-
return std::string(formatv(
223-
"::mlir::Operation::operand_range {0}(op0->getOperands());\n", name));
226+
return std::string(
227+
formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
228+
getVarName(name)));
224229
}
225230
case Kind::Value: {
226231
return std::string(formatv("::llvm::ArrayRef<::mlir::Value> {0};\n", name));
@@ -359,27 +364,73 @@ bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
359364
? SymbolInfo::getAttr(&op, argIndex)
360365
: SymbolInfo::getOperand(&op, argIndex);
361366

362-
return symbolInfoMap.insert({symbol, symInfo}).second;
367+
std::string key = symbol.str();
368+
if (auto numberOfEntries = symbolInfoMap.count(key)) {
369+
// Only non unique name for the operand is supported.
370+
if (symInfo.kind != SymbolInfo::Kind::Operand) {
371+
return false;
372+
}
373+
374+
// Cannot add new operand if there is already non operand with the same
375+
// name.
376+
if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
377+
return false;
378+
}
379+
}
380+
381+
symbolInfoMap.emplace(key, symInfo);
382+
return true;
363383
}
364384

365385
bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
366386
StringRef name = getValuePackName(symbol);
367-
return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second;
387+
auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
388+
389+
return symbolInfoMap.count(inserted->first) == 1;
368390
}
369391

370392
bool SymbolInfoMap::bindValue(StringRef symbol) {
371-
return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
393+
auto inserted = symbolInfoMap.emplace(symbol, SymbolInfo::getValue());
394+
return symbolInfoMap.count(inserted->first) == 1;
372395
}
373396

374397
bool SymbolInfoMap::contains(StringRef symbol) const {
375398
return find(symbol) != symbolInfoMap.end();
376399
}
377400

378401
SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
379-
StringRef name = getValuePackName(key);
402+
std::string name = getValuePackName(key).str();
403+
380404
return symbolInfoMap.find(name);
381405
}
382406

407+
SymbolInfoMap::const_iterator
408+
SymbolInfoMap::findBoundSymbol(StringRef key, const Operator &op,
409+
int argIndex) const {
410+
std::string name = getValuePackName(key).str();
411+
auto range = symbolInfoMap.equal_range(name);
412+
413+
for (auto it = range.first; it != range.second; ++it) {
414+
if (it->second.op == &op && it->second.argIndex == argIndex) {
415+
return it;
416+
}
417+
}
418+
419+
return symbolInfoMap.end();
420+
}
421+
422+
std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
423+
SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
424+
std::string name = getValuePackName(key).str();
425+
426+
return symbolInfoMap.equal_range(name);
427+
}
428+
429+
int SymbolInfoMap::count(StringRef key) const {
430+
std::string name = getValuePackName(key).str();
431+
return symbolInfoMap.count(name);
432+
}
433+
383434
int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
384435
StringRef name = getValuePackName(symbol);
385436
if (name != symbol) {
@@ -388,7 +439,7 @@ int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
388439
return 1;
389440
}
390441
// Otherwise, find how many it represents by querying the symbol's info.
391-
return find(name)->getValue().getStaticValueCount();
442+
return find(name)->second.getStaticValueCount();
392443
}
393444

394445
std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
@@ -397,27 +448,58 @@ std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
397448
int index = -1;
398449
StringRef name = getValuePackName(symbol, &index);
399450

400-
auto it = symbolInfoMap.find(name);
451+
auto it = symbolInfoMap.find(name.str());
401452
if (it == symbolInfoMap.end()) {
402453
auto error = formatv("referencing unbound symbol '{0}'", symbol);
403454
PrintFatalError(loc, error);
404455
}
405456

406-
return it->getValue().getValueAndRangeUse(name, index, fmt, separator);
457+
return it->second.getValueAndRangeUse(name, index, fmt, separator);
407458
}
408459

409460
std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
410461
const char *separator) const {
411462
int index = -1;
412463
StringRef name = getValuePackName(symbol, &index);
413464

414-
auto it = symbolInfoMap.find(name);
465+
auto it = symbolInfoMap.find(name.str());
415466
if (it == symbolInfoMap.end()) {
416467
auto error = formatv("referencing unbound symbol '{0}'", symbol);
417468
PrintFatalError(loc, error);
418469
}
419470

420-
return it->getValue().getAllRangeUse(name, index, fmt, separator);
471+
return it->second.getAllRangeUse(name, index, fmt, separator);
472+
}
473+
474+
void SymbolInfoMap::assignUniqueAlternativeNames() {
475+
llvm::StringSet<> usedNames;
476+
477+
for (auto symbolInfoIt = symbolInfoMap.begin();
478+
symbolInfoIt != symbolInfoMap.end();) {
479+
auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
480+
auto startRange = range.first;
481+
auto endRange = range.second;
482+
483+
auto operandName = symbolInfoIt->first;
484+
int startSearchIndex = 0;
485+
for (++startRange; startRange != endRange; ++startRange) {
486+
// Current operand name is not unique, find a unique one
487+
// and set the alternative name.
488+
for (int i = startSearchIndex;; ++i) {
489+
std::string alternativeName = operandName + std::to_string(i);
490+
if (!usedNames.contains(alternativeName) &&
491+
symbolInfoMap.count(alternativeName) == 0) {
492+
usedNames.insert(alternativeName);
493+
startRange->second.alternativeName = alternativeName;
494+
startSearchIndex = i + 1;
495+
496+
break;
497+
}
498+
}
499+
}
500+
501+
symbolInfoIt = endRange;
502+
}
421503
}
422504

423505
//===----------------------------------------------------------------------===//
@@ -445,6 +527,10 @@ void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
445527
LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
446528
collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
447529
LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
530+
531+
LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
532+
infoMap.assignUniqueAlternativeNames();
533+
LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
448534
}
449535

450536
void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,32 @@ def OpM : TEST_Op<"op_m"> {
619619
let results = (outs I32);
620620
}
621621

622+
def OpN : TEST_Op<"op_n"> {
623+
let arguments = (ins I32, I32);
624+
let results = (outs I32);
625+
}
626+
627+
def OpO : TEST_Op<"op_o"> {
628+
let arguments = (ins I32);
629+
let results = (outs I32);
630+
}
631+
632+
def OpP : TEST_Op<"op_p"> {
633+
let arguments = (ins I32, I32, I32, I32, I32, I32);
634+
let results = (outs I32);
635+
}
636+
637+
// Test same operand name enforces equality condition check.
638+
def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;
639+
640+
// Test when equality is enforced at different depth.
641+
def TestNestedOpEqualArgsPattern :
642+
Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
643+
644+
// Test multiple equal arguments check enforced.
645+
def TestMultipleEqualArgsPattern :
646+
Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
647+
622648
// Test for memrefs normalization of an op with normalizable memrefs.
623649
def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
624650
let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);

mlir/test/mlir-tblgen/pattern.mlir

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,64 @@ func @verifyManyArgs(%arg: i32) {
111111
return
112112
}
113113

114+
// CHECK-LABEL: verifyEqualArgs
115+
func @verifyEqualArgs(%arg0: i32, %arg1: i32) {
116+
// def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;
117+
118+
// CHECK: "test.op_o"(%arg0) : (i32) -> i32
119+
"test.op_n"(%arg0, %arg0) : (i32, i32) -> (i32)
120+
121+
// CHECK: "test.op_n"(%arg0, %arg1) : (i32, i32) -> i32
122+
"test.op_n"(%arg0, %arg1) : (i32, i32) -> (i32)
123+
124+
return
125+
}
126+
127+
// CHECK-LABEL: verifyNestedOpEqualArgs
128+
func @verifyNestedOpEqualArgs(
129+
%arg0: i32, %arg1: i32, %arg2 : i32, %arg3 : i32, %arg4 : i32, %arg5 : i32) {
130+
// def TestNestedOpEqualArgsPattern :
131+
// Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
132+
133+
// CHECK: %arg1
134+
%0 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
135+
: (i32, i32, i32, i32, i32, i32) -> (i32)
136+
%1 = "test.op_n"(%arg1, %0) : (i32, i32) -> (i32)
137+
138+
// CHECK: test.op_p
139+
// CHECK: test.op_n
140+
%2 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
141+
: (i32, i32, i32, i32, i32, i32) -> (i32)
142+
%3 = "test.op_n"(%arg0, %2) : (i32, i32) -> (i32)
143+
144+
return
145+
}
146+
147+
// CHECK-LABEL: verifyMultipleEqualArgs
148+
func @verifyMultipleEqualArgs(
149+
%arg0: i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) {
150+
// def TestMultipleEqualArgsPattern :
151+
// Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
152+
153+
// CHECK: "test.op_n"(%arg2, %arg1) : (i32, i32) -> i32
154+
"test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg1, %arg2) :
155+
(i32, i32, i32, i32 , i32, i32) -> i32
156+
157+
// CHECK: test.op_p
158+
"test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg0, %arg2) :
159+
(i32, i32, i32, i32 , i32, i32) -> i32
160+
161+
// CHECK: test.op_p
162+
"test.op_p"(%arg0, %arg1, %arg1, %arg0, %arg1, %arg2) :
163+
(i32, i32, i32, i32 , i32, i32) -> i32
164+
165+
// CHECK: test.op_p
166+
"test.op_p"(%arg0, %arg1, %arg2, %arg2, %arg3, %arg4) :
167+
(i32, i32, i32, i32 , i32, i32) -> i32
168+
169+
return
170+
}
171+
114172
//===----------------------------------------------------------------------===//
115173
// Test Symbol Binding
116174
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)