Skip to content

Commit 93ff19c

Browse files
authored
[CIR] Upstream global initialization for VectorType (#137511)
This change adds global initialization for VectorType Issue #136487
1 parent c51be1b commit 93ff19c

File tree

8 files changed

+156
-10
lines changed

8 files changed

+156
-10
lines changed

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]>
204204
}]>
205205
];
206206

207-
// Printing and parsing available in CIRDialect.cpp
207+
// Printing and parsing available in CIRAttrs.cpp
208208
let hasCustomAssemblyFormat = 1;
209209

210210
// Enable verifier.
@@ -215,6 +215,38 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]>
215215
}];
216216
}
217217

218+
//===----------------------------------------------------------------------===//
219+
// ConstVectorAttr
220+
//===----------------------------------------------------------------------===//
221+
222+
def ConstVectorAttr : CIR_Attr<"ConstVector", "const_vector",
223+
[TypedAttrInterface]> {
224+
let summary = "A constant vector from ArrayAttr";
225+
let description = [{
226+
A CIR vector attribute is an array of literals of the specified attribute
227+
types.
228+
}];
229+
230+
let parameters = (ins AttributeSelfTypeParameter<"">:$type,
231+
"mlir::ArrayAttr":$elts);
232+
233+
// Define a custom builder for the type; that removes the need to pass in an
234+
// MLIRContext instance, as it can be inferred from the `type`.
235+
let builders = [
236+
AttrBuilderWithInferredContext<(ins "cir::VectorType":$type,
237+
"mlir::ArrayAttr":$elts), [{
238+
return $_get(type.getContext(), type, elts);
239+
}]>
240+
];
241+
242+
let assemblyFormat = [{
243+
`<` $elts `>`
244+
}];
245+
246+
// Enable verifier.
247+
let genVerifyDecl = 1;
248+
}
249+
218250
//===----------------------------------------------------------------------===//
219251
// ConstPtrAttr
220252
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,8 +373,27 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &value,
373373
elements, typedFiller);
374374
}
375375
case APValue::Vector: {
376-
cgm.errorNYI("ConstExprEmitter::tryEmitPrivate vector");
377-
return {};
376+
const QualType elementType =
377+
destType->castAs<VectorType>()->getElementType();
378+
const unsigned numElements = value.getVectorLength();
379+
380+
SmallVector<mlir::Attribute, 16> elements;
381+
elements.reserve(numElements);
382+
383+
for (unsigned i = 0; i < numElements; ++i) {
384+
const mlir::Attribute element =
385+
tryEmitPrivateForMemory(value.getVectorElt(i), elementType);
386+
if (!element)
387+
return {};
388+
elements.push_back(element);
389+
}
390+
391+
const auto desiredVecTy =
392+
mlir::cast<cir::VectorType>(cgm.convertType(destType));
393+
394+
return cir::ConstVectorAttr::get(
395+
desiredVecTy,
396+
mlir::ArrayAttr::get(cgm.getBuilder().getContext(), elements));
378397
}
379398
case APValue::MemberPointer: {
380399
cgm.errorNYI("ConstExprEmitter::tryEmitPrivate member pointer");

clang/lib/CIR/Dialect/IR/CIRAttrs.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,47 @@ void ConstArrayAttr::print(AsmPrinter &printer) const {
299299
printer << ">";
300300
}
301301

302+
//===----------------------------------------------------------------------===//
303+
// CIR ConstVectorAttr
304+
//===----------------------------------------------------------------------===//
305+
306+
LogicalResult cir::ConstVectorAttr::verify(
307+
function_ref<::mlir::InFlightDiagnostic()> emitError, Type type,
308+
ArrayAttr elts) {
309+
310+
if (!mlir::isa<cir::VectorType>(type)) {
311+
return emitError() << "type of cir::ConstVectorAttr is not a "
312+
"cir::VectorType: "
313+
<< type;
314+
}
315+
316+
const auto vecType = mlir::cast<cir::VectorType>(type);
317+
318+
if (vecType.getSize() != elts.size()) {
319+
return emitError()
320+
<< "number of constant elements should match vector size";
321+
}
322+
323+
// Check if the types of the elements match
324+
LogicalResult elementTypeCheck = success();
325+
elts.walkImmediateSubElements(
326+
[&](Attribute element) {
327+
if (elementTypeCheck.failed()) {
328+
// An earlier element didn't match
329+
return;
330+
}
331+
auto typedElement = mlir::dyn_cast<TypedAttr>(element);
332+
if (!typedElement ||
333+
typedElement.getType() != vecType.getElementType()) {
334+
elementTypeCheck = failure();
335+
emitError() << "constant type should match vector element type";
336+
}
337+
},
338+
[&](Type) {});
339+
340+
return elementTypeCheck;
341+
}
342+
302343
//===----------------------------------------------------------------------===//
303344
// CIR Dialect
304345
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
244244
return success();
245245
}
246246

247-
if (mlir::isa<cir::ConstArrayAttr>(attrType))
247+
if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr>(attrType))
248248
return success();
249249

250250
assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,15 +188,17 @@ class CIRAttrToValue {
188188

189189
mlir::Value visit(mlir::Attribute attr) {
190190
return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
191-
.Case<cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr, cir::ConstPtrAttr,
192-
cir::ZeroAttr>([&](auto attrT) { return visitCirAttr(attrT); })
191+
.Case<cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr,
192+
cir::ConstVectorAttr, cir::ConstPtrAttr, cir::ZeroAttr>(
193+
[&](auto attrT) { return visitCirAttr(attrT); })
193194
.Default([&](auto attrT) { return mlir::Value(); });
194195
}
195196

196197
mlir::Value visitCirAttr(cir::IntAttr intAttr);
197198
mlir::Value visitCirAttr(cir::FPAttr fltAttr);
198199
mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
199200
mlir::Value visitCirAttr(cir::ConstArrayAttr attr);
201+
mlir::Value visitCirAttr(cir::ConstVectorAttr attr);
200202
mlir::Value visitCirAttr(cir::ZeroAttr attr);
201203

202204
private:
@@ -275,6 +277,33 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) {
275277
return result;
276278
}
277279

280+
/// ConstVectorAttr visitor.
281+
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstVectorAttr attr) {
282+
const mlir::Type llvmTy = converter->convertType(attr.getType());
283+
const mlir::Location loc = parentOp->getLoc();
284+
285+
SmallVector<mlir::Attribute> mlirValues;
286+
for (const mlir::Attribute elementAttr : attr.getElts()) {
287+
mlir::Attribute mlirAttr;
288+
if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(elementAttr)) {
289+
mlirAttr = rewriter.getIntegerAttr(
290+
converter->convertType(intAttr.getType()), intAttr.getValue());
291+
} else if (auto floatAttr = mlir::dyn_cast<cir::FPAttr>(elementAttr)) {
292+
mlirAttr = rewriter.getFloatAttr(
293+
converter->convertType(floatAttr.getType()), floatAttr.getValue());
294+
} else {
295+
llvm_unreachable(
296+
"vector constant with an element that is neither an int nor a float");
297+
}
298+
mlirValues.push_back(mlirAttr);
299+
}
300+
301+
return rewriter.create<mlir::LLVM::ConstantOp>(
302+
loc, llvmTy,
303+
mlir::DenseElementsAttr::get(mlir::cast<mlir::ShapedType>(llvmTy),
304+
mlirValues));
305+
}
306+
278307
/// ZeroAttr visitor.
279308
mlir::Value CIRAttrToValue::visitCirAttr(cir::ZeroAttr attr) {
280309
mlir::Location loc = parentOp->getLoc();
@@ -888,7 +917,8 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
888917
cir::GlobalOp op, mlir::Attribute init,
889918
mlir::ConversionPatternRewriter &rewriter) const {
890919
// TODO: Generalize this handling when more types are needed here.
891-
assert((isa<cir::ConstArrayAttr, cir::ConstPtrAttr, cir::ZeroAttr>(init)));
920+
assert((isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
921+
cir::ZeroAttr>(init)));
892922

893923
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
894924
// should be updated. For now, we use a custom op to initialize globals
@@ -941,8 +971,8 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
941971
op.emitError() << "unsupported initializer '" << init.value() << "'";
942972
return mlir::failure();
943973
}
944-
} else if (mlir::isa<cir::ConstArrayAttr, cir::ConstPtrAttr, cir::ZeroAttr>(
945-
init.value())) {
974+
} else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
975+
cir::ConstPtrAttr, cir::ZeroAttr>(init.value())) {
946976
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
947977
// should be updated. For now, we use a custom op to initialize globals
948978
// to the appropriate value.

clang/test/CIR/CodeGen/vector-ext.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,23 @@ vi2 vec_c;
3131

3232
// OGCG: @[[VEC_C:.*]] = global <2 x i32> zeroinitializer
3333

34-
vd2 d;
34+
vd2 vec_d;
3535

3636
// CIR: cir.global external @[[VEC_D:.*]] = #cir.zero : !cir.vector<2 x !cir.double>
3737

3838
// LLVM: @[[VEC_D:.*]] = dso_local global <2 x double> zeroinitialize
3939

4040
// OGCG: @[[VEC_D:.*]] = global <2 x double> zeroinitializer
4141

42+
vi4 vec_e = { 1, 2, 3, 4 };
43+
44+
// CIR: cir.global external @[[VEC_E:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> :
45+
// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
46+
47+
// LLVM: @[[VEC_E:.*]] = dso_local global <4 x i32> <i32 1, i32 2, i32 3, i32 4>
48+
49+
// OGCG: @[[VEC_E:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4>
50+
4251
void foo() {
4352
vi4 a;
4453
vi3 b;

clang/test/CIR/CodeGen/vector.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ vll2 c;
3030

3131
// OGCG: @[[VEC_C:.*]] = global <2 x i64> zeroinitializer
3232

33+
vi4 d = { 1, 2, 3, 4 };
34+
35+
// CIR: cir.global external @[[VEC_D:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> :
36+
// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
37+
38+
// LLVM: @[[VEC_D:.*]] = dso_local global <4 x i32> <i32 1, i32 2, i32 3, i32 4>
39+
40+
// OGCG: @[[VEC_D:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4>
41+
3342
void vec_int_test() {
3443
vi4 a;
3544
vd2 b;

clang/test/CIR/IR/vector.cir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ cir.global external @vec_b = #cir.zero : !cir.vector<3 x !s32i>
1313
cir.global external @vec_c = #cir.zero : !cir.vector<2 x !s32i>
1414
// CHECK: cir.global external @vec_c = #cir.zero : !cir.vector<2 x !s32i>
1515

16+
cir.global external @vec_d = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2>
17+
: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
18+
19+
// CIR: cir.global external @vec_d = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> :
20+
// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
21+
1622
cir.func @vec_int_test() {
1723
%0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
1824
%1 = cir.alloca !cir.vector<3 x !s32i>, !cir.ptr<!cir.vector<3 x !s32i>>, ["b"]

0 commit comments

Comments
 (0)