@@ -354,14 +354,13 @@ static BuiltinInst *simplifyOperands(BuiltinInst *inst, TFDeabstraction &TFDA) {
354
354
if (!decl || !isa<StructDecl>(decl)) return nullptr ;
355
355
356
356
// Check to see if there is a single stored field.
357
- auto fieldIt = decl-> getStoredProperties (). begin ( );
358
- if (fieldIt == decl-> getStoredProperties (). end () ) return nullptr ;
357
+ auto field = tf::getFieldIfContainsSingleField (decl );
358
+ if (!field ) return nullptr ;
359
359
360
360
// If this is the top level of the struct, retain the field decl.
361
- if (result == nullptr ) result = *fieldIt ;
361
+ if (result == nullptr ) result = field ;
362
362
363
- type = (*fieldIt++)->getType ();
364
- if (fieldIt != decl->getStoredProperties ().end ()) return nullptr ;
363
+ type = field->getType ();
365
364
366
365
// If we unwrapped a level and got to a builtin type, then this is a
367
366
// wrapper.
@@ -1190,10 +1189,33 @@ void TFDeabstraction::prepareStackAllocForPromotion(AllocStackInst *alloc) {
1190
1189
// we have tensor values mixed in with other random values that shouldn't
1191
1190
// (or can't) be loaded. For now, we can just fail to deabstract these
1192
1191
// cases.
1192
+
1193
+ // Our first scan will look for begin_access instructions and remove them,
1194
+ // allowing the second pass to be simpler.
1195
+ for (auto UI = alloc->use_begin (); UI != alloc->use_end ();) {
1196
+ auto *begin = dyn_cast<BeginAccessInst>((*UI++)->getUser ());
1197
+ if (!begin)
1198
+ continue ;
1199
+
1200
+ // If we have a begin_access instruction, replace uses of begin_access with
1201
+ // uses of the original value and remove the end_access.
1202
+ for (auto UI = begin->use_begin (); UI != begin->use_end ();) {
1203
+ auto *use = *UI++;
1204
+ auto inst = use->getUser ();
1205
+ if (isa<EndAccessInst>(inst))
1206
+ inst->eraseFromParent ();
1207
+ else
1208
+ use->set (alloc);
1209
+ }
1210
+ begin->eraseFromParent ();
1211
+ }
1212
+
1213
+ // Our second pass looks for aggregate operations and struct_element_addrs
1214
+ // that poke inside the allocation.
1193
1215
for (auto UI = alloc->use_begin (); UI != alloc->use_end ();) {
1194
- auto inst = (*UI)->getUser ();
1216
+ auto * inst = (*UI)->getUser ();
1195
1217
1196
- if (auto sea = dyn_cast<StructElementAddrInst>(inst))
1218
+ if (auto * sea = dyn_cast<StructElementAddrInst>(inst)) {
1197
1219
if (auto *use = sea->getSingleUse ()) {
1198
1220
// If we have a load(struct_element_addr(alloc)) turn it into
1199
1221
// struct_extract(load(alloc)).
@@ -1210,7 +1232,31 @@ void TFDeabstraction::prepareStackAllocForPromotion(AllocStackInst *alloc) {
1210
1232
sea->eraseFromParent ();
1211
1233
continue ;
1212
1234
}
1235
+
1236
+ // If we have a store(x ->struct_element_addr(alloc)), turn it into a
1237
+ // load of the whole value, a bunch of extracts, then a struct_inst
1238
+ // to rebuild the whole value, then a store of the whole thing.
1239
+ //
1240
+ // TODO: For now, we only handle a single element struct, which is
1241
+ // considerably simpler.
1242
+ //
1243
+ if (auto *store = dyn_cast<StoreInst>(use->getUser ())) {
1244
+ if (use->getOperandNumber () == 1 && // store TO the alloca.
1245
+ tf::getFieldIfContainsSingleField (sea->getStructDecl ())) {
1246
+ SILBuilder B (store);
1247
+ auto *newStruct = B.createStruct (store->getLoc (),
1248
+ alloc->getType ().getObjectType (),
1249
+ store->getOperand (0 ));
1250
+ B.createStore (store->getLoc (), newStruct, sea->getOperand (),
1251
+ store->getOwnershipQualifier ());
1252
+ store->eraseFromParent ();
1253
+ ++UI;
1254
+ sea->eraseFromParent ();
1255
+ continue ;
1256
+ }
1257
+ }
1213
1258
}
1259
+ }
1214
1260
1215
1261
// Explode aggregate by-address instructions like copy-addr.
1216
1262
if (explodeAggregateInst (inst, /* all types*/ nullptr )) {
@@ -1219,28 +1265,8 @@ void TFDeabstraction::prepareStackAllocForPromotion(AllocStackInst *alloc) {
1219
1265
continue ;
1220
1266
}
1221
1267
1222
- // If we have an instruction other than begin_access, remember it.
1223
- auto *begin = dyn_cast<BeginAccessInst>(inst);
1224
- if (!begin) {
1225
- ++UI;
1226
- continue ;
1227
- }
1228
-
1229
- // If we have a begin_access instruction, look through it. Add all of the
1230
- // users to the users list, and replace uses of begin_access with uses of
1231
- // the original value. Finally, ignore and remove the end_access.
1232
- for (auto UI = begin->use_begin (); UI != begin->use_end ();) {
1233
- auto *use = *UI++;
1234
- auto inst = use->getUser ();
1235
- if (isa<EndAccessInst>(inst)) {
1236
- inst->eraseFromParent ();
1237
- } else {
1238
- use->set (alloc);
1239
- }
1240
- }
1241
-
1268
+ // Otherwise we have something else, leave it alone.
1242
1269
++UI;
1243
- begin->eraseFromParent ();
1244
1270
}
1245
1271
}
1246
1272
0 commit comments