Skip to content

(SR-8783) nicer textual sil for graph_op #19429

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 21, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1645,6 +1645,8 @@ ERROR(autodiff_use_wrt_not_withrespectto,none,
//------------------------------------------------------------------------------
// Graph operation related parsing diagnostics
//------------------------------------------------------------------------------
ERROR(sil_graph_op_name_comma,PointsToFirstBadToken,
"'graph_op' name cannot contain ','", ())
ERROR(sil_graph_op_expected_attr_name,PointsToFirstBadToken,
"expected 'graph_op' attribute name", ())
ERROR(sil_graph_op_expected_attr_value,PointsToFirstBadToken,
Expand All @@ -1653,6 +1655,8 @@ ERROR(sil_graph_op_unhandled_attr_value,PointsToFirstBadToken,
"unhandled 'graph_op' attribute value", ())
ERROR(sil_graph_op_expected_rparen,PointsToFirstBadToken,
"expected ')' in 'graph_op' argument list", ())
ERROR(sil_graph_op_expected_rsquare,PointsToFirstBadToken,
"expected ']' in 'graph_op' list operand", ())
ERROR(sil_graph_op_expected_rbrace,PointsToFirstBadToken,
"expected '}' in 'graph_op' attribute list", ())
ERROR(sil_graph_op_expected_colon_after_attr_name,PointsToFirstBadToken,
Expand Down
10 changes: 0 additions & 10 deletions include/swift/SIL/GraphOperationBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,6 @@ class GraphOperationBuilder {
GraphOperationAttribute &addAttribute(
const GraphOperationAttribute &attribute);

/// Special method that should only be used for "tfc.scalarToTensor"'s operand,
/// because it has special name mangling. (Marker is "s").
void addScalarOperand(SILValue operand);

/// Special method that should only be used for "tf_tensor_to_i1"'s operand,
/// because it has special name mangling. (No marker for its operand).
/// TODO: Make "tf_tensor_to_i1" support normal name mangling, and then remove
/// this.
void addTFTensorToI1Operand(SILValue operand);

/// Build the GraphOperationInst.
GraphOperationInst* build(
SILBuilder &B, ASTContext &C, SILLocation loc,
Expand Down
7 changes: 2 additions & 5 deletions include/swift/SIL/GraphOperationInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,6 @@ struct GraphOperationInfo {
// }

enum StructuredOperandKind {
/// Scalar input, used by tfc.scalarToTensor only.
/// Mangled name is ",s"
SOK_Scalar,
/// Single operand.
/// Mangled name is ",i${name}" where ${name} is an optional name.
SOK_Single,
Expand All @@ -117,7 +114,7 @@ struct GraphOperationInfo {
StructuredOperandKind Kind;
StringRef Name;
union {
/// Operand for SOK_Scalar and SOK_Single.
/// Operand for SOK_Single.
SILValue SingleOperand;
/// Operands for SOK_List.
ArrayRef<Operand> OperandList;
Expand All @@ -140,7 +137,7 @@ struct GraphOperationInfo {
}

SILValue getSingleOperand() const {
assert(getKind() == SOK_Scalar || getKind() == SOK_Single);
assert(getKind() == SOK_Single);
return SingleOperand;
}

Expand Down
20 changes: 4 additions & 16 deletions lib/AST/GraphOperationBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ namespace tf {

/// Start building a GraphOperationInst for op `OpName`.
GraphOperationBuilder::GraphOperationBuilder(StringRef OpName)
: MangledName(OpName) {}
: MangledName(OpName) {
assert(MangledName.find(',') == StringRef::npos &&
"graph_op name cannot include ','");
}

/// Add a single operand to the GraphOperationInst, with an optional name.
void GraphOperationBuilder::addOperand(SILValue operand, StringRef name) {
Expand Down Expand Up @@ -50,21 +53,6 @@ GraphOperationAttribute &GraphOperationBuilder::addAttribute(
return Attributes.back();
}

/// Special method that should only be used for "tfc.scalarToTensor"'s operand,
/// because it has special name mangling. (Marker is "s").
void GraphOperationBuilder::addScalarOperand(SILValue operand) {
MangledName += ",s";
Operands.push_back(operand);
}

/// Special method that should only be used for "tf_tensor_to_i1"'s operand,
/// because it has special name mangling. (No marker for its operand).
/// TODO: Make "tf_tensor_to_i1" support normal name mangling, and then remove
/// this.
void GraphOperationBuilder::addTFTensorToI1Operand(SILValue operand) {
Operands.push_back(operand);
}

/// Build the GraphOperationInst.
GraphOperationInst* GraphOperationBuilder::build(
SILBuilder &B, ASTContext &C, SILLocation loc,
Expand Down
7 changes: 0 additions & 7 deletions lib/AST/GraphOperationInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,6 @@ StringRef GraphOperationInfo::decodeName(
StringRef thisMarkerName = thisMarker.drop_front(2);
assert(thisMarker.size() >= 2 && "marker too short");
switch (thisMarker[1]) {
case 's':
// Push a SOK_Scalar.
assert(thisMarkerName.empty() && "SOK_Scalar should not have name");
structuredOperands.emplace_back(SOK_Scalar, StringRef(),
remainingOperands.front().get());
remainingOperands = remainingOperands.drop_front(1);
break;
case 'i':
// Push a SOK_Single.
structuredOperands.emplace_back(SOK_Single, thisMarkerName,
Expand Down
79 changes: 59 additions & 20 deletions lib/ParseSIL/ParseSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "swift/SIL/SILArgument.h"
#include "swift/SIL/SILBuilder.h"
/// SWIFT_ENABLE_TENSORFLOW
#include "swift/SIL/GraphOperationBuilder.h"
#include "swift/SIL/SILConstants.h"
#include "swift/SIL/SILDebugScope.h"
#include "swift/SIL/SILModule.h"
Expand Down Expand Up @@ -2960,35 +2961,73 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
P.diagnose(P.Tok, diag::expected_tok_in_sil_instr, "graph_op name");
return true;
}
StringRef rawString = P.Tok.getText().drop_front().drop_back();
Identifier name = P.Context.getIdentifier(rawString);
StringRef opName = P.Tok.getText().drop_front().drop_back();
if (opName.find(',') != StringRef::npos) {
P.diagnose(P.Tok, diag::sil_graph_op_name_comma);
return true;
}
tf::GraphOperationBuilder opBuilder(opName);
P.consumeToken(tok::string_literal);

// Parse graph operation arguments.
// Parses a top-level operand to the graphop, and add it to `opBuilder`.
auto parseOperand = [&]() -> ParserStatus {
// Parse the optional operand name.
StringRef operandName;
if (P.Tok.is(tok::identifier)) {
operandName = P.Tok.getText();
P.consumeToken();
}

if (P.Tok.is(tok::l_square)) {
// It is a list operand.
SourceLoc lSquareLoc = P.consumeToken(tok::l_square);
SourceLoc rSquareLoc;
SmallVector<SILValue, 4> elements;

// Parses an element of a list operand, and adds it to `elements`.
auto parseListOperandElement = [&]() -> ParserStatus {
SILValue value;
if (parseTypedValueRef(value, B))
return makeParserError();
elements.push_back(value);
return makeParserSuccess();
};

ParserStatus status = P.parseList(tok::r_square, lSquareLoc, rSquareLoc,
/*AllowSepAfterLast*/ false,
diag::sil_graph_op_expected_rsquare,
SyntaxKind::TuplePatternElementList,
parseListOperandElement);
if (status.isError())
return status;
opBuilder.addListOperand(elements, operandName);
return makeParserSuccess();
} else {
// It is a single operand.
SILValue value;
if (parseTypedValueRef(value, B))
return makeParserError();
opBuilder.addOperand(value, operandName);
return makeParserSuccess();
}
};

// Parse graph operation operands.
if (P.Tok.isNot(tok::l_paren)) {
P.diagnose(P.Tok, diag::expected_tok_in_sil_instr, "(");
return true;
}
SmallVector<SILValue, 4> arguments;
SourceLoc lParenLoc = P.consumeToken(tok::l_paren);
SourceLoc rParenLoc;
ParserStatus status =
P.parseList(tok::r_paren, lParenLoc, rParenLoc,
/*AllowSepAfterLast*/ false,
diag::sil_graph_op_expected_rparen,
SyntaxKind::TuplePatternElementList,
[&]() -> ParserStatus {
SILValue value;
if (parseTypedValueRef(value, B))
return makeParserError();
arguments.push_back(value);
return makeParserSuccess();
});
ParserStatus status = P.parseList(tok::r_paren, lParenLoc, rParenLoc,
/*AllowSepAfterLast*/ false,
diag::sil_graph_op_expected_rparen,
SyntaxKind::TuplePatternElementList,
parseOperand);
if (status.isError())
return true;

// Parse optional graph operation attributes.
SmallVector<GraphOperationAttribute, 4> attributes;
SourceLoc lBraceLoc;
if (P.consumeIf(tok::l_brace, lBraceLoc)) {
SourceLoc rBraceLoc;
Expand All @@ -3010,7 +3049,7 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
}
if (parseSymbolicValue(attrValue, *this, B))
return makeParserError();
attributes.push_back({ attrName, attrValue });
opBuilder.addAttribute({ attrName, attrValue });
return makeParserSuccess();
});
if (status.isError())
Expand All @@ -3033,8 +3072,8 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {

if (parseSILDebugLocation(InstLoc, B))
return true;
ResultVal = B.createGraphOperation(InstLoc, name, arguments, attributes,
resultTypes);

ResultVal = opBuilder.build(B, P.Context, InstLoc, resultTypes);
break;
}
case SILInstructionKind::OpenExistentialAddrInst:
Expand Down
26 changes: 23 additions & 3 deletions lib/SIL/SILPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "swift/SIL/SILPrintContext.h"
#include "swift/SIL/CFG.h"
// SWIFT_ENABLE_TENSORFLOW
#include "swift/SIL/GraphOperationInfo.h"
#include "swift/SIL/SILConstants.h"
#include "swift/SIL/SILFunction.h"
#include "swift/SIL/SILCoverageMap.h"
Expand Down Expand Up @@ -1248,11 +1249,30 @@ class SILPrinter : public SILInstructionVisitor<SILPrinter> {

// SWIFT_ENABLE_TENSORFLOW
void visitGraphOperationInst(GraphOperationInst *GI) {
*this << QuotedString(GI->getName().str());
tf::GraphOperationInfo info(GI);
SmallVector<tf::GraphOperationInfo::StructuredOperand, 4> operands;
auto opName = info.decodeName(operands);

*this << QuotedString(opName);

*this << "(";
interleave(GI->getArguments(), [&](SILValue v) {
*this << getIDAndType(v);
interleave(operands, [&](tf::GraphOperationInfo::StructuredOperand operand) {
if (!operand.getName().empty())
*this << operand.getName() << " ";
switch (operand.getKind()) {
case tf::GraphOperationInfo::SOK_Single:
*this << getIDAndType(operand.getSingleOperand());
break;
case tf::GraphOperationInfo::SOK_List:
*this << "[";
interleave(operand.getOperandList(), [&](SILValue v) {
*this << getIDAndType(v);
}, [&] {
*this << ", ";
});
*this << "]";
break;
}
}, [&] {
*this << ", ";
});
Expand Down
2 changes: 1 addition & 1 deletion lib/SILOptimizer/Mandatory/TFDeabstraction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2162,7 +2162,7 @@ void TFDeabstraction::formGraphOp(SILTensorOpInfo &opInfo,
return;
}

opBuilder.addScalarOperand(operand);
opBuilder.addOperand(operand);
continue;
}

Expand Down
6 changes: 2 additions & 4 deletions lib/SILOptimizer/Mandatory/TFLowerGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1842,7 +1842,7 @@ TFGraphFunctionLowering::visitGraphOperationInst(GraphOperationInst *inst) {
// If this is the magic tf_tensor_to_i1 builtin, then we completely ignore it.
// the only user of it are things that take conditional branches, and they
// handle it directly.
if (inst->getName().str() == "tf_tensor_to_i1")
if (inst->getName().str() == "tf_tensor_to_i1,i")
return GLStatus::Success;

// Decode information about the graph_op.
Expand Down Expand Up @@ -1886,8 +1886,6 @@ TFGraphFunctionLowering::visitGraphOperationInst(GraphOperationInst *inst) {
for (auto structuredOperand : structuredOperands) {
assert(structuredOperand.getName().empty() && "cannot lower named operands");
switch (structuredOperand.getKind()) {
case GraphOperationInfo::SOK_Scalar:
llvm_unreachable("tfc.scalarToTensor should be lowered by now");
case GraphOperationInfo::SOK_Single: {
// Normal tensor inputs.
auto operand = structuredOperand.getSingleOperand();
Expand Down Expand Up @@ -2332,7 +2330,7 @@ static TF_Output getCondition(CondBranchInst *condBr,
auto *graphOpResult = cast<GraphOperationResult>(cond);
auto *graphOpInst = graphOpResult->getParent();
assert(graphOpInst->getNumResults() == 1);
assert(graphOpInst->getName().str() == "tf_tensor_to_i1");
assert(graphOpInst->getName().str() == "tf_tensor_to_i1,i");
tensorToI1 = graphOpInst;
}
assert(tensorToI1->getNumOperands() == 1 &&
Expand Down
2 changes: 1 addition & 1 deletion lib/SILOptimizer/Mandatory/TFPartition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2408,7 +2408,7 @@ void PartitionCloner::visitCondBranchInst(CondBranchInst *inst) {

void PartitionCloner::visitGraphOperationInst(GraphOperationInst *inst) {
// Handle special case "ops".
if (inst->getName().is("tfc.scalarToTensor,s")) {
if (inst->getName().is("tfc.scalarToTensor,i")) {
assert(inst->getNumOperands() == 1 && "invalid tfc.scalarToTensor!");
// We just lower the result as the input, since the scalar input will have
// been promoted to a tensor already. It is possible that the input will
Expand Down
2 changes: 1 addition & 1 deletion lib/SILOptimizer/Mandatory/TFUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ tf::createTensorToInt1Inst(SILValue value, SILBuilder &builder,
GraphFunctionDeviceInfo &deviceInfo) {
ASTContext &context = builder.getASTContext();
GraphOperationBuilder opBuilder("tf_tensor_to_i1");
opBuilder.addTFTensorToI1Operand(value);
opBuilder.addOperand(value);
deviceInfo.handleDevicePlacement(
"tf_tensor_to_i1",
/*opDevice*/ getDeviceString(DeviceType::ALL),
Expand Down
6 changes: 3 additions & 3 deletions test/TensorFlow/dataset_legacy.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public func testDatasetWithFakeData() {
// CHECK-LABEL: --- TFPartition Accelerator Result: {{.*}}testDatasetWithFakeData{{.*}}
// CHECK: bb0:
// CHECK: [[GETNEXT:%[0-9]+]] = graph_op "tfc.makeIteratorGetNextWithDatasets{{.*}} : $TensorHandle<Float>
// CHECK: [[RESULT:%[0-9]+]] = graph_op "Add,i,i"([[GETNEXT]] : $TensorHandle<Float>, {{.*}} : $TensorHandle<Float>
// CHECK: [[RESULT:%[0-9]+]] = graph_op "Add"([[GETNEXT]] : $TensorHandle<Float>, {{.*}} : $TensorHandle<Float>
// CHECK-NEXT: return [[RESULT]] : $TensorHandle<Float>

public func testDatasetWithMNIST() {
Expand All @@ -39,8 +39,8 @@ public func testDatasetWithMNIST() {
// CHECK-LABEL: --- TFPartition Accelerator Result: {{.*}}testDatasetWithMNIST{{.*}}
// CHECK: bb0:
// CHECK: (%0, %1) = graph_op "tfc.makeIteratorGetNextWithDatasets{{.*}} : $TensorHandle<Float>, $TensorHandle<Int32>
// CHECK: graph_op "Add,i,i"(
// CHECK: graph_op "Add,i,i"(
// CHECK: graph_op "Add"(
// CHECK: graph_op "Add"(
// The operands can appear in arbitrary order here.
// CHECK: [[RESULT:%.*]] = tuple ({{.*}} : $TensorHandle<{{.*}}>, {{.*}} : $TensorHandle<{{.*}}>)
// CHECK-NEXT: return [[RESULT]] : $(TensorHandle<{{.*}}>, TensorHandle<{{.*}}>)
Loading