@@ -208,6 +208,10 @@ int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
208
208
llvm_unreachable (" unknown kind" );
209
209
}
210
210
211
+ std::string SymbolInfoMap::SymbolInfo::getVarName (StringRef name) const {
212
+ return alternativeName.hasValue () ? alternativeName.getValue () : name.str ();
213
+ }
214
+
211
215
std::string SymbolInfoMap::SymbolInfo::getVarDecl (StringRef name) const {
212
216
LLVM_DEBUG (llvm::dbgs () << " getVarDecl for '" << name << " ': " );
213
217
switch (kind) {
@@ -219,8 +223,9 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
219
223
case Kind::Operand: {
220
224
// Use operand range for captured operands (to support potential variadic
221
225
// operands).
222
- return std::string (formatv (
223
- " ::mlir::Operation::operand_range {0}(op0->getOperands());\n " , name));
226
+ return std::string (
227
+ formatv (" ::mlir::Operation::operand_range {0}(op0->getOperands());\n " ,
228
+ getVarName (name)));
224
229
}
225
230
case Kind::Value: {
226
231
return std::string (formatv (" ::llvm::ArrayRef<::mlir::Value> {0};\n " , name));
@@ -359,27 +364,73 @@ bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
359
364
? SymbolInfo::getAttr (&op, argIndex)
360
365
: SymbolInfo::getOperand (&op, argIndex);
361
366
362
- return symbolInfoMap.insert ({symbol, symInfo}).second ;
367
+ std::string key = symbol.str ();
368
+ if (auto numberOfEntries = symbolInfoMap.count (key)) {
369
+ // Only non unique name for the operand is supported.
370
+ if (symInfo.kind != SymbolInfo::Kind::Operand) {
371
+ return false ;
372
+ }
373
+
374
+ // Cannot add new operand if there is already non operand with the same
375
+ // name.
376
+ if (symbolInfoMap.find (key)->second .kind != SymbolInfo::Kind::Operand) {
377
+ return false ;
378
+ }
379
+ }
380
+
381
+ symbolInfoMap.emplace (key, symInfo);
382
+ return true ;
363
383
}
364
384
365
385
bool SymbolInfoMap::bindOpResult (StringRef symbol, const Operator &op) {
366
386
StringRef name = getValuePackName (symbol);
367
- return symbolInfoMap.insert ({name, SymbolInfo::getResult (&op)}).second ;
387
+ auto inserted = symbolInfoMap.emplace (name, SymbolInfo::getResult (&op));
388
+
389
+ return symbolInfoMap.count (inserted->first ) == 1 ;
368
390
}
369
391
370
392
bool SymbolInfoMap::bindValue (StringRef symbol) {
371
- return symbolInfoMap.insert ({symbol, SymbolInfo::getValue ()}).second ;
393
+ auto inserted = symbolInfoMap.emplace (symbol, SymbolInfo::getValue ());
394
+ return symbolInfoMap.count (inserted->first ) == 1 ;
372
395
}
373
396
374
397
bool SymbolInfoMap::contains (StringRef symbol) const {
375
398
return find (symbol) != symbolInfoMap.end ();
376
399
}
377
400
378
401
SymbolInfoMap::const_iterator SymbolInfoMap::find (StringRef key) const {
379
- StringRef name = getValuePackName (key);
402
+ std::string name = getValuePackName (key).str ();
403
+
380
404
return symbolInfoMap.find (name);
381
405
}
382
406
407
+ SymbolInfoMap::const_iterator
408
+ SymbolInfoMap::findBoundSymbol (StringRef key, const Operator &op,
409
+ int argIndex) const {
410
+ std::string name = getValuePackName (key).str ();
411
+ auto range = symbolInfoMap.equal_range (name);
412
+
413
+ for (auto it = range.first ; it != range.second ; ++it) {
414
+ if (it->second .op == &op && it->second .argIndex == argIndex) {
415
+ return it;
416
+ }
417
+ }
418
+
419
+ return symbolInfoMap.end ();
420
+ }
421
+
422
+ std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
423
+ SymbolInfoMap::getRangeOfEqualElements (StringRef key) {
424
+ std::string name = getValuePackName (key).str ();
425
+
426
+ return symbolInfoMap.equal_range (name);
427
+ }
428
+
429
+ int SymbolInfoMap::count (StringRef key) const {
430
+ std::string name = getValuePackName (key).str ();
431
+ return symbolInfoMap.count (name);
432
+ }
433
+
383
434
int SymbolInfoMap::getStaticValueCount (StringRef symbol) const {
384
435
StringRef name = getValuePackName (symbol);
385
436
if (name != symbol) {
@@ -388,7 +439,7 @@ int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
388
439
return 1 ;
389
440
}
390
441
// Otherwise, find how many it represents by querying the symbol's info.
391
- return find (name)->getValue () .getStaticValueCount ();
442
+ return find (name)->second .getStaticValueCount ();
392
443
}
393
444
394
445
std::string SymbolInfoMap::getValueAndRangeUse (StringRef symbol,
@@ -397,27 +448,58 @@ std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
397
448
int index = -1 ;
398
449
StringRef name = getValuePackName (symbol, &index);
399
450
400
- auto it = symbolInfoMap.find (name);
451
+ auto it = symbolInfoMap.find (name. str () );
401
452
if (it == symbolInfoMap.end ()) {
402
453
auto error = formatv (" referencing unbound symbol '{0}'" , symbol);
403
454
PrintFatalError (loc, error);
404
455
}
405
456
406
- return it->getValue () .getValueAndRangeUse (name, index, fmt, separator);
457
+ return it->second .getValueAndRangeUse (name, index, fmt, separator);
407
458
}
408
459
409
460
std::string SymbolInfoMap::getAllRangeUse (StringRef symbol, const char *fmt,
410
461
const char *separator) const {
411
462
int index = -1 ;
412
463
StringRef name = getValuePackName (symbol, &index);
413
464
414
- auto it = symbolInfoMap.find (name);
465
+ auto it = symbolInfoMap.find (name. str () );
415
466
if (it == symbolInfoMap.end ()) {
416
467
auto error = formatv (" referencing unbound symbol '{0}'" , symbol);
417
468
PrintFatalError (loc, error);
418
469
}
419
470
420
- return it->getValue ().getAllRangeUse (name, index, fmt, separator);
471
+ return it->second .getAllRangeUse (name, index, fmt, separator);
472
+ }
473
+
474
+ void SymbolInfoMap::assignUniqueAlternativeNames () {
475
+ llvm::StringSet<> usedNames;
476
+
477
+ for (auto symbolInfoIt = symbolInfoMap.begin ();
478
+ symbolInfoIt != symbolInfoMap.end ();) {
479
+ auto range = symbolInfoMap.equal_range (symbolInfoIt->first );
480
+ auto startRange = range.first ;
481
+ auto endRange = range.second ;
482
+
483
+ auto operandName = symbolInfoIt->first ;
484
+ int startSearchIndex = 0 ;
485
+ for (++startRange; startRange != endRange; ++startRange) {
486
+ // Current operand name is not unique, find a unique one
487
+ // and set the alternative name.
488
+ for (int i = startSearchIndex;; ++i) {
489
+ std::string alternativeName = operandName + std::to_string (i);
490
+ if (!usedNames.contains (alternativeName) &&
491
+ symbolInfoMap.count (alternativeName) == 0 ) {
492
+ usedNames.insert (alternativeName);
493
+ startRange->second .alternativeName = alternativeName;
494
+ startSearchIndex = i + 1 ;
495
+
496
+ break ;
497
+ }
498
+ }
499
+ }
500
+
501
+ symbolInfoIt = endRange;
502
+ }
421
503
}
422
504
423
505
// ===----------------------------------------------------------------------===//
@@ -445,6 +527,10 @@ void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
445
527
LLVM_DEBUG (llvm::dbgs () << " start collecting source pattern bound symbols\n " );
446
528
collectBoundSymbols (getSourcePattern (), infoMap, /* isSrcPattern=*/ true );
447
529
LLVM_DEBUG (llvm::dbgs () << " done collecting source pattern bound symbols\n " );
530
+
531
+ LLVM_DEBUG (llvm::dbgs () << " start assigning alternative names for symbols\n " );
532
+ infoMap.assignUniqueAlternativeNames ();
533
+ LLVM_DEBUG (llvm::dbgs () << " done assigning alternative names for symbols\n " );
448
534
}
449
535
450
536
void Pattern::collectResultPatternBoundSymbols (SymbolInfoMap &infoMap) {
0 commit comments