Skip to content

Commit df50751

Browse files
authored
[SandboxIR] Implement ConstantAggregateZero (#107172)
This patch implements sandboxir::ConstantAggregateZero mirroring llvm::ConstantAggregateZero.
1 parent 98c6bbf commit df50751

File tree

7 files changed

+186
-16
lines changed

7 files changed

+186
-16
lines changed

llvm/include/llvm/SandboxIR/SandboxIR.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ namespace sandboxir {
120120
class BasicBlock;
121121
class ConstantInt;
122122
class ConstantFP;
123+
class ConstantAggregateZero;
123124
class Context;
124125
class Function;
125126
class Instruction;
@@ -316,6 +317,7 @@ class Value {
316317
friend class CmpInst; // For getting `Val`.
317318
friend class ConstantArray; // For `Val`.
318319
friend class ConstantStruct; // For `Val`.
320+
friend class ConstantAggregateZero; // For `Val`.
319321

320322
/// All values point to the context.
321323
Context &Ctx;
@@ -943,6 +945,48 @@ class ConstantVector final : public ConstantAggregate {
943945
}
944946
};
945947

948+
// TODO: Inherit from ConstantData.
949+
class ConstantAggregateZero final : public Constant {
950+
ConstantAggregateZero(llvm::ConstantAggregateZero *C, Context &Ctx)
951+
: Constant(ClassID::ConstantAggregateZero, C, Ctx) {}
952+
friend class Context; // For constructor.
953+
954+
public:
955+
static ConstantAggregateZero *get(Type *Ty);
956+
/// If this CAZ has array or vector type, return a zero with the right element
957+
/// type.
958+
Constant *getSequentialElement() const;
959+
/// If this CAZ has struct type, return a zero with the right element type for
960+
/// the specified element.
961+
Constant *getStructElement(unsigned Elt) const;
962+
/// Return a zero of the right value for the specified GEP index if we can,
963+
/// otherwise return null (e.g. if C is a ConstantExpr).
964+
Constant *getElementValue(Constant *C) const;
965+
/// Return a zero of the right value for the specified GEP index.
966+
Constant *getElementValue(unsigned Idx) const;
967+
/// Return the number of elements in the array, vector, or struct.
968+
ElementCount getElementCount() const {
969+
return cast<llvm::ConstantAggregateZero>(Val)->getElementCount();
970+
}
971+
972+
/// For isa/dyn_cast.
973+
static bool classof(const sandboxir::Value *From) {
974+
return From->getSubclassID() == ClassID::ConstantAggregateZero;
975+
}
976+
unsigned getUseOperandNo(const Use &Use) const final {
977+
llvm_unreachable("ConstantAggregateZero has no operands!");
978+
}
979+
#ifndef NDEBUG
980+
void verify() const override {
981+
assert(isa<llvm::ConstantAggregateZero>(Val) && "Expected a CAZ!");
982+
}
983+
void dumpOS(raw_ostream &OS) const override {
984+
dumpCommonPrefix(OS);
985+
dumpCommonSuffix(OS);
986+
}
987+
#endif
988+
};
989+
946990
/// Iterator for `Instruction`s in a `BasicBlock.
947991
/// \Returns an sandboxir::Instruction & when derereferenced.
948992
class BBIterator {

llvm/include/llvm/SandboxIR/SandboxIRValues.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ DEF_CONST(ConstantFP, ConstantFP)
3030
DEF_CONST(ConstantArray, ConstantArray)
3131
DEF_CONST(ConstantStruct, ConstantStruct)
3232
DEF_CONST(ConstantVector, ConstantVector)
33+
DEF_CONST(ConstantAggregateZero, ConstantAggregateZero)
3334

3435
#ifndef DEF_INSTR
3536
#define DEF_INSTR(ID, OPCODE, CLASS)

llvm/include/llvm/SandboxIR/Type.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ class PointerType : public Type {
293293

294294
class ArrayType : public Type {
295295
public:
296+
static ArrayType *get(Type *ElementType, uint64_t NumElements);
296297
// TODO: add missing functions
297298
static bool classof(const Type *From) {
298299
return isa<llvm::ArrayType>(From->LLVMTy);

llvm/lib/SandboxIR/SandboxIR.cpp

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2402,6 +2402,30 @@ StructType *ConstantStruct::getTypeForElements(Context &Ctx,
24022402
return StructType::get(Ctx, EltTypes, Packed);
24032403
}
24042404

2405+
ConstantAggregateZero *ConstantAggregateZero::get(Type *Ty) {
2406+
auto *LLVMC = llvm::ConstantAggregateZero::get(Ty->LLVMTy);
2407+
return cast<ConstantAggregateZero>(
2408+
Ty->getContext().getOrCreateConstant(LLVMC));
2409+
}
2410+
2411+
Constant *ConstantAggregateZero::getSequentialElement() const {
2412+
return cast<Constant>(Ctx.getValue(
2413+
cast<llvm::ConstantAggregateZero>(Val)->getSequentialElement()));
2414+
}
2415+
Constant *ConstantAggregateZero::getStructElement(unsigned Elt) const {
2416+
return cast<Constant>(Ctx.getValue(
2417+
cast<llvm::ConstantAggregateZero>(Val)->getStructElement(Elt)));
2418+
}
2419+
Constant *ConstantAggregateZero::getElementValue(Constant *C) const {
2420+
return cast<Constant>(
2421+
Ctx.getValue(cast<llvm::ConstantAggregateZero>(Val)->getElementValue(
2422+
cast<llvm::Constant>(C->Val))));
2423+
}
2424+
Constant *ConstantAggregateZero::getElementValue(unsigned Idx) const {
2425+
return cast<Constant>(Ctx.getValue(
2426+
cast<llvm::ConstantAggregateZero>(Val)->getElementValue(Idx)));
2427+
}
2428+
24052429
FunctionType *Function::getFunctionType() const {
24062430
return cast<FunctionType>(
24072431
Ctx.getType(cast<llvm::Function>(Val)->getFunctionType()));
@@ -2489,26 +2513,48 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
24892513
return It->second.get();
24902514

24912515
if (auto *C = dyn_cast<llvm::Constant>(LLVMV)) {
2492-
if (auto *CI = dyn_cast<llvm::ConstantInt>(C)) {
2493-
It->second = std::unique_ptr<ConstantInt>(new ConstantInt(CI, *this));
2516+
switch (C->getValueID()) {
2517+
case llvm::Value::ConstantIntVal:
2518+
It->second = std::unique_ptr<ConstantInt>(
2519+
new ConstantInt(cast<llvm::ConstantInt>(C), *this));
24942520
return It->second.get();
2495-
}
2496-
if (auto *CF = dyn_cast<llvm::ConstantFP>(C)) {
2497-
It->second = std::unique_ptr<ConstantFP>(new ConstantFP(CF, *this));
2521+
case llvm::Value::ConstantFPVal:
2522+
It->second = std::unique_ptr<ConstantFP>(
2523+
new ConstantFP(cast<llvm::ConstantFP>(C), *this));
24982524
return It->second.get();
2525+
case llvm::Value::ConstantAggregateZeroVal: {
2526+
auto *CAZ = cast<llvm::ConstantAggregateZero>(C);
2527+
It->second = std::unique_ptr<ConstantAggregateZero>(
2528+
new ConstantAggregateZero(CAZ, *this));
2529+
auto *Ret = It->second.get();
2530+
// Must create sandboxir for elements.
2531+
auto EC = CAZ->getElementCount();
2532+
if (EC.isFixed()) {
2533+
for (auto ElmIdx : seq<unsigned>(0, EC.getFixedValue()))
2534+
getOrCreateValueInternal(CAZ->getElementValue(ElmIdx), CAZ);
2535+
}
2536+
return Ret;
24992537
}
2500-
if (auto *CA = dyn_cast<llvm::ConstantArray>(C))
2501-
It->second = std::unique_ptr<ConstantArray>(new ConstantArray(CA, *this));
2502-
else if (auto *CS = dyn_cast<llvm::ConstantStruct>(C))
2503-
It->second =
2504-
std::unique_ptr<ConstantStruct>(new ConstantStruct(CS, *this));
2505-
else if (auto *CV = dyn_cast<llvm::ConstantVector>(C))
2506-
It->second =
2507-
std::unique_ptr<ConstantVector>(new ConstantVector(CV, *this));
2508-
else if (auto *F = dyn_cast<llvm::Function>(LLVMV))
2509-
It->second = std::unique_ptr<Function>(new Function(F, *this));
2510-
else
2538+
case llvm::Value::ConstantArrayVal:
2539+
It->second = std::unique_ptr<ConstantArray>(
2540+
new ConstantArray(cast<llvm::ConstantArray>(C), *this));
2541+
break;
2542+
case llvm::Value::ConstantStructVal:
2543+
It->second = std::unique_ptr<ConstantStruct>(
2544+
new ConstantStruct(cast<llvm::ConstantStruct>(C), *this));
2545+
break;
2546+
case llvm::Value::ConstantVectorVal:
2547+
It->second = std::unique_ptr<ConstantVector>(
2548+
new ConstantVector(cast<llvm::ConstantVector>(C), *this));
2549+
break;
2550+
case llvm::Value::FunctionVal:
2551+
It->second = std::unique_ptr<Function>(
2552+
new Function(cast<llvm::Function>(C), *this));
2553+
break;
2554+
default:
25112555
It->second = std::unique_ptr<Constant>(new Constant(C, *this));
2556+
break;
2557+
}
25122558
auto *NewC = It->second.get();
25132559
for (llvm::Value *COp : C->operands())
25142560
getOrCreateValueInternal(COp, C);

llvm/lib/SandboxIR/Type.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ PointerType *PointerType::get(Context &Ctx, unsigned AddressSpace) {
4747
Ctx.getType(llvm::PointerType::get(Ctx.LLVMCtx, AddressSpace)));
4848
}
4949

50+
ArrayType *ArrayType::get(Type *ElementType, uint64_t NumElements) {
51+
return cast<ArrayType>(ElementType->getContext().getType(
52+
llvm::ArrayType::get(ElementType->LLVMTy, NumElements)));
53+
}
54+
5055
StructType *StructType::get(Context &Ctx, ArrayRef<Type *> Elements,
5156
bool IsPacked) {
5257
SmallVector<llvm::Type *> LLVMElements;

llvm/unittests/SandboxIR/SandboxIRTest.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,75 @@ define void @foo() {
520520
EXPECT_EQ(StructTy2Packed, StructTyPacked);
521521
}
522522

523+
TEST_F(SandboxIRTest, ConstantAggregateZero) {
524+
parseIR(C, R"IR(
525+
define void @foo(ptr %ptr, {i32, i8} %v1, <2 x i8> %v2) {
526+
%extr0 = extractvalue [2 x i8] zeroinitializer, 0
527+
%extr1 = extractvalue {i32, i8} zeroinitializer, 0
528+
%extr2 = extractelement <2 x i8> zeroinitializer, i32 0
529+
ret void
530+
}
531+
)IR");
532+
Function &LLVMF = *M->getFunction("foo");
533+
sandboxir::Context Ctx(C);
534+
535+
auto &F = *Ctx.createFunction(&LLVMF);
536+
auto &BB = *F.begin();
537+
auto It = BB.begin();
538+
auto *Extr0 = &*It++;
539+
auto *Extr1 = &*It++;
540+
auto *Extr2 = &*It++;
541+
[[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
542+
auto *Zero32 =
543+
sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 0);
544+
auto *Zero8 = sandboxir::ConstantInt::get(sandboxir::Type::getInt8Ty(Ctx), 0);
545+
auto *Int8Ty = sandboxir::Type::getInt8Ty(Ctx);
546+
auto *Int32Ty = sandboxir::Type::getInt32Ty(Ctx);
547+
auto *ArrayTy = sandboxir::ArrayType::get(Int8Ty, 2u);
548+
auto *StructTy = sandboxir::StructType::get(Ctx, {Int32Ty, Int8Ty});
549+
auto *VectorTy =
550+
sandboxir::VectorType::get(Int8Ty, ElementCount::getFixed(2u));
551+
552+
// Check creation and classof().
553+
auto *ArrayCAZ = cast<sandboxir::ConstantAggregateZero>(Extr0->getOperand(0));
554+
EXPECT_EQ(ArrayCAZ->getType(), ArrayTy);
555+
auto *StructCAZ =
556+
cast<sandboxir::ConstantAggregateZero>(Extr1->getOperand(0));
557+
EXPECT_EQ(StructCAZ->getType(), StructTy);
558+
auto *VectorCAZ =
559+
cast<sandboxir::ConstantAggregateZero>(Extr2->getOperand(0));
560+
EXPECT_EQ(VectorCAZ->getType(), VectorTy);
561+
// Check get().
562+
auto *SameVectorCAZ =
563+
sandboxir::ConstantAggregateZero::get(sandboxir::VectorType::get(
564+
sandboxir::Type::getInt8Ty(Ctx), ElementCount::getFixed(2)));
565+
EXPECT_EQ(SameVectorCAZ, VectorCAZ); // Should be uniqued.
566+
auto *NewVectorCAZ =
567+
sandboxir::ConstantAggregateZero::get(sandboxir::VectorType::get(
568+
sandboxir::Type::getInt8Ty(Ctx), ElementCount::getFixed(4)));
569+
EXPECT_NE(NewVectorCAZ, VectorCAZ);
570+
// Check getSequentialElement().
571+
auto *SeqElm = VectorCAZ->getSequentialElement();
572+
EXPECT_EQ(SeqElm,
573+
sandboxir::ConstantInt::get(sandboxir::Type::getInt8Ty(Ctx), 0));
574+
// Check getStructElement().
575+
auto *StructElm0 = StructCAZ->getStructElement(0);
576+
auto *StructElm1 = StructCAZ->getStructElement(1);
577+
EXPECT_EQ(StructElm0, Zero32);
578+
EXPECT_EQ(StructElm1, Zero8);
579+
// Check getElementValue(Constant).
580+
EXPECT_EQ(ArrayCAZ->getElementValue(Zero32), Zero8);
581+
EXPECT_EQ(StructCAZ->getElementValue(Zero32), Zero32);
582+
EXPECT_EQ(VectorCAZ->getElementValue(Zero32), Zero8);
583+
// Check getElementValue(unsigned).
584+
EXPECT_EQ(ArrayCAZ->getElementValue(0u), Zero8);
585+
EXPECT_EQ(StructCAZ->getElementValue(0u), Zero32);
586+
EXPECT_EQ(VectorCAZ->getElementValue(0u), Zero8);
587+
// Check getElementCount().
588+
EXPECT_EQ(ArrayCAZ->getElementCount(), ElementCount::getFixed(2));
589+
EXPECT_EQ(NewVectorCAZ->getElementCount(), ElementCount::getFixed(4));
590+
}
591+
523592
TEST_F(SandboxIRTest, Use) {
524593
parseIR(C, R"IR(
525594
define i32 @foo(i32 %v0, i32 %v1) {

llvm/unittests/SandboxIR/TypesTest.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,10 @@ define void @foo([2 x i8] %v0) {
236236
// Check classof(), creation.
237237
[[maybe_unused]] auto *ArrayTy =
238238
cast<sandboxir::ArrayType>(F->getArg(0)->getType());
239+
// Check get().
240+
auto *NewArrayTy =
241+
sandboxir::ArrayType::get(sandboxir::Type::getInt8Ty(Ctx), 2u);
242+
EXPECT_EQ(NewArrayTy, ArrayTy);
239243
}
240244

241245
TEST_F(SandboxTypeTest, StructType) {

0 commit comments

Comments
 (0)