Skip to content

[clang] Add builtin to clear padding bytes (prework for P0528R3) #75371

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,12 @@ def IsWithinLifetime : LangBuiltin<"CXX_LANG"> {
let Prototype = "bool(void*)";
}

// __builtin_clear_padding: C++-only builtin that takes a pointer to an object
// and clears that object's padding (prework for P0528R3).
// The "void(void*)" prototype is only nominal: CustomTypeChecking means the
// argument is not checked against it generically.
// NOTE(review): the custom argument checking presumably lives in Sema; it is
// not visible in this hunk - confirm it rejects non-pointer arguments and
// pointers to incomplete types.
def ClearPadding : LangBuiltin<"CXX_LANG"> {
let Spellings = ["__builtin_clear_padding"];
let Attributes = [NoThrow, CustomTypeChecking];
let Prototype = "void(void*)";
}

// GCC exception builtins
def EHReturn : Builtin {
let Spellings = ["__builtin_eh_return"];
Expand Down
327 changes: 327 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,14 @@
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/Support/ConvertUTF.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/TargetParser/AArch64TargetParser.h"
#include "llvm/TargetParser/X86TargetParser.h"
#include <algorithm>
#include <optional>
#include <utility>
#include <deque>
#include <vector>
#include <sstream>

using namespace clang;
using namespace CodeGen;
Expand Down Expand Up @@ -2554,6 +2560,320 @@ static RValue EmitHipStdParUnsupportedBuiltin(CodeGenFunction *CGF,
return RValue::get(CGF->Builder.CreateCall(UBF, Args));
}

namespace {


// PaddingClearer is a utility class that clears padding bits in a
// c++ type. It traverses the type recursively, collecting occupied
// bit intervals, and then compute the padding intervals.
// In the end, it clears the padding bits by writing zeros
// to the padding intervals bytes-by-bytes. If a byte only contains
// some padding bits, it writes zeros to only those bits. This is
// the case for bit-fields.
struct PaddingClearer {
PaddingClearer(CodeGenFunction &F)
: CGF(F), CharWidth(CGF.getContext().getCharWidth()) {}

void run(Value *Ptr, QualType Ty) {
OccuppiedIntervals.clear();
Queue.clear();

Queue.push_back(Data{0, Ty, true});
while (!Queue.empty()) {
auto Current = Queue.back();
Queue.pop_back();
Visit(Current);
}

MergeOccuppiedIntervals();
auto PaddingIntervals =
GetPaddingIntervals(CGF.getContext().getTypeSize(Ty));
llvm::dbgs() << "Occuppied Bits:\n";
for (auto [first, last] : OccuppiedIntervals) {
llvm::dbgs() << "[" << first << ", " << last << ")\n";
}
llvm::dbgs() << "Padding Bits:\n";
for (auto [first, last] : PaddingIntervals) {
llvm::dbgs() << "[" << first << ", " << last << ")\n";
}

for (const auto &Interval : PaddingIntervals) {
ClearPadding(Ptr, Interval);
}
}

private:
struct BitInterval {
// [First, Last)
uint64_t First;
uint64_t Last;
};

struct Data {
uint64_t StartBitOffset;
QualType Ty;
bool VisitVirtualBase;
};

void Visit(Data const &D) {
if (auto *AT = dyn_cast<ConstantArrayType>(D.Ty)) {
VisitArray(AT, D.StartBitOffset);
return;
}

if (auto *Record = D.Ty->getAsCXXRecordDecl()) {
VisitStruct(Record, D.StartBitOffset, D.VisitVirtualBase);
return;
}

if (D.Ty->isAtomicType()) {
auto Unwrapped = D;
Unwrapped.Ty = D.Ty.getAtomicUnqualifiedType();
Queue.push_back(Unwrapped);
return;
}

if (const auto *Complex = D.Ty->getAs<ComplexType>()) {
VisitComplex(Complex, D.StartBitOffset);
return;
}

auto *Type = CGF.ConvertTypeForMem(D.Ty);
auto SizeBit = CGF.CGM.getModule()
.getDataLayout()
.getTypeSizeInBits(Type)
.getKnownMinValue();
llvm::dbgs() << "clear_padding primitive type. adding Interval ["
<< D.StartBitOffset << ", " << D.StartBitOffset + SizeBit
<< ")\n";
OccuppiedIntervals.push_back(
BitInterval{D.StartBitOffset, D.StartBitOffset + SizeBit});
}

void VisitArray(const ConstantArrayType *AT, uint64_t StartBitOffset) {
llvm::dbgs() << "clear_padding visiting constant array starting from "
<< StartBitOffset << "\n";
for (uint64_t ArrIndex = 0; ArrIndex < AT->getSize().getLimitedValue();
++ArrIndex) {

QualType ElementQualType = AT->getElementType();
auto ElementSize = CGF.getContext().getTypeSizeInChars(ElementQualType);
auto ElementAlign = CGF.getContext().getTypeAlignInChars(ElementQualType);
auto Offset = ElementSize.alignTo(ElementAlign);

Queue.push_back(
Data{StartBitOffset + ArrIndex * Offset.getQuantity() * CharWidth,
ElementQualType, /*VisitVirtualBase*/true});
}
}

void VisitStruct(const CXXRecordDecl *R, uint64_t StartBitOffset,
bool VisitVirtualBase) {
llvm::dbgs() << "clear_padding visiting struct: "
<< R->getQualifiedNameAsString() << " starting from offset "
<< StartBitOffset << '\n';
const auto &DL = CGF.CGM.getModule().getDataLayout();

const ASTRecordLayout &ASTLayout = CGF.getContext().getASTRecordLayout(R);
if (ASTLayout.hasOwnVFPtr()) {
llvm::dbgs()
<< "clear_padding found vtable ptr. Adding occuppied interval ["
<< StartBitOffset << ", "
<< (StartBitOffset + DL.getPointerSizeInBits()) << ")\n";
OccuppiedIntervals.push_back(BitInterval{
StartBitOffset, StartBitOffset + DL.getPointerSizeInBits()});
}

const auto VisitBase = [&ASTLayout, StartBitOffset, this](
const CXXBaseSpecifier &Base, auto GetOffset) {
auto *BaseRecord = Base.getType()->getAsCXXRecordDecl();
if (!BaseRecord) {
llvm::dbgs() << "Base is not a CXXRecord!\n";
return;
}
auto BaseOffset =
std::invoke(GetOffset, ASTLayout, BaseRecord).getQuantity();

llvm::dbgs() << "visiting base at offset " << StartBitOffset << " + "
<< BaseOffset * CharWidth << '\n';
Queue.push_back(Data{StartBitOffset + BaseOffset * CharWidth,
Base.getType(), /*VisitVirtualBase*/ false});
};

for (auto Base : R->bases()) {
if (!Base.isVirtual()) {
VisitBase(Base, &ASTRecordLayout::getBaseClassOffset);
}
}

if (VisitVirtualBase) {
for (auto VBase : R->vbases()) {
VisitBase(VBase, &ASTRecordLayout::getVBaseClassOffset);
}
}

for (auto *Field : R->fields()) {
auto FieldOffset = ASTLayout.getFieldOffset(Field->getFieldIndex());
llvm::dbgs() << "visiting field at offset " << StartBitOffset << " + "
<< FieldOffset << '\n';
if (Field->isBitField()) {
llvm::dbgs() << "clear_padding found bit field. Adding Interval ["
<< StartBitOffset + FieldOffset << " , "
<< FieldOffset + Field->getBitWidthValue()
<< ")\n";
OccuppiedIntervals.push_back(
BitInterval{StartBitOffset + FieldOffset,
StartBitOffset + FieldOffset +
Field->getBitWidthValue()});
} else {
Queue.push_back(Data{StartBitOffset + FieldOffset, Field->getType(),
/*VisitVirtualBase*/ true});
}
}
}

void VisitComplex(const ComplexType *CT, uint64_t StartBitOffset) {
QualType ElementQualType = CT->getElementType();
auto ElementSize = CGF.getContext().getTypeSizeInChars(ElementQualType);
auto ElementAlign = CGF.getContext().getTypeAlignInChars(ElementQualType);
auto ImgOffset = ElementSize.alignTo(ElementAlign);

llvm::dbgs() << "clear_padding visiting Complex Type. Real from "
<< StartBitOffset << "Img from "
<< StartBitOffset + ImgOffset.getQuantity() * CharWidth
<< "\n";
Queue.push_back(
Data{StartBitOffset, ElementQualType, /*VisitVirtualBase*/ true});
Queue.push_back(Data{StartBitOffset + ImgOffset.getQuantity() * CharWidth,
ElementQualType, /*VisitVirtualBase*/ true});
}

void MergeOccuppiedIntervals() {
std::sort(OccuppiedIntervals.begin(), OccuppiedIntervals.end(),
[](const BitInterval &lhs, const BitInterval &rhs) {
return std::tie(lhs.First, lhs.Last) <
std::tie(rhs.First, rhs.Last);
});

std::vector<BitInterval> Merged;
Merged.reserve(OccuppiedIntervals.size());

for (const BitInterval &NextInterval : OccuppiedIntervals) {
if (Merged.empty()) {
Merged.push_back(NextInterval);
continue;
}
auto &LastInterval = Merged.back();

if (NextInterval.First > LastInterval.Last) {
Merged.push_back(NextInterval);
} else {
LastInterval.Last = std::max(LastInterval.Last, NextInterval.Last);
}
}

OccuppiedIntervals = Merged;
}

std::vector<BitInterval> GetPaddingIntervals(uint64_t SizeInBits) const {
std::vector<BitInterval> Results;
if (OccuppiedIntervals.size() == 1 &&
OccuppiedIntervals.front().First == 0 &&
OccuppiedIntervals.end()->Last == SizeInBits) {
return Results;
}
Results.reserve(OccuppiedIntervals.size() + 1);
uint64_t CurrentPos = 0;
for (const BitInterval &OccupiedInterval : OccuppiedIntervals) {
if (OccupiedInterval.First > CurrentPos) {
Results.push_back(BitInterval{CurrentPos, OccupiedInterval.First});
}
CurrentPos = OccupiedInterval.Last;
}
if (SizeInBits > CurrentPos) {
Results.push_back(BitInterval{CurrentPos, SizeInBits});
}
return Results;
}



void ClearPadding(Value *Ptr, const BitInterval &PaddingInterval) {
auto *I8Ptr = CGF.Builder.CreateBitCast(Ptr, CGF.Int8PtrTy);
auto *Zero = ConstantInt::get(CGF.Int8Ty, 0);

// Calculate byte indices and bit positions
auto StartByte = PaddingInterval.First / CharWidth;
auto StartBit = PaddingInterval.First % CharWidth;
auto EndByte = PaddingInterval.Last / CharWidth;
auto EndBit = PaddingInterval.Last % CharWidth;

if (StartByte == EndByte) {
// Interval is within a single byte
auto *Index = ConstantInt::get(CGF.IntTy, StartByte);
auto *Element = CGF.Builder.CreateGEP(CGF.Int8Ty, I8Ptr, Index);
Address ElementAddr(Element, CGF.Int8Ty, CharUnits::One());

auto *Value = CGF.Builder.CreateLoad(ElementAddr);

// Create mask to clear bits within the byte
uint8_t mask = ((1 << EndBit) - 1) & ~((1 << StartBit) - 1);
auto *MaskValue = ConstantInt::get(CGF.Int8Ty, mask);
auto *NewValue = CGF.Builder.CreateAnd(Value, MaskValue);

CGF.Builder.CreateStore(NewValue, ElementAddr);
} else {
// Handle the start byte
if (StartBit != 0) {
auto *Index = ConstantInt::get(CGF.IntTy, StartByte);
auto *Element = CGF.Builder.CreateGEP(CGF.Int8Ty, I8Ptr, Index);
Address ElementAddr(Element, CGF.Int8Ty, CharUnits::One());

auto *Value = CGF.Builder.CreateLoad(ElementAddr);

uint8_t startMask = ((1 << (CharWidth - StartBit)) - 1) << StartBit;
auto *MaskValue = ConstantInt::get(CGF.Int8Ty, ~startMask);
auto *NewValue = CGF.Builder.CreateAnd(Value, MaskValue);

CGF.Builder.CreateStore(NewValue, ElementAddr);
++StartByte;
}

// Handle full bytes in the middle
for (auto Offset = StartByte; Offset < EndByte; ++Offset) {
auto *Index = ConstantInt::get(CGF.IntTy, Offset);
auto *Element = CGF.Builder.CreateGEP(CGF.Int8Ty, I8Ptr, Index);
Address ElementAddr(Element, CGF.Int8Ty, CharUnits::One());

CGF.Builder.CreateStore(Zero, ElementAddr);
}

// Handle the end byte
if (EndBit != 0) {
auto *Index = ConstantInt::get(CGF.IntTy, EndByte);
auto *Element = CGF.Builder.CreateGEP(CGF.Int8Ty, I8Ptr, Index);
Address ElementAddr(Element, CGF.Int8Ty, CharUnits::One());

auto *Value = CGF.Builder.CreateLoad(ElementAddr);

uint8_t endMask = (1 << EndBit) - 1;
auto *MaskValue = ConstantInt::get(CGF.Int8Ty, endMask);
auto *NewValue = CGF.Builder.CreateAnd(Value, MaskValue);

CGF.Builder.CreateStore(NewValue, ElementAddr);
}
}
}


CodeGenFunction &CGF;
const uint64_t CharWidth;
std::deque<Data> Queue;
std::vector<BitInterval> OccuppiedIntervals;
};

} // namespace

RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
const CallExpr *E,
ReturnValueSlot ReturnValue) {
Expand Down Expand Up @@ -4766,6 +5086,13 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,

return RValue::get(Ptr);
}
case Builtin::BI__builtin_clear_padding: {
  // The only argument is a pointer to the object whose padding is cleared;
  // the type to analyse is the pointee type of that argument.
  const Expr *PtrArg = E->getArg(0);
  Address ObjectAddr = EmitPointerWithAlignment(PtrArg);
  QualType ObjectTy = PtrArg->getType()->getPointeeType();
  PaddingClearer(*this).run(ObjectAddr.getBasePointer(), ObjectTy);
  // No result value is produced for this builtin.
  return RValue::get(nullptr);
}
case Builtin::BI__sync_fetch_and_add:
case Builtin::BI__sync_fetch_and_sub:
case Builtin::BI__sync_fetch_and_or:
Expand Down
Loading
Loading