Skip to content

Commit d591630

Browse files
authored
[CIR] Upstream initial support for fixed size VectorType (#136488)
This change adds the initial support for VectorType Issue #136487
1 parent bb17651 commit d591630

File tree

12 files changed

+287
-4
lines changed

12 files changed

+287
-4
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
9090
return cir::FPAttr::getZero(ty);
9191
if (auto arrTy = mlir::dyn_cast<cir::ArrayType>(ty))
9292
return cir::ZeroAttr::get(arrTy);
93+
if (auto vecTy = mlir::dyn_cast<cir::VectorType>(ty))
94+
return cir::ZeroAttr::get(vecTy);
9395
if (auto ptrTy = mlir::dyn_cast<cir::PointerType>(ty))
9496
return getConstNullPtrAttr(ptrTy);
9597
if (auto recordTy = mlir::dyn_cast<cir::RecordType>(ty))

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,43 @@ def CIR_ArrayType : CIR_Type<"Array", "array",
307307
}];
308308
}
309309

310+
//===----------------------------------------------------------------------===//
311+
// VectorType (fixed size)
312+
//===----------------------------------------------------------------------===//
313+
314+
def CIR_VectorType : CIR_Type<"Vector", "vector",
315+
[DeclareTypeInterfaceMethods<DataLayoutTypeInterface>]> {
316+
317+
let summary = "CIR vector type";
318+
let description = [{
319+
`!cir.vector' represents fixed-size vector types, parameterized
320+
by the element type and the number of elements.
321+
322+
Example:
323+
324+
```mlir
325+
!cir.vector<!u64i x 2>
326+
!cir.vector<!cir.float x 4>
327+
```
328+
}];
329+
330+
let parameters = (ins "mlir::Type":$elementType, "uint64_t":$size);
331+
332+
let assemblyFormat = [{
333+
`<` $size `x` $elementType `>`
334+
}];
335+
336+
let builders = [
337+
TypeBuilderWithInferredContext<(ins
338+
"mlir::Type":$elementType, "uint64_t":$size
339+
), [{
340+
return $_get(elementType.getContext(), elementType, size);
341+
}]>,
342+
];
343+
344+
let genVerifyDecl = 1;
345+
}
346+
310347
//===----------------------------------------------------------------------===//
311348
// FuncType
312349
//===----------------------------------------------------------------------===//
@@ -533,8 +570,8 @@ def CIRRecordType : Type<
533570
//===----------------------------------------------------------------------===//
534571

535572
def CIR_AnyType : AnyTypeOf<[
536-
CIR_VoidType, CIR_BoolType, CIR_ArrayType, CIR_IntType, CIR_AnyFloat,
537-
CIR_PointerType, CIR_FuncType, CIR_RecordType
573+
CIR_VoidType, CIR_BoolType, CIR_ArrayType, CIR_VectorType, CIR_IntType,
574+
CIR_AnyFloat, CIR_PointerType, CIR_FuncType, CIR_RecordType
538575
]>;
539576

540577
#endif // MLIR_CIR_DIALECT_CIR_TYPES

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
8383
cir::IntType>(ty))
8484
return true;
8585

86+
if (const auto vt = mlir::dyn_cast<cir::VectorType>(ty))
87+
return isSized(vt.getElementType());
88+
8689
assert(!cir::MissingFeatures::unsizedTypes());
8790
return false;
8891
}

clang/lib/CIR/CodeGen/CIRGenTypes.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,14 @@ mlir::Type CIRGenTypes::convertType(QualType type) {
399399
break;
400400
}
401401

402+
case Type::ExtVector:
403+
case Type::Vector: {
404+
const VectorType *vec = cast<VectorType>(ty);
405+
const mlir::Type elemTy = convertType(vec->getElementType());
406+
resultType = cir::VectorType::get(elemTy, vec->getNumElements());
407+
break;
408+
}
409+
402410
case Type::FunctionNoProto:
403411
case Type::FunctionProto:
404412
resultType = convertFunctionTypeInternal(type);

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
220220
}
221221

222222
if (isa<cir::ZeroAttr>(attrType)) {
223-
if (isa<cir::RecordType, cir::ArrayType>(opType))
223+
if (isa<cir::RecordType, cir::ArrayType, cir::VectorType>(opType))
224224
return success();
225225
return op->emitOpError("zero expects struct or array type");
226226
}

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ BoolType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
652652
}
653653

654654
//===----------------------------------------------------------------------===//
655-
// Definitions
655+
// ArrayType Definitions
656656
//===----------------------------------------------------------------------===//
657657

658658
llvm::TypeSize
@@ -667,6 +667,41 @@ ArrayType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
667667
return dataLayout.getTypeABIAlignment(getEltType());
668668
}
669669

670+
//===----------------------------------------------------------------------===//
671+
// VectorType Definitions
672+
//===----------------------------------------------------------------------===//
673+
674+
llvm::TypeSize cir::VectorType::getTypeSizeInBits(
675+
const ::mlir::DataLayout &dataLayout,
676+
::mlir::DataLayoutEntryListRef params) const {
677+
return llvm::TypeSize::getFixed(
678+
getSize() * dataLayout.getTypeSizeInBits(getElementType()));
679+
}
680+
681+
uint64_t
682+
cir::VectorType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
683+
::mlir::DataLayoutEntryListRef params) const {
684+
return llvm::NextPowerOf2(dataLayout.getTypeSizeInBits(*this));
685+
}
686+
687+
mlir::LogicalResult cir::VectorType::verify(
688+
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
689+
mlir::Type elementType, uint64_t size) {
690+
if (size == 0)
691+
return emitError() << "the number of vector elements must be non-zero";
692+
693+
// Check if it a valid FixedVectorType
694+
if (mlir::isa<cir::PointerType, cir::FP128Type>(elementType))
695+
return success();
696+
697+
// Check if it a valid VectorType
698+
if (mlir::isa<cir::IntType>(elementType) ||
699+
isAnyFloatingPointType(elementType))
700+
return success();
701+
702+
return emitError() << "unsupported element type for CIR vector";
703+
}
704+
670705
//===----------------------------------------------------------------------===//
671706
// PointerType Definitions
672707
//===----------------------------------------------------------------------===//

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Dialect/DLTI/DLTI.h"
2020
#include "mlir/Dialect/Func/IR/FuncOps.h"
2121
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22+
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
2223
#include "mlir/IR/BuiltinAttributes.h"
2324
#include "mlir/IR/BuiltinDialect.h"
2425
#include "mlir/IR/BuiltinOps.h"
@@ -1392,6 +1393,10 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
13921393
convertTypeForMemory(converter, dataLayout, type.getEltType());
13931394
return mlir::LLVM::LLVMArrayType::get(ty, type.getSize());
13941395
});
1396+
converter.addConversion([&](cir::VectorType type) -> mlir::Type {
1397+
const mlir::Type ty = converter.convertType(type.getElementType());
1398+
return mlir::VectorType::get(type.getSize(), ty);
1399+
});
13951400
converter.addConversion([&](cir::BoolType type) -> mlir::Type {
13961401
return mlir::IntegerType::get(type.getContext(), 1,
13971402
mlir::IntegerType::Signless);

clang/test/CIR/CodeGen/vector-ext.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
2+
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
4+
// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM
5+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -emit-llvm %s -o %t.ll
6+
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
7+
8+
typedef int vi4 __attribute__((ext_vector_type(4)));
9+
typedef int vi3 __attribute__((ext_vector_type(3)));
10+
typedef int vi2 __attribute__((ext_vector_type(2)));
11+
typedef double vd2 __attribute__((ext_vector_type(2)));
12+
13+
vi4 vec_a;
14+
// CIR: cir.global external @[[VEC_A:.*]] = #cir.zero : !cir.vector<4 x !s32i>
15+
16+
// LLVM: @[[VEC_A:.*]] = dso_local global <4 x i32> zeroinitializer
17+
18+
// OGCG: @[[VEC_A:.*]] = global <4 x i32> zeroinitializer
19+
20+
vi3 vec_b;
21+
// CIR: cir.global external @[[VEC_B:.*]] = #cir.zero : !cir.vector<3 x !s32i>
22+
23+
// LLVM: @[[VEC_B:.*]] = dso_local global <3 x i32> zeroinitializer
24+
25+
// OGCG: @[[VEC_B:.*]] = global <3 x i32> zeroinitializer
26+
27+
vi2 vec_c;
28+
// CIR: cir.global external @[[VEC_C:.*]] = #cir.zero : !cir.vector<2 x !s32i>
29+
30+
// LLVM: @[[VEC_C:.*]] = dso_local global <2 x i32> zeroinitializer
31+
32+
// OGCG: @[[VEC_C:.*]] = global <2 x i32> zeroinitializer
33+
34+
vd2 d;
35+
36+
// CIR: cir.global external @[[VEC_D:.*]] = #cir.zero : !cir.vector<2 x !cir.double>
37+
38+
// LLVM: @[[VEC_D:.*]] = dso_local global <2 x double> zeroinitialize
39+
40+
// OGCG: @[[VEC_D:.*]] = global <2 x double> zeroinitializer
41+
42+
void foo() {
43+
vi4 a;
44+
vi3 b;
45+
vi2 c;
46+
vd2 d;
47+
}
48+
49+
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
50+
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<3 x !s32i>, !cir.ptr<!cir.vector<3 x !s32i>>, ["b"]
51+
// CIR: %[[VEC_C:.*]] = cir.alloca !cir.vector<2 x !s32i>, !cir.ptr<!cir.vector<2 x !s32i>>, ["c"]
52+
// CIR: %[[VEC_D:.*]] = cir.alloca !cir.vector<2 x !cir.double>, !cir.ptr<!cir.vector<2 x !cir.double>>, ["d"]
53+
54+
// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
55+
// LLVM: %[[VEC_B:.*]] = alloca <3 x i32>, i64 1, align 16
56+
// LLVM: %[[VEC_C:.*]] = alloca <2 x i32>, i64 1, align 8
57+
// LLVM: %[[VEC_D:.*]] = alloca <2 x double>, i64 1, align 16
58+
59+
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
60+
// OGCG: %[[VEC_B:.*]] = alloca <3 x i32>, align 16
61+
// OGCG: %[[VEC_C:.*]] = alloca <2 x i32>, align 8
62+
// OGCG: %[[VEC_D:.*]] = alloca <2 x double>, align 16
63+
64+
void foo2(vi4 p) {}
65+
66+
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["p", init]
67+
// CIR: cir.store %{{.*}}, %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
68+
69+
// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
70+
// LLVM: store <4 x i32> %{{.*}}, ptr %[[VEC_A]], align 16
71+
72+
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
73+
// OGCG: store <4 x i32> %{{.*}}, ptr %[[VEC_A]], align 16

clang/test/CIR/CodeGen/vector.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
2+
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
4+
// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM
5+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -emit-llvm %s -o %t.ll
6+
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
7+
8+
typedef int vi4 __attribute__((vector_size(16)));
9+
typedef double vd2 __attribute__((vector_size(16)));
10+
typedef long long vll2 __attribute__((vector_size(16)));
11+
12+
vi4 vec_a;
13+
// CIR: cir.global external @[[VEC_A:.*]] = #cir.zero : !cir.vector<4 x !s32i>
14+
15+
// LLVM: @[[VEC_A:.*]] = dso_local global <4 x i32> zeroinitializer
16+
17+
// OGCG: @[[VEC_A:.*]] = global <4 x i32> zeroinitializer
18+
19+
vd2 b;
20+
// CIR: cir.global external @[[VEC_B:.*]] = #cir.zero : !cir.vector<2 x !cir.double>
21+
22+
// LLVM: @[[VEC_B:.*]] = dso_local global <2 x double> zeroinitialize
23+
24+
// OGCG: @[[VEC_B:.*]] = global <2 x double> zeroinitializer
25+
26+
vll2 c;
27+
// CIR: cir.global external @[[VEC_C:.*]] = #cir.zero : !cir.vector<2 x !s64i>
28+
29+
// LLVM: @[[VEC_C:.*]] = dso_local global <2 x i64> zeroinitialize
30+
31+
// OGCG: @[[VEC_C:.*]] = global <2 x i64> zeroinitializer
32+
33+
void vec_int_test() {
34+
vi4 a;
35+
vd2 b;
36+
vll2 c;
37+
}
38+
39+
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
40+
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<2 x !cir.double>, !cir.ptr<!cir.vector<2 x !cir.double>>, ["b"]
41+
// CIR: %[[VEC_C:.*]] = cir.alloca !cir.vector<2 x !s64i>, !cir.ptr<!cir.vector<2 x !s64i>>, ["c"]
42+
43+
// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
44+
// LLVM: %[[VEC_B:.*]] = alloca <2 x double>, i64 1, align 16
45+
// LLVM: %[[VEC_C:.*]] = alloca <2 x i64>, i64 1, align 16
46+
47+
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
48+
// OGCG: %[[VEC_B:.*]] = alloca <2 x double>, align 16
49+
// OGCG: %[[VEC_C:.*]] = alloca <2 x i64>, align 16
50+
51+
void foo2(vi4 p) {}
52+
53+
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["p", init]
54+
// CIR: cir.store %{{.*}}, %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
55+
56+
// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
57+
// LLVM: store <4 x i32> %{{.*}}, ptr %[[VEC_A]], align 16
58+
59+
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
60+
// OGCG: store <4 x i32> %{{.*}}, ptr %[[VEC_A]], align 16
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: cir-opt %s -verify-diagnostics -split-input-file
2+
3+
!s32i = !cir.int<s, 32>
4+
5+
module {
6+
7+
// expected-error @below {{the number of vector elements must be non-zero}}
8+
cir.global external @vec_a = #cir.zero : !cir.vector<0 x !s32i>
9+
10+
}

clang/test/CIR/IR/invalid-vector.cir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: cir-opt %s -verify-diagnostics -split-input-file
2+
3+
!s32i = !cir.int<s, 32>
4+
5+
module {
6+
7+
// expected-error @below {{unsupported element type for CIR vector}}
8+
cir.global external @vec_b = #cir.zero : !cir.vector<4 x !cir.array<!s32i x 10>>
9+
10+
}

clang/test/CIR/IR/vector.cir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: cir-opt %s | FileCheck %s
2+
3+
!s32i = !cir.int<s, 32>
4+
5+
module {
6+
7+
cir.global external @vec_a = #cir.zero : !cir.vector<4 x !s32i>
8+
// CHECK: cir.global external @vec_a = #cir.zero : !cir.vector<4 x !s32i>
9+
10+
cir.global external @vec_b = #cir.zero : !cir.vector<3 x !s32i>
11+
// CHECK: cir.global external @vec_b = #cir.zero : !cir.vector<3 x !s32i>
12+
13+
cir.global external @vec_c = #cir.zero : !cir.vector<2 x !s32i>
14+
// CHECK: cir.global external @vec_c = #cir.zero : !cir.vector<2 x !s32i>
15+
16+
cir.func @vec_int_test() {
17+
%0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
18+
%1 = cir.alloca !cir.vector<3 x !s32i>, !cir.ptr<!cir.vector<3 x !s32i>>, ["b"]
19+
%2 = cir.alloca !cir.vector<2 x !s32i>, !cir.ptr<!cir.vector<2 x !s32i>>, ["c"]
20+
cir.return
21+
}
22+
23+
// CHECK: cir.func @vec_int_test() {
24+
// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
25+
// CHECK: %1 = cir.alloca !cir.vector<3 x !s32i>, !cir.ptr<!cir.vector<3 x !s32i>>, ["b"]
26+
// CHECK: %2 = cir.alloca !cir.vector<2 x !s32i>, !cir.ptr<!cir.vector<2 x !s32i>>, ["c"]
27+
// CHECK: cir.return
28+
// CHECK: }
29+
30+
cir.func @vec_double_test() {
31+
%0 = cir.alloca !cir.vector<2 x !cir.double>, !cir.ptr<!cir.vector<2 x !cir.double>>, ["a"]
32+
cir.return
33+
}
34+
35+
// CHECK: cir.func @vec_double_test() {
36+
// CHECK: %0 = cir.alloca !cir.vector<2 x !cir.double>, !cir.ptr<!cir.vector<2 x !cir.double>>, ["a"]
37+
// CHECK: cir.return
38+
// CHECK: }
39+
40+
}

0 commit comments

Comments
 (0)