Skip to content

Commit 44c18b2

Browse files
authored
Improve deabstraction SSA promotion logic, fixing SR-8395 (#18400)
Improve deabstraction SSA promotion logic, fixing SR-8395 - Handle the simple case of storing into a struct, to handle the common pattern of a store through a struct_element_addr into a Tensor. - Fix handling of begin_access to be a separate loop. Before, we were adding new entries to the use list that we were iterating over. This was almost fine (no invalid iterators or anything), but the new entries were added to the start of the list, so we would never visit them. - Fix a bug this exposed in graph lowering, where we wouldn't handle aggregate wrappers around elements in a Tensor initializer.
1 parent fa0d224 commit 44c18b2

File tree

7 files changed

+76
-41
lines changed

7 files changed

+76
-41
lines changed

lib/SILOptimizer/Mandatory/TFDeabstraction.cpp

Lines changed: 54 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -354,14 +354,13 @@ static BuiltinInst *simplifyOperands(BuiltinInst *inst, TFDeabstraction &TFDA) {
354354
if (!decl || !isa<StructDecl>(decl)) return nullptr;
355355

356356
// Check to see if there is a single stored field.
357-
auto fieldIt = decl->getStoredProperties().begin();
358-
if (fieldIt == decl->getStoredProperties().end()) return nullptr;
357+
auto field = tf::getFieldIfContainsSingleField(decl);
358+
if (!field) return nullptr;
359359

360360
// If this is the top level of the struct, retain the field decl.
361-
if (result == nullptr) result = *fieldIt;
361+
if (result == nullptr) result = field;
362362

363-
type = (*fieldIt++)->getType();
364-
if (fieldIt != decl->getStoredProperties().end()) return nullptr;
363+
type = field->getType();
365364

366365
// If we unwrapped a level and got to a builtin type, then this is a
367366
// wrapper.
@@ -1190,10 +1189,33 @@ void TFDeabstraction::prepareStackAllocForPromotion(AllocStackInst *alloc) {
11901189
// we have tensor values mixed in with other random values that shouldn't
11911190
// (or can't) be loaded. For now, we can just fail to deabstract these
11921191
// cases.
1192+
1193+
// Our first scan will look for begin_access instructions and remove them,
1194+
// allowing the second pass to be simpler.
1195+
for (auto UI = alloc->use_begin(); UI != alloc->use_end();) {
1196+
auto *begin = dyn_cast<BeginAccessInst>((*UI++)->getUser());
1197+
if (!begin)
1198+
continue;
1199+
1200+
// If we have a begin_access instruction, replace uses of begin_access with
1201+
// uses of the original value and remove the end_access.
1202+
for (auto UI = begin->use_begin(); UI != begin->use_end();) {
1203+
auto *use = *UI++;
1204+
auto inst = use->getUser();
1205+
if (isa<EndAccessInst>(inst))
1206+
inst->eraseFromParent();
1207+
else
1208+
use->set(alloc);
1209+
}
1210+
begin->eraseFromParent();
1211+
}
1212+
1213+
// Our second pass looks for aggregate operations and struct_element_addrs
1214+
// that poke inside the allocation.
11931215
for (auto UI = alloc->use_begin(); UI != alloc->use_end();) {
1194-
auto inst = (*UI)->getUser();
1216+
auto *inst = (*UI)->getUser();
11951217

1196-
if (auto sea = dyn_cast<StructElementAddrInst>(inst))
1218+
if (auto *sea = dyn_cast<StructElementAddrInst>(inst)) {
11971219
if (auto *use = sea->getSingleUse()) {
11981220
// If we have a load(struct_element_addr(alloc)) turn it into
11991221
// struct_extract(load(alloc)).
@@ -1210,7 +1232,31 @@ void TFDeabstraction::prepareStackAllocForPromotion(AllocStackInst *alloc) {
12101232
sea->eraseFromParent();
12111233
continue;
12121234
}
1235+
1236+
// If we have a store(x ->struct_element_addr(alloc)), turn it into a
1237+
// load of the whole value, a bunch of extracts, then a struct_inst
1238+
// to rebuild the whole value, then a store of the whole thing.
1239+
//
1240+
// TODO: For now, we only handle a single element struct, which is
1241+
// considerably simpler.
1242+
//
1243+
if (auto *store = dyn_cast<StoreInst>(use->getUser())) {
1244+
if (use->getOperandNumber() == 1 && // store TO the alloca.
1245+
tf::getFieldIfContainsSingleField(sea->getStructDecl())) {
1246+
SILBuilder B(store);
1247+
auto *newStruct = B.createStruct(store->getLoc(),
1248+
alloc->getType().getObjectType(),
1249+
store->getOperand(0));
1250+
B.createStore(store->getLoc(), newStruct, sea->getOperand(),
1251+
store->getOwnershipQualifier());
1252+
store->eraseFromParent();
1253+
++UI;
1254+
sea->eraseFromParent();
1255+
continue;
1256+
}
1257+
}
12131258
}
1259+
}
12141260

12151261
// Explode aggregate by-address instructions like copy-addr.
12161262
if (explodeAggregateInst(inst, /*all types*/nullptr)) {
@@ -1219,28 +1265,8 @@ void TFDeabstraction::prepareStackAllocForPromotion(AllocStackInst *alloc) {
12191265
continue;
12201266
}
12211267

1222-
// If we have an instruction other than begin_access, remember it.
1223-
auto *begin = dyn_cast<BeginAccessInst>(inst);
1224-
if (!begin) {
1225-
++UI;
1226-
continue;
1227-
}
1228-
1229-
// If we have a begin_access instruction, look through it. Add all of the
1230-
// users to the users list, and replace uses of begin_access with uses of
1231-
// the original value. Finally, ignore and remove the end_access.
1232-
for (auto UI = begin->use_begin(); UI != begin->use_end();) {
1233-
auto *use = *UI++;
1234-
auto inst = use->getUser();
1235-
if (isa<EndAccessInst>(inst)) {
1236-
inst->eraseFromParent();
1237-
} else {
1238-
use->set(alloc);
1239-
}
1240-
}
1241-
1268+
// Otherwise we have something else, leave it alone.
12421269
++UI;
1243-
begin->eraseFromParent();
12441270
}
12451271
}
12461272

lib/SILOptimizer/Mandatory/TFLowerGraph.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2119,6 +2119,7 @@ TFGraphFunctionLowering::visitGraphOperationInst(GraphOperationInst *inst) {
21192119
// The scalar case is very simple, the shape of a scalar is 0d, and the
21202120
// data type comes from an attr that should already be processed.
21212121
SmallVector<int64_t, 4> shape;
2122+
attrValue = attrValue.lookThroughSingleElementAggregates();
21222123
if (attrValue.getKind() == SymbolicValue::Integer ||
21232124
attrValue.getKind() == SymbolicValue::Float) {
21242125
if (addScalar(attrValue, elements))

lib/SILOptimizer/Mandatory/TFPartition.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,9 @@ static bool isUserIgnoredByPartitioning(SILInstruction *inst) {
8282
/// type of the single member, asserting and aborting if we get something
8383
/// unexpected.
8484
static CanType getSingleElementDeclFieldType(NominalTypeDecl *decl) {
85-
auto fieldIt = decl->getStoredProperties().begin();
86-
assert(fieldIt != decl->getStoredProperties().end() &&
87-
"Struct should have one member");
88-
auto fieldType = (*fieldIt++)->getType()->getCanonicalType();
89-
assert(fieldIt == decl->getStoredProperties().end() &&
90-
"Struct should have one member");
91-
return fieldType;
85+
auto *field = tf::getFieldIfContainsSingleField(decl);
86+
assert(field && "Struct should have one member");
87+
return field->getType()->getCanonicalType();
9288
}
9389

9490
/// Classification of instructions that are interesting to the partitioning

lib/SILOptimizer/Mandatory/TFUtilities.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,19 @@ llvm::raw_ostream *tf::getTFDumpIntermediateStream() {
8484
return &fileStream;
8585
}
8686

87+
/// If the specified decl has a single stored field, return it. Otherwise
88+
/// return null.
89+
VarDecl *tf::getFieldIfContainsSingleField(NominalTypeDecl *decl) {
90+
// Check to see if there is a single stored field.
91+
auto fieldIt = decl->getStoredProperties().begin();
92+
if (fieldIt == decl->getStoredProperties().end())
93+
return nullptr;
94+
auto result = *fieldIt++;
95+
if (fieldIt != decl->getStoredProperties().end())
96+
return nullptr;
97+
return result;
98+
}
99+
87100
bool tf::isTensorHandle(SILType ty) {
88101
return (bool)isTensorHandle(ty.getASTType());
89102
}

lib/SILOptimizer/Mandatory/TFUtilities.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ namespace tf {
3737
/// return null. This is used for integration unit tests and debugging.
3838
llvm::raw_ostream *getTFDumpIntermediateStream();
3939

40+
/// If the specified decl has a single stored field, return it. Otherwise
41+
/// return null.
42+
VarDecl *getFieldIfContainsSingleField(NominalTypeDecl *decl);
43+
4044
/// If the specified type is the well-known TensorHandle<T> type, then return
4145
/// "T". If not, return a null type.
4246
bool isTensorHandle(SILType ty);

test/TensorFlow/debugging.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@ public func debugValuesInLoop(_ x: Tensor<Float>) {
1818
// CHECK-LABEL: --- TFPartition Accelerator Result: {{.*}}basicDebugValues{{.*}}
1919
// CHECK: @{{.*}}basicDebugValues{{.*}}.tf
2020
// CHECK: [[ONE:%.*]] = graph_op "Const"
21-
// CHECK-NEXT: graph_op "tfc.SendToHost,i"
2221
// CHECK: [[ADD_RESULT:%.*]] = graph_op "Add,i,i"
23-
// CHECK-NEXT: graph_op "tfc.SendToHost,i"([[ADD_RESULT]] : $TensorHandle<Float>)
2422
// CHECK: graph_op "Square,i"([[ADD_RESULT]] : $TensorHandle<Float>) {T: $Float, __device: "/device:CPU:0"} : $TensorHandle<Float>
2523

2624

test/TensorFlow/optimization-disabled.swift

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,14 @@
22
import TensorFlow
33

44
public func testArrayValues() -> Tensor<Float> {
5-
// expected-warning @+1 14 {{value implicitly copied to the host}}
65
let x: Tensor<Float> = [[1, 2], [3, 4]]
76
return (matmul(x, x) + x).toHost()
8-
// expected-warning @-1 {{value implicitly copied to the host}}
97
}
108

119
/*
1210
CHECK-LABEL: --- TFPartition Accelerator Result: {{.*}}testArrayValues
1311
CHECK: %0 = graph_op "Const"() {dtype: $Float, value$tensor: f32 0x3F800000 /* 1 */, __device: "ALL_DEVICES"} : $TensorHandle<Float>
14-
CHECK: %1 = graph_op "tfc.SendToHost,i"(%0 : $TensorHandle<Float>) {tensorId: i32 0, __device: "/device:CPU:0"}
15-
CHECK-NOT: tfc.RecvFromHost
12+
CHECK: %1 = graph_op "Const"() {dtype: $Float, value$tensor: f32 0x40000000 /* 2 */
1613
CHECK-LABEL: ----
1714
*/
1815

0 commit comments

Comments (0)