Skip to content

Commit c12db28

Browse files
authored
(SR-8783) nicer textual sil for graph_op (#19429)
1 parent a0c68fe commit c12db28

24 files changed

+208
-144
lines changed

include/swift/AST/DiagnosticsParse.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1645,6 +1645,8 @@ ERROR(autodiff_use_wrt_not_withrespectto,none,
16451645
//------------------------------------------------------------------------------
16461646
// Graph operation related parsing diagnostics
16471647
//------------------------------------------------------------------------------
1648+
ERROR(sil_graph_op_name_comma,PointsToFirstBadToken,
1649+
"'graph_op' name cannot contain ','", ())
16481650
ERROR(sil_graph_op_expected_attr_name,PointsToFirstBadToken,
16491651
"expected 'graph_op' attribute name", ())
16501652
ERROR(sil_graph_op_expected_attr_value,PointsToFirstBadToken,
@@ -1653,6 +1655,8 @@ ERROR(sil_graph_op_unhandled_attr_value,PointsToFirstBadToken,
16531655
"unhandled 'graph_op' attribute value", ())
16541656
ERROR(sil_graph_op_expected_rparen,PointsToFirstBadToken,
16551657
"expected ')' in 'graph_op' argument list", ())
1658+
ERROR(sil_graph_op_expected_rsquare,PointsToFirstBadToken,
1659+
"expected ']' in 'graph_op' list operand", ())
16561660
ERROR(sil_graph_op_expected_rbrace,PointsToFirstBadToken,
16571661
"expected '}' in 'graph_op' attribute list", ())
16581662
ERROR(sil_graph_op_expected_colon_after_attr_name,PointsToFirstBadToken,

include/swift/SIL/GraphOperationBuilder.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,6 @@ class GraphOperationBuilder {
5656
GraphOperationAttribute &addAttribute(
5757
const GraphOperationAttribute &attribute);
5858

59-
/// Special method that should only be used for "tfc.scalarToTensor"'s operand,
60-
/// because it has special name mangling. (Marker is "s").
61-
void addScalarOperand(SILValue operand);
62-
63-
/// Special method that should only be used for "tf_tensor_to_i1"'s operand,
64-
/// because it has special name mangling. (No marker for its operand).
65-
/// TODO: Make "tf_tensor_to_i1" support normal name mangling, and then remove
66-
/// this.
67-
void addTFTensorToI1Operand(SILValue operand);
68-
6959
/// Build the GraphOperationInst.
7060
GraphOperationInst* build(
7161
SILBuilder &B, ASTContext &C, SILLocation loc,

include/swift/SIL/GraphOperationInfo.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,6 @@ struct GraphOperationInfo {
9797
// }
9898

9999
enum StructuredOperandKind {
100-
/// Scalar input, used by tfc.scalarToTensor only.
101-
/// Mangled name is ",s"
102-
SOK_Scalar,
103100
/// Single operand.
104101
/// Mangled name is ",i${name}" where ${name} is an optional name.
105102
SOK_Single,
@@ -117,7 +114,7 @@ struct GraphOperationInfo {
117114
StructuredOperandKind Kind;
118115
StringRef Name;
119116
union {
120-
/// Operand for SOK_Scalar and SOK_Single.
117+
/// Operand for SOK_Single.
121118
SILValue SingleOperand;
122119
/// Operands for SOK_List.
123120
ArrayRef<Operand> OperandList;
@@ -140,7 +137,7 @@ struct GraphOperationInfo {
140137
}
141138

142139
SILValue getSingleOperand() const {
143-
assert(getKind() == SOK_Scalar || getKind() == SOK_Single);
140+
assert(getKind() == SOK_Single);
144141
return SingleOperand;
145142
}
146143

lib/AST/GraphOperationBuilder.cpp

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ namespace tf {
2020

2121
/// Start building a GraphOperationInst for op `OpName`.
2222
GraphOperationBuilder::GraphOperationBuilder(StringRef OpName)
23-
: MangledName(OpName) {}
23+
: MangledName(OpName) {
24+
assert(MangledName.find(',') == StringRef::npos &&
25+
"graph_op name cannot include ','");
26+
}
2427

2528
/// Add a single operand to the GraphOperationInst, with an optional name.
2629
void GraphOperationBuilder::addOperand(SILValue operand, StringRef name) {
@@ -50,21 +53,6 @@ GraphOperationAttribute &GraphOperationBuilder::addAttribute(
5053
return Attributes.back();
5154
}
5255

53-
/// Special method that should only be used for "tfc.scalarToTensor"'s operand,
54-
/// because it has special name mangling. (Marker is "s").
55-
void GraphOperationBuilder::addScalarOperand(SILValue operand) {
56-
MangledName += ",s";
57-
Operands.push_back(operand);
58-
}
59-
60-
/// Special method that should only be used for "tf_tensor_to_i1"'s operand,
61-
/// because it has special name mangling. (No marker for its operand).
62-
/// TODO: Make "tf_tensor_to_i1" support normal name mangling, and then remove
63-
/// this.
64-
void GraphOperationBuilder::addTFTensorToI1Operand(SILValue operand) {
65-
Operands.push_back(operand);
66-
}
67-
6856
/// Build the GraphOperationInst.
6957
GraphOperationInst* GraphOperationBuilder::build(
7058
SILBuilder &B, ASTContext &C, SILLocation loc,

lib/AST/GraphOperationInfo.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,6 @@ StringRef GraphOperationInfo::decodeName(
9292
StringRef thisMarkerName = thisMarker.drop_front(2);
9393
assert(thisMarker.size() >= 2 && "marker too short");
9494
switch (thisMarker[1]) {
95-
case 's':
96-
// Push a SOK_Scalar.
97-
assert(thisMarkerName.empty() && "SOK_Scalar should not have name");
98-
structuredOperands.emplace_back(SOK_Scalar, StringRef(),
99-
remainingOperands.front().get());
100-
remainingOperands = remainingOperands.drop_front(1);
101-
break;
10295
case 'i':
10396
// Push a SOK_Single.
10497
structuredOperands.emplace_back(SOK_Single, thisMarkerName,

lib/ParseSIL/ParseSIL.cpp

Lines changed: 59 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "swift/SIL/SILArgument.h"
3030
#include "swift/SIL/SILBuilder.h"
3131
/// SWIFT_ENABLE_TENSORFLOW
32+
#include "swift/SIL/GraphOperationBuilder.h"
3233
#include "swift/SIL/SILConstants.h"
3334
#include "swift/SIL/SILDebugScope.h"
3435
#include "swift/SIL/SILModule.h"
@@ -2960,35 +2961,73 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
29602961
P.diagnose(P.Tok, diag::expected_tok_in_sil_instr, "graph_op name");
29612962
return true;
29622963
}
2963-
StringRef rawString = P.Tok.getText().drop_front().drop_back();
2964-
Identifier name = P.Context.getIdentifier(rawString);
2964+
StringRef opName = P.Tok.getText().drop_front().drop_back();
2965+
if (opName.find(',') != StringRef::npos) {
2966+
P.diagnose(P.Tok, diag::sil_graph_op_name_comma);
2967+
return true;
2968+
}
2969+
tf::GraphOperationBuilder opBuilder(opName);
29652970
P.consumeToken(tok::string_literal);
29662971

2967-
// Parse graph operation arguments.
2972+
// Parses a top-level operand to the graphop, and add it to `opBuilder`.
2973+
auto parseOperand = [&]() -> ParserStatus {
2974+
// Parse the optional operand name.
2975+
StringRef operandName;
2976+
if (P.Tok.is(tok::identifier)) {
2977+
operandName = P.Tok.getText();
2978+
P.consumeToken();
2979+
}
2980+
2981+
if (P.Tok.is(tok::l_square)) {
2982+
// It is a list operand.
2983+
SourceLoc lSquareLoc = P.consumeToken(tok::l_square);
2984+
SourceLoc rSquareLoc;
2985+
SmallVector<SILValue, 4> elements;
2986+
2987+
// Parses an element of a list operand, and adds it to `elements`.
2988+
auto parseListOperandElement = [&]() -> ParserStatus {
2989+
SILValue value;
2990+
if (parseTypedValueRef(value, B))
2991+
return makeParserError();
2992+
elements.push_back(value);
2993+
return makeParserSuccess();
2994+
};
2995+
2996+
ParserStatus status = P.parseList(tok::r_square, lSquareLoc, rSquareLoc,
2997+
/*AllowSepAfterLast*/ false,
2998+
diag::sil_graph_op_expected_rsquare,
2999+
SyntaxKind::TuplePatternElementList,
3000+
parseListOperandElement);
3001+
if (status.isError())
3002+
return status;
3003+
opBuilder.addListOperand(elements, operandName);
3004+
return makeParserSuccess();
3005+
} else {
3006+
// It is a single operand.
3007+
SILValue value;
3008+
if (parseTypedValueRef(value, B))
3009+
return makeParserError();
3010+
opBuilder.addOperand(value, operandName);
3011+
return makeParserSuccess();
3012+
}
3013+
};
3014+
3015+
// Parse graph operation operands.
29683016
if (P.Tok.isNot(tok::l_paren)) {
29693017
P.diagnose(P.Tok, diag::expected_tok_in_sil_instr, "(");
29703018
return true;
29713019
}
2972-
SmallVector<SILValue, 4> arguments;
29733020
SourceLoc lParenLoc = P.consumeToken(tok::l_paren);
29743021
SourceLoc rParenLoc;
2975-
ParserStatus status =
2976-
P.parseList(tok::r_paren, lParenLoc, rParenLoc,
2977-
/*AllowSepAfterLast*/ false,
2978-
diag::sil_graph_op_expected_rparen,
2979-
SyntaxKind::TuplePatternElementList,
2980-
[&]() -> ParserStatus {
2981-
SILValue value;
2982-
if (parseTypedValueRef(value, B))
2983-
return makeParserError();
2984-
arguments.push_back(value);
2985-
return makeParserSuccess();
2986-
});
3022+
ParserStatus status = P.parseList(tok::r_paren, lParenLoc, rParenLoc,
3023+
/*AllowSepAfterLast*/ false,
3024+
diag::sil_graph_op_expected_rparen,
3025+
SyntaxKind::TuplePatternElementList,
3026+
parseOperand);
29873027
if (status.isError())
29883028
return true;
29893029

29903030
// Parse optional graph operation attributes.
2991-
SmallVector<GraphOperationAttribute, 4> attributes;
29923031
SourceLoc lBraceLoc;
29933032
if (P.consumeIf(tok::l_brace, lBraceLoc)) {
29943033
SourceLoc rBraceLoc;
@@ -3010,7 +3049,7 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
30103049
}
30113050
if (parseSymbolicValue(attrValue, *this, B))
30123051
return makeParserError();
3013-
attributes.push_back({ attrName, attrValue });
3052+
opBuilder.addAttribute({ attrName, attrValue });
30143053
return makeParserSuccess();
30153054
});
30163055
if (status.isError())
@@ -3033,8 +3072,8 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
30333072

30343073
if (parseSILDebugLocation(InstLoc, B))
30353074
return true;
3036-
ResultVal = B.createGraphOperation(InstLoc, name, arguments, attributes,
3037-
resultTypes);
3075+
3076+
ResultVal = opBuilder.build(B, P.Context, InstLoc, resultTypes);
30383077
break;
30393078
}
30403079
case SILInstructionKind::OpenExistentialAddrInst:

lib/SIL/SILPrinter.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "swift/SIL/SILPrintContext.h"
2323
#include "swift/SIL/CFG.h"
2424
// SWIFT_ENABLE_TENSORFLOW
25+
#include "swift/SIL/GraphOperationInfo.h"
2526
#include "swift/SIL/SILConstants.h"
2627
#include "swift/SIL/SILFunction.h"
2728
#include "swift/SIL/SILCoverageMap.h"
@@ -1248,11 +1249,30 @@ class SILPrinter : public SILInstructionVisitor<SILPrinter> {
12481249

12491250
// SWIFT_ENABLE_TENSORFLOW
12501251
void visitGraphOperationInst(GraphOperationInst *GI) {
1251-
*this << QuotedString(GI->getName().str());
1252+
tf::GraphOperationInfo info(GI);
1253+
SmallVector<tf::GraphOperationInfo::StructuredOperand, 4> operands;
1254+
auto opName = info.decodeName(operands);
1255+
1256+
*this << QuotedString(opName);
12521257

12531258
*this << "(";
1254-
interleave(GI->getArguments(), [&](SILValue v) {
1255-
*this << getIDAndType(v);
1259+
interleave(operands, [&](tf::GraphOperationInfo::StructuredOperand operand) {
1260+
if (!operand.getName().empty())
1261+
*this << operand.getName() << " ";
1262+
switch (operand.getKind()) {
1263+
case tf::GraphOperationInfo::SOK_Single:
1264+
*this << getIDAndType(operand.getSingleOperand());
1265+
break;
1266+
case tf::GraphOperationInfo::SOK_List:
1267+
*this << "[";
1268+
interleave(operand.getOperandList(), [&](SILValue v) {
1269+
*this << getIDAndType(v);
1270+
}, [&] {
1271+
*this << ", ";
1272+
});
1273+
*this << "]";
1274+
break;
1275+
}
12561276
}, [&] {
12571277
*this << ", ";
12581278
});

lib/SILOptimizer/Mandatory/TFDeabstraction.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2162,7 +2162,7 @@ void TFDeabstraction::formGraphOp(SILTensorOpInfo &opInfo,
21622162
return;
21632163
}
21642164

2165-
opBuilder.addScalarOperand(operand);
2165+
opBuilder.addOperand(operand);
21662166
continue;
21672167
}
21682168

lib/SILOptimizer/Mandatory/TFLowerGraph.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1842,7 +1842,7 @@ TFGraphFunctionLowering::visitGraphOperationInst(GraphOperationInst *inst) {
18421842
// If this is the magic tf_tensor_to_i1 builtin, then we completely ignore it.
18431843
// the only user of it are things that take conditional branches, and they
18441844
// handle it directly.
1845-
if (inst->getName().str() == "tf_tensor_to_i1")
1845+
if (inst->getName().str() == "tf_tensor_to_i1,i")
18461846
return GLStatus::Success;
18471847

18481848
// Decode information about the graph_op.
@@ -1886,8 +1886,6 @@ TFGraphFunctionLowering::visitGraphOperationInst(GraphOperationInst *inst) {
18861886
for (auto structuredOperand : structuredOperands) {
18871887
assert(structuredOperand.getName().empty() && "cannot lower named operands");
18881888
switch (structuredOperand.getKind()) {
1889-
case GraphOperationInfo::SOK_Scalar:
1890-
llvm_unreachable("tfc.scalarToTensor should be lowered by now");
18911889
case GraphOperationInfo::SOK_Single: {
18921890
// Normal tensor inputs.
18931891
auto operand = structuredOperand.getSingleOperand();
@@ -2332,7 +2330,7 @@ static TF_Output getCondition(CondBranchInst *condBr,
23322330
auto *graphOpResult = cast<GraphOperationResult>(cond);
23332331
auto *graphOpInst = graphOpResult->getParent();
23342332
assert(graphOpInst->getNumResults() == 1);
2335-
assert(graphOpInst->getName().str() == "tf_tensor_to_i1");
2333+
assert(graphOpInst->getName().str() == "tf_tensor_to_i1,i");
23362334
tensorToI1 = graphOpInst;
23372335
}
23382336
assert(tensorToI1->getNumOperands() == 1 &&

lib/SILOptimizer/Mandatory/TFPartition.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2408,7 +2408,7 @@ void PartitionCloner::visitCondBranchInst(CondBranchInst *inst) {
24082408

24092409
void PartitionCloner::visitGraphOperationInst(GraphOperationInst *inst) {
24102410
// Handle special case "ops".
2411-
if (inst->getName().is("tfc.scalarToTensor,s")) {
2411+
if (inst->getName().is("tfc.scalarToTensor,i")) {
24122412
assert(inst->getNumOperands() == 1 && "invalid tfc.scalarToTensor!");
24132413
// We just lower the result as the input, since the scalar input will have
24142414
// been promoted to a tensor already. It is possible that the input will

lib/SILOptimizer/Mandatory/TFUtilities.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ tf::createTensorToInt1Inst(SILValue value, SILBuilder &builder,
497497
GraphFunctionDeviceInfo &deviceInfo) {
498498
ASTContext &context = builder.getASTContext();
499499
GraphOperationBuilder opBuilder("tf_tensor_to_i1");
500-
opBuilder.addTFTensorToI1Operand(value);
500+
opBuilder.addOperand(value);
501501
deviceInfo.handleDevicePlacement(
502502
"tf_tensor_to_i1",
503503
/*opDevice*/ getDeviceString(DeviceType::ALL),

test/TensorFlow/dataset_legacy.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public func testDatasetWithFakeData() {
1616
// CHECK-LABEL: --- TFPartition Accelerator Result: {{.*}}testDatasetWithFakeData{{.*}}
1717
// CHECK: bb0:
1818
// CHECK: [[GETNEXT:%[0-9]+]] = graph_op "tfc.makeIteratorGetNextWithDatasets{{.*}} : $TensorHandle<Float>
19-
// CHECK: [[RESULT:%[0-9]+]] = graph_op "Add,i,i"([[GETNEXT]] : $TensorHandle<Float>, {{.*}} : $TensorHandle<Float>
19+
// CHECK: [[RESULT:%[0-9]+]] = graph_op "Add"([[GETNEXT]] : $TensorHandle<Float>, {{.*}} : $TensorHandle<Float>
2020
// CHECK-NEXT: return [[RESULT]] : $TensorHandle<Float>
2121

2222
public func testDatasetWithMNIST() {
@@ -39,8 +39,8 @@ public func testDatasetWithMNIST() {
3939
// CHECK-LABEL: --- TFPartition Accelerator Result: {{.*}}testDatasetWithMNIST{{.*}}
4040
// CHECK: bb0:
4141
// CHECK: (%0, %1) = graph_op "tfc.makeIteratorGetNextWithDatasets{{.*}} : $TensorHandle<Float>, $TensorHandle<Int32>
42-
// CHECK: graph_op "Add,i,i"(
43-
// CHECK: graph_op "Add,i,i"(
42+
// CHECK: graph_op "Add"(
43+
// CHECK: graph_op "Add"(
4444
// The operands can appear in arbitrary order here.
4545
// CHECK: [[RESULT:%.*]] = tuple ({{.*}} : $TensorHandle<{{.*}}>, {{.*}} : $TensorHandle<{{.*}}>)
4646
// CHECK-NEXT: return [[RESULT]] : $(TensorHandle<{{.*}}>, TensorHandle<{{.*}}>)

0 commit comments

Comments
 (0)