Skip to content

[CIR] Upstream initial support for fixed size VectorType #136488

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return cir::FPAttr::getZero(ty);
if (auto arrTy = mlir::dyn_cast<cir::ArrayType>(ty))
return cir::ZeroAttr::get(arrTy);
if (auto vecTy = mlir::dyn_cast<cir::VectorType>(ty))
return cir::ZeroAttr::get(vecTy);
if (auto ptrTy = mlir::dyn_cast<cir::PointerType>(ty))
return getConstNullPtrAttr(ptrTy);
if (auto recordTy = mlir::dyn_cast<cir::RecordType>(ty))
Expand Down
41 changes: 39 additions & 2 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,43 @@ def CIR_ArrayType : CIR_Type<"Array", "array",
}];
}

//===----------------------------------------------------------------------===//
// VectorType (fixed size)
//===----------------------------------------------------------------------===//

def CIR_VectorType : CIR_Type<"Vector", "vector",
[DeclareTypeInterfaceMethods<DataLayoutTypeInterface>]> {

let summary = "CIR vector type";
let description = [{
`!cir.vector' represents fixed-size vector types, parameterized
by the element type and the number of elements.

Example:

```mlir
!cir.vector<!u64i x 2>
!cir.vector<!cir.float x 4>
```
}];

let parameters = (ins "mlir::Type":$elementType, "uint64_t":$size);

let assemblyFormat = [{
`<` $size `x` $elementType `>`
}];

let builders = [
TypeBuilderWithInferredContext<(ins
"mlir::Type":$elementType, "uint64_t":$size
), [{
return $_get(elementType.getContext(), elementType, size);
}]>,
];

let genVerifyDecl = 1;
}

//===----------------------------------------------------------------------===//
// FuncType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -533,8 +570,8 @@ def CIRRecordType : Type<
//===----------------------------------------------------------------------===//

def CIR_AnyType : AnyTypeOf<[
CIR_VoidType, CIR_BoolType, CIR_ArrayType, CIR_IntType, CIR_AnyFloat,
CIR_PointerType, CIR_FuncType, CIR_RecordType
CIR_VoidType, CIR_BoolType, CIR_ArrayType, CIR_VectorType, CIR_IntType,
CIR_AnyFloat, CIR_PointerType, CIR_FuncType, CIR_RecordType
]>;

#endif // MLIR_CIR_DIALECT_CIR_TYPES
3 changes: 3 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
cir::IntType>(ty))
return true;

if (const auto vt = mlir::dyn_cast<cir::VectorType>(ty))
return isSized(vt.getElementType());

assert(!cir::MissingFeatures::unsizedTypes());
return false;
}
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,14 @@ mlir::Type CIRGenTypes::convertType(QualType type) {
break;
}

case Type::ExtVector:
case Type::Vector: {
const VectorType *vec = cast<VectorType>(ty);
const mlir::Type elemTy = convertType(vec->getElementType());
resultType = cir::VectorType::get(elemTy, vec->getNumElements());
break;
}

case Type::FunctionNoProto:
case Type::FunctionProto:
resultType = convertFunctionTypeInternal(type);
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
}

if (isa<cir::ZeroAttr>(attrType)) {
if (isa<cir::RecordType, cir::ArrayType>(opType))
if (isa<cir::RecordType, cir::ArrayType, cir::VectorType>(opType))
return success();
return op->emitOpError("zero expects struct or array type");
}
Expand Down
37 changes: 36 additions & 1 deletion clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ BoolType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
}

//===----------------------------------------------------------------------===//
// Definitions
// ArrayType Definitions
//===----------------------------------------------------------------------===//

llvm::TypeSize
Expand All @@ -667,6 +667,41 @@ ArrayType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
return dataLayout.getTypeABIAlignment(getEltType());
}

//===----------------------------------------------------------------------===//
// VectorType Definitions
//===----------------------------------------------------------------------===//

llvm::TypeSize cir::VectorType::getTypeSizeInBits(
const ::mlir::DataLayout &dataLayout,
::mlir::DataLayoutEntryListRef params) const {
return llvm::TypeSize::getFixed(
getSize() * dataLayout.getTypeSizeInBits(getElementType()));
}

uint64_t
cir::VectorType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
::mlir::DataLayoutEntryListRef params) const {
return llvm::NextPowerOf2(dataLayout.getTypeSizeInBits(*this));
}

mlir::LogicalResult cir::VectorType::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::Type elementType, uint64_t size) {
if (size == 0)
return emitError() << "the number of vector elements must be non-zero";

// Check if it a valid FixedVectorType
if (mlir::isa<cir::PointerType, cir::FP128Type>(elementType))
return success();

// Check if it a valid VectorType
if (mlir::isa<cir::IntType>(elementType) ||
isAnyFloatingPointType(elementType))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see this check in mlir::LLVM::isCompatibleVectorType():

    if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
      return intType.isSignless();
    return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
                     Float80Type, Float128Type>(elementType);

So that's not necessarily compatible with isAnyFloatingPointType to match the error message below. I don't know if we want to form an LLVM type here and call that function directly, but it seems like we otherwise run the risk of this getting out of sync.

On the other hand, I'd like to be able to create CIR vectors of other types like float8 without requiring that the LLVM dialect support them.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mmmmmm maybe we should put our checks like it isa (Int, Index, Ptr, Float) 🤔

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if we want to form an LLVM type here and call that function directly, but it seems like we otherwise run the risk of this getting out of sync.

My suggestion is to stay out of using LLVM dialect directly besides lowering, my reasons include:

  • LLVM dialect changes frequently, and since CIR doesn't yet build by default, it's a maintenance burden for us (while in this phase).
  • Staying out of sync could be a good thing here since it forces us to write tests for what we support.
  • The thing you mentioned about being able to support other things not necessarily LLVM related.

return success();

return emitError() << "unsupported element type for CIR vector";
}

//===----------------------------------------------------------------------===//
// PointerType Definitions
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
Expand Down Expand Up @@ -1392,6 +1393,10 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
convertTypeForMemory(converter, dataLayout, type.getEltType());
return mlir::LLVM::LLVMArrayType::get(ty, type.getSize());
});
converter.addConversion([&](cir::VectorType type) -> mlir::Type {
const mlir::Type ty = converter.convertType(type.getElementType());
return mlir::VectorType::get(type.getSize(), ty);
});
converter.addConversion([&](cir::BoolType type) -> mlir::Type {
return mlir::IntegerType::get(type.getContext(), 1,
mlir::IntegerType::Signless);
Expand Down
73 changes: 73 additions & 0 deletions clang/test/CIR/CodeGen/vector-ext.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG

typedef int vi4 __attribute__((ext_vector_type(4)));
typedef int vi3 __attribute__((ext_vector_type(3)));
typedef int vi2 __attribute__((ext_vector_type(2)));
typedef double vd2 __attribute__((ext_vector_type(2)));

vi4 vec_a;
// CIR: cir.global external @[[VEC_A:.*]] = #cir.zero : !cir.vector<4 x !s32i>

// LLVM: @[[VEC_A:.*]] = dso_local global <4 x i32> zeroinitializer

// OGCG: @[[VEC_A:.*]] = global <4 x i32> zeroinitializer

vi3 vec_b;
// CIR: cir.global external @[[VEC_B:.*]] = #cir.zero : !cir.vector<3 x !s32i>

// LLVM: @[[VEC_B:.*]] = dso_local global <3 x i32> zeroinitializer

// OGCG: @[[VEC_B:.*]] = global <3 x i32> zeroinitializer

vi2 vec_c;
// CIR: cir.global external @[[VEC_C:.*]] = #cir.zero : !cir.vector<2 x !s32i>

// LLVM: @[[VEC_C:.*]] = dso_local global <2 x i32> zeroinitializer

// OGCG: @[[VEC_C:.*]] = global <2 x i32> zeroinitializer

vd2 d;

// CIR: cir.global external @[[VEC_D:.*]] = #cir.zero : !cir.vector<2 x !cir.double>

// LLVM: @[[VEC_D:.*]] = dso_local global <2 x double> zeroinitialize

// OGCG: @[[VEC_D:.*]] = global <2 x double> zeroinitializer

void foo() {
vi4 a;
vi3 b;
vi2 c;
vd2 d;
}

// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<3 x !s32i>, !cir.ptr<!cir.vector<3 x !s32i>>, ["b"]
// CIR: %[[VEC_C:.*]] = cir.alloca !cir.vector<2 x !s32i>, !cir.ptr<!cir.vector<2 x !s32i>>, ["c"]
// CIR: %[[VEC_D:.*]] = cir.alloca !cir.vector<2 x !cir.double>, !cir.ptr<!cir.vector<2 x !cir.double>>, ["d"]

// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[VEC_B:.*]] = alloca <3 x i32>, i64 1, align 16
// LLVM: %[[VEC_C:.*]] = alloca <2 x i32>, i64 1, align 8
// LLVM: %[[VEC_D:.*]] = alloca <2 x double>, i64 1, align 16

// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[VEC_B:.*]] = alloca <3 x i32>, align 16
// OGCG: %[[VEC_C:.*]] = alloca <2 x i32>, align 8
// OGCG: %[[VEC_D:.*]] = alloca <2 x double>, align 16

void foo2(vi4 p) {}

// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["p", init]
// CIR: cir.store %{{.*}}, %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>

// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: store <4 x i32> %{{.*}}, ptr %[[VEC_A]], align 16

// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
// OGCG: store <4 x i32> %{{.*}}, ptr %[[VEC_A]], align 16
60 changes: 60 additions & 0 deletions clang/test/CIR/CodeGen/vector.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG

typedef int vi4 __attribute__((vector_size(16)));
typedef double vd2 __attribute__((vector_size(16)));
typedef long long vll2 __attribute__((vector_size(16)));

vi4 vec_a;
// CIR: cir.global external @[[VEC_A:.*]] = #cir.zero : !cir.vector<4 x !s32i>

// LLVM: @[[VEC_A:.*]] = dso_local global <4 x i32> zeroinitializer

// OGCG: @[[VEC_A:.*]] = global <4 x i32> zeroinitializer

vd2 b;
// CIR: cir.global external @[[VEC_B:.*]] = #cir.zero : !cir.vector<2 x !cir.double>

// LLVM: @[[VEC_B:.*]] = dso_local global <2 x double> zeroinitialize

// OGCG: @[[VEC_B:.*]] = global <2 x double> zeroinitializer

vll2 c;
// CIR: cir.global external @[[VEC_C:.*]] = #cir.zero : !cir.vector<2 x !s64i>

// LLVM: @[[VEC_C:.*]] = dso_local global <2 x i64> zeroinitialize

// OGCG: @[[VEC_C:.*]] = global <2 x i64> zeroinitializer

void vec_int_test() {
vi4 a;
vd2 b;
vll2 c;
}

// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<2 x !cir.double>, !cir.ptr<!cir.vector<2 x !cir.double>>, ["b"]
// CIR: %[[VEC_C:.*]] = cir.alloca !cir.vector<2 x !s64i>, !cir.ptr<!cir.vector<2 x !s64i>>, ["c"]

// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[VEC_B:.*]] = alloca <2 x double>, i64 1, align 16
// LLVM: %[[VEC_C:.*]] = alloca <2 x i64>, i64 1, align 16

// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[VEC_B:.*]] = alloca <2 x double>, align 16
// OGCG: %[[VEC_C:.*]] = alloca <2 x i64>, align 16

void foo2(vi4 p) {}

// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["p", init]
// CIR: cir.store %{{.*}}, %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>

// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: store <4 x i32> %{{.*}}, ptr %[[VEC_A]], align 16

// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
// OGCG: store <4 x i32> %{{.*}}, ptr %[[VEC_A]], align 16
10 changes: 10 additions & 0 deletions clang/test/CIR/IR/invalid-vector-zero-size.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: cir-opt %s -verify-diagnostics -split-input-file

!s32i = !cir.int<s, 32>

module {

// expected-error @below {{the number of vector elements must be non-zero}}
cir.global external @vec_a = #cir.zero : !cir.vector<0 x !s32i>

}
10 changes: 10 additions & 0 deletions clang/test/CIR/IR/invalid-vector.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: cir-opt %s -verify-diagnostics -split-input-file

!s32i = !cir.int<s, 32>

module {

// expected-error @below {{unsupported element type for CIR vector}}
cir.global external @vec_b = #cir.zero : !cir.vector<4 x !cir.array<!s32i x 10>>

}
40 changes: 40 additions & 0 deletions clang/test/CIR/IR/vector.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// RUN: cir-opt %s | FileCheck %s

!s32i = !cir.int<s, 32>

module {

cir.global external @vec_a = #cir.zero : !cir.vector<4 x !s32i>
// CHECK: cir.global external @vec_a = #cir.zero : !cir.vector<4 x !s32i>

cir.global external @vec_b = #cir.zero : !cir.vector<3 x !s32i>
// CHECK: cir.global external @vec_b = #cir.zero : !cir.vector<3 x !s32i>

cir.global external @vec_c = #cir.zero : !cir.vector<2 x !s32i>
// CHECK: cir.global external @vec_c = #cir.zero : !cir.vector<2 x !s32i>

cir.func @vec_int_test() {
%0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
%1 = cir.alloca !cir.vector<3 x !s32i>, !cir.ptr<!cir.vector<3 x !s32i>>, ["b"]
%2 = cir.alloca !cir.vector<2 x !s32i>, !cir.ptr<!cir.vector<2 x !s32i>>, ["c"]
cir.return
}

// CHECK: cir.func @vec_int_test() {
// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
// CHECK: %1 = cir.alloca !cir.vector<3 x !s32i>, !cir.ptr<!cir.vector<3 x !s32i>>, ["b"]
// CHECK: %2 = cir.alloca !cir.vector<2 x !s32i>, !cir.ptr<!cir.vector<2 x !s32i>>, ["c"]
// CHECK: cir.return
// CHECK: }

cir.func @vec_double_test() {
%0 = cir.alloca !cir.vector<2 x !cir.double>, !cir.ptr<!cir.vector<2 x !cir.double>>, ["a"]
cir.return
}

// CHECK: cir.func @vec_double_test() {
// CHECK: %0 = cir.alloca !cir.vector<2 x !cir.double>, !cir.ptr<!cir.vector<2 x !cir.double>>, ["a"]
// CHECK: cir.return
// CHECK: }

}
Loading