Skip to content

Commit 3b4e7c9

Browse files
[SandboxIR] Implement ScalableVectorType (#108124)
As in the heading.
1 parent c571113 commit 3b4e7c9

File tree

3 files changed

+133
-15
lines changed

3 files changed

+133
-15
lines changed

llvm/include/llvm/SandboxIR/Type.h

Lines changed: 68 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class Context;
2626
class PointerType;
2727
class VectorType;
2828
class FixedVectorType;
29+
class ScalableVectorType;
2930
class IntegerType;
3031
class FunctionType;
3132
class ArrayType;
@@ -39,21 +40,22 @@ class StructType;
3940
class Type {
4041
protected:
4142
llvm::Type *LLVMTy;
42-
friend class ArrayType; // For LLVMTy.
43-
friend class StructType; // For LLVMTy.
44-
friend class VectorType; // For LLVMTy.
45-
friend class FixedVectorType; // For LLVMTy.
46-
friend class PointerType; // For LLVMTy.
47-
friend class FunctionType; // For LLVMTy.
48-
friend class IntegerType; // For LLVMTy.
49-
friend class Function; // For LLVMTy.
50-
friend class CallBase; // For LLVMTy.
51-
friend class ConstantInt; // For LLVMTy.
52-
friend class ConstantArray; // For LLVMTy.
53-
friend class ConstantStruct; // For LLVMTy.
54-
friend class ConstantVector; // For LLVMTy.
55-
friend class CmpInst; // For LLVMTy. TODO: Cleanup after
56-
// sandboxir::VectorType is more complete.
43+
friend class ArrayType; // For LLVMTy.
44+
friend class StructType; // For LLVMTy.
45+
friend class VectorType; // For LLVMTy.
46+
friend class FixedVectorType; // For LLVMTy.
47+
friend class ScalableVectorType; // For LLVMTy.
48+
friend class PointerType; // For LLVMTy.
49+
friend class FunctionType; // For LLVMTy.
50+
friend class IntegerType; // For LLVMTy.
51+
friend class Function; // For LLVMTy.
52+
friend class CallBase; // For LLVMTy.
53+
friend class ConstantInt; // For LLVMTy.
54+
friend class ConstantArray; // For LLVMTy.
55+
friend class ConstantStruct; // For LLVMTy.
56+
friend class ConstantVector; // For LLVMTy.
57+
friend class CmpInst; // For LLVMTy. TODO: Cleanup after
58+
// sandboxir::VectorType is more complete.
5759

5860
// Friend all instruction classes because `create()` functions use LLVMTy.
5961
#define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
@@ -390,6 +392,57 @@ class FixedVectorType : public VectorType {
390392
}
391393
};
392394

395+
class ScalableVectorType : public VectorType {
396+
public:
397+
static ScalableVectorType *get(Type *ElementType, unsigned MinNumElts);
398+
399+
static ScalableVectorType *get(Type *ElementType,
400+
const ScalableVectorType *SVTy) {
401+
return get(ElementType, SVTy->getMinNumElements());
402+
}
403+
404+
static ScalableVectorType *getInteger(ScalableVectorType *VTy) {
405+
return cast<ScalableVectorType>(VectorType::getInteger(VTy));
406+
}
407+
408+
static ScalableVectorType *
409+
getExtendedElementVectorType(ScalableVectorType *VTy) {
410+
return cast<ScalableVectorType>(
411+
VectorType::getExtendedElementVectorType(VTy));
412+
}
413+
414+
static ScalableVectorType *
415+
getTruncatedElementVectorType(ScalableVectorType *VTy) {
416+
return cast<ScalableVectorType>(
417+
VectorType::getTruncatedElementVectorType(VTy));
418+
}
419+
420+
static ScalableVectorType *getSubdividedVectorType(ScalableVectorType *VTy,
421+
int NumSubdivs) {
422+
return cast<ScalableVectorType>(
423+
VectorType::getSubdividedVectorType(VTy, NumSubdivs));
424+
}
425+
426+
static ScalableVectorType *
427+
getHalfElementsVectorType(ScalableVectorType *VTy) {
428+
return cast<ScalableVectorType>(VectorType::getHalfElementsVectorType(VTy));
429+
}
430+
431+
static ScalableVectorType *
432+
getDoubleElementsVectorType(ScalableVectorType *VTy) {
433+
return cast<ScalableVectorType>(
434+
VectorType::getDoubleElementsVectorType(VTy));
435+
}
436+
437+
unsigned getMinNumElements() const {
438+
return cast<llvm::ScalableVectorType>(LLVMTy)->getMinNumElements();
439+
}
440+
441+
static bool classof(const Type *T) {
442+
return isa<llvm::ScalableVectorType>(T->LLVMTy);
443+
}
444+
};
445+
393446
class FunctionType : public Type {
394447
public:
395448
// TODO: add missing functions

llvm/lib/SandboxIR/Type.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,12 @@ FixedVectorType *FixedVectorType::get(Type *ElementType, unsigned NumElts) {
108108
llvm::FixedVectorType::get(ElementType->LLVMTy, NumElts)));
109109
}
110110

111+
ScalableVectorType *ScalableVectorType::get(Type *ElementType,
112+
unsigned NumElts) {
113+
return cast<ScalableVectorType>(ElementType->getContext().getType(
114+
llvm::ScalableVectorType::get(ElementType->LLVMTy, NumElts)));
115+
}
116+
111117
IntegerType *IntegerType::get(Context &Ctx, unsigned NumBits) {
112118
return cast<IntegerType>(
113119
Ctx.getType(llvm::IntegerType::get(Ctx.LLVMCtx, NumBits)));

llvm/unittests/SandboxIR/TypesTest.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,65 @@ define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) {
381381
EXPECT_EQ(Vec8i16Ty->getElementCount(), ElementCount::getFixed(8));
382382
}
383383

384+
TEST_F(SandboxTypeTest, ScalableVectorType) {
385+
parseIR(C, R"IR(
386+
define void @foo(<vscale x 4 x i16> %vi0, <vscale x 4 x float> %vf1, i8 %i0) {
387+
ret void
388+
}
389+
)IR");
390+
llvm::Function *LLVMF = &*M->getFunction("foo");
391+
sandboxir::Context Ctx(C);
392+
auto *F = Ctx.createFunction(LLVMF);
393+
// Check classof(), creation, accessors
394+
auto *Vec4i16Ty =
395+
cast<sandboxir::ScalableVectorType>(F->getArg(0)->getType());
396+
EXPECT_TRUE(Vec4i16Ty->getElementType()->isIntegerTy(16));
397+
EXPECT_EQ(Vec4i16Ty->getMinNumElements(), 4u);
398+
399+
// get(ElementType, NumElements)
400+
EXPECT_EQ(
401+
sandboxir::ScalableVectorType::get(sandboxir::Type::getInt16Ty(Ctx), 4),
402+
F->getArg(0)->getType());
403+
// get(ElementType, Other)
404+
EXPECT_EQ(sandboxir::ScalableVectorType::get(
405+
sandboxir::Type::getInt16Ty(Ctx),
406+
cast<sandboxir::ScalableVectorType>(F->getArg(0)->getType())),
407+
F->getArg(0)->getType());
408+
auto *Vec4FTy = cast<sandboxir::ScalableVectorType>(F->getArg(1)->getType());
409+
EXPECT_TRUE(Vec4FTy->getElementType()->isFloatTy());
410+
// getInteger
411+
auto *Vec4i32Ty = sandboxir::ScalableVectorType::getInteger(Vec4FTy);
412+
EXPECT_TRUE(Vec4i32Ty->getElementType()->isIntegerTy(32));
413+
EXPECT_EQ(Vec4i32Ty->getMinNumElements(), Vec4FTy->getMinNumElements());
414+
// getExtendedElementCountVectorType
415+
auto *Vec4i64Ty =
416+
sandboxir::ScalableVectorType::getExtendedElementVectorType(Vec4i16Ty);
417+
EXPECT_TRUE(Vec4i64Ty->getElementType()->isIntegerTy(32));
418+
EXPECT_EQ(Vec4i64Ty->getMinNumElements(), Vec4i16Ty->getMinNumElements());
419+
// getTruncatedElementVectorType
420+
auto *Vec4i8Ty =
421+
sandboxir::ScalableVectorType::getTruncatedElementVectorType(Vec4i16Ty);
422+
EXPECT_TRUE(Vec4i8Ty->getElementType()->isIntegerTy(8));
423+
EXPECT_EQ(Vec4i8Ty->getMinNumElements(), Vec4i8Ty->getMinNumElements());
424+
// getSubdividedVectorType
425+
auto *Vec8i8Ty =
426+
sandboxir::ScalableVectorType::getSubdividedVectorType(Vec4i16Ty, 1);
427+
EXPECT_TRUE(Vec8i8Ty->getElementType()->isIntegerTy(8));
428+
EXPECT_EQ(Vec8i8Ty->getMinNumElements(), 8u);
429+
// getMinNumElements
430+
EXPECT_EQ(Vec8i8Ty->getMinNumElements(), 8u);
431+
// getHalfElementsVectorType
432+
auto *Vec2i16Ty =
433+
sandboxir::ScalableVectorType::getHalfElementsVectorType(Vec4i16Ty);
434+
EXPECT_TRUE(Vec2i16Ty->getElementType()->isIntegerTy(16));
435+
EXPECT_EQ(Vec2i16Ty->getMinNumElements(), 2u);
436+
// getDoubleElementsVectorType
437+
auto *Vec8i16Ty =
438+
sandboxir::ScalableVectorType::getDoubleElementsVectorType(Vec4i16Ty);
439+
EXPECT_TRUE(Vec8i16Ty->getElementType()->isIntegerTy(16));
440+
EXPECT_EQ(Vec8i16Ty->getMinNumElements(), 8u);
441+
}
442+
384443
TEST_F(SandboxTypeTest, FunctionType) {
385444
parseIR(C, R"IR(
386445
define void @foo() {

0 commit comments

Comments
 (0)