Skip to content

Commit 61ba3e4

Browse files
authored
[CIR][LLVMLowering] Upstream unary operators for VectorType (#139444)
This change adds support for unary ops for VectorType. Issue #136487
1 parent 0fc9cd1 commit 61ba3e4

File tree

3 files changed

+134
-11
lines changed

3 files changed

+134
-11
lines changed

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,9 +1044,8 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
10441044
assert(op.getType() == op.getInput().getType() &&
10451045
"Unary operation's operand type and result type are different");
10461046
mlir::Type type = op.getType();
1047-
mlir::Type elementType = type;
1048-
bool isVector = false;
1049-
assert(!cir::MissingFeatures::vectorType());
1047+
mlir::Type elementType = elementTypeIfVector(type);
1048+
bool isVector = mlir::isa<cir::VectorType>(type);
10501049
mlir::Type llvmType = getTypeConverter()->convertType(type);
10511050
mlir::Location loc = op.getLoc();
10521051

@@ -1076,20 +1075,30 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
10761075
rewriter.replaceOp(op, adaptor.getInput());
10771076
return mlir::success();
10781077
case cir::UnaryOpKind::Minus: {
1079-
assert(!isVector &&
1080-
"Add vector handling when vector types are supported");
1081-
mlir::LLVM::ConstantOp zero = rewriter.create<mlir::LLVM::ConstantOp>(
1082-
loc, llvmType, mlir::IntegerAttr::get(llvmType, 0));
1078+
mlir::Value zero;
1079+
if (isVector)
1080+
zero = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmType);
1081+
else
1082+
zero = rewriter.create<mlir::LLVM::ConstantOp>(
1083+
loc, llvmType, mlir::IntegerAttr::get(llvmType, 0));
10831084
rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(
10841085
op, llvmType, zero, adaptor.getInput(), maybeNSW);
10851086
return mlir::success();
10861087
}
10871088
case cir::UnaryOpKind::Not: {
10881089
// Bitwise complement operator, implemented as an XOR with -1.
1089-
assert(!isVector &&
1090-
"Add vector handling when vector types are supported");
1091-
mlir::LLVM::ConstantOp minusOne = rewriter.create<mlir::LLVM::ConstantOp>(
1092-
loc, llvmType, mlir::IntegerAttr::get(llvmType, -1));
1090+
mlir::Value minusOne;
1091+
if (isVector) {
1092+
const uint64_t numElements =
1093+
mlir::dyn_cast<cir::VectorType>(type).getSize();
1094+
std::vector<int32_t> values(numElements, -1);
1095+
mlir::DenseIntElementsAttr denseVec = rewriter.getI32VectorAttr(values);
1096+
minusOne =
1097+
rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, denseVec);
1098+
} else {
1099+
minusOne = rewriter.create<mlir::LLVM::ConstantOp>(
1100+
loc, llvmType, mlir::IntegerAttr::get(llvmType, -1));
1101+
}
10931102
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(
10941103
op, llvmType, adaptor.getInput(), minusOne);
10951104
return mlir::success();

clang/test/CIR/CodeGen/vector-ext.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,63 @@ void foo7() {
337337
// OGCG: %[[NEW_VEC:.*]] = insertelement <4 x i32> %[[TMP2]], i32 %[[RES]], i32 2
338338
// OGCG: store <4 x i32> %[[NEW_VEC]], ptr %[[VEC]], align 16
339339

340+
341+
void foo8() {
342+
vi4 a = { 1, 2, 3, 4 };
343+
vi4 plus_res = +a;
344+
vi4 minus_res = -a;
345+
vi4 not_res = ~a;
346+
}
347+
348+
// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
349+
// CIR: %[[PLUS_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["plus_res", init]
350+
// CIR: %[[MINUS_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["minus_res", init]
351+
// CIR: %[[NOT_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["not_res", init]
352+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
353+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
354+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
355+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
356+
// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
357+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
358+
// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
359+
// CIR: %[[TMP1:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
360+
// CIR: %[[PLUS:.*]] = cir.unary(plus, %[[TMP1]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
361+
// CIR: cir.store %[[PLUS]], %[[PLUS_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
362+
// CIR: %[[TMP2:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
363+
// CIR: %[[MINUS:.*]] = cir.unary(minus, %[[TMP2]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
364+
// CIR: cir.store %[[MINUS]], %[[MINUS_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
365+
// CIR: %[[TMP3:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
366+
// CIR: %[[NOT:.*]] = cir.unary(not, %[[TMP3]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
367+
// CIR: cir.store %[[NOT]], %[[NOT_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
368+
369+
// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
370+
// LLVM: %[[PLUS_RES:.*]] = alloca <4 x i32>, i64 1, align 16
371+
// LLVM: %[[MINUS_RES:.*]] = alloca <4 x i32>, i64 1, align 16
372+
// LLVM: %[[NOT_RES:.*]] = alloca <4 x i32>, i64 1, align 16
373+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
374+
// LLVM: %[[TMP1:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
375+
// LLVM: store <4 x i32> %[[TMP1]], ptr %[[PLUS_RES]], align 16
376+
// LLVM: %[[TMP2:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
377+
// LLVM: %[[SUB:.*]] = sub <4 x i32> zeroinitializer, %[[TMP2]]
378+
// LLVM: store <4 x i32> %[[SUB]], ptr %[[MINUS_RES]], align 16
379+
// LLVM: %[[TMP3:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
380+
// LLVM: %[[NOT:.*]] = xor <4 x i32> %[[TMP3]], splat (i32 -1)
381+
// LLVM: store <4 x i32> %[[NOT]], ptr %[[NOT_RES]], align 16
382+
383+
// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
384+
// OGCG: %[[PLUS_RES:.*]] = alloca <4 x i32>, align 16
385+
// OGCG: %[[MINUS_RES:.*]] = alloca <4 x i32>, align 16
386+
// OGCG: %[[NOT_RES:.*]] = alloca <4 x i32>, align 16
387+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
388+
// OGCG: %[[TMP1:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
389+
// OGCG: store <4 x i32> %[[TMP1]], ptr %[[PLUS_RES]], align 16
390+
// OGCG: %[[TMP2:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
391+
// OGCG: %[[SUB:.*]] = sub <4 x i32> zeroinitializer, %[[TMP2]]
392+
// OGCG: store <4 x i32> %[[SUB]], ptr %[[MINUS_RES]], align 16
393+
// OGCG: %[[TMP3:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
394+
// OGCG: %[[NOT:.*]] = xor <4 x i32> %[[TMP3]], splat (i32 -1)
395+
// OGCG: store <4 x i32> %[[NOT]], ptr %[[NOT_RES]], align 16
396+
340397
void foo9() {
341398
vi4 a = {1, 2, 3, 4};
342399
vi4 b = {5, 6, 7, 8};

clang/test/CIR/CodeGen/vector.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,63 @@ void foo7() {
325325
// OGCG: %[[NEW_VEC:.*]] = insertelement <4 x i32> %[[TMP2]], i32 %[[RES]], i32 2
326326
// OGCG: store <4 x i32> %[[NEW_VEC]], ptr %[[VEC]], align 16
327327

328+
329+
void foo8() {
330+
vi4 a = { 1, 2, 3, 4 };
331+
vi4 plus_res = +a;
332+
vi4 minus_res = -a;
333+
vi4 not_res = ~a;
334+
}
335+
336+
// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
337+
// CIR: %[[PLUS_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["plus_res", init]
338+
// CIR: %[[MINUS_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["minus_res", init]
339+
// CIR: %[[NOT_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["not_res", init]
340+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
341+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
342+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
343+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
344+
// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
345+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
346+
// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
347+
// CIR: %[[TMP1:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
348+
// CIR: %[[PLUS:.*]] = cir.unary(plus, %[[TMP1]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
349+
// CIR: cir.store %[[PLUS]], %[[PLUS_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
350+
// CIR: %[[TMP2:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
351+
// CIR: %[[MINUS:.*]] = cir.unary(minus, %[[TMP2]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
352+
// CIR: cir.store %[[MINUS]], %[[MINUS_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
353+
// CIR: %[[TMP3:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
354+
// CIR: %[[NOT:.*]] = cir.unary(not, %[[TMP3]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
355+
// CIR: cir.store %[[NOT]], %[[NOT_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
356+
357+
// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
358+
// LLVM: %[[PLUS_RES:.*]] = alloca <4 x i32>, i64 1, align 16
359+
// LLVM: %[[MINUS_RES:.*]] = alloca <4 x i32>, i64 1, align 16
360+
// LLVM: %[[NOT_RES:.*]] = alloca <4 x i32>, i64 1, align 16
361+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
362+
// LLVM: %[[TMP1:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
363+
// LLVM: store <4 x i32> %[[TMP1]], ptr %[[PLUS_RES]], align 16
364+
// LLVM: %[[TMP2:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
365+
// LLVM: %[[SUB:.*]] = sub <4 x i32> zeroinitializer, %[[TMP2]]
366+
// LLVM: store <4 x i32> %[[SUB]], ptr %[[MINUS_RES]], align 16
367+
// LLVM: %[[TMP3:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
368+
// LLVM: %[[NOT:.*]] = xor <4 x i32> %[[TMP3]], splat (i32 -1)
369+
// LLVM: store <4 x i32> %[[NOT]], ptr %[[NOT_RES]], align 16
370+
371+
// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
372+
// OGCG: %[[PLUS_RES:.*]] = alloca <4 x i32>, align 16
373+
// OGCG: %[[MINUS_RES:.*]] = alloca <4 x i32>, align 16
374+
// OGCG: %[[NOT_RES:.*]] = alloca <4 x i32>, align 16
375+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
376+
// OGCG: %[[TMP1:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
377+
// OGCG: store <4 x i32> %[[TMP1]], ptr %[[PLUS_RES]], align 16
378+
// OGCG: %[[TMP2:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
379+
// OGCG: %[[SUB:.*]] = sub <4 x i32> zeroinitializer, %[[TMP2]]
380+
// OGCG: store <4 x i32> %[[SUB]], ptr %[[MINUS_RES]], align 16
381+
// OGCG: %[[TMP3:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
382+
// OGCG: %[[NOT:.*]] = xor <4 x i32> %[[TMP3]], splat (i32 -1)
383+
// OGCG: store <4 x i32> %[[NOT]], ptr %[[NOT_RES]], align 16
384+
328385
void foo9() {
329386
vi4 a = {1, 2, 3, 4};
330387
vi4 b = {5, 6, 7, 8};

0 commit comments

Comments
 (0)