Skip to content

Commit bb72865

Browse files
[SandboxIR] Implement FixedVectorType (#107930)
1 parent d14a600 commit bb72865

File tree

3 files changed

+109
-0
lines changed

3 files changed

+109
-0
lines changed

llvm/include/llvm/SandboxIR/Type.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class Context;
2525
// Forward declare friend classes for MSVC.
2626
class PointerType;
2727
class VectorType;
28+
class FixedVectorType;
2829
class IntegerType;
2930
class FunctionType;
3031
class ArrayType;
@@ -41,6 +42,7 @@ class Type {
4142
friend class ArrayType; // For LLVMTy.
4243
friend class StructType; // For LLVMTy.
4344
friend class VectorType; // For LLVMTy.
45+
friend class FixedVectorType; // For LLVMTy.
4446
friend class PointerType; // For LLVMTy.
4547
friend class FunctionType; // For LLVMTy.
4648
friend class IntegerType; // For LLVMTy.
@@ -344,6 +346,50 @@ class VectorType : public Type {
344346
}
345347
};
346348

349+
class FixedVectorType : public VectorType {
350+
public:
351+
static FixedVectorType *get(Type *ElementType, unsigned NumElts);
352+
353+
static FixedVectorType *get(Type *ElementType, const FixedVectorType *FVTy) {
354+
return get(ElementType, FVTy->getNumElements());
355+
}
356+
357+
static FixedVectorType *getInteger(FixedVectorType *VTy) {
358+
return cast<FixedVectorType>(VectorType::getInteger(VTy));
359+
}
360+
361+
static FixedVectorType *getExtendedElementVectorType(FixedVectorType *VTy) {
362+
return cast<FixedVectorType>(VectorType::getExtendedElementVectorType(VTy));
363+
}
364+
365+
static FixedVectorType *getTruncatedElementVectorType(FixedVectorType *VTy) {
366+
return cast<FixedVectorType>(
367+
VectorType::getTruncatedElementVectorType(VTy));
368+
}
369+
370+
static FixedVectorType *getSubdividedVectorType(FixedVectorType *VTy,
371+
int NumSubdivs) {
372+
return cast<FixedVectorType>(
373+
VectorType::getSubdividedVectorType(VTy, NumSubdivs));
374+
}
375+
376+
static FixedVectorType *getHalfElementsVectorType(FixedVectorType *VTy) {
377+
return cast<FixedVectorType>(VectorType::getHalfElementsVectorType(VTy));
378+
}
379+
380+
static FixedVectorType *getDoubleElementsVectorType(FixedVectorType *VTy) {
381+
return cast<FixedVectorType>(VectorType::getDoubleElementsVectorType(VTy));
382+
}
383+
384+
static bool classof(const Type *T) {
385+
return isa<llvm::FixedVectorType>(T->LLVMTy);
386+
}
387+
388+
unsigned getNumElements() const {
389+
return cast<llvm::FixedVectorType>(LLVMTy)->getNumElements();
390+
}
391+
};
392+
347393
class FunctionType : public Type {
348394
public:
349395
// TODO: add missing functions

llvm/lib/SandboxIR/Type.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,11 @@ bool VectorType::isValidElementType(Type *ElemTy) {
103103
return llvm::VectorType::isValidElementType(ElemTy->LLVMTy);
104104
}
105105

106+
FixedVectorType *FixedVectorType::get(Type *ElementType, unsigned NumElts) {
107+
return cast<FixedVectorType>(ElementType->getContext().getType(
108+
llvm::FixedVectorType::get(ElementType->LLVMTy, NumElts)));
109+
}
110+
106111
IntegerType *IntegerType::get(Context &Ctx, unsigned NumBits) {
107112
return cast<IntegerType>(
108113
Ctx.getType(llvm::IntegerType::get(Ctx.LLVMCtx, NumBits)));

llvm/unittests/SandboxIR/TypesTest.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,64 @@ define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) {
323323
EXPECT_FALSE(sandboxir::VectorType::isValidElementType(FVecTy));
324324
}
325325

326+
TEST_F(SandboxTypeTest, FixedVectorType) {
327+
parseIR(C, R"IR(
328+
define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) {
329+
ret void
330+
}
331+
)IR");
332+
llvm::Function *LLVMF = &*M->getFunction("foo");
333+
sandboxir::Context Ctx(C);
334+
auto *F = Ctx.createFunction(LLVMF);
335+
// Check classof(), creation, accessors
336+
auto *Vec4i16Ty = cast<sandboxir::FixedVectorType>(F->getArg(0)->getType());
337+
EXPECT_TRUE(Vec4i16Ty->getElementType()->isIntegerTy(16));
338+
EXPECT_EQ(Vec4i16Ty->getElementCount(), ElementCount::getFixed(4));
339+
340+
// get(ElementType, NumElements)
341+
EXPECT_EQ(
342+
sandboxir::FixedVectorType::get(sandboxir::Type::getInt16Ty(Ctx), 4),
343+
F->getArg(0)->getType());
344+
// get(ElementType, Other)
345+
EXPECT_EQ(sandboxir::FixedVectorType::get(
346+
sandboxir::Type::getInt16Ty(Ctx),
347+
cast<sandboxir::FixedVectorType>(F->getArg(0)->getType())),
348+
F->getArg(0)->getType());
349+
auto *Vec4FTy = cast<sandboxir::FixedVectorType>(F->getArg(1)->getType());
350+
EXPECT_TRUE(Vec4FTy->getElementType()->isFloatTy());
351+
// getInteger
352+
auto *Vec4i32Ty = sandboxir::FixedVectorType::getInteger(Vec4FTy);
353+
EXPECT_TRUE(Vec4i32Ty->getElementType()->isIntegerTy(32));
354+
EXPECT_EQ(Vec4i32Ty->getElementCount(), Vec4FTy->getElementCount());
355+
// getExtendedElementCountVectorType
356+
auto *Vec4i64Ty =
357+
sandboxir::FixedVectorType::getExtendedElementVectorType(Vec4i16Ty);
358+
EXPECT_TRUE(Vec4i64Ty->getElementType()->isIntegerTy(32));
359+
EXPECT_EQ(Vec4i64Ty->getElementCount(), Vec4i16Ty->getElementCount());
360+
// getTruncatedElementVectorType
361+
auto *Vec4i8Ty =
362+
sandboxir::FixedVectorType::getTruncatedElementVectorType(Vec4i16Ty);
363+
EXPECT_TRUE(Vec4i8Ty->getElementType()->isIntegerTy(8));
364+
EXPECT_EQ(Vec4i8Ty->getElementCount(), Vec4i8Ty->getElementCount());
365+
// getSubdividedVectorType
366+
auto *Vec8i8Ty =
367+
sandboxir::FixedVectorType::getSubdividedVectorType(Vec4i16Ty, 1);
368+
EXPECT_TRUE(Vec8i8Ty->getElementType()->isIntegerTy(8));
369+
EXPECT_EQ(Vec8i8Ty->getElementCount(), ElementCount::getFixed(8));
370+
// getNumElements
371+
EXPECT_EQ(Vec8i8Ty->getNumElements(), 8u);
372+
// getHalfElementsVectorType
373+
auto *Vec2i16Ty =
374+
sandboxir::FixedVectorType::getHalfElementsVectorType(Vec4i16Ty);
375+
EXPECT_TRUE(Vec2i16Ty->getElementType()->isIntegerTy(16));
376+
EXPECT_EQ(Vec2i16Ty->getElementCount(), ElementCount::getFixed(2));
377+
// getDoubleElementsVectorType
378+
auto *Vec8i16Ty =
379+
sandboxir::FixedVectorType::getDoubleElementsVectorType(Vec4i16Ty);
380+
EXPECT_TRUE(Vec8i16Ty->getElementType()->isIntegerTy(16));
381+
EXPECT_EQ(Vec8i16Ty->getElementCount(), ElementCount::getFixed(8));
382+
}
383+
326384
TEST_F(SandboxTypeTest, FunctionType) {
327385
parseIR(C, R"IR(
328386
define void @foo() {

0 commit comments

Comments
 (0)