Skip to content

[CIR] Upstream local initialization for VectorType #138107

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 2, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1903,4 +1903,28 @@ def TrapOp : CIR_Op<"trap", [Terminator]> {
let assemblyFormat = "attr-dict";
}

//===----------------------------------------------------------------------===//
// VecCreate
//===----------------------------------------------------------------------===//

def VecCreateOp : CIR_Op<"vec.create", [Pure]> {

let summary = "Create a vector value";
let description = [{
The `cir.vec.create` operation creates a vector value with the given element
values. The number of element arguments must match the number of elements
in the vector type.
}];

let arguments = (ins Variadic<CIR_AnyType>:$elements);
let results = (outs CIR_VectorType:$result);

let assemblyFormat = [{
`(` ($elements^ `:` type($elements))? `)` `:` qualified(type($result))
attr-dict
}];

let hasVerifier = 1;
}

#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
17 changes: 14 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,20 @@ void CIRGenFunction::emitStoreOfScalar(mlir::Value value, Address addr,
bool isInit, bool isNontemporal) {
assert(!cir::MissingFeatures::opLoadStoreThreadLocal());

if (ty->getAs<clang::VectorType>()) {
cgm.errorNYI(addr.getPointer().getLoc(), "emitStoreOfScalar vector type");
return;
if (const auto *clangVecTy = ty->getAs<clang::VectorType>()) {
// Boolean vectors use `iN` as storage type.
if (clangVecTy->isExtVectorBoolType())
cgm.errorNYI(addr.getPointer().getLoc(),
"emitStoreOfScalar ExtVectorBoolType");

// Handle vectors of size 3 like size 4 for better performance.
const mlir::Type elementType = addr.getElementType();
const auto vecTy = cast<cir::VectorType>(elementType);

// TODO(CIR): Use `ABIInfo::getOptimalVectorMemoryType` once it upstreamed
if (vecTy.getSize() == 3 && !getLangOpts().PreserveVec3Type)
cgm.errorNYI(addr.getPointer().getLoc(),
"emitStoreOfScalar Vec3 & PreserveVec3Type disabled");
}

value = emitToMemory(value, ty);
Expand Down
43 changes: 43 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {

mlir::Value VisitMemberExpr(MemberExpr *e);

mlir::Value VisitInitListExpr(InitListExpr *e);

mlir::Value VisitExplicitCastExpr(ExplicitCastExpr *e) {
return VisitCastExpr(e);
}
Expand Down Expand Up @@ -1584,6 +1586,47 @@ mlir::Value ScalarExprEmitter::VisitMemberExpr(MemberExpr *e) {
return emitLoadOfLValue(e);
}

mlir::Value ScalarExprEmitter::VisitInitListExpr(InitListExpr *e) {
const unsigned numInitElements = e->getNumInits();

if (e->hadArrayRangeDesignator()) {
cgf.cgm.errorNYI(e->getSourceRange(), "ArrayRangeDesignator");
return {};
}

if (numInitElements == 0) {
cgf.cgm.errorNYI(e->getSourceRange(), "InitListExpr with 0 init elements");
return {};
}

if (e->getType()->isVectorType()) {
const auto vectorType =
mlir::cast<cir::VectorType>(cgf.convertType(e->getType()));

SmallVector<mlir::Value, 16> elements;
for (Expr *init : e->inits()) {
elements.push_back(Visit(init));
}

// Zero-initialize any remaining values.
if (numInitElements < vectorType.getSize()) {
mlir::TypedAttr zeroInitAttr =
cgf.getBuilder().getZeroInitAttr(vectorType.getElementType());
cir::ConstantOp zeroValue = cgf.getBuilder().getConstant(
cgf.getLoc(e->getSourceRange()), zeroInitAttr);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
mlir::TypedAttr zeroInitAttr =
cgf.getBuilder().getZeroInitAttr(vectorType.getElementType());
cir::ConstantOp zeroValue = cgf.getBuilder().getConstant(
cgf.getLoc(e->getSourceRange()), zeroInitAttr);
cir::ConstantOp zeroValue = cgf.getBuilder().getNullValue(
vectorType.getElementType(), cgf.getLoc(e->getSourceRange()));


for (uint64_t i = numInitElements; i < vectorType.getSize(); ++i) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (uint64_t i = numInitElements; i < vectorType.getSize(); ++i) {
elements.assign(vectorType.getSize(), zeroValue);

Copy link
Member Author

@AmrDeveloper AmrDeveloper May 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I understood, elements.assign will assign zeroValue to all elements, not just the remaining uninit ones?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right. I had forgotten this vector already had entries from above. What about this?

std::fill_n(std::back_inserter(elements), vectorType.getSize(), zeroValue);

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be like this

std::fill_n(std::back_inserter(elements), vectorType.getSize() - numInitElements, zeroValue);

To fill in the remaining size not init + full size

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that looks right.

elements.push_back(zeroValue);
}
}

return cgf.getBuilder().create<cir::VecCreateOp>(
cgf.getLoc(e->getSourceRange()), vectorType, elements);
}

return Visit(e->getInit(0));
}

mlir::Value CIRGenFunction::emitScalarConversion(mlir::Value src,
QualType srcTy, QualType dstTy,
SourceLocation loc) {
Expand Down
27 changes: 27 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1307,6 +1307,33 @@ LogicalResult cir::GetMemberOp::verify() {
return mlir::success();
}

//===----------------------------------------------------------------------===//
// VecCreateOp
//===----------------------------------------------------------------------===//

LogicalResult cir::VecCreateOp::verify() {
// Verify that the number of arguments matches the number of elements in the
// vector, and that the type of all the arguments matches the type of the
// elements in the vector.
const VectorType vecTy = getResult().getType();
if (getElements().size() != vecTy.getSize()) {
return emitOpError() << "operand count of " << getElements().size()
<< " doesn't match vector type " << vecTy
<< " element count of " << vecTy.getSize();
}

const mlir::Type elementType = vecTy.getElementType();
for (const mlir::Value element : getElements()) {
if (element.getType() != elementType) {
return emitOpError() << "operand type " << element.getType()
<< " doesn't match vector element type "
<< elementType;
}
}

return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
26 changes: 25 additions & 1 deletion clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1599,7 +1599,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMStackSaveOpLowering,
CIRToLLVMStackRestoreOpLowering,
CIRToLLVMTrapOpLowering,
CIRToLLVMUnaryOpLowering
CIRToLLVMUnaryOpLowering,
CIRToLLVMVecCreateOpLowering
// clang-format on
>(converter, patterns.getContext());

Expand Down Expand Up @@ -1685,6 +1686,29 @@ mlir::LogicalResult CIRToLLVMStackRestoreOpLowering::matchAndRewrite(
return mlir::success();
}

mlir::LogicalResult CIRToLLVMVecCreateOpLowering::matchAndRewrite(
cir::VecCreateOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
// Start with an 'undef' value for the vector. Then 'insertelement' for
// each of the vector elements.
const auto vecTy = mlir::cast<cir::VectorType>(op.getType());
const mlir::Type llvmTy = typeConverter->convertType(vecTy);
const mlir::Location loc = op.getLoc();
mlir::Value result = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
assert(vecTy.getSize() == op.getElements().size() &&
"cir.vec.create op count doesn't match vector type elements count");

for (uint64_t i = 0; i < vecTy.getSize(); ++i) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we not do this with an llvm..store operation?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will try to see if i can pass the values to StoreOp and use it

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I addressed all comments except this one :D,

To use llvm::StoreOp we need to construct mlir::DenseElementsAttr, and it takes Array<Attribute>, but we have Array<Value> for Constant values. it's easy to get the attribute from them using getDefiningOp but now what is needed is to get value from non constant value for example x in { 1, 2, X, 3}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. We can possibly revisit this later. For constant values, it looks like it ends up as a single store in LLVM IR anyway.

const mlir::Value indexValue =
rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i);
result = rewriter.create<mlir::LLVM::InsertElementOp>(
loc, result, adaptor.getElements()[i], indexValue);
}

rewriter.replaceOp(op, result);
return mlir::success();
}

std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
return std::make_unique<ConvertCIRToLLVMPass>();
}
Expand Down
10 changes: 10 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,16 @@ class CIRToLLVMStackRestoreOpLowering
mlir::ConversionPatternRewriter &rewriter) const override;
};

class CIRToLLVMVecCreateOpLowering
: public mlir::OpConversionPattern<cir::VecCreateOp> {
public:
using mlir::OpConversionPattern<cir::VecCreateOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(cir::VecCreateOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};

} // namespace direct
} // namespace cir

Expand Down
20 changes: 20 additions & 0 deletions clang/test/CIR/CodeGen/vector-ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,27 +48,47 @@ vi4 vec_e = { 1, 2, 3, 4 };

// OGCG: @[[VEC_E:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4>

int x = 5;

void foo() {
vi4 a;
vi3 b;
vi2 c;
vd2 d;

vi4 e = { 1, 2, 3, 4 };

vi4 f = { x, 5, 6, x + 1 };
}

// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<3 x !s32i>, !cir.ptr<!cir.vector<3 x !s32i>>, ["b"]
// CIR: %[[VEC_C:.*]] = cir.alloca !cir.vector<2 x !s32i>, !cir.ptr<!cir.vector<2 x !s32i>>, ["c"]
// CIR: %[[VEC_D:.*]] = cir.alloca !cir.vector<2 x !cir.double>, !cir.ptr<!cir.vector<2 x !cir.double>>, ["d"]
// CIR: %[[VEC_E:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["e", init]
// CIR: %[[VEC_F:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["f", init]
// CIR: %[[VEC_E_VAL:.*]] = cir.vec.create({{.*}}, {{.*}}, {{.*}}, {{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
// CIR: cir.store %[[VEC_E_VAL]], %[[VEC_E]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
// CIR: %[[VEC_F_VAL:.*]] = cir.vec.create({{.*}}, {{.*}}, {{.*}}, {{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
// CIR: cir.store %[[VEC_F_VAL]], %[[VEC_F]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>

// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[VEC_B:.*]] = alloca <3 x i32>, i64 1, align 16
// LLVM: %[[VEC_C:.*]] = alloca <2 x i32>, i64 1, align 8
// LLVM: %[[VEC_D:.*]] = alloca <2 x double>, i64 1, align 16
// LLVM: %[[VEC_E:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[VEC_F:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_E]], align 16
// LLVM: store <4 x i32> {{.*}}, ptr %[[VEC_F:.*]], align 16

// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[VEC_B:.*]] = alloca <3 x i32>, align 16
// OGCG: %[[VEC_C:.*]] = alloca <2 x i32>, align 8
// OGCG: %[[VEC_D:.*]] = alloca <2 x double>, align 16
// OGCG: %[[VEC_E:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[VEC_F:.*]] = alloca <4 x i32>, align 16
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_E]], align 16
// OGCG: store <4 x i32> {{.*}}, ptr %[[VEC_F:.*]], align 16

void foo2(vi4 p) {}

Expand Down
22 changes: 21 additions & 1 deletion clang/test/CIR/CodeGen/vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,43 @@ vi4 d = { 1, 2, 3, 4 };

// OGCG: @[[VEC_D:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4>

void vec_int_test() {
int x = 5;

void foo() {
vi4 a;
vd2 b;
vll2 c;

vi4 d = { 1, 2, 3, 4 };

vi4 e = { x, 5, 6, x + 1 };
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test case where the initializer list doesn't fill the entire vector?

}

// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<2 x !cir.double>, !cir.ptr<!cir.vector<2 x !cir.double>>, ["b"]
// CIR: %[[VEC_C:.*]] = cir.alloca !cir.vector<2 x !s64i>, !cir.ptr<!cir.vector<2 x !s64i>>, ["c"]
// CIR: %[[VEC_D:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["d", init]
// CIR: %[[VEC_E:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["e", init]
// CIR: %[[VEC_D_VAL:.*]] = cir.vec.create({{.*}}, {{.*}}, {{.*}}, {{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
// CIR: cir.store %[[VEC_D_VAL]], %[[VEC_D]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
// CIR: %[[VEC_E_VAL:.*]] = cir.vec.create({{.*}}, {{.*}}, {{.*}}, {{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
// CIR: cir.store %[[VEC_E_VAL]], %[[VEC_E]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>

// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[VEC_B:.*]] = alloca <2 x double>, i64 1, align 16
// LLVM: %[[VEC_C:.*]] = alloca <2 x i64>, i64 1, align 16
// LLVM: %[[VEC_D:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[VEC_E:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_D]], align 16
// LLVM: store <4 x i32> {{.*}}, ptr %[[VEC_E:.*]], align 16

// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[VEC_B:.*]] = alloca <2 x double>, align 16
// OGCG: %[[VEC_C:.*]] = alloca <2 x i64>, align 16
// OGCG: %[[VEC_D:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[VEC_E:.*]] = alloca <4 x i32>, align 16
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_D]], align 16
// OGCG: store <4 x i32> {{.*}}, ptr %[[VEC_E:.*]], align 16

void foo2(vi4 p) {}

Expand Down
16 changes: 16 additions & 0 deletions clang/test/CIR/IR/invalid-vector-create-wrong-size.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: cir-opt %s -verify-diagnostics -split-input-file

!s32i = !cir.int<s, 32>

module {
cir.func @foo() {
%1 = cir.const #cir.int<1> : !s32i
%2 = cir.const #cir.int<2> : !s32i
%3 = cir.const #cir.int<3> : !s32i
%4 = cir.const #cir.int<4> : !s32i

// expected-error @below {{operand count of 4 doesn't match vector type '!cir.vector<8 x !cir.int<s, 32>>' element count of 8}}
%5 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<8 x !s32i>
cir.return
}
}
17 changes: 17 additions & 0 deletions clang/test/CIR/IR/invalid-vector-create-wrong-type.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: cir-opt %s -verify-diagnostics -split-input-file

!s32i = !cir.int<s, 32>
!s64i = !cir.int<s, 64>

module {
cir.func @foo() {
%1 = cir.const #cir.int<1> : !s32i
%2 = cir.const #cir.int<2> : !s32i
%3 = cir.const #cir.int<3> : !s32i
%4 = cir.const #cir.int<4> : !s64i

// expected-error @below {{operand type '!cir.int<s, 64>' doesn't match vector element type '!cir.int<s, 32>'}}
%5 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s64i) : !cir.vector<4 x !s32i>
cir.return
}
}
22 changes: 22 additions & 0 deletions clang/test/CIR/IR/vector.cir
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,26 @@ cir.func @vec_double_test() {
// CHECK: cir.return
// CHECK: }

cir.func @local_vector_create_test() {
%0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
%1 = cir.const #cir.int<1> : !s32i
%2 = cir.const #cir.int<2> : !s32i
%3 = cir.const #cir.int<3> : !s32i
%4 = cir.const #cir.int<4> : !s32i
%5 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
cir.store %5, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
cir.return
}

// CHECK: cir.func @local_vector_create_test() {
// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
// CHECK: %1 = cir.const #cir.int<1> : !s32i
// CHECK: %2 = cir.const #cir.int<2> : !s32i
// CHECK: %3 = cir.const #cir.int<3> : !s32i
// CHECK: %4 = cir.const #cir.int<4> : !s32i
// CHECK: %5 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
// CHECK: cir.store %5, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
// CHECK: cir.return
// CHECK: }

}