Skip to content

Commit 2b6b6ce

Browse files
authored
fix double-deabstraction bug (#19593)
1 parent bb26134 commit 2b6b6ce

File tree

4 files changed

+82
-12
lines changed

4 files changed

+82
-12
lines changed

include/swift/SIL/GraphOperationBuilder.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ class GraphOperationBuilder {
4444
void addListArgument(llvm::ArrayRef<SILValue> arguments,
4545
llvm::StringRef name = llvm::StringRef());
4646

47+
/// Add a list argument to the GraphOperationInst, with an optional name.
48+
void addListArgument(OperandValueArrayRef arguments,
49+
llvm::StringRef name = llvm::StringRef());
50+
4751
/// Add an attribute with known constant value to the GraphOperationInst.
4852
/// Returns a reference to the attribute, valid for the lifetime of the
4953
/// GraphOperationBuilder, that you can use to mutate the attribute before

lib/AST/GraphOperationBuilder.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,17 @@ void GraphOperationBuilder::addListArgument(ArrayRef<SILValue> arguments,
4343
}
4444
}
4545

46+
/// Add a list argument to the GraphOperationInst, with an optional name.
47+
void GraphOperationBuilder::addListArgument(OperandValueArrayRef arguments,
48+
StringRef name) {
49+
MangledName += ",L";
50+
MangledName += name;
51+
for (auto argument : arguments) {
52+
MangledName += ",e";
53+
Operands.push_back(argument);
54+
}
55+
}
56+
4657
/// Add an attribute with known constant value to the GraphOperationInst.
4758
/// Returns a reference to the attribute, valid for the lifetime of the
4859
/// GraphOperationBuilder, that you can use to mutate the attribute before

lib/SILOptimizer/Mandatory/TFDeabstraction.cpp

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -384,22 +384,31 @@ static GraphOperationInst *simplifyOperands(GraphOperationInst *origInst,
384384
};
385385

386386
// Predicate that returns true if an argument of the specified type should be
387-
// rewritten - either to load an address argument or expand a struct
388-
// parameter.
389-
auto canSimplifyOperand =
390-
[&](SILType type, GraphOperationInfo::ArgumentLowering lowering) -> bool {
387+
// rewritten to load an address argument or expand a struct parameter.
388+
auto canSimplifyArgumentType = [&](SILType type) -> bool {
391389
return isLoadableAddressType(type) ||
392-
getPrimitiveStructField(type.getASTType()) != nullptr ||
393-
lowering == GraphOperationInfo::ArgumentLowering::Out;
390+
getPrimitiveStructField(type.getASTType()) != nullptr;
391+
};
392+
393+
// Predicate that returns true if an argument should be rewritten.
394+
auto canSimplifyArgument =
395+
[&](const GraphOperationInfo::StructuredArgument &argument) -> bool {
396+
switch (argument.getKind()) {
397+
case GraphOperationInfo::SAK_Single:
398+
return canSimplifyArgumentType(argument.getSingleArgument()->getType()) ||
399+
std::get<1>(argument.getArgumentNameAndLowering()) ==
400+
GraphOperationInfo::ArgumentLowering::Out;
401+
case GraphOperationInfo::SAK_List:
402+
// We can get SAK_List arguments from inlining functions that have already
403+
// been deabstracted. These arguments do not need further simplification.
404+
return false;
405+
}
394406
};
395407

396408
// If we don't have to change any arguments, don't rewrite the graph_op.
397409
bool mustChangeGraphOp = false;
398410
for (auto &argument : opInfo.getStructuredArguments()) {
399-
assert(argument.getKind() == GraphOperationInfo::SAK_Single &&
400-
"SILGen should not have generated a list argument");
401-
if (canSimplifyOperand(argument.getSingleArgument()->getType(),
402-
std::get<1>(argument.getArgumentNameAndLowering()))) {
411+
if (canSimplifyArgument(argument)) {
403412
mustChangeGraphOp = true;
404413
break;
405414
}
@@ -413,10 +422,21 @@ static GraphOperationInst *simplifyOperands(GraphOperationInst *origInst,
413422
// Okay, we do have to simplify something. Scan through and rewrite arguments.
414423
SILBuilder B(origInst);
415424
GraphOperationBuilder opBuilder(opInfo.getOperationName());
425+
// Pass attributes through.
426+
for (auto &attr : origInst->getAttributes())
427+
opBuilder.addAttribute(attr);
416428
SILValue outParameterAddress;
417429
for (auto &argument : opInfo.getStructuredArguments()) {
430+
if (argument.getKind() == GraphOperationInfo::SAK_List) {
431+
// We can get SAK_List arguments from inlining functions that have already
432+
// been deabstracted. Pass these arguments through.
433+
opBuilder.addListArgument(argument.getArgumentList(),
434+
argument.getArgumentNameWithSuffix());
435+
continue;
436+
}
437+
418438
assert(argument.getKind() == GraphOperationInfo::SAK_Single &&
419-
"SILGen should not have generated a list argument");
439+
"should have already handled all other argument kinds");
420440
auto argumentValue = argument.getSingleArgument();
421441
auto argumentLowering = std::get<1>(argument.getArgumentNameAndLowering());
422442

@@ -2126,6 +2146,14 @@ void TFDeabstraction::evaluateAttributesAndDoPacking(
21262146
// Find the device attribute specified for the instruction if present.
21272147
StringRef opDevice;
21282148

2149+
// Pass attributes through.
2150+
for (auto &attr : origInst->getAttributes()) {
2151+
if (attr.name.str() == DEVICE_ATTR) {
2152+
opDevice = attr.value.getStringValue();
2153+
}
2154+
opBuilder.addAttribute(attr);
2155+
}
2156+
21292157
// It is common to have input lists with repeated elements. These will
21302158
// generally be uniqued on entry to this routine. We cache the projections in
21312159
// these maps so that we can reuse them and avoid code bloat.
@@ -2137,8 +2165,17 @@ void TFDeabstraction::evaluateAttributesAndDoPacking(
21372165

21382166
for (auto i : range(opInfo.getStructuredArguments().size())) {
21392167
auto argument = opInfo.getStructuredArguments()[i];
2168+
2169+
if (argument.getKind() == GraphOperationInfo::SAK_List) {
2170+
// We can get SAK_List arguments from inlining functions that have already
2171+
// been deabstracted. Pass these arguments through.
2172+
opBuilder.addListArgument(argument.getArgumentList(),
2173+
argument.getArgumentNameWithSuffix());
2174+
continue;
2175+
}
2176+
21402177
assert(argument.getKind() == GraphOperationInfo::SAK_Single &&
2141-
"SILGen should not have generated a list argument");
2178+
"should have already handled all other argument kinds");
21422179
auto argumentValue = argument.getSingleArgument();
21432180
auto argumentTy = argumentValue->getType();
21442181
auto argumentNameAndLowering = argument.getArgumentNameAndLowering();

test/TensorFlow/crashers.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,21 @@ public func SR8419(iterationCount: Int) {
372372
}
373373
}
374374
}
375+
376+
// If `deabstractedCallee` gets deabstracted before `inlineDeabstracted_*`,
377+
// then the insts in `deabstractedCallee` get deabstracted twice. There was
378+
// a bug where the compiler crashed when deabstracting certain graph_ops twice.
379+
// There is no guaranteed deabstraction order, so this test isn't guaranteed to
380+
// catch the problem. Sandwiching `deabstractedCallee` between two callers
381+
// makes this test catch the problem as long as the order happens to be linear
382+
// up or down.
383+
public func inlineDeabstracted_a() -> Tensor<Float> {
384+
return deabstractedCallee([1, 2, 3])
385+
}
386+
// expected-warning @+1 {{implicitly copied}}
387+
public func deabstractedCallee(_ t: Tensor<Float>) -> Tensor<Float> {
388+
return t ++ Tensor<Float>([1, 2, 3]) // expected-note {{value used here}}
389+
}
390+
public func inlineDeabstracted_b() -> Tensor<Float> {
391+
return deabstractedCallee([1, 2, 3])
392+
}

0 commit comments

Comments
 (0)