Skip to content

Commit 814aa43

Browse files
authored
[SandboxIR] Implement ConstantAggregate (#107136)
This patch implements sandboxir:: ConstantAggregate, ConstantStruct, ConstantArray and ConstantVector, mirroring LLVM IR.
1 parent 83ad644 commit 814aa43

File tree

7 files changed

+310
-8
lines changed

7 files changed

+310
-8
lines changed

llvm/include/llvm/SandboxIR/SandboxIR.h

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,8 @@ class Value {
304304
friend class PHINode; // For getting `Val`.
305305
friend class UnreachableInst; // For getting `Val`.
306306
friend class CatchSwitchAddHandler; // For `Val`.
307+
friend class ConstantArray; // For `Val`.
308+
friend class ConstantStruct; // For `Val`.
307309

308310
/// All values point to the context.
309311
Context &Ctx;
@@ -840,6 +842,97 @@ class ConstantFP final : public Constant {
840842
#endif
841843
};
842844

845+
/// Base class for aggregate constants (with operands).
846+
class ConstantAggregate : public Constant {
847+
protected:
848+
ConstantAggregate(ClassID ID, llvm::Constant *C, Context &Ctx)
849+
: Constant(ID, C, Ctx) {}
850+
851+
public:
852+
/// For isa/dyn_cast.
853+
static bool classof(const sandboxir::Value *From) {
854+
auto ID = From->getSubclassID();
855+
return ID == ClassID::ConstantVector || ID == ClassID::ConstantStruct ||
856+
ID == ClassID::ConstantArray;
857+
}
858+
};
859+
860+
class ConstantArray final : public ConstantAggregate {
861+
ConstantArray(llvm::ConstantArray *C, Context &Ctx)
862+
: ConstantAggregate(ClassID::ConstantArray, C, Ctx) {}
863+
friend class Context; // For constructor.
864+
865+
public:
866+
static Constant *get(ArrayType *T, ArrayRef<Constant *> V);
867+
ArrayType *getType() const;
868+
869+
// TODO: Missing functions: getType(), getTypeForElements(), getAnon(), get().
870+
871+
/// For isa/dyn_cast.
872+
static bool classof(const Value *From) {
873+
return From->getSubclassID() == ClassID::ConstantArray;
874+
}
875+
};
876+
877+
class ConstantStruct final : public ConstantAggregate {
878+
ConstantStruct(llvm::ConstantStruct *C, Context &Ctx)
879+
: ConstantAggregate(ClassID::ConstantStruct, C, Ctx) {}
880+
friend class Context; // For constructor.
881+
882+
public:
883+
static Constant *get(StructType *T, ArrayRef<Constant *> V);
884+
885+
template <typename... Csts>
886+
static std::enable_if_t<are_base_of<Constant, Csts...>::value, Constant *>
887+
get(StructType *T, Csts *...Vs) {
888+
return get(T, ArrayRef<Constant *>({Vs...}));
889+
}
890+
/// Return an anonymous struct that has the specified elements.
891+
/// If the struct is possibly empty, then you must specify a context.
892+
static Constant *getAnon(ArrayRef<Constant *> V, bool Packed = false) {
893+
return get(getTypeForElements(V, Packed), V);
894+
}
895+
static Constant *getAnon(Context &Ctx, ArrayRef<Constant *> V,
896+
bool Packed = false) {
897+
return get(getTypeForElements(Ctx, V, Packed), V);
898+
}
899+
/// This version of the method allows an empty list.
900+
static StructType *getTypeForElements(Context &Ctx, ArrayRef<Constant *> V,
901+
bool Packed = false);
902+
/// Return an anonymous struct type to use for a constant with the specified
903+
/// set of elements. The list must not be empty.
904+
static StructType *getTypeForElements(ArrayRef<Constant *> V,
905+
bool Packed = false) {
906+
assert(!V.empty() &&
907+
"ConstantStruct::getTypeForElements cannot be called on empty list");
908+
return getTypeForElements(V[0]->getContext(), V, Packed);
909+
}
910+
911+
/// Specialization - reduce amount of casting.
912+
inline StructType *getType() const {
913+
return cast<StructType>(Value::getType());
914+
}
915+
916+
/// For isa/dyn_cast.
917+
static bool classof(const Value *From) {
918+
return From->getSubclassID() == ClassID::ConstantStruct;
919+
}
920+
};
921+
922+
class ConstantVector final : public ConstantAggregate {
923+
ConstantVector(llvm::ConstantVector *C, Context &Ctx)
924+
: ConstantAggregate(ClassID::ConstantVector, C, Ctx) {}
925+
friend class Context; // For constructor.
926+
927+
public:
928+
// TODO: Missing functions: getSplat(), getType(), getSplatValue(), get().
929+
930+
/// For isa/dyn_cast.
931+
static bool classof(const Value *From) {
932+
return From->getSubclassID() == ClassID::ConstantVector;
933+
}
934+
};
935+
843936
/// Iterator for `Instruction`s in a `BasicBlock.
844937
/// \Returns an sandboxir::Instruction & when derereferenced.
845938
class BBIterator {
@@ -3353,6 +3446,7 @@ class Context {
33533446
friend class Type; // For LLVMCtx.
33543447
friend class PointerType; // For LLVMCtx.
33553448
friend class IntegerType; // For LLVMCtx.
3449+
friend class StructType; // For LLVMCtx.
33563450
Tracker IRTracker;
33573451

33583452
/// Maps LLVM Value to the corresponding sandboxir::Value. Owns all

llvm/include/llvm/SandboxIR/SandboxIRValues.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ DEF_VALUE(Block, BasicBlock)
2727
DEF_CONST(Constant, Constant)
2828
DEF_CONST(ConstantInt, ConstantInt)
2929
DEF_CONST(ConstantFP, ConstantFP)
30+
DEF_CONST(ConstantArray, ConstantArray)
31+
DEF_CONST(ConstantStruct, ConstantStruct)
32+
DEF_CONST(ConstantVector, ConstantVector)
3033

3134
#ifndef DEF_INSTR
3235
#define DEF_INSTR(ID, OPCODE, CLASS)

llvm/include/llvm/SandboxIR/Type.h

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ class PointerType;
2727
class VectorType;
2828
class IntegerType;
2929
class FunctionType;
30+
class ArrayType;
31+
class StructType;
3032
#define DEF_INSTR(ID, OPCODE, CLASS) class CLASS;
3133
#define DEF_CONST(ID, CLASS) class CLASS;
3234
#include "llvm/SandboxIR/SandboxIRValues.def"
@@ -36,13 +38,19 @@ class FunctionType;
3638
class Type {
3739
protected:
3840
llvm::Type *LLVMTy;
39-
friend class VectorType; // For LLVMTy.
40-
friend class PointerType; // For LLVMTy.
41-
friend class FunctionType; // For LLVMTy.
42-
friend class IntegerType; // For LLVMTy.
43-
friend class Function; // For LLVMTy.
44-
friend class CallBase; // For LLVMTy.
45-
friend class ConstantInt; // For LLVMTy.
41+
friend class ArrayType; // For LLVMTy.
42+
friend class StructType; // For LLVMTy.
43+
friend class VectorType; // For LLVMTy.
44+
friend class PointerType; // For LLVMTy.
45+
friend class FunctionType; // For LLVMTy.
46+
friend class IntegerType; // For LLVMTy.
47+
friend class Function; // For LLVMTy.
48+
friend class CallBase; // For LLVMTy.
49+
friend class ConstantInt; // For LLVMTy.
50+
friend class ConstantArray; // For LLVMTy.
51+
friend class ConstantStruct; // For LLVMTy.
52+
friend class ConstantVector; // For LLVMTy.
53+
4654
// Friend all instruction classes because `create()` functions use LLVMTy.
4755
#define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
4856
#define DEF_CONST(ID, CLASS) friend class CLASS;
@@ -281,8 +289,31 @@ class PointerType : public Type {
281289
}
282290
};
283291

292+
class ArrayType : public Type {
293+
public:
294+
// TODO: add missing functions
295+
static bool classof(const Type *From) {
296+
return isa<llvm::ArrayType>(From->LLVMTy);
297+
}
298+
};
299+
300+
class StructType : public Type {
301+
public:
302+
/// This static method is the primary way to create a literal StructType.
303+
static StructType *get(Context &Ctx, ArrayRef<Type *> Elements,
304+
bool IsPacked = false);
305+
306+
bool isPacked() const { return cast<llvm::StructType>(LLVMTy)->isPacked(); }
307+
308+
// TODO: add missing functions
309+
static bool classof(const Type *From) {
310+
return isa<llvm::StructType>(From->LLVMTy);
311+
}
312+
};
313+
284314
class VectorType : public Type {
285315
public:
316+
static VectorType *get(Type *ElementType, ElementCount EC);
286317
// TODO: add missing functions
287318
static bool classof(const Type *From) {
288319
return isa<llvm::VectorType>(From->LLVMTy);

llvm/lib/SandboxIR/SandboxIR.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2364,6 +2364,44 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat &V) {
23642364
return llvm::ConstantFP::isValueValidForType(Ty->LLVMTy, V);
23652365
}
23662366

2367+
Constant *ConstantArray::get(ArrayType *T, ArrayRef<Constant *> V) {
2368+
auto &Ctx = T->getContext();
2369+
SmallVector<llvm::Constant *> LLVMValues;
2370+
LLVMValues.reserve(V.size());
2371+
for (auto *Elm : V)
2372+
LLVMValues.push_back(cast<llvm::Constant>(Elm->Val));
2373+
auto *LLVMC =
2374+
llvm::ConstantArray::get(cast<llvm::ArrayType>(T->LLVMTy), LLVMValues);
2375+
return cast<ConstantArray>(Ctx.getOrCreateConstant(LLVMC));
2376+
}
2377+
2378+
ArrayType *ConstantArray::getType() const {
2379+
return cast<ArrayType>(
2380+
Ctx.getType(cast<llvm::ConstantArray>(Val)->getType()));
2381+
}
2382+
2383+
Constant *ConstantStruct::get(StructType *T, ArrayRef<Constant *> V) {
2384+
auto &Ctx = T->getContext();
2385+
SmallVector<llvm::Constant *> LLVMValues;
2386+
LLVMValues.reserve(V.size());
2387+
for (auto *Elm : V)
2388+
LLVMValues.push_back(cast<llvm::Constant>(Elm->Val));
2389+
auto *LLVMC =
2390+
llvm::ConstantStruct::get(cast<llvm::StructType>(T->LLVMTy), LLVMValues);
2391+
return cast<ConstantStruct>(Ctx.getOrCreateConstant(LLVMC));
2392+
}
2393+
2394+
StructType *ConstantStruct::getTypeForElements(Context &Ctx,
2395+
ArrayRef<Constant *> V,
2396+
bool Packed) {
2397+
unsigned VecSize = V.size();
2398+
SmallVector<Type *, 16> EltTypes;
2399+
EltTypes.reserve(VecSize);
2400+
for (Constant *Elm : V)
2401+
EltTypes.push_back(Elm->getType());
2402+
return StructType::get(Ctx, EltTypes, Packed);
2403+
}
2404+
23672405
FunctionType *Function::getFunctionType() const {
23682406
return cast<FunctionType>(
23692407
Ctx.getType(cast<llvm::Function>(Val)->getFunctionType()));
@@ -2459,7 +2497,15 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
24592497
It->second = std::unique_ptr<ConstantFP>(new ConstantFP(CF, *this));
24602498
return It->second.get();
24612499
}
2462-
if (auto *F = dyn_cast<llvm::Function>(LLVMV))
2500+
if (auto *CA = dyn_cast<llvm::ConstantArray>(C))
2501+
It->second = std::unique_ptr<ConstantArray>(new ConstantArray(CA, *this));
2502+
else if (auto *CS = dyn_cast<llvm::ConstantStruct>(C))
2503+
It->second =
2504+
std::unique_ptr<ConstantStruct>(new ConstantStruct(CS, *this));
2505+
else if (auto *CV = dyn_cast<llvm::ConstantVector>(C))
2506+
It->second =
2507+
std::unique_ptr<ConstantVector>(new ConstantVector(CV, *this));
2508+
else if (auto *F = dyn_cast<llvm::Function>(LLVMV))
24632509
It->second = std::unique_ptr<Function>(new Function(F, *this));
24642510
else
24652511
It->second = std::unique_ptr<Constant>(new Constant(C, *this));

llvm/lib/SandboxIR/Type.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,21 @@ PointerType *PointerType::get(Context &Ctx, unsigned AddressSpace) {
4747
Ctx.getType(llvm::PointerType::get(Ctx.LLVMCtx, AddressSpace)));
4848
}
4949

50+
StructType *StructType::get(Context &Ctx, ArrayRef<Type *> Elements,
51+
bool IsPacked) {
52+
SmallVector<llvm::Type *> LLVMElements;
53+
LLVMElements.reserve(Elements.size());
54+
for (Type *Elm : Elements)
55+
LLVMElements.push_back(Elm->LLVMTy);
56+
return cast<StructType>(
57+
Ctx.getType(llvm::StructType::get(Ctx.LLVMCtx, LLVMElements, IsPacked)));
58+
}
59+
60+
VectorType *VectorType::get(Type *ElementType, ElementCount EC) {
61+
return cast<VectorType>(ElementType->getContext().getType(
62+
llvm::VectorType::get(ElementType->LLVMTy, EC)));
63+
}
64+
5065
IntegerType *IntegerType::get(Context &Ctx, unsigned NumBits) {
5166
return cast<IntegerType>(
5267
Ctx.getType(llvm::IntegerType::get(Ctx.LLVMCtx, NumBits)));

llvm/unittests/SandboxIR/SandboxIRTest.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,81 @@ define void @foo(float %v0, double %v1) {
445445
EXPECT_TRUE(NegZero->isExactlyValue(-0.0));
446446
}
447447

448+
// Tests ConstantArray, ConstantStruct and ConstantVector.
449+
TEST_F(SandboxIRTest, ConstantAggregate) {
450+
// Note: we are using i42 to avoid the creation of ConstantDataVector or
451+
// ConstantDataArray.
452+
parseIR(C, R"IR(
453+
define void @foo() {
454+
%array = extractvalue [2 x i42] [i42 0, i42 1], 0
455+
%struct = extractvalue {i42, i42} {i42 0, i42 1}, 0
456+
%vector = extractelement <2 x i42> <i42 0, i42 1>, i32 0
457+
ret void
458+
}
459+
)IR");
460+
Function &LLVMF = *M->getFunction("foo");
461+
sandboxir::Context Ctx(C);
462+
463+
auto &F = *Ctx.createFunction(&LLVMF);
464+
auto &BB = *F.begin();
465+
auto It = BB.begin();
466+
auto *I0 = &*It++;
467+
auto *I1 = &*It++;
468+
auto *I2 = &*It++;
469+
// Check classof() and creation.
470+
auto *Array = cast<sandboxir::ConstantArray>(I0->getOperand(0));
471+
EXPECT_TRUE(isa<sandboxir::ConstantAggregate>(Array));
472+
auto *Struct = cast<sandboxir::ConstantStruct>(I1->getOperand(0));
473+
EXPECT_TRUE(isa<sandboxir::ConstantAggregate>(Struct));
474+
auto *Vector = cast<sandboxir::ConstantVector>(I2->getOperand(0));
475+
EXPECT_TRUE(isa<sandboxir::ConstantAggregate>(Vector));
476+
477+
auto *ZeroI42 = cast<sandboxir::ConstantInt>(Array->getOperand(0));
478+
auto *OneI42 = cast<sandboxir::ConstantInt>(Array->getOperand(1));
479+
// Check ConstantArray::get(), getType().
480+
auto *NewCA =
481+
sandboxir::ConstantArray::get(Array->getType(), {ZeroI42, OneI42});
482+
EXPECT_EQ(NewCA, Array);
483+
484+
// Check ConstantStruct::get(), getType().
485+
auto *NewCS =
486+
sandboxir::ConstantStruct::get(Struct->getType(), {ZeroI42, OneI42});
487+
EXPECT_EQ(NewCS, Struct);
488+
// Check ConstantStruct::get(...).
489+
auto *NewCS2 =
490+
sandboxir::ConstantStruct::get(Struct->getType(), ZeroI42, OneI42);
491+
EXPECT_EQ(NewCS2, Struct);
492+
// Check ConstantStruct::getAnon(ArayRef).
493+
auto *AnonCS = sandboxir::ConstantStruct::getAnon({ZeroI42, OneI42});
494+
EXPECT_FALSE(cast<sandboxir::StructType>(AnonCS->getType())->isPacked());
495+
auto *AnonCSPacked =
496+
sandboxir::ConstantStruct::getAnon({ZeroI42, OneI42}, /*Packed=*/true);
497+
EXPECT_TRUE(cast<sandboxir::StructType>(AnonCSPacked->getType())->isPacked());
498+
// Check ConstantStruct::getAnon(Ctx, ArrayRef).
499+
auto *AnonCS2 = sandboxir::ConstantStruct::getAnon(Ctx, {ZeroI42, OneI42});
500+
EXPECT_EQ(AnonCS2, AnonCS);
501+
auto *AnonCS2Packed = sandboxir::ConstantStruct::getAnon(
502+
Ctx, {ZeroI42, OneI42}, /*Packed=*/true);
503+
EXPECT_EQ(AnonCS2Packed, AnonCSPacked);
504+
// Check ConstantStruct::getTypeForElements(Ctx, ArrayRef).
505+
auto *StructTy =
506+
sandboxir::ConstantStruct::getTypeForElements(Ctx, {ZeroI42, OneI42});
507+
EXPECT_EQ(StructTy, Struct->getType());
508+
EXPECT_FALSE(StructTy->isPacked());
509+
// Check ConstantStruct::getTypeForElements(Ctx, ArrayRef, Packed).
510+
auto *StructTyPacked = sandboxir::ConstantStruct::getTypeForElements(
511+
Ctx, {ZeroI42, OneI42}, /*Packed=*/true);
512+
EXPECT_TRUE(StructTyPacked->isPacked());
513+
// Check ConstantStruct::getTypeForElements(ArrayRef).
514+
auto *StructTy2 =
515+
sandboxir::ConstantStruct::getTypeForElements(Ctx, {ZeroI42, OneI42});
516+
EXPECT_EQ(StructTy2, Struct->getType());
517+
// Check ConstantStruct::getTypeForElements(ArrayRef, Packed).
518+
auto *StructTy2Packed = sandboxir::ConstantStruct::getTypeForElements(
519+
Ctx, {ZeroI42, OneI42}, /*Packed=*/true);
520+
EXPECT_EQ(StructTy2Packed, StructTyPacked);
521+
}
522+
448523
TEST_F(SandboxIRTest, Use) {
449524
parseIR(C, R"IR(
450525
define i32 @foo(i32 %v0, i32 %v1) {

0 commit comments

Comments
 (0)