Skip to content

[CIR] Upstream local initialization for VectorType #138107

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1903,4 +1903,28 @@ def TrapOp : CIR_Op<"trap", [Terminator]> {
let assemblyFormat = "attr-dict";
}

//===----------------------------------------------------------------------===//
// VecCreate
//===----------------------------------------------------------------------===//

// `cir.vec.create` builds a vector value from individual element operands.
// The op is declared with the `Pure` trait.
def VecCreateOp : CIR_Op<"vec.create", [Pure]> {

let summary = "Create a vector value";
let description = [{
The `cir.vec.create` operation creates a vector value with the given element
values. The number of element arguments must match the number of elements
in the vector type.
}];

// ODS accepts any CIR type for each operand; agreement between the operands
// and the result vector type is left to the op verifier (see hasVerifier).
let arguments = (ins Variadic<CIR_AnyType>:$elements);
let results = (outs CIR_VectorType:$result);

// Textual form, e.g.:
//   cir.vec.create(%a, %b : !s32i, !s32i) : !cir.vector<2 x !s32i>
// The operand list and its types form an optional group.
let assemblyFormat = [{
`(` ($elements^ `:` type($elements))? `)` `:` qualified(type($result))
attr-dict
}];

// Requests a hand-written `verify()` method for this op.
let hasVerifier = 1;
}

#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
17 changes: 14 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,20 @@ void CIRGenFunction::emitStoreOfScalar(mlir::Value value, Address addr,
bool isInit, bool isNontemporal) {
assert(!cir::MissingFeatures::opLoadStoreThreadLocal());

if (ty->getAs<clang::VectorType>()) {
cgm.errorNYI(addr.getPointer().getLoc(), "emitStoreOfScalar vector type");
return;
if (const auto *clangVecTy = ty->getAs<clang::VectorType>()) {
// Boolean vectors use `iN` as storage type.
if (clangVecTy->isExtVectorBoolType())
cgm.errorNYI(addr.getPointer().getLoc(),
"emitStoreOfScalar ExtVectorBoolType");

// Handle vectors of size 3 like size 4 for better performance.
const mlir::Type elementType = addr.getElementType();
const auto vecTy = cast<cir::VectorType>(elementType);

// TODO(CIR): Use `ABIInfo::getOptimalVectorMemoryType` once it upstreamed
if (vecTy.getSize() == 3 && !getLangOpts().PreserveVec3Type)
cgm.errorNYI(addr.getPointer().getLoc(),
"emitStoreOfScalar Vec3 & PreserveVec3Type disabled");
}

value = emitToMemory(value, ty);
Expand Down
39 changes: 39 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {

mlir::Value VisitMemberExpr(MemberExpr *e);

mlir::Value VisitInitListExpr(InitListExpr *e);

mlir::Value VisitExplicitCastExpr(ExplicitCastExpr *e) {
return VisitCastExpr(e);
}
Expand Down Expand Up @@ -1584,6 +1586,43 @@ mlir::Value ScalarExprEmitter::VisitMemberExpr(MemberExpr *e) {
return emitLoadOfLValue(e);
}

/// Emit the value of an initializer list expression.
///
/// For vector-typed lists, every initializer is visited in order and the
/// results are packed into a single `cir.vec.create`; lanes without an
/// explicit initializer are filled with the element type's null value.
/// For any other type, the value of the first initializer is returned.
/// Lists with array range designators and empty lists are reported as NYI and
/// yield a null value.
mlir::Value ScalarExprEmitter::VisitInitListExpr(InitListExpr *e) {
  if (e->hadArrayRangeDesignator()) {
    cgf.cgm.errorNYI(e->getSourceRange(), "ArrayRangeDesignator");
    return {};
  }

  const unsigned initCount = e->getNumInits();
  if (initCount == 0) {
    cgf.cgm.errorNYI(e->getSourceRange(), "InitListExpr with 0 init elements");
    return {};
  }

  // Not a vector: the list just wraps a single scalar initializer.
  if (!e->getType()->isVectorType())
    return Visit(e->getInit(0));

  const auto vecTy = mlir::cast<cir::VectorType>(cgf.convertType(e->getType()));

  // Explicitly provided lanes, in source order.
  SmallVector<mlir::Value, 16> lanes;
  for (Expr *initExpr : e->inits())
    lanes.push_back(Visit(initExpr));

  const auto loc = cgf.getLoc(e->getSourceRange());

  // Pad the tail with zeros when fewer initializers than lanes were given.
  if (initCount < vecTy.getSize()) {
    const mlir::Value zero =
        cgf.getBuilder().getNullValue(vecTy.getElementType(), loc);
    lanes.append(vecTy.getSize() - initCount, zero);
  }

  return cgf.getBuilder().create<cir::VecCreateOp>(loc, vecTy, lanes);
}

mlir::Value CIRGenFunction::emitScalarConversion(mlir::Value src,
QualType srcTy, QualType dstTy,
SourceLocation loc) {
Expand Down
27 changes: 27 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1307,6 +1307,33 @@ LogicalResult cir::GetMemberOp::verify() {
return mlir::success();
}

//===----------------------------------------------------------------------===//
// VecCreateOp
//===----------------------------------------------------------------------===//

/// Verifier for `cir.vec.create`.
///
/// Succeeds only when the op has exactly one operand per element of the result
/// vector type and each operand's type equals the vector's element type. The
/// first violation found is reported through `emitOpError()`.
LogicalResult cir::VecCreateOp::verify() {
  const VectorType resultTy = getResult().getType();
  const auto operands = getElements();

  // One operand per vector element.
  if (operands.size() != resultTy.getSize())
    return emitOpError() << "operand count of " << operands.size()
                         << " doesn't match vector type " << resultTy
                         << " element count of " << resultTy.getSize();

  // Every operand must have the vector's element type.
  const mlir::Type laneTy = resultTy.getElementType();
  for (const mlir::Value operand : operands) {
    const mlir::Type operandTy = operand.getType();
    if (operandTy == laneTy)
      continue;
    return emitOpError() << "operand type " << operandTy
                         << " doesn't match vector element type " << laneTy;
  }

  return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
26 changes: 25 additions & 1 deletion clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1599,7 +1599,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMStackSaveOpLowering,
CIRToLLVMStackRestoreOpLowering,
CIRToLLVMTrapOpLowering,
CIRToLLVMUnaryOpLowering
CIRToLLVMUnaryOpLowering,
CIRToLLVMVecCreateOpLowering
// clang-format on
>(converter, patterns.getContext());

Expand Down Expand Up @@ -1685,6 +1686,29 @@ mlir::LogicalResult CIRToLLVMStackRestoreOpLowering::matchAndRewrite(
return mlir::success();
}

// Lowers `cir.vec.create` by building the vector one element at a time and
// replacing the op with the final inserted value.
mlir::LogicalResult CIRToLLVMVecCreateOpLowering::matchAndRewrite(
cir::VecCreateOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
// Start with a 'poison' value for the vector (an LLVM::PoisonOp of the
// converted vector type). Then emit one 'insertelement' for each of the
// vector elements, threading the partially built vector through `result`.
const auto vecTy = mlir::cast<cir::VectorType>(op.getType());
const mlir::Type llvmTy = typeConverter->convertType(vecTy);
const mlir::Location loc = op.getLoc();
mlir::Value result = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
assert(vecTy.getSize() == op.getElements().size() &&
"cir.vec.create op count doesn't match vector type elements count");

for (uint64_t i = 0; i < vecTy.getSize(); ++i) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we not do this with an llvm..store operation?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will try to see if i can pass the values to StoreOp and use it

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I addressed all comments except this one :D,

To use llvm::StoreOp we need to construct mlir::DenseElementsAttr, and it takes Array<Attribute>, but we have Array<Value> for Constant values. it's easy to get the attribute from them using getDefiningOp but now what is needed is to get value from non constant value for example x in { 1, 2, X, 3}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. We can possibly revisit this later. For constant values, it looks like it ends up as a single store in LLVM IR anyway.

const mlir::Value indexValue =
rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i);
result = rewriter.create<mlir::LLVM::InsertElementOp>(
loc, result, adaptor.getElements()[i], indexValue);
}

rewriter.replaceOp(op, result);
return mlir::success();
}

std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
return std::make_unique<ConvertCIRToLLVMPass>();
}
Expand Down
10 changes: 10 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,16 @@ class CIRToLLVMStackRestoreOpLowering
mlir::ConversionPatternRewriter &rewriter) const override;
};

/// Conversion pattern that matches `cir::VecCreateOp`; the rewrite itself is
/// defined out of line.
class CIRToLLVMVecCreateOpLowering
    : public mlir::OpConversionPattern<cir::VecCreateOp> {
public:
  // Inherit the base pattern's constructors.
  using mlir::OpConversionPattern<cir::VecCreateOp>::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(cir::VecCreateOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override;
};

} // namespace direct
} // namespace cir

Expand Down
29 changes: 29 additions & 0 deletions clang/test/CIR/CodeGen/vector-ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,27 +48,56 @@ vi4 vec_e = { 1, 2, 3, 4 };

// OGCG: @[[VEC_E:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4>

int x = 5;

// Local vector declarations: some without an initializer, some initialized
// from a braced list.
void foo() {
// Declared without an initializer.
vi4 a;
vi3 b;
vi2 c;
vd2 d;

// Every element given as an integer literal.
vi4 e = { 1, 2, 3, 4 };

// Some elements read the global `x`, so they are not literals.
vi4 f = { x, 5, 6, x + 1 };

// Single initializer; the checks below expect the remaining elements to be 0.
vi4 g = { 5 };
}

// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<3 x !s32i>, !cir.ptr<!cir.vector<3 x !s32i>>, ["b"]
// CIR: %[[VEC_C:.*]] = cir.alloca !cir.vector<2 x !s32i>, !cir.ptr<!cir.vector<2 x !s32i>>, ["c"]
// CIR: %[[VEC_D:.*]] = cir.alloca !cir.vector<2 x !cir.double>, !cir.ptr<!cir.vector<2 x !cir.double>>, ["d"]
// CIR: %[[VEC_E:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["e", init]
// CIR: %[[VEC_F:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["f", init]
// CIR: %[[VEC_G:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["g", init]
// CIR: %[[VEC_E_VAL:.*]] = cir.vec.create({{.*}}, {{.*}}, {{.*}}, {{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
// CIR: cir.store %[[VEC_E_VAL]], %[[VEC_E]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
// CIR: %[[VEC_F_VAL:.*]] = cir.vec.create({{.*}}, {{.*}}, {{.*}}, {{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
// CIR: cir.store %[[VEC_F_VAL]], %[[VEC_F]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
// CIR: %[[VEC_G_VAL:.*]] = cir.vec.create({{.*}}, {{.*}}, {{.*}}, {{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
// CIR: cir.store %[[VEC_G_VAL]], %[[VEC_G]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>

// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[VEC_B:.*]] = alloca <3 x i32>, i64 1, align 16
// LLVM: %[[VEC_C:.*]] = alloca <2 x i32>, i64 1, align 8
// LLVM: %[[VEC_D:.*]] = alloca <2 x double>, i64 1, align 16
// LLVM: %[[VEC_E:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[VEC_F:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[VEC_G:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_E]], align 16
// LLVM: store <4 x i32> {{.*}}, ptr %[[VEC_F]], align 16
// LLVM: store <4 x i32> <i32 5, i32 0, i32 0, i32 0>, ptr %[[VEC_G]], align 16

// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[VEC_B:.*]] = alloca <3 x i32>, align 16
// OGCG: %[[VEC_C:.*]] = alloca <2 x i32>, align 8
// OGCG: %[[VEC_D:.*]] = alloca <2 x double>, align 16
// OGCG: %[[VEC_E:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[VEC_F:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[VEC_G:.*]] = alloca <4 x i32>, align 16
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_E]], align 16
// OGCG: store <4 x i32> {{.*}}, ptr %[[VEC_F]], align 16
// OGCG: store <4 x i32> <i32 5, i32 0, i32 0, i32 0>, ptr %[[VEC_G]], align 16

void foo2(vi4 p) {}

Expand Down
31 changes: 30 additions & 1 deletion clang/test/CIR/CodeGen/vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,52 @@ vi4 d = { 1, 2, 3, 4 };

// OGCG: @[[VEC_D:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4>

void vec_int_test() {
int x = 5;

void foo() {
vi4 a;
vd2 b;
vll2 c;

vi4 d = { 1, 2, 3, 4 };

vi4 e = { x, 5, 6, x + 1 };
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test case where the initializer list doesn't fill the entire vector?


vi4 f = { 5 };
}

// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<2 x !cir.double>, !cir.ptr<!cir.vector<2 x !cir.double>>, ["b"]
// CIR: %[[VEC_C:.*]] = cir.alloca !cir.vector<2 x !s64i>, !cir.ptr<!cir.vector<2 x !s64i>>, ["c"]
// CIR: %[[VEC_D:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["d", init]
// CIR: %[[VEC_E:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["e", init]
// CIR: %[[VEC_F:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["f", init]
// CIR: %[[VEC_D_VAL:.*]] = cir.vec.create({{.*}}, {{.*}}, {{.*}}, {{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
// CIR: cir.store %[[VEC_D_VAL]], %[[VEC_D]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
// CIR: %[[VEC_E_VAL:.*]] = cir.vec.create({{.*}}, {{.*}}, {{.*}}, {{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
// CIR: cir.store %[[VEC_E_VAL]], %[[VEC_E]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
// CIR: %[[VEC_F_VAL:.*]] = cir.vec.create({{.*}}, {{.*}}, {{.*}}, {{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
// CIR: cir.store %[[VEC_F_VAL]], %[[VEC_F]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>

// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[VEC_B:.*]] = alloca <2 x double>, i64 1, align 16
// LLVM: %[[VEC_C:.*]] = alloca <2 x i64>, i64 1, align 16
// LLVM: %[[VEC_D:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[VEC_E:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[VEC_F:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_D]], align 16
// LLVM: store <4 x i32> {{.*}}, ptr %[[VEC_E]], align 16
// LLVM: store <4 x i32> <i32 5, i32 0, i32 0, i32 0>, ptr %[[VEC_F]], align 16

// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[VEC_B:.*]] = alloca <2 x double>, align 16
// OGCG: %[[VEC_C:.*]] = alloca <2 x i64>, align 16
// OGCG: %[[VEC_D:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[VEC_E:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[VEC_F:.*]] = alloca <4 x i32>, align 16
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_D]], align 16
// OGCG: store <4 x i32> {{.*}}, ptr %[[VEC_E]], align 16
// OGCG: store <4 x i32> <i32 5, i32 0, i32 0, i32 0>, ptr %[[VEC_F]], align 16

void foo2(vi4 p) {}

Expand Down
16 changes: 16 additions & 0 deletions clang/test/CIR/IR/invalid-vector-create-wrong-size.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: cir-opt %s -verify-diagnostics -split-input-file

!s32i = !cir.int<s, 32>

// Negative test: cir.vec.create is given four operands but its result type is
// a vector of eight elements, so the op is expected to be rejected.
module {
cir.func @foo() {
%1 = cir.const #cir.int<1> : !s32i
%2 = cir.const #cir.int<2> : !s32i
%3 = cir.const #cir.int<3> : !s32i
%4 = cir.const #cir.int<4> : !s32i

// expected-error @below {{operand count of 4 doesn't match vector type '!cir.vector<8 x !cir.int<s, 32>>' element count of 8}}
%5 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<8 x !s32i>
cir.return
}
}
17 changes: 17 additions & 0 deletions clang/test/CIR/IR/invalid-vector-create-wrong-type.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: cir-opt %s -verify-diagnostics -split-input-file

!s32i = !cir.int<s, 32>
!s64i = !cir.int<s, 64>

// Negative test: the fourth operand of cir.vec.create is !s64i while the
// result vector's element type is !s32i, so the op is expected to be rejected.
module {
cir.func @foo() {
%1 = cir.const #cir.int<1> : !s32i
%2 = cir.const #cir.int<2> : !s32i
%3 = cir.const #cir.int<3> : !s32i
%4 = cir.const #cir.int<4> : !s64i

// expected-error @below {{operand type '!cir.int<s, 64>' doesn't match vector element type '!cir.int<s, 32>'}}
%5 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s64i) : !cir.vector<4 x !s32i>
cir.return
}
}
22 changes: 22 additions & 0 deletions clang/test/CIR/IR/vector.cir
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,26 @@ cir.func @vec_double_test() {
// CHECK: cir.return
// CHECK: }

// Builds a 4-element !s32i vector from four constants with cir.vec.create and
// stores it into a local alloca; the CHECK lines below expect the same ops.
cir.func @local_vector_create_test() {
%0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
%1 = cir.const #cir.int<1> : !s32i
%2 = cir.const #cir.int<2> : !s32i
%3 = cir.const #cir.int<3> : !s32i
%4 = cir.const #cir.int<4> : !s32i
%5 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
cir.store %5, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
cir.return
}

// CHECK: cir.func @local_vector_create_test() {
// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
// CHECK: %1 = cir.const #cir.int<1> : !s32i
// CHECK: %2 = cir.const #cir.int<2> : !s32i
// CHECK: %3 = cir.const #cir.int<3> : !s32i
// CHECK: %4 = cir.const #cir.int<4> : !s32i
// CHECK: %5 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
// CHECK: cir.store %5, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
// CHECK: cir.return
// CHECK: }

}