Skip to content

Commit b81947e

Browse files
authored
[SandboxIR] Implement ConstantDataVector member functions (llvm#136200)
Mirroring LLVM IR.
1 parent 1444951 commit b81947e

File tree

3 files changed

+160
-1
lines changed

3 files changed

+160
-1
lines changed

llvm/include/llvm/SandboxIR/Constant.h

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,96 @@ class ConstantDataVector final : public ConstantDataSequential {
670670
friend class Context;
671671

672672
public:
673-
// TODO: Add missing functions.
673+
/// Methods for support type inquiry through isa, cast, and dyn_cast:
674+
static bool classof(const Value *From) {
675+
return From->getSubclassID() == ClassID::ConstantDataVector;
676+
}
677+
/// get() constructors - Return a constant with vector type with an element
678+
/// count and element type matching the ArrayRef passed in. Note that this
679+
/// can return a ConstantAggregateZero object.
680+
static Constant *get(Context &Ctx, ArrayRef<uint8_t> Elts) {
681+
auto *NewLLVMC = llvm::ConstantDataVector::get(Ctx.LLVMCtx, Elts);
682+
return Ctx.getOrCreateConstant(NewLLVMC);
683+
}
684+
static Constant *get(Context &Ctx, ArrayRef<uint16_t> Elts) {
685+
auto *NewLLVMC = llvm::ConstantDataVector::get(Ctx.LLVMCtx, Elts);
686+
return Ctx.getOrCreateConstant(NewLLVMC);
687+
}
688+
static Constant *get(Context &Ctx, ArrayRef<uint32_t> Elts) {
689+
auto *NewLLVMC = llvm::ConstantDataVector::get(Ctx.LLVMCtx, Elts);
690+
return Ctx.getOrCreateConstant(NewLLVMC);
691+
}
692+
static Constant *get(Context &Ctx, ArrayRef<uint64_t> Elts) {
693+
auto *NewLLVMC = llvm::ConstantDataVector::get(Ctx.LLVMCtx, Elts);
694+
return Ctx.getOrCreateConstant(NewLLVMC);
695+
}
696+
static Constant *get(Context &Ctx, ArrayRef<float> Elts) {
697+
auto *NewLLVMC = llvm::ConstantDataVector::get(Ctx.LLVMCtx, Elts);
698+
return Ctx.getOrCreateConstant(NewLLVMC);
699+
}
700+
static Constant *get(Context &Ctx, ArrayRef<double> Elts) {
701+
auto *NewLLVMC = llvm::ConstantDataVector::get(Ctx.LLVMCtx, Elts);
702+
return Ctx.getOrCreateConstant(NewLLVMC);
703+
}
704+
705+
/// getRaw() constructor - Return a constant with vector type with an element
706+
/// count and element type matching the NumElements and ElementTy parameters
707+
/// passed in. Note that this can return a ConstantAggregateZero object.
708+
/// ElementTy must be one of i8/i16/i32/i64/half/bfloat/float/double. Data is
709+
/// the buffer containing the elements. Be careful to make sure Data uses the
710+
/// right endianness, the buffer will be used as-is.
711+
static Constant *getRaw(StringRef Data, uint64_t NumElements,
712+
Type *ElementTy) {
713+
auto *NewLLVMC =
714+
llvm::ConstantDataVector::getRaw(Data, NumElements, ElementTy->LLVMTy);
715+
return ElementTy->getContext().getOrCreateConstant(NewLLVMC);
716+
}
717+
/// getFP() constructors - Return a constant of vector type with a float
718+
/// element type taken from argument `ElementType', and count taken from
719+
/// argument `Elts'. The amount of bits of the contained type must match the
720+
/// number of bits of the type contained in the passed in ArrayRef.
721+
/// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
722+
/// that this can return a ConstantAggregateZero object.
723+
static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts) {
724+
auto *NewLLVMC = llvm::ConstantDataVector::getFP(ElementType->LLVMTy, Elts);
725+
return ElementType->getContext().getOrCreateConstant(NewLLVMC);
726+
}
727+
static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts) {
728+
auto *NewLLVMC = llvm::ConstantDataVector::getFP(ElementType->LLVMTy, Elts);
729+
return ElementType->getContext().getOrCreateConstant(NewLLVMC);
730+
}
731+
static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts) {
732+
auto *NewLLVMC = llvm::ConstantDataVector::getFP(ElementType->LLVMTy, Elts);
733+
return ElementType->getContext().getOrCreateConstant(NewLLVMC);
734+
}
735+
736+
/// Return a ConstantVector with the specified constant in each element.
737+
/// The specified constant has to be a of a compatible type (i8/i16/
738+
/// i32/i64/half/bfloat/float/double) and must be a ConstantFP or ConstantInt.
739+
static Constant *getSplat(unsigned NumElts, Constant *Elt) {
740+
auto *NewLLVMC = llvm::ConstantDataVector::getSplat(
741+
NumElts, cast<llvm::Constant>(Elt->Val));
742+
return Elt->getContext().getOrCreateConstant(NewLLVMC);
743+
}
744+
745+
/// Returns true if this is a splat constant, meaning that all elements have
746+
/// the same value.
747+
bool isSplat() const {
748+
return cast<llvm::ConstantDataVector>(Val)->isSplat();
749+
}
750+
751+
/// If this is a splat constant, meaning that all of the elements have the
752+
/// same value, return that value. Otherwise return NULL.
753+
Constant *getSplatValue() const {
754+
return Ctx.getOrCreateConstant(
755+
cast<llvm::ConstantDataVector>(Val)->getSplatValue());
756+
}
757+
758+
/// Specialize the getType() method to always return a FixedVectorType,
759+
/// which reduces the amount of casting needed in parts of the compiler.
760+
inline FixedVectorType *getType() const {
761+
return cast<FixedVectorType>(Value::getType());
762+
}
674763
};
675764

676765
// TODO: Inherit from ConstantData.

llvm/include/llvm/SandboxIR/Value.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ class Value {
171171
friend class Region;
172172
friend class ScoreBoard; // Needs access to `Val` for the instruction cost.
173173
friend class ConstantDataArray; // For `Val`
174+
friend class ConstantDataVector; // For `Val`
174175

175176
/// All values point to the context.
176177
Context &Ctx;

llvm/unittests/SandboxIR/SandboxIRTest.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,7 @@ define void @foo() {
622622
%fvector = extractelement <2 x double> <double 0.0, double 1.0>, i32 0
623623
%string = extractvalue [6 x i8] [i8 72, i8 69, i8 76, i8 76, i8 79, i8 0], 0
624624
%stringNoNull = extractvalue [5 x i8] [i8 72, i8 69, i8 76, i8 76, i8 79], 0
625+
%splat = extractelement <4 x i8> <i8 1, i8 1, i8 1, i8 1>, i32 0
625626
ret void
626627
}
627628
)IR");
@@ -637,6 +638,7 @@ define void @foo() {
637638
auto *I3 = &*It++;
638639
auto *I4 = &*It++;
639640
auto *I5 = &*It++;
641+
auto *I6 = &*It++;
640642
auto *Array = cast<sandboxir::ConstantDataArray>(I0->getOperand(0));
641643
EXPECT_TRUE(isa<sandboxir::ConstantDataSequential>(Array));
642644
auto *Vector = cast<sandboxir::ConstantDataVector>(I1->getOperand(0));
@@ -649,6 +651,8 @@ define void @foo() {
649651
EXPECT_TRUE(isa<sandboxir::ConstantDataArray>(String));
650652
auto *StringNoNull = cast<sandboxir::ConstantDataArray>(I5->getOperand(0));
651653
EXPECT_TRUE(isa<sandboxir::ConstantDataArray>(StringNoNull));
654+
auto *Splat = cast<sandboxir::ConstantDataVector>(I6->getOperand(0));
655+
EXPECT_TRUE(isa<sandboxir::ConstantDataVector>(Splat));
652656

653657
auto *Zero8 = sandboxir::ConstantInt::get(sandboxir::Type::getInt8Ty(Ctx), 0);
654658
auto *One8 = sandboxir::ConstantInt::get(sandboxir::Type::getInt8Ty(Ctx), 1);
@@ -750,9 +754,74 @@ define void @foo() {
750754
llvm::Type::getDoubleTy(C), Elts64))));
751755
// Check getString().
752756
EXPECT_EQ(sandboxir::ConstantDataArray::getString(Ctx, "HELLO"), String);
757+
753758
EXPECT_EQ(sandboxir::ConstantDataArray::getString(Ctx, "HELLO",
754759
/*AddNull=*/false),
755760
StringNoNull);
761+
EXPECT_EQ(
762+
sandboxir::ConstantDataArray::getString(Ctx, "HELLO", /*AddNull=*/false),
763+
StringNoNull);
764+
765+
{
766+
// Check ConstantDataArray member functions
767+
// ----------------------------------------
768+
// Check get().
769+
SmallVector<uint8_t> Elts8({0u, 1u});
770+
SmallVector<uint16_t> Elts16({0u, 1u});
771+
SmallVector<uint32_t> Elts32({0u, 1u});
772+
SmallVector<uint64_t> Elts64({0u, 1u});
773+
SmallVector<float> EltsF32({0.0, 1.0});
774+
SmallVector<double> EltsF64({0.0, 1.0});
775+
auto *CDV8 = sandboxir::ConstantDataVector::get(Ctx, Elts8);
776+
EXPECT_EQ(CDV8, cast<sandboxir::ConstantDataVector>(
777+
Ctx.getValue(llvm::ConstantDataVector::get(C, Elts8))));
778+
auto *CDV16 = sandboxir::ConstantDataVector::get(Ctx, Elts16);
779+
EXPECT_EQ(CDV16, cast<sandboxir::ConstantDataVector>(Ctx.getValue(
780+
llvm::ConstantDataVector::get(C, Elts16))));
781+
auto *CDV32 = sandboxir::ConstantDataVector::get(Ctx, Elts32);
782+
EXPECT_EQ(CDV32, cast<sandboxir::ConstantDataVector>(Ctx.getValue(
783+
llvm::ConstantDataVector::get(C, Elts32))));
784+
auto *CDVF32 = sandboxir::ConstantDataVector::get(Ctx, EltsF32);
785+
EXPECT_EQ(CDVF32, cast<sandboxir::ConstantDataVector>(Ctx.getValue(
786+
llvm::ConstantDataVector::get(C, EltsF32))));
787+
auto *CDVF64 = sandboxir::ConstantDataVector::get(Ctx, EltsF64);
788+
EXPECT_EQ(CDVF64, cast<sandboxir::ConstantDataVector>(Ctx.getValue(
789+
llvm::ConstantDataVector::get(C, EltsF64))));
790+
// Check getRaw().
791+
auto *CDVRaw = sandboxir::ConstantDataVector::getRaw(
792+
StringRef("HELLO"), 5, sandboxir::Type::getInt8Ty(Ctx));
793+
EXPECT_EQ(CDVRaw,
794+
cast<sandboxir::ConstantDataVector>(
795+
Ctx.getValue(llvm::ConstantDataVector::getRaw(
796+
StringRef("HELLO"), 5, llvm::Type::getInt8Ty(C)))));
797+
// Check getFP().
798+
auto *CDVFP16 = sandboxir::ConstantDataVector::getFP(F16Ty, Elts16);
799+
EXPECT_EQ(CDVFP16, cast<sandboxir::ConstantDataVector>(
800+
Ctx.getValue(llvm::ConstantDataVector::getFP(
801+
llvm::Type::getHalfTy(C), Elts16))));
802+
auto *CDVFP32 = sandboxir::ConstantDataVector::getFP(F32Ty, Elts32);
803+
EXPECT_EQ(CDVFP32, cast<sandboxir::ConstantDataVector>(
804+
Ctx.getValue(llvm::ConstantDataVector::getFP(
805+
llvm::Type::getFloatTy(C), Elts32))));
806+
auto *CDVFP64 = sandboxir::ConstantDataVector::getFP(F64Ty, Elts64);
807+
EXPECT_EQ(CDVFP64, cast<sandboxir::ConstantDataVector>(
808+
Ctx.getValue(llvm::ConstantDataVector::getFP(
809+
llvm::Type::getDoubleTy(C), Elts64))));
810+
// Check getSplat().
811+
auto *NewSplat = cast<sandboxir::ConstantDataVector>(
812+
sandboxir::ConstantDataVector::getSplat(4, One8));
813+
EXPECT_EQ(NewSplat, Splat);
814+
// Check isSplat().
815+
EXPECT_TRUE(NewSplat->isSplat());
816+
EXPECT_FALSE(Vector->isSplat());
817+
// Check getSplatValue().
818+
EXPECT_EQ(NewSplat->getSplatValue(), One8);
819+
// Check getType().
820+
EXPECT_TRUE(isa<sandboxir::FixedVectorType>(NewSplat->getType()));
821+
EXPECT_EQ(
822+
cast<sandboxir::FixedVectorType>(NewSplat->getType())->getNumElements(),
823+
4u);
824+
}
756825
}
757826

758827
TEST_F(SandboxIRTest, ConstantPointerNull) {

0 commit comments

Comments
 (0)