Skip to content

Commit ea3d903

Browse files
committed
[CIR] Upstream local initialization for VectorType
1 parent 93ff19c commit ea3d903

File tree

11 files changed

+239
-5
lines changed

11 files changed

+239
-5
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1903,4 +1903,28 @@ def TrapOp : CIR_Op<"trap", [Terminator]> {
19031903
let assemblyFormat = "attr-dict";
19041904
}
19051905

1906+
//===----------------------------------------------------------------------===//
1907+
// VecCreate
1908+
//===----------------------------------------------------------------------===//
1909+
1910+
def VecCreateOp : CIR_Op<"vec.create", [Pure]> {
1911+
1912+
let summary = "Create a vector value";
1913+
let description = [{
1914+
The `cir.vec.create` operation creates a vector value with the given element
1915+
values. The number of element arguments must match the number of elements
1916+
in the vector type.
1917+
}];
1918+
1919+
let arguments = (ins Variadic<CIR_AnyType>:$elements);
1920+
let results = (outs CIR_VectorType:$result);
1921+
1922+
let assemblyFormat = [{
1923+
`(` ($elements^ `:` type($elements))? `)` `:` qualified(type($result))
1924+
attr-dict
1925+
}];
1926+
1927+
let hasVerifier = 1;
1928+
}
1929+
19061930
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,9 +258,20 @@ void CIRGenFunction::emitStoreOfScalar(mlir::Value value, Address addr,
258258
bool isInit, bool isNontemporal) {
259259
assert(!cir::MissingFeatures::opLoadStoreThreadLocal());
260260

261-
if (ty->getAs<clang::VectorType>()) {
262-
cgm.errorNYI(addr.getPointer().getLoc(), "emitStoreOfScalar vector type");
263-
return;
261+
if (const auto *clangVecTy = ty->getAs<clang::VectorType>()) {
262+
// Boolean vectors use `iN` as storage type.
263+
if (clangVecTy->isExtVectorBoolType())
264+
cgm.errorNYI(addr.getPointer().getLoc(),
265+
"emitStoreOfScalar ExtVectorBoolType");
266+
267+
// Handle vectors of size 3 like size 4 for better performance.
268+
const mlir::Type elementType = addr.getElementType();
269+
const auto vecTy = cast<cir::VectorType>(elementType);
270+
271+
// TODO(CIR): Use `ABIInfo::getOptimalVectorMemoryType` once it upstreamed
272+
if (vecTy.getSize() == 3 && !getLangOpts().PreserveVec3Type)
273+
cgm.errorNYI(addr.getPointer().getLoc(),
274+
"emitStoreOfScalar Vec3 & PreserveVec3Type disabled");
264275
}
265276

266277
value = emitToMemory(value, ty);

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
170170

171171
mlir::Value VisitMemberExpr(MemberExpr *e);
172172

173+
mlir::Value VisitInitListExpr(InitListExpr *e);
174+
173175
mlir::Value VisitExplicitCastExpr(ExplicitCastExpr *e) {
174176
return VisitCastExpr(e);
175177
}
@@ -1584,6 +1586,47 @@ mlir::Value ScalarExprEmitter::VisitMemberExpr(MemberExpr *e) {
15841586
return emitLoadOfLValue(e);
15851587
}
15861588

1589+
mlir::Value ScalarExprEmitter::VisitInitListExpr(InitListExpr *e) {
1590+
const unsigned numInitElements = e->getNumInits();
1591+
1592+
if (e->hadArrayRangeDesignator()) {
1593+
cgf.cgm.errorNYI(e->getSourceRange(), "ArrayRangeDesignator");
1594+
return {};
1595+
}
1596+
1597+
if (numInitElements == 0) {
1598+
cgf.cgm.errorNYI(e->getSourceRange(), "InitListExpr with 0 init elements");
1599+
return {};
1600+
}
1601+
1602+
if (e->getType()->isVectorType()) {
1603+
const auto vectorType =
1604+
mlir::cast<cir::VectorType>(cgf.convertType(e->getType()));
1605+
1606+
SmallVector<mlir::Value, 16> elements;
1607+
for (Expr *init : e->inits()) {
1608+
elements.push_back(Visit(init));
1609+
}
1610+
1611+
// Zero-initialize any remaining values.
1612+
if (numInitElements < vectorType.getSize()) {
1613+
mlir::TypedAttr zeroInitAttr =
1614+
cgf.getBuilder().getZeroInitAttr(vectorType.getElementType());
1615+
cir::ConstantOp zeroValue =
1616+
cgf.getBuilder().getConstant(cgf.getLoc(e->getSourceRange()), zeroInitAttr);
1617+
1618+
for (uint64_t i = numInitElements; i < vectorType.getSize(); ++i) {
1619+
elements.push_back(zeroValue);
1620+
}
1621+
}
1622+
1623+
return cgf.getBuilder().create<cir::VecCreateOp>(
1624+
cgf.getLoc(e->getSourceRange()), vectorType, elements);
1625+
}
1626+
1627+
return Visit(e->getInit(0));
1628+
}
1629+
15871630
mlir::Value CIRGenFunction::emitScalarConversion(mlir::Value src,
15881631
QualType srcTy, QualType dstTy,
15891632
SourceLocation loc) {

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,6 +1307,33 @@ LogicalResult cir::GetMemberOp::verify() {
13071307
return mlir::success();
13081308
}
13091309

1310+
//===----------------------------------------------------------------------===//
1311+
// VecCreateOp
1312+
//===----------------------------------------------------------------------===//
1313+
1314+
LogicalResult cir::VecCreateOp::verify() {
1315+
// Verify that the number of arguments matches the number of elements in the
1316+
// vector, and that the type of all the arguments matches the type of the
1317+
// elements in the vector.
1318+
const VectorType vecTy = getResult().getType();
1319+
if (getElements().size() != vecTy.getSize()) {
1320+
return emitOpError() << "operand count of " << getElements().size()
1321+
<< " doesn't match vector type " << vecTy
1322+
<< " element count of " << vecTy.getSize();
1323+
}
1324+
1325+
const mlir::Type elementType = vecTy.getElementType();
1326+
for (const mlir::Value element : getElements()) {
1327+
if (element.getType() != elementType) {
1328+
return emitOpError() << "operand type " << element.getType()
1329+
<< " doesn't match vector element type "
1330+
<< elementType;
1331+
}
1332+
}
1333+
1334+
return success();
1335+
}
1336+
13101337
//===----------------------------------------------------------------------===//
13111338
// TableGen'd op method definitions
13121339
//===----------------------------------------------------------------------===//

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1599,7 +1599,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
15991599
CIRToLLVMStackSaveOpLowering,
16001600
CIRToLLVMStackRestoreOpLowering,
16011601
CIRToLLVMTrapOpLowering,
1602-
CIRToLLVMUnaryOpLowering
1602+
CIRToLLVMUnaryOpLowering,
1603+
CIRToLLVMVecCreateOpLowering
16031604
// clang-format on
16041605
>(converter, patterns.getContext());
16051606

@@ -1685,6 +1686,29 @@ mlir::LogicalResult CIRToLLVMStackRestoreOpLowering::matchAndRewrite(
16851686
return mlir::success();
16861687
}
16871688

1689+
mlir::LogicalResult CIRToLLVMVecCreateOpLowering::matchAndRewrite(
1690+
cir::VecCreateOp op, OpAdaptor adaptor,
1691+
mlir::ConversionPatternRewriter &rewriter) const {
1692+
// Start with an 'undef' value for the vector. Then 'insertelement' for
1693+
// each of the vector elements.
1694+
const auto vecTy = mlir::cast<cir::VectorType>(op.getType());
1695+
const mlir::Type llvmTy = typeConverter->convertType(vecTy);
1696+
const mlir::Location loc = op.getLoc();
1697+
mlir::Value result = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
1698+
assert(vecTy.getSize() == op.getElements().size() &&
1699+
"cir.vec.create op count doesn't match vector type elements count");
1700+
1701+
for (uint64_t i = 0; i < vecTy.getSize(); ++i) {
1702+
const mlir::Value indexValue =
1703+
rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i);
1704+
result = rewriter.create<mlir::LLVM::InsertElementOp>(
1705+
loc, result, adaptor.getElements()[i], indexValue);
1706+
}
1707+
1708+
rewriter.replaceOp(op, result);
1709+
return mlir::success();
1710+
}
1711+
16881712
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
16891713
return std::make_unique<ConvertCIRToLLVMPass>();
16901714
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,16 @@ class CIRToLLVMStackRestoreOpLowering
293293
mlir::ConversionPatternRewriter &rewriter) const override;
294294
};
295295

296+
class CIRToLLVMVecCreateOpLowering
297+
: public mlir::OpConversionPattern<cir::VecCreateOp> {
298+
public:
299+
using mlir::OpConversionPattern<cir::VecCreateOp>::OpConversionPattern;
300+
301+
mlir::LogicalResult
302+
matchAndRewrite(cir::VecCreateOp op, OpAdaptor,
303+
mlir::ConversionPatternRewriter &) const override;
304+
};
305+
296306
} // namespace direct
297307
} // namespace cir
298308

clang/test/CIR/CodeGen/vector-ext.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,27 +48,47 @@ vi4 vec_e = { 1, 2, 3, 4 };
4848

4949
// OGCG: @[[VEC_E:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4>
5050

51+
int x = 5;
52+
5153
void foo() {
5254
vi4 a;
5355
vi3 b;
5456
vi2 c;
5557
vd2 d;
58+
59+
vi4 e = { 1, 2, 3, 4 };
60+
61+
vi4 f = { x, 5, 6, x + 1 };
5662
}
5763

5864
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
5965
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<3 x !s32i>, !cir.ptr<!cir.vector<3 x !s32i>>, ["b"]
6066
// CIR: %[[VEC_C:.*]] = cir.alloca !cir.vector<2 x !s32i>, !cir.ptr<!cir.vector<2 x !s32i>>, ["c"]
6167
// CIR: %[[VEC_D:.*]] = cir.alloca !cir.vector<2 x !cir.double>, !cir.ptr<!cir.vector<2 x !cir.double>>, ["d"]
68+
// CIR: %[[VEC_E:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["e", init]
69+
// CIR: %[[VEC_F:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["f", init]
70+
// CIR: %[[VEC_E_VAL:.*]] = cir.vec.create({{.*}}, {{.*}}, {{.*}}, {{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
71+
// CIR: cir.store %[[VEC_E_VAL]], %[[VEC_E]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
72+
// CIR: %[[VEC_F_VAL:.*]] = cir.vec.create({{.*}}, {{.*}}, {{.*}}, {{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
73+
// CIR: cir.store %[[VEC_F_VAL]], %[[VEC_F]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
6274

6375
// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
6476
// LLVM: %[[VEC_B:.*]] = alloca <3 x i32>, i64 1, align 16
6577
// LLVM: %[[VEC_C:.*]] = alloca <2 x i32>, i64 1, align 8
6678
// LLVM: %[[VEC_D:.*]] = alloca <2 x double>, i64 1, align 16
79+
// LLVM: %[[VEC_E:.*]] = alloca <4 x i32>, i64 1, align 16
80+
// LLVM: %[[VEC_F:.*]] = alloca <4 x i32>, i64 1, align 16
81+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_E]], align 16
82+
// LLVM: store <4 x i32> {{.*}}, ptr %[[VEC_F:.*]], align 16
6783

6884
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
6985
// OGCG: %[[VEC_B:.*]] = alloca <3 x i32>, align 16
7086
// OGCG: %[[VEC_C:.*]] = alloca <2 x i32>, align 8
7187
// OGCG: %[[VEC_D:.*]] = alloca <2 x double>, align 16
88+
// OGCG: %[[VEC_E:.*]] = alloca <4 x i32>, align 16
89+
// OGCG: %[[VEC_F:.*]] = alloca <4 x i32>, align 16
90+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_E]], align 16
91+
// OGCG: store <4 x i32> {{.*}}, ptr %[[VEC_F:.*]], align 16
7292

7393
void foo2(vi4 p) {}
7494

clang/test/CIR/CodeGen/vector.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,43 @@ vi4 d = { 1, 2, 3, 4 };
3939

4040
// OGCG: @[[VEC_D:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4>
4141

42-
void vec_int_test() {
42+
int x = 5;
43+
44+
void foo() {
4345
vi4 a;
4446
vd2 b;
4547
vll2 c;
48+
49+
vi4 d = { 1, 2, 3, 4 };
50+
51+
vi4 e = { x, 5, 6, x + 1 };
4652
}
4753

4854
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
4955
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<2 x !cir.double>, !cir.ptr<!cir.vector<2 x !cir.double>>, ["b"]
5056
// CIR: %[[VEC_C:.*]] = cir.alloca !cir.vector<2 x !s64i>, !cir.ptr<!cir.vector<2 x !s64i>>, ["c"]
57+
// CIR: %[[VEC_D:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["d", init]
58+
// CIR: %[[VEC_E:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["e", init]
59+
// CIR: %[[VEC_D_VAL:.*]] = cir.vec.create({{.*}}, {{.*}}, {{.*}}, {{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
60+
// CIR: cir.store %[[VEC_D_VAL]], %[[VEC_D]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
61+
// CIR: %[[VEC_E_VAL:.*]] = cir.vec.create({{.*}}, {{.*}}, {{.*}}, {{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
62+
// CIR: cir.store %[[VEC_E_VAL]], %[[VEC_E]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
5163

5264
// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
5365
// LLVM: %[[VEC_B:.*]] = alloca <2 x double>, i64 1, align 16
5466
// LLVM: %[[VEC_C:.*]] = alloca <2 x i64>, i64 1, align 16
67+
// LLVM: %[[VEC_D:.*]] = alloca <4 x i32>, i64 1, align 16
68+
// LLVM: %[[VEC_E:.*]] = alloca <4 x i32>, i64 1, align 16
69+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_D]], align 16
70+
// LLVM: store <4 x i32> {{.*}}, ptr %[[VEC_E:.*]], align 16
5571

5672
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
5773
// OGCG: %[[VEC_B:.*]] = alloca <2 x double>, align 16
5874
// OGCG: %[[VEC_C:.*]] = alloca <2 x i64>, align 16
75+
// OGCG: %[[VEC_D:.*]] = alloca <4 x i32>, align 16
76+
// OGCG: %[[VEC_E:.*]] = alloca <4 x i32>, align 16
77+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_D]], align 16
78+
// OGCG: store <4 x i32> {{.*}}, ptr %[[VEC_E:.*]], align 16
5979

6080
void foo2(vi4 p) {}
6181

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: cir-opt %s -verify-diagnostics -split-input-file
2+
3+
!s32i = !cir.int<s, 32>
4+
5+
module {
6+
cir.func @foo() {
7+
%1 = cir.const #cir.int<1> : !s32i
8+
%2 = cir.const #cir.int<2> : !s32i
9+
%3 = cir.const #cir.int<3> : !s32i
10+
%4 = cir.const #cir.int<4> : !s32i
11+
12+
// expected-error @below {{operand count of 4 doesn't match vector type '!cir.vector<8 x !cir.int<s, 32>>' element count of 8}}
13+
%5 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<8 x !s32i>
14+
cir.return
15+
}
16+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: cir-opt %s -verify-diagnostics -split-input-file
2+
3+
!s32i = !cir.int<s, 32>
4+
!s64i = !cir.int<s, 64>
5+
6+
module {
7+
cir.func @foo() {
8+
%1 = cir.const #cir.int<1> : !s32i
9+
%2 = cir.const #cir.int<2> : !s32i
10+
%3 = cir.const #cir.int<3> : !s32i
11+
%4 = cir.const #cir.int<4> : !s64i
12+
13+
// expected-error @below {{operand type '!cir.int<s, 64>' doesn't match vector element type '!cir.int<s, 32>'}}
14+
%5 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s64i) : !cir.vector<4 x !s32i>
15+
cir.return
16+
}
17+
}

clang/test/CIR/IR/vector.cir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,26 @@ cir.func @vec_double_test() {
4343
// CHECK: cir.return
4444
// CHECK: }
4545

46+
cir.func @local_vector_create_test() {
47+
%0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
48+
%1 = cir.const #cir.int<1> : !s32i
49+
%2 = cir.const #cir.int<2> : !s32i
50+
%3 = cir.const #cir.int<3> : !s32i
51+
%4 = cir.const #cir.int<4> : !s32i
52+
%5 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
53+
cir.store %5, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
54+
cir.return
55+
}
56+
57+
// CHECK: cir.func @local_vector_create_test() {
58+
// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
59+
// CHECK: %1 = cir.const #cir.int<1> : !s32i
60+
// CHECK: %2 = cir.const #cir.int<2> : !s32i
61+
// CHECK: %3 = cir.const #cir.int<3> : !s32i
62+
// CHECK: %4 = cir.const #cir.int<4> : !s32i
63+
// CHECK: %5 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
64+
// CHECK: cir.store %5, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
65+
// CHECK: cir.return
66+
// CHECK: }
67+
4668
}

0 commit comments

Comments
 (0)