20
20
#include " llvm/ADT/STLExtras.h"
21
21
#include " llvm/ADT/TypeSwitch.h"
22
22
23
+ #define DEBUG_TYPE " sroa"
24
+
23
25
using namespace mlir ;
24
26
25
27
// ===----------------------------------------------------------------------===//
@@ -431,29 +433,165 @@ DeletionKind LLVM::GEPOp::removeBlockingUses(
431
433
return DeletionKind::Delete;
432
434
}
433
435
434
- static bool isFirstIndexZero (LLVM::GEPOp gep) {
435
- IntegerAttr index =
436
- llvm::dyn_cast_if_present<IntegerAttr>(gep.getIndices ()[0 ]);
437
- return index && index.getInt () == 0 ;
436
+ // / Returns the amount of bytes the provided GEP elements will offset the
437
+ // / pointer by. Returns nullopt if no constant offset could be computed.
438
+ static std::optional<uint64_t > gepToByteOffset (const DataLayout &dataLayout,
439
+ LLVM::GEPOp gep) {
440
+ // Collects all indices.
441
+ SmallVector<uint64_t > indices;
442
+ for (auto index : gep.getIndices ()) {
443
+ auto constIndex = dyn_cast<IntegerAttr>(index);
444
+ if (!constIndex)
445
+ return {};
446
+ int64_t gepIndex = constIndex.getInt ();
447
+ // Negative indices are not supported.
448
+ if (gepIndex < 0 )
449
+ return {};
450
+ indices.push_back (gepIndex);
451
+ }
452
+
453
+ Type currentType = gep.getElemType ();
454
+ uint64_t offset = indices[0 ] * dataLayout.getTypeSize (currentType);
455
+
456
+ for (uint64_t index : llvm::drop_begin (indices)) {
457
+ bool shouldCancel =
458
+ TypeSwitch<Type, bool >(currentType)
459
+ .Case ([&](LLVM::LLVMArrayType arrayType) {
460
+ offset +=
461
+ index * dataLayout.getTypeSize (arrayType.getElementType ());
462
+ currentType = arrayType.getElementType ();
463
+ return false ;
464
+ })
465
+ .Case ([&](LLVM::LLVMStructType structType) {
466
+ ArrayRef<Type> body = structType.getBody ();
467
+ assert (index < body.size () && " expected valid struct indexing" );
468
+ for (uint32_t i : llvm::seq (index)) {
469
+ if (!structType.isPacked ())
470
+ offset = llvm::alignTo (
471
+ offset, dataLayout.getTypeABIAlignment (body[i]));
472
+ offset += dataLayout.getTypeSize (body[i]);
473
+ }
474
+
475
+ // Align for the current type as well.
476
+ if (!structType.isPacked ())
477
+ offset = llvm::alignTo (
478
+ offset, dataLayout.getTypeABIAlignment (body[index]));
479
+ currentType = body[index];
480
+ return false ;
481
+ })
482
+ .Default ([&](Type type) {
483
+ LLVM_DEBUG (llvm::dbgs ()
484
+ << " [sroa] Unsupported type for offset computations"
485
+ << type << " \n " );
486
+ return true ;
487
+ });
488
+
489
+ if (shouldCancel)
490
+ return std::nullopt;
491
+ }
492
+
493
+ return offset;
494
+ }
495
+
496
+ namespace {
497
+ // / A struct that stores both the index into the aggregate type of the slot as
498
+ // / well as the corresponding byte offset in memory.
499
+ struct SubslotAccessInfo {
500
+ // / The parent slot's index that the access falls into.
501
+ uint32_t index;
502
+ // / The offset into the subslot of the access.
503
+ uint64_t subslotOffset;
504
+ };
505
+ } // namespace
506
+
507
+ // / Computes subslot access information for an access into `slot` with the given
508
+ // / offset.
509
+ // / Returns nullopt when the offset is out-of-bounds or when the access is into
510
+ // / the padding of `slot`.
511
+ static std::optional<SubslotAccessInfo>
512
+ getSubslotAccessInfo (const DestructurableMemorySlot &slot,
513
+ const DataLayout &dataLayout, LLVM::GEPOp gep) {
514
+ std::optional<uint64_t > offset = gepToByteOffset (dataLayout, gep);
515
+ if (!offset)
516
+ return {};
517
+
518
+ // Helper to check that a constant index is in the bounds of the GEP index
519
+ // representation. LLVM dialects's GEP arguments have a limited bitwidth, thus
520
+ // this additional check is necessary.
521
+ auto isOutOfBoundsGEPIndex = [](uint64_t index) {
522
+ return index >= (1 << LLVM::kGEPConstantBitWidth );
523
+ };
524
+
525
+ Type type = slot.elemType ;
526
+ if (*offset >= dataLayout.getTypeSize (type))
527
+ return {};
528
+ return TypeSwitch<Type, std::optional<SubslotAccessInfo>>(type)
529
+ .Case ([&](LLVM::LLVMArrayType arrayType)
530
+ -> std::optional<SubslotAccessInfo> {
531
+ // Find which element of the array contains the offset.
532
+ uint64_t elemSize = dataLayout.getTypeSize (arrayType.getElementType ());
533
+ uint64_t index = *offset / elemSize;
534
+ if (isOutOfBoundsGEPIndex (index))
535
+ return {};
536
+ return SubslotAccessInfo{static_cast <uint32_t >(index),
537
+ *offset - (index * elemSize)};
538
+ })
539
+ .Case ([&](LLVM::LLVMStructType structType)
540
+ -> std::optional<SubslotAccessInfo> {
541
+ uint64_t distanceToStart = 0 ;
542
+ // Walk over the elements of the struct to find in which of
543
+ // them the offset is.
544
+ for (auto [index, elem] : llvm::enumerate (structType.getBody ())) {
545
+ uint64_t elemSize = dataLayout.getTypeSize (elem);
546
+ if (!structType.isPacked ()) {
547
+ distanceToStart = llvm::alignTo (
548
+ distanceToStart, dataLayout.getTypeABIAlignment (elem));
549
+ // If the offset is in padding, cancel the rewrite.
550
+ if (offset < distanceToStart)
551
+ return {};
552
+ }
553
+
554
+ if (offset < distanceToStart + elemSize) {
555
+ if (isOutOfBoundsGEPIndex (index))
556
+ return {};
557
+ // The offset is within this element, stop iterating the
558
+ // struct and return the index.
559
+ return SubslotAccessInfo{static_cast <uint32_t >(index),
560
+ *offset - distanceToStart};
561
+ }
562
+
563
+ // The offset is not within this element, continue walking
564
+ // over the struct.
565
+ distanceToStart += elemSize;
566
+ }
567
+
568
+ return {};
569
+ });
570
+ }
571
+
572
+ // / Constructs a byte array type of the given size.
573
+ static LLVM::LLVMArrayType getByteArrayType (MLIRContext *context,
574
+ unsigned size) {
575
+ auto byteType = IntegerType::get (context, 8 );
576
+ return LLVM::LLVMArrayType::get (context, byteType, size);
438
577
}
439
578
440
579
LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses (
441
580
const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
442
581
const DataLayout &dataLayout) {
443
582
if (getBase () != slot.ptr )
444
583
return success ();
445
- if (slot.elemType != getElemType ())
446
- return failure ();
447
- if (!isFirstIndexZero (*this ))
584
+ std::optional<uint64_t > gepOffset = gepToByteOffset (dataLayout, *this );
585
+ if (!gepOffset)
448
586
return failure ();
449
- // Dynamic indices can be out-of-bounds (even negative), so an access with
450
- // dynamic indices can never be considered safe .
451
- if (! getDynamicIndices (). empty () )
587
+ uint64_t slotSize = dataLayout. getTypeSize (slot. elemType );
588
+ // Check that the access is strictly inside the slot .
589
+ if (*gepOffset >= slotSize )
452
590
return failure ();
453
- Type reachedType = getResultPtrElementType ();
454
- if (!reachedType)
455
- return failure ();
456
- mustBeSafelyUsed. emplace_back <MemorySlot>({ getResult (), reachedType });
591
+ // Every access that remains in bounds of the remaining slot is considered
592
+ // legal.
593
+ mustBeSafelyUsed. emplace_back <MemorySlot>(
594
+ { getRes (), getByteArrayType ( getContext (), slotSize - *gepOffset) });
457
595
return success ();
458
596
}
459
597
@@ -464,60 +602,45 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
464
602
if (!isa<LLVM::LLVMPointerType>(getBase ().getType ()))
465
603
return false ;
466
604
467
- if (getBase () != slot.ptr || slot.elemType != getElemType ())
468
- return false ;
469
- if (!isFirstIndexZero (*this ))
470
- return false ;
471
- // Dynamic indices can be out-of-bounds (even negative), so an access with
472
- // dynamic indices can never be properly rewired.
473
- if (!getDynamicIndices ().empty ())
474
- return false ;
475
- Type reachedType = getResultPtrElementType ();
476
- if (!reachedType || getIndices ().size () < 2 )
605
+ if (getBase () != slot.ptr )
477
606
return false ;
478
- auto firstLevelIndex = dyn_cast<IntegerAttr>(getIndices ()[1 ]);
479
- if (!firstLevelIndex)
607
+ std::optional<SubslotAccessInfo> accessInfo =
608
+ getSubslotAccessInfo (slot, dataLayout, *this );
609
+ if (!accessInfo)
480
610
return false ;
481
- mustBeSafelyUsed.emplace_back <MemorySlot>({getResult (), reachedType});
482
- assert (slot.elementPtrs .contains (firstLevelIndex));
483
- usedIndices.insert (firstLevelIndex);
611
+ auto indexAttr =
612
+ IntegerAttr::get (IntegerType::get (getContext (), 32 ), accessInfo->index );
613
+ assert (slot.elementPtrs .contains (indexAttr));
614
+ usedIndices.insert (indexAttr);
615
+
616
+ // The remainder of the subslot should be accesses in-bounds. Thus, we create
617
+ // a dummy slot with the size of the remainder.
618
+ Type subslotType = slot.elementPtrs .lookup (indexAttr);
619
+ uint64_t slotSize = dataLayout.getTypeSize (subslotType);
620
+ LLVM::LLVMArrayType remainingSlotType =
621
+ getByteArrayType (getContext (), slotSize - accessInfo->subslotOffset );
622
+ mustBeSafelyUsed.emplace_back <MemorySlot>({getRes (), remainingSlotType});
623
+
484
624
return true ;
485
625
}
486
626
487
627
DeletionKind LLVM::GEPOp::rewire (const DestructurableMemorySlot &slot,
488
628
DenseMap<Attribute, MemorySlot> &subslots,
489
629
RewriterBase &rewriter,
490
630
const DataLayout &dataLayout) {
491
- IntegerAttr firstLevelIndex =
492
- llvm::dyn_cast_if_present<IntegerAttr>(getIndices ()[1 ]);
493
- const MemorySlot &newSlot = subslots.at (firstLevelIndex);
494
-
495
- ArrayRef<int32_t > remainingIndices = getRawConstantIndices ().slice (2 );
496
-
497
- // If the GEP would become trivial after this transformation, eliminate it.
498
- // A GEP should only be eliminated if it has no indices (except the first
499
- // pointer index), as simplifying GEPs with all-zero indices would eliminate
500
- // structure information useful for further destruction.
501
- if (remainingIndices.empty ()) {
502
- rewriter.replaceAllUsesWith (getResult (), newSlot.ptr );
503
- return DeletionKind::Delete;
504
- }
505
-
506
- rewriter.modifyOpInPlace (*this , [&]() {
507
- // Rewire the indices by popping off the second index.
508
- // Start with a single zero, then add the indices beyond the second.
509
- SmallVector<int32_t > newIndices (1 );
510
- newIndices.append (remainingIndices.begin (), remainingIndices.end ());
511
- setRawConstantIndices (newIndices);
512
-
513
- // Rewire the pointed type.
514
- setElemType (newSlot.elemType );
515
-
516
- // Rewire the pointer.
517
- getBaseMutable ().assign (newSlot.ptr );
518
- });
519
-
520
- return DeletionKind::Keep;
631
+ std::optional<SubslotAccessInfo> accessInfo =
632
+ getSubslotAccessInfo (slot, dataLayout, *this );
633
+ assert (accessInfo && " expected access info to be checked before" );
634
+ auto indexAttr =
635
+ IntegerAttr::get (IntegerType::get (getContext (), 32 ), accessInfo->index );
636
+ const MemorySlot &newSlot = subslots.at (indexAttr);
637
+
638
+ auto byteType = IntegerType::get (rewriter.getContext (), 8 );
639
+ auto newPtr = rewriter.createOrFold <LLVM::GEPOp>(
640
+ getLoc (), getResult ().getType (), byteType, newSlot.ptr ,
641
+ ArrayRef<GEPArg>(accessInfo->subslotOffset ), getInbounds ());
642
+ rewriter.replaceAllUsesWith (getResult (), newPtr);
643
+ return DeletionKind::Delete;
521
644
}
522
645
523
646
// ===----------------------------------------------------------------------===//
0 commit comments