Skip to content

Commit 7a480c7

Browse files
committed
[CIR] Upstream shift operators for VectorType
1 parent 3e393d9 commit 7a480c7

File tree

6 files changed

+168
-16
lines changed

6 files changed

+168
-16
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,18 +1401,19 @@ def ShiftOp : CIR_Op<"shift", [Pure]> {
14011401
The `cir.shift` operation performs a bitwise shift, either to the left or to
14021402
the right, based on the first operand. The second operand specifies the
14031403
value to be shifted, and the third operand determines the number of
1404-
positions by which the shift is applied. Both the second and third operands
1405-
are required to be integers.
1404+
positions by which the shift is applied, they must be either all vector of
1405+
integer type, or all integer type. If they are vectors, each vector element of
1406+
the shift target is shifted by the corresponding shift amount in
1407+
the shift amount vector.
14061408

14071409
```mlir
1408-
%7 = cir.shift(left, %1 : !u64i, %4 : !s32i) -> !u64i
1410+
%res = cir.shift(left, %lhs : !u64i, %amount : !s32i) -> !u64i
1411+
%new_vec = cir.shift(left, %lhs : !cir.vector<2 x !s32i>, %rhs : !cir.vector<2 x !s32i>) -> !cir.vector<2 x !s32i>
14091412
```
14101413
}];
14111414

1412-
// TODO(cir): Support vectors. CIR_IntType -> CIR_AnyIntOrVecOfInt. Also
1413-
// update the description above.
1414-
let results = (outs CIR_IntType:$result);
1415-
let arguments = (ins CIR_IntType:$value, CIR_IntType:$amount,
1415+
let results = (outs CIR_AnyIntOrVecOfInt:$result);
1416+
let arguments = (ins CIR_AnyIntOrVecOfInt:$value, CIR_AnyIntOrVecOfInt:$amount,
14161417
UnitAttr:$isShiftleft);
14171418

14181419
let assemblyFormat = [{

clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,4 +174,23 @@ def CIR_PtrToVoidPtrType
174174
"$_builder.getType<" # cppType # ">("
175175
"cir::VoidType::get($_builder.getContext())))">;
176176

177+
//===----------------------------------------------------------------------===//
178+
// Vector Type predicates
179+
//===----------------------------------------------------------------------===//
180+
181+
// Vector of integral type
182+
def IntegerVector : Type<
183+
And<[
184+
CPred<"::mlir::isa<::cir::VectorType>($_self)">,
185+
CPred<"::mlir::isa<::cir::IntType>("
186+
"::mlir::cast<::cir::VectorType>($_self).getElementType())">,
187+
CPred<"::mlir::cast<::cir::IntType>("
188+
"::mlir::cast<::cir::VectorType>($_self).getElementType())"
189+
".isFundamental()">
190+
]>, "!cir.vector of !cir.int"> {
191+
}
192+
193+
// Any Integer or Vector of Integer Constraints
194+
def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_AnyIntType, IntegerVector]>;
195+
177196
#endif // CLANG_CIR_DIALECT_IR_CIRTYPECONSTRAINTS_TD

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1297,9 +1297,8 @@ OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
12971297
LogicalResult cir::ShiftOp::verify() {
12981298
mlir::Operation *op = getOperation();
12991299
mlir::Type resType = getResult().getType();
1300-
assert(!cir::MissingFeatures::vectorType());
1301-
bool isOp0Vec = false;
1302-
bool isOp1Vec = false;
1300+
const bool isOp0Vec = mlir::isa<cir::VectorType>(op->getOperand(0).getType());
1301+
const bool isOp1Vec = mlir::isa<cir::VectorType>(op->getOperand(1).getType());
13031302
if (isOp0Vec != isOp1Vec)
13041303
return emitOpError() << "input types cannot be one vector and one scalar";
13051304
if (isOp1Vec && op->getOperand(1).getType() != resType) {

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1376,16 +1376,17 @@ mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
13761376
auto cirValTy = mlir::dyn_cast<cir::IntType>(op.getValue().getType());
13771377

13781378
// Operands could also be vector type
1379-
assert(!cir::MissingFeatures::vectorType());
1379+
auto cirAmtVTy = mlir::dyn_cast<cir::VectorType>(op.getAmount().getType());
1380+
auto cirValVTy = mlir::dyn_cast<cir::VectorType>(op.getValue().getType());
13801381
mlir::Type llvmTy = getTypeConverter()->convertType(op.getType());
13811382
mlir::Value amt = adaptor.getAmount();
13821383
mlir::Value val = adaptor.getValue();
13831384

1384-
// TODO(cir): Assert for vector types
1385-
assert((cirValTy && cirAmtTy) &&
1385+
assert(((cirValTy && cirAmtTy) || (cirAmtVTy && cirValVTy)) &&
13861386
"shift input type must be integer or vector type, otherwise NYI");
13871387

1388-
assert((cirValTy == op.getType()) && "inconsistent operands' types NYI");
1388+
assert((cirValTy == op.getType() || cirValVTy == op.getType()) &&
1389+
"inconsistent operands' types NYI");
13891390

13901391
// Ensure shift amount is the same type as the value. Some undefined
13911392
// behavior might occur in the casts below as per [C99 6.5.7.3].
@@ -1399,8 +1400,10 @@ mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
13991400
if (op.getIsShiftleft()) {
14001401
rewriter.replaceOpWithNewOp<mlir::LLVM::ShlOp>(op, llvmTy, val, amt);
14011402
} else {
1402-
assert(!cir::MissingFeatures::vectorType());
1403-
bool isUnsigned = !cirValTy.isSigned();
1403+
const bool isUnsigned =
1404+
cirValTy
1405+
? !cirValTy.isSigned()
1406+
: !mlir::cast<cir::IntType>(cirValVTy.getElementType()).isSigned();
14041407
if (isUnsigned)
14051408
rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt);
14061409
else

clang/test/CIR/CodeGen/vector-ext.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,68 @@ void foo4() {
213213
// OGCG: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
214214
// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP1]], i32 %[[TMP2]]
215215
// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
216+
217+
void foo9() {
218+
vi4 a = {1, 2, 3, 4};
219+
vi4 b = {5, 6, 7, 8};
220+
221+
vi4 shl = a << b;
222+
vi4 shr = a >> b;
223+
}
224+
225+
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
226+
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["b", init]
227+
// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
228+
// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shr", init]
229+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
230+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
231+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
232+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
233+
// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
234+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
235+
// CIR: cir.store %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
236+
// CIR: %[[CONST_5:.*]] = cir.const #cir.int<5> : !s32i
237+
// CIR: %[[CONST_6:.*]] = cir.const #cir.int<6> : !s32i
238+
// CIR: %[[CONST_7:.*]] = cir.const #cir.int<7> : !s32i
239+
// CIR: %[[CONST_8:.*]] = cir.const #cir.int<8> : !s32i
240+
// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_5]], %[[CONST_6]], %[[CONST_7]], %[[CONST_8]] :
241+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
242+
// CIR: cir.store %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
243+
// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
244+
// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
245+
// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
246+
// CIR: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
247+
// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
248+
// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
249+
// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
250+
// CIR: cir.store %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
251+
252+
// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
253+
// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
254+
// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
255+
// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
256+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
257+
// LLVM: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
258+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
259+
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
260+
// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
261+
// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
262+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
263+
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
264+
// LLVM: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
265+
// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
266+
267+
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
268+
// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
269+
// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
270+
// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
271+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
272+
// OGCG: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
273+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
274+
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
275+
// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
276+
// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
277+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
278+
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
279+
// OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
280+
// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16

clang/test/CIR/CodeGen/vector.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,68 @@ void foo4() {
201201
// OGCG: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
202202
// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP1]], i32 %[[TMP2]]
203203
// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
204+
205+
void foo9() {
206+
vi4 a = {1, 2, 3, 4};
207+
vi4 b = {5, 6, 7, 8};
208+
209+
vi4 shl = a << b;
210+
vi4 shr = a >> b;
211+
}
212+
213+
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
214+
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["b", init]
215+
// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
216+
// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shr", init]
217+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
218+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
219+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
220+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
221+
// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
222+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
223+
// CIR: cir.store %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
224+
// CIR: %[[CONST_5:.*]] = cir.const #cir.int<5> : !s32i
225+
// CIR: %[[CONST_6:.*]] = cir.const #cir.int<6> : !s32i
226+
// CIR: %[[CONST_7:.*]] = cir.const #cir.int<7> : !s32i
227+
// CIR: %[[CONST_8:.*]] = cir.const #cir.int<8> : !s32i
228+
// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_5]], %[[CONST_6]], %[[CONST_7]], %[[CONST_8]] :
229+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
230+
// CIR: cir.store %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
231+
// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
232+
// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
233+
// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
234+
// CIR: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
235+
// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
236+
// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
237+
// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
238+
// CIR: cir.store %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
239+
240+
// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
241+
// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
242+
// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
243+
// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
244+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
245+
// LLVM: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
246+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
247+
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
248+
// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
249+
// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
250+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
251+
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
252+
// LLVM: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
253+
// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
254+
255+
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
256+
// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
257+
// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
258+
// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
259+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
260+
// OGCG: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
261+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
262+
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
263+
// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
264+
// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
265+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
266+
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
267+
// OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
268+
// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16

0 commit comments

Comments
 (0)