@@ -217,6 +217,11 @@ struct TFGraphLowering : public SILInstructionVisitor<TFGraphLowering> {
217
217
llvm::SmallSet<int , 4 > processedTensorIdsForSend;
218
218
llvm::SmallSet<int , 4 > processedTensorIdsForReceive;
219
219
220
+ // / Mapping from declarations to the number to times a TF_Function took the
221
+ // / name from the declaration. This will be used in `getUniqueName` to produce
222
+ // / uniqued graph node names.
223
+ llvm::SmallDenseMap<ValueDecl *, unsigned > uniqueNames;
224
+
220
225
// / This flag gets set if lowering code to the graph produces a TensorFlow
221
226
// / error and emits a diagnostic. This tells us to stop lowering and give up
222
227
// / gracefully.
@@ -558,14 +563,29 @@ static void escapeOpName(std::string &name) {
558
563
// Currently, invalid characters are simply replaced with underscores.
559
564
// TODO: Use a more robust escaping transformation. It should handle unicode
560
565
// characters (using llvm::UTF8 or some other means) and be reversible.
561
- for (unsigned i = 0 , n = name. size (); i < n; i++ ) {
562
- char c = name[i];
563
- if (!std::isalnum (c) && c != ' .' )
564
- if (i == 0 || (i != ' _' && i != ' /' ))
565
- name[i] = ' _' ;
566
+ for (auto i : indices (name) ) {
567
+ char & c = name[i];
568
+ if (!llvm::isAlnum (c) && c != ' .' )
569
+ if (i == 0 || (c != ' _' && c != ' /' ))
570
+ c = ' _' ;
566
571
}
567
572
}
568
573
574
+ // / Given a DeclName, returns an escaped (TF-compatible)
575
+ // / name that replaces parentheses with '.' and colons with '_', for example:
576
+ // / `foo(x:y:z:)` -> `foo.x_y_z_.`.
577
+ static std::string escapeDeclName (DeclName name) {
578
+ SmallVector<char , 8 > buffer;
579
+ auto newName = name.getString (buffer, /* skipEmptyArgumentNames*/ true );
580
+ for (char &c : buffer) {
581
+ if (c == ' (' || c == ' )' )
582
+ c = ' .' ;
583
+ else if (!llvm::isAlnum (c))
584
+ c = ' _' ;
585
+ }
586
+ return newName.str ();
587
+ }
588
+
569
589
// / Produce a "stack trace" for the specified location, producing it in a form
570
590
// / that we can use as a unique op name.
571
591
std::string TFGraphLowering::getUniqueName (SILDebugLocation loc,
@@ -580,6 +600,9 @@ std::string TFGraphLowering::getUniqueName(SILDebugLocation loc,
580
600
// Form a name for this op based on the user's source location and "stack
581
601
// trace" of where it got inlined in user code. We use the form
582
602
// "file:line:col".
603
+ //
604
+ // FIXME: InlinedCallSite is always nullptr even if we use the performance
605
+ // inliner, so it currently does not track the inlined call site.
583
606
for (auto ds = loc.getScope (); ds; ds = ds->InlinedCallSite ) {
584
607
// If the call site location is invalid, stop scanning.
585
608
if (!ds->Loc .getSourceLoc ().isValid ())
@@ -596,7 +619,27 @@ std::string TFGraphLowering::getUniqueName(SILDebugLocation loc,
596
619
if (fnName.endswith (" .device_partition" ))
597
620
fnName = fnName.drop_back (strlen (" .device_partition" ));
598
621
599
- name += " ." + fnName.str () + " ." + llvm::utostr (lineCol.first );
622
+ // Separate functions using '/' so that TensorBoard can treat it as a
623
+ // hierarchical separator.
624
+ name += ' /' ;
625
+
626
+ // If the SIL function is backed by a Swift decl, use the decl name.
627
+ // Otherwise, use the SIL name.
628
+ std::string funcName;
629
+ auto *dc = SILFn.getDeclContext ();
630
+ if (auto *afd = dyn_cast_or_null<AbstractFunctionDecl>(dc)) {
631
+ funcName = escapeDeclName (afd->getEffectiveFullName ());
632
+ // Make sure the name is unique.
633
+ auto declCountLookup = uniqueNames.find (afd);
634
+ if (declCountLookup != uniqueNames.end ())
635
+ funcName += " _" + llvm::itostr (declCountLookup->getSecond ()++);
636
+ else
637
+ uniqueNames.insert ({afd, 1 });
638
+ } else {
639
+ funcName = fnName.str ();
640
+ }
641
+
642
+ name += funcName + " ." + llvm::utostr (lineCol.first );
600
643
name += " ." + llvm::utostr (lineCol.second );
601
644
}
602
645
}
@@ -608,7 +651,7 @@ std::string TFGraphLowering::getUniqueName(SILDebugLocation loc,
608
651
if (sourceLoc.isValid ()) {
609
652
auto lineCol = SM.getLineAndColumn (sourceLoc);
610
653
auto bufferID = SM.getBufferIdentifierForLoc (sourceLoc);
611
- name += " . " + bufferID.str () + " ." + llvm::utostr (lineCol.first );
654
+ name += " / " + bufferID.str () + " ." + llvm::utostr (lineCol.first );
612
655
name += " ." + llvm::utostr (lineCol.second );
613
656
}
614
657
}
@@ -1099,7 +1142,6 @@ void TFGraphLowering::visitTFDataset(BuiltinInst *inst) {
1099
1142
StringRef filePath;
1100
1143
if (dataSource != DatasetCreationContext::FAKE) {
1101
1144
auto operand = inst->getOperand (1 );
1102
- auto opInfo = tfopInfo.operandClasses [1 ];
1103
1145
auto *sli = cast<StringLiteralInst>(operand);
1104
1146
assert (sli->getEncoding () == StringLiteralInst::Encoding::UTF8);
1105
1147
filePath = sli->getValue ();
0 commit comments