Skip to content

Commit 81e1cd7

Browse files
committed
[CIR] Upstream ArraySubscriptExpr for fixed size array
1 parent ba3fa39 commit 81e1cd7

File tree

10 files changed

+268
-11
lines changed

10 files changed

+268
-11
lines changed

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ struct MissingFeatures {
108108
static bool cgFPOptionsRAII() { return false; }
109109
static bool metaDataNode() { return false; }
110110
static bool fastMathFlags() { return false; }
111+
static bool emitCheckedInBoundsGEP() { return false; }
111112

112113
// Missing types
113114
static bool dataMemberType() { return false; }
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "CIRGenBuilder.h"
10+
11+
using namespace clang::CIRGen;
12+
13+
mlir::Value CIRGenBuilderTy::maybeBuildArrayDecay(mlir::Location loc,
14+
mlir::Value arrayPtr,
15+
mlir::Type eltTy) {
16+
const auto arrayPtrTy = mlir::cast<cir::PointerType>(arrayPtr.getType());
17+
const auto arrayTy = mlir::dyn_cast<cir::ArrayType>(arrayPtrTy.getPointee());
18+
19+
if (arrayTy) {
20+
const cir::PointerType flatPtrTy = getPointerTo(arrayTy.getEltType());
21+
return create<cir::CastOp>(loc, flatPtrTy, cir::CastKind::array_to_ptrdecay,
22+
arrayPtr);
23+
}
24+
25+
assert(arrayPtrTy.getPointee() == eltTy &&
26+
"flat pointee type must match original array element type");
27+
return arrayPtr;
28+
}
29+
30+
mlir::Value CIRGenBuilderTy::getArrayElement(mlir::Location arrayLocBegin,
31+
mlir::Location arrayLocEnd,
32+
mlir::Value arrayPtr,
33+
mlir::Type eltTy, mlir::Value idx,
34+
bool shouldDecay) {
35+
mlir::Value basePtr = arrayPtr;
36+
if (shouldDecay)
37+
basePtr = maybeBuildArrayDecay(arrayLocBegin, arrayPtr, eltTy);
38+
const mlir::Type flatPtrTy = basePtr.getType();
39+
return create<cir::PtrStrideOp>(arrayLocEnd, flatPtrTy, basePtr, idx);
40+
}

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,19 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
198198

199199
return create<cir::BinOp>(loc, cir::BinOpKind::Div, lhs, rhs);
200200
}
201+
202+
/// Create a cir.ptr_stride operation to get access to an array element.
203+
/// idx is the index of the element to access, shouldDecay is true if the
204+
/// result should decay to a pointer to the element type.
205+
mlir::Value getArrayElement(mlir::Location arrayLocBegin,
206+
mlir::Location arrayLocEnd, mlir::Value arrayPtr,
207+
mlir::Type eltTy, mlir::Value idx,
208+
bool shouldDecay);
209+
210+
/// Returns a decayed pointer to the first element of the array
211+
/// pointed to by arrayPtr.
212+
mlir::Value maybeBuildArrayDecay(mlir::Location loc, mlir::Value arrayPtr,
213+
mlir::Type eltTy);
201214
};
202215

203216
} // namespace clang::CIRGen

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "Address.h"
1414
#include "CIRGenFunction.h"
15+
#include "CIRGenModule.h"
1516
#include "CIRGenValue.h"
1617
#include "mlir/IR/BuiltinAttributes.h"
1718
#include "clang/AST/Attr.h"
@@ -232,6 +233,161 @@ LValue CIRGenFunction::emitUnaryOpLValue(const UnaryOperator *e) {
232233
llvm_unreachable("Unknown unary operator kind!");
233234
}
234235

236+
/// If the specified expr is a simple decay from an array to pointer,
237+
/// return the array subexpression.
238+
/// FIXME: this could be abstracted into a commeon AST helper.
239+
static const Expr *isSimpleArrayDecayOperand(const Expr *e) {
240+
// If this isn't just an array->pointer decay, bail out.
241+
const auto *castExpr = dyn_cast<CastExpr>(e);
242+
if (!castExpr || castExpr->getCastKind() != CK_ArrayToPointerDecay)
243+
return nullptr;
244+
245+
// If this is a decay from variable width array, bail out.
246+
const Expr *subExpr = castExpr->getSubExpr();
247+
if (subExpr->getType()->isVariableArrayType())
248+
return nullptr;
249+
250+
return subExpr;
251+
}
252+
253+
static mlir::IntegerAttr getConstantIndexOrNull(mlir::Value idx) {
254+
// TODO(cir): should we consider using MLIRs IndexType instead of IntegerAttr?
255+
if (auto constantOp = dyn_cast<cir::ConstantOp>(idx.getDefiningOp()))
256+
return mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue());
257+
return {};
258+
}
259+
260+
static CharUnits getArrayElementAlign(CharUnits arrayAlign, mlir::Value idx,
261+
CharUnits eltSize) {
262+
// If we have a constant index, we can use the exact offset of the
263+
// element we're accessing.
264+
const mlir::IntegerAttr constantIdx = getConstantIndexOrNull(idx);
265+
if (constantIdx) {
266+
const CharUnits offset = constantIdx.getValue().getZExtValue() * eltSize;
267+
return arrayAlign.alignmentAtOffset(offset);
268+
}
269+
// Otherwise, use the worst-case alignment for any element.
270+
return arrayAlign.alignmentOfArrayElement(eltSize);
271+
}
272+
273+
static QualType getFixedSizeElementType(const ASTContext &astContext,
274+
const VariableArrayType *vla) {
275+
QualType eltType;
276+
do {
277+
eltType = vla->getElementType();
278+
} while ((vla = astContext.getAsVariableArrayType(eltType)));
279+
return eltType;
280+
}
281+
282+
static mlir::Value
283+
emitArraySubscriptPtr(CIRGenFunction &cgf, mlir::Location beginLoc,
284+
mlir::Location endLoc, mlir::Value ptr, mlir::Type eltTy,
285+
ArrayRef<mlir::Value> indices, bool inbounds,
286+
bool signedIndices, bool shouldDecay,
287+
const llvm::Twine &name = "arrayidx") {
288+
if (indices.size() > 1) {
289+
cgf.cgm.errorNYI("emitArraySubscriptPtr: handle multiple indices");
290+
return {};
291+
}
292+
293+
const mlir::Value idx = indices.back();
294+
CIRGenModule &cgm = cgf.getCIRGenModule();
295+
// TODO(cir): LLVM codegen emits in bound gep check here, is there anything
296+
// that would enhance tracking this later in CIR?
297+
if (inbounds)
298+
assert(!cir::MissingFeatures::emitCheckedInBoundsGEP() && "NYI");
299+
return cgm.getBuilder().getArrayElement(beginLoc, endLoc, ptr, eltTy, idx,
300+
shouldDecay);
301+
}
302+
303+
static Address emitArraySubscriptPtr(
304+
CIRGenFunction &cgf, mlir::Location beginLoc, mlir::Location endLoc,
305+
Address addr, ArrayRef<mlir::Value> indices, QualType eltType,
306+
bool inbounds, bool signedIndices, mlir::Location loc, bool shouldDecay,
307+
QualType *arrayType = nullptr, const Expr *base = nullptr,
308+
const llvm::Twine &name = "arrayidx") {
309+
310+
// Determine the element size of the statically-sized base. This is
311+
// the thing that the indices are expressed in terms of.
312+
if (const VariableArrayType *vla =
313+
cgf.getContext().getAsVariableArrayType(eltType)) {
314+
eltType = getFixedSizeElementType(cgf.getContext(), vla);
315+
}
316+
317+
// We can use that to compute the best alignment of the element.
318+
const CharUnits eltSize = cgf.getContext().getTypeSizeInChars(eltType);
319+
const CharUnits eltAlign =
320+
getArrayElementAlign(addr.getAlignment(), indices.back(), eltSize);
321+
322+
mlir::Value eltPtr;
323+
const mlir::IntegerAttr lastIndex = getConstantIndexOrNull(indices.back());
324+
if (!lastIndex) {
325+
eltPtr = emitArraySubscriptPtr(cgf, beginLoc, endLoc, addr.getPointer(),
326+
addr.getElementType(), indices, inbounds,
327+
signedIndices, shouldDecay, name);
328+
}
329+
const mlir::Type elementType = cgf.convertTypeForMem(eltType);
330+
return Address(eltPtr, elementType, eltAlign);
331+
}
332+
333+
LValue
334+
CIRGenFunction::emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e) {
335+
if (e->getBase()->getType()->isVectorType() &&
336+
!isa<ExtVectorElementExpr>(e->getBase())) {
337+
cgm.errorNYI(e->getSourceRange(), "emitArraySubscriptExpr: VectorType");
338+
return {};
339+
}
340+
341+
if (isa<ExtVectorElementExpr>(e->getBase())) {
342+
cgm.errorNYI(e->getSourceRange(),
343+
"emitArraySubscriptExpr: ExtVectorElementExpr");
344+
return {};
345+
}
346+
347+
if (getContext().getAsVariableArrayType(e->getType())) {
348+
cgm.errorNYI(e->getSourceRange(),
349+
"emitArraySubscriptExpr: VariableArrayType");
350+
return {};
351+
}
352+
353+
if (e->getType()->getAs<ObjCObjectType>()) {
354+
cgm.errorNYI(e->getSourceRange(), "emitArraySubscriptExpr: ObjCObjectType");
355+
return {};
356+
}
357+
358+
// The index must always be an integer, which is not an aggregate. Emit it
359+
// in lexical order (this complexity is, sadly, required by C++17).
360+
assert((e->getIdx() == e->getLHS() || e->getIdx() == e->getRHS()) &&
361+
"index was neither LHS nor RHS");
362+
const mlir::Value idx = emitScalarExpr(e->getIdx());
363+
const QualType idxTy = e->getIdx()->getType();
364+
const bool signedIndices = idxTy->isSignedIntegerOrEnumerationType();
365+
366+
if (const Expr *array = isSimpleArrayDecayOperand(e->getBase())) {
367+
LValue arrayLV;
368+
if (const auto *ase = dyn_cast<ArraySubscriptExpr>(array))
369+
arrayLV = emitArraySubscriptExpr(ase);
370+
else
371+
arrayLV = emitLValue(array);
372+
373+
// Propagate the alignment from the array itself to the result.
374+
QualType arrayType = array->getType();
375+
const Address addr = emitArraySubscriptPtr(
376+
*this, cgm.getLoc(array->getBeginLoc()), cgm.getLoc(array->getEndLoc()),
377+
arrayLV.getAddress(), {idx}, e->getType(),
378+
!getLangOpts().isSignedOverflowDefined(), signedIndices,
379+
cgm.getLoc(e->getExprLoc()), /*shouldDecay=*/true, &arrayType,
380+
e->getBase());
381+
382+
return LValue::makeAddr(addr, e->getType());
383+
}
384+
385+
// The base must be a pointer; emit it with an estimate of its alignment.
386+
cgm.errorNYI(e->getSourceRange(),
387+
"emitArraySubscriptExpr: The base must be a pointer");
388+
return {};
389+
}
390+
235391
LValue CIRGenFunction::emitBinaryOperatorLValue(const BinaryOperator *e) {
236392
// Comma expressions just emit their LHS then their RHS as an l-value.
237393
if (e->getOpcode() == BO_Comma) {

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,15 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
157157

158158
mlir::Value VisitCastExpr(CastExpr *e);
159159

160+
mlir::Value VisitArraySubscriptExpr(ArraySubscriptExpr *e) {
161+
if (e->getBase()->getType()->isVectorType()) {
162+
assert(!cir::MissingFeatures::scalableVectors() &&
163+
"NYI: index into scalable vector");
164+
}
165+
// Just load the lvalue formed by the subscript expression.
166+
return emitLoadOfLValue(e);
167+
}
168+
160169
mlir::Value VisitExplicitCastExpr(ExplicitCastExpr *e) {
161170
return VisitCastExpr(e);
162171
}

clang/lib/CIR/CodeGen/CIRGenFunction.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,8 @@ LValue CIRGenFunction::emitLValue(const Expr *e) {
444444
std::string("l-value not implemented for '") +
445445
e->getStmtClassName() + "'");
446446
return LValue();
447+
case Expr::ArraySubscriptExprClass:
448+
return emitArraySubscriptExpr(cast<ArraySubscriptExpr>(e));
447449
case Expr::UnaryOperatorClass:
448450
return emitUnaryOpLValue(cast<UnaryOperator>(e));
449451
case Expr::BinaryOperatorClass:

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,8 @@ class CIRGenFunction : public CIRGenTypeCache {
387387
/// should be returned.
388388
RValue emitAnyExpr(const clang::Expr *e);
389389

390+
LValue emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e);
391+
390392
AutoVarEmission emitAutoVarAlloca(const clang::VarDecl &d);
391393

392394
/// Emit code and set up symbol table for a variable declaration with auto,

clang/lib/CIR/CodeGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
88

99
add_clang_library(clangCIR
1010
CIRGenerator.cpp
11+
CIRGenBuilder.cpp
1112
CIRGenDecl.cpp
1213
CIRGenExpr.cpp
1314
CIRGenExprAggregate.cpp

clang/test/CIR/CodeGen/array.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,15 @@ int f[5] = {1, 2};
2929

3030
void func() {
3131
int arr[10];
32-
3332
// CHECK: %[[ARR:.*]] = cir.alloca !cir.array<!s32i x 10>, !cir.ptr<!cir.array<!s32i x 10>>, ["arr"]
33+
34+
int e = arr[1];
35+
// CHECK: %[[INIT:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
36+
// CHECK: %[[IDX:.*]] = cir.const #cir.int<1> : !s32i
37+
// CHECK: %[[ARR_PTR:.*]] = cir.cast(array_to_ptrdecay, %[[ARR]] : !cir.ptr<!cir.array<!s32i x 10>>), !cir.ptr<!s32i>
38+
// CHECK: %[[ELE_PTR:.*]] = cir.ptr_stride(%[[ARR_PTR]] : !cir.ptr<!s32i>, %[[IDX]] : !s32i), !cir.ptr<!s32i>
39+
// CHECK: %[[TMP:.*]] = cir.load %[[ELE_PTR]] : !cir.ptr<!s32i>, !s32i
40+
// CHECK" cir.store %[[TMP]], %[[INIT]] : !s32i, !cir.ptr<!s32i>
3441
}
3542

3643
void func2() {
@@ -69,6 +76,7 @@ void func4() {
6976
int arr[2][1] = {{5}, {6}};
7077

7178
// CHECK: %[[ARR:.*]] = cir.alloca !cir.array<!cir.array<!s32i x 1> x 2>, !cir.ptr<!cir.array<!cir.array<!s32i x 1> x 2>>, ["arr", init]
79+
// CHECK: %[[INIT:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
7280
// CHECK: %[[ARR_PTR:.*]] = cir.cast(array_to_ptrdecay, %[[ARR]] : !cir.ptr<!cir.array<!cir.array<!s32i x 1> x 2>>), !cir.ptr<!cir.array<!s32i x 1>>
7381
// CHECK: %[[ARR_0_PTR:.*]] = cir.cast(array_to_ptrdecay, %[[ARR_PTR]] : !cir.ptr<!cir.array<!s32i x 1>>), !cir.ptr<!s32i>
7482
// CHECK: %[[V_0_0:.*]] = cir.const #cir.int<5> : !s32i
@@ -78,6 +86,17 @@ void func4() {
7886
// CHECK: %[[ARR_1_PTR:.*]] = cir.cast(array_to_ptrdecay, %[[ARR_1]] : !cir.ptr<!cir.array<!s32i x 1>>), !cir.ptr<!s32i>
7987
// CHECK: %[[V_1_0:.*]] = cir.const #cir.int<6> : !s32i
8088
// CHECK: cir.store %[[V_1_0]], %[[ARR_1_PTR]] : !s32i, !cir.ptr<!s32i>
89+
90+
int e = arr[1][0];
91+
92+
// CHECK: %[[IDX:.*]] = cir.const #cir.int<0> : !s32i
93+
// CHECK: %[[IDX_1:.*]] = cir.const #cir.int<1> : !s32i
94+
// CHECK: %[[ARR_PTR:.*]] = cir.cast(array_to_ptrdecay, %[[ARR]] : !cir.ptr<!cir.array<!cir.array<!s32i x 1> x 2>>), !cir.ptr<!cir.array<!s32i x 1>>
95+
// CHECK: %[[ARR_1:.*]] = cir.ptr_stride(%[[ARR_PTR]] : !cir.ptr<!cir.array<!s32i x 1>>, %[[IDX_1]] : !s32i), !cir.ptr<!cir.array<!s32i x 1>>
96+
// CHECK: %[[ARR_1_PTR:.*]] = cir.cast(array_to_ptrdecay, %[[ARR_1]] : !cir.ptr<!cir.array<!s32i x 1>>), !cir.ptr<!s32i>
97+
// CHECK: %[[ELE_0:.*]] = cir.ptr_stride(%[[ARR_1_PTR]] : !cir.ptr<!s32i>, %[[IDX]] : !s32i), !cir.ptr<!s32i>
98+
// CHECK: %[[TMP:.*]] = cir.load %[[ELE_0]] : !cir.ptr<!s32i>, !s32i
99+
// CHECK: cir.store %[[TMP]], %[[INIT]] : !s32i, !cir.ptr<!s32i>
81100
}
82101

83102
void func5() {

clang/test/CIR/Lowering/array.cpp

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,18 @@ int f[5] = {1, 2};
3131

3232
void func() {
3333
int arr[10];
34+
int e = arr[1];
3435
}
3536
// CHECK: define void @func()
36-
// CHECK-NEXT: alloca [10 x i32], i64 1, align 16
37+
// CHECK-NEXT: %[[ARR_ALLOCA:.*]] = alloca [10 x i32], i64 1, align 16
38+
// CHECK-NEXT: %[[INIT:.*]] = alloca i32, i64 1, align 4
39+
// CHECK-NEXT: %[[ARR_PTR:.*]] = getelementptr i32, ptr %[[ARR_ALLOCA]], i32 0
40+
// CHECK-NEXT: %[[ELE_PTR:.*]] = getelementptr i32, ptr %[[ARR_PTR]], i64 1
41+
// CHECK-NEXT: %[[TMP:.*]] = load i32, ptr %[[ELE_PTR]], align 4
42+
// CHECK-NEXT: store i32 %[[TMP]], ptr %[[INIT]], align 4
3743

3844
void func2() {
39-
int arr2[2] = {5};
45+
int arr[2] = {5};
4046
}
4147
// CHECK: define void @func2()
4248
// CHECK: %[[ARR_ALLOCA:.*]] = alloca [2 x i32], i64 1, align 4
@@ -61,19 +67,27 @@ void func3() {
6167
// CHECK: store i32 6, ptr %[[ELE_1_PTR]], align 4
6268

6369
void func4() {
64-
int arr4[2][1] = {{5}, {6}};
70+
int arr[2][1] = {{5}, {6}};
71+
int e = arr[1][0];
6572
}
6673
// CHECK: define void @func4()
6774
// CHECK: %[[ARR_ALLOCA:.*]] = alloca [2 x [1 x i32]], i64 1, align 4
68-
// CHECK: %[[ARR_0:.*]] = getelementptr [1 x i32], ptr %[[ARR_ALLOCA]], i32 0
69-
// CHECK: %[[ARR_0_ELE_0:.*]] = getelementptr i32, ptr %[[ARR_0]], i32 0
70-
// CHECK: store i32 5, ptr %[[ARR_0_ELE_0]], align 4
71-
// CHECK: %[[ARR_1:.*]] = getelementptr [1 x i32], ptr %2, i64 1
72-
// CHECK: %[[ARR_0_ELE_0:.*]] = getelementptr i32, ptr %[[ARR_1]], i32 0
73-
// CHECK: store i32 6, ptr %[[ARR_0_ELE_0]], align 4
75+
// CHECK: %[[INIT:.*]] = alloca i32, i64 1, align 4
76+
// CHECK: %[[ARR_PTR:.*]] = getelementptr [1 x i32], ptr %[[ARR_ALLOCA]], i32 0
77+
// CHECK: %[[ARR_0_0:.*]] = getelementptr i32, ptr %[[ARR_PTR]], i32 0
78+
// CHECK: store i32 5, ptr %[[ARR_0_0]], align 4
79+
// CHECK: %[[ARR_1:.*]] = getelementptr [1 x i32], ptr %[[ARR_PTR]], i64 1
80+
// CHECK: %[[ARR_1_0:.*]] = getelementptr i32, ptr %[[ARR_1]], i32 0
81+
// CHECK: store i32 6, ptr %[[ARR_1_0]], align 4
82+
// CHECK: %[[ARR_PTR:.*]] = getelementptr [1 x i32], ptr %[[ARR_ALLOCA]], i32 0
83+
// CHECK: %[[ARR_1:.*]] = getelementptr [1 x i32], ptr %[[ARR_PTR]], i64 1
84+
// CHECK: %[[ARR_1_0:.*]] = getelementptr i32, ptr %[[ARR_1]], i32 0
85+
// CHECK: %[[ELE_PTR:.*]] = getelementptr i32, ptr %[[ARR_1_0]], i64 0
86+
// CHECK: %[[TMP:.*]] = load i32, ptr %[[ELE_PTR]], align 4
87+
// CHECK: store i32 %[[TMP]], ptr %[[INIT]], align 4
7488

7589
void func5() {
76-
int arr5[2][1] = {{5}};
90+
int arr[2][1] = {{5}};
7791
}
7892
// CHECK: define void @func5()
7993
// CHECK: %[[ARR_ALLOCA:.*]] = alloca [2 x [1 x i32]], i64 1, align 4

0 commit comments

Comments
 (0)