Skip to content

Commit d0a6c4f

Browse files
committed
[CIR] Upstream extract op for VectorType
1 parent edb690d commit d0a6c4f

File tree

7 files changed

+147
-3
lines changed

7 files changed

+147
-3
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1976,4 +1976,28 @@ def VecCreateOp : CIR_Op<"vec.create", [Pure]> {
19761976
let hasVerifier = 1;
19771977
}
19781978

1979+
//===----------------------------------------------------------------------===//
1980+
// VecExtractOp
1981+
//===----------------------------------------------------------------------===//
1982+
1983+
def VecExtractOp : CIR_Op<"vec.extract", [Pure,
1984+
TypesMatchWith<"type of 'result' matches element type of 'vec'", "vec",
1985+
"result", "cast<VectorType>($_self).getElementType()">]> {
1986+
1987+
let summary = "Extract one element from a vector object";
1988+
let description = [{
1989+
The `cir.vec.extract` operation extracts the element at the given index
1990+
from a vector object.
1991+
}];
1992+
1993+
let arguments = (ins CIR_VectorType:$vec, CIR_AnyFundamentalIntType:$index);
1994+
let results = (outs CIR_AnyType:$result);
1995+
1996+
let assemblyFormat = [{
1997+
$vec `[` $index `:` type($index) `]` attr-dict `:` qualified(type($vec))
1998+
}];
1999+
2000+
let hasVerifier = 0;
2001+
}
2002+
19792003
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,11 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
161161
mlir::Value VisitArraySubscriptExpr(ArraySubscriptExpr *e) {
162162
if (e->getBase()->getType()->isVectorType()) {
163163
assert(!cir::MissingFeatures::scalableVectors());
164-
cgf.getCIRGenModule().errorNYI("VisitArraySubscriptExpr: VectorType");
165-
return {};
164+
165+
const mlir::Location loc = cgf.getLoc(e->getSourceRange());
166+
const mlir::Value vecValue = Visit(e->getBase());
167+
const mlir::Value indexValue = Visit(e->getIdx());
168+
return cgf.builder.create<cir::VecExtractOp>(loc, vecValue, indexValue);
166169
}
167170
// Just load the lvalue formed by the subscript expression.
168171
return emitLoadOfLValue(e);

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1600,7 +1600,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
16001600
CIRToLLVMStackRestoreOpLowering,
16011601
CIRToLLVMTrapOpLowering,
16021602
CIRToLLVMUnaryOpLowering,
1603-
CIRToLLVMVecCreateOpLowering
1603+
CIRToLLVMVecCreateOpLowering,
1604+
CIRToLLVMVecExtractOpLowering
16041605
// clang-format on
16051606
>(converter, patterns.getContext());
16061607

@@ -1709,6 +1710,14 @@ mlir::LogicalResult CIRToLLVMVecCreateOpLowering::matchAndRewrite(
17091710
return mlir::success();
17101711
}
17111712

1713+
mlir::LogicalResult CIRToLLVMVecExtractOpLowering::matchAndRewrite(
1714+
cir::VecExtractOp op, OpAdaptor adaptor,
1715+
mlir::ConversionPatternRewriter &rewriter) const {
1716+
rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractElementOp>(
1717+
op, adaptor.getVec(), adaptor.getIndex());
1718+
return mlir::success();
1719+
}
1720+
17121721
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
17131722
return std::make_unique<ConvertCIRToLLVMPass>();
17141723
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,16 @@ class CIRToLLVMVecCreateOpLowering
303303
mlir::ConversionPatternRewriter &) const override;
304304
};
305305

306+
class CIRToLLVMVecExtractOpLowering
307+
: public mlir::OpConversionPattern<cir::VecExtractOp> {
308+
public:
309+
using mlir::OpConversionPattern<cir::VecExtractOp>::OpConversionPattern;
310+
311+
mlir::LogicalResult
312+
matchAndRewrite(cir::VecExtractOp op, OpAdaptor,
313+
mlir::ConversionPatternRewriter &) const override;
314+
};
315+
306316
} // namespace direct
307317
} // namespace cir
308318

clang/test/CIR/CodeGen/vector-ext.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,36 @@ void foo2(vi4 p) {}
109109

110110
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
111111
// OGCG: store <4 x i32> %{{.*}}, ptr %[[VEC_A]], align 16
112+
113+
void foo3() {
114+
vi4 a = { 1, 2, 3, 4 };
115+
int e = a[1];
116+
}
117+
118+
// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
119+
// CIR: %[[INIT:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
120+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
121+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
122+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
123+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
124+
// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
125+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
126+
// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
127+
// CIR: %[[TMP:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
128+
// CIR: %[[IDX:.*]] = cir.const #cir.int<1> : !s32i
129+
// CIR: %[[ELE:.*]] = cir.vec.extract %[[TMP]][%[[IDX]] : !s32i] : !cir.vector<4 x !s32i>
130+
// CIR: cir.store %[[ELE]], %[[INIT]] : !s32i, !cir.ptr<!s32i>
131+
132+
// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
133+
// LLVM: %[[INIT:.*]] = alloca i32, i64 1, align 4
134+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
135+
// LLVM: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
136+
// LLVM: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
137+
// LLVM: store i32 %[[ELE]], ptr %[[INIT]], align 4
138+
139+
// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
140+
// OGCG: %[[INIT:.*]] = alloca i32, align 4
141+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
142+
// OGCG: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
143+
// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
144+
// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4

clang/test/CIR/CodeGen/vector.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,36 @@ void foo2(vi4 p) {}
9696

9797
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
9898
// OGCG: store <4 x i32> %{{.*}}, ptr %[[VEC_A]], align 16
99+
100+
void foo3() {
101+
vi4 a = { 1, 2, 3, 4 };
102+
int e = a[1];
103+
}
104+
105+
// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
106+
// CIR: %[[INIT:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
107+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
108+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
109+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
110+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
111+
// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
112+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
113+
// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
114+
// CIR: %[[TMP:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
115+
// CIR: %[[IDX:.*]] = cir.const #cir.int<1> : !s32i
116+
// CIR: %[[ELE:.*]] = cir.vec.extract %[[TMP]][%[[IDX]] : !s32i] : !cir.vector<4 x !s32i>
117+
// CIR: cir.store %[[ELE]], %[[INIT]] : !s32i, !cir.ptr<!s32i>
118+
119+
// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
120+
// LLVM: %[[INIT:.*]] = alloca i32, i64 1, align 4
121+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
122+
// LLVM: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
123+
// LLVM: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
124+
// LLVM: store i32 %[[ELE]], ptr %[[INIT]], align 4
125+
126+
// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
127+
// OGCG: %[[INIT:.*]] = alloca i32, align 4
128+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
129+
// OGCG: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
130+
// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
131+
// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4

clang/test/CIR/IR/vector.cir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,36 @@ cir.func @local_vector_create_test() {
6565
// CHECK: cir.return
6666
// CHECK: }
6767

68+
cir.func @vector_extract_element_test() {
69+
%0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["d", init]
70+
%1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
71+
%2 = cir.const #cir.int<1> : !s32i
72+
%3 = cir.const #cir.int<2> : !s32i
73+
%4 = cir.const #cir.int<3> : !s32i
74+
%5 = cir.const #cir.int<4> : !s32i
75+
%6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
76+
cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
77+
%7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
78+
%8 = cir.const #cir.int<1> : !s32i
79+
%9 = cir.vec.extract %7[%8 : !s32i] : !cir.vector<4 x !s32i>
80+
cir.store %9, %1 : !s32i, !cir.ptr<!s32i>
81+
cir.return
82+
}
83+
84+
// CHECK: cir.func @vector_extract_element_test() {
85+
// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["d", init]
86+
// CHECK: %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
87+
// CHECK: %2 = cir.const #cir.int<1> : !s32i
88+
// CHECK: %3 = cir.const #cir.int<2> : !s32i
89+
// CHECK: %4 = cir.const #cir.int<3> : !s32i
90+
// CHECK: %5 = cir.const #cir.int<4> : !s32i
91+
// CHECK: %6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
92+
// CHECK: cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
93+
// CHECK: %7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
94+
// CHECK: %8 = cir.const #cir.int<1> : !s32i
95+
// CHECK: %9 = cir.vec.extract %7[%8 : !s32i] : !cir.vector<4 x !s32i>
96+
// CHECK: cir.store %9, %1 : !s32i, !cir.ptr<!s32i>
97+
// CHECK: cir.return
98+
// CHECK: }
99+
68100
}

0 commit comments

Comments
 (0)