Skip to content

Commit b705233

Browse files
committed
[CIR] Upstream TernaryOp for VectorType
1 parent 4b2cb11 commit b705233

File tree

7 files changed

+141
-2
lines changed

7 files changed

+141
-2
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2194,4 +2194,40 @@ def VecShuffleDynamicOp : CIR_Op<"vec.shuffle.dynamic",
21942194
let hasVerifier = 1;
21952195
}
21962196

2197+
//===----------------------------------------------------------------------===//
2198+
// VecTernaryOp
2199+
//===----------------------------------------------------------------------===//
2200+
2201+
def VecTernaryOp : CIR_Op<"vec.ternary",
2202+
[Pure, AllTypesMatch<["result", "vec1", "vec2"]>]> {
2203+
let summary = "The `cond ? a : b` ternary operator for vector types";
2204+
let description = [{
2205+
The `cir.vec.ternary` operation represents the C/C++ ternary operator,
2206+
`?:`, for vector types, which does a `select` on individual elements of the
2207+
vectors. Unlike a regular `?:` operator, there is no short circuiting. All
2208+
three arguments are always evaluated. Because there is no short
2209+
circuiting, there are no regions in this operation, unlike cir.ternary.
2210+
2211+
The first argument is a vector of integral type. The second and third
2212+
arguments are vectors of the same type and have the same number of elements
2213+
as the first argument.
2214+
2215+
The result is a vector of the same type as the second and third arguments.
2216+
Each element of the result is `(bool)a[n] ? b[n] : c[n]`.
2217+
}];
2218+
2219+
let arguments = (ins
2220+
IntegerVector:$cond,
2221+
CIR_VectorType:$vec1,
2222+
CIR_VectorType:$vec2
2223+
);
2224+
2225+
let results = (outs CIR_VectorType:$result);
2226+
let assemblyFormat = [{
2227+
`(` $cond `,` $vec1 `,` $vec2 `)` `:` qualified(type($cond)) `,`
2228+
qualified(type($vec1)) attr-dict
2229+
}];
2230+
let hasVerifier = 1;
2231+
}
2232+
21972233
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,36 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
201201
e->getSourceRange().getBegin());
202202
}
203203

204+
mlir::Value
205+
VisitAbstractConditionalOperator(const AbstractConditionalOperator *e) {
206+
mlir::Location loc = cgf.getLoc(e->getSourceRange());
207+
Expr *condExpr = e->getCond();
208+
Expr *lhsExpr = e->getTrueExpr();
209+
Expr *rhsExpr = e->getFalseExpr();
210+
211+
// OpenCL: If the condition is a vector, we can treat this condition like
212+
// the select function.
213+
if ((cgf.getLangOpts().OpenCL && condExpr->getType()->isVectorType()) ||
214+
condExpr->getType()->isExtVectorType()) {
215+
cgf.getCIRGenModule().errorNYI(loc,
216+
"TernaryOp OpenCL VectorType condition");
217+
return {};
218+
}
219+
220+
if (condExpr->getType()->isVectorType() ||
221+
condExpr->getType()->isSveVLSBuiltinType()) {
222+
assert(condExpr->getType()->isVectorType() && "?: op for SVE vector NYI");
223+
mlir::Value condValue = Visit(condExpr);
224+
mlir::Value lhsValue = Visit(lhsExpr);
225+
mlir::Value rhsValue = Visit(rhsExpr);
226+
return builder.create<cir::VecTernaryOp>(loc, condValue, lhsValue,
227+
rhsValue);
228+
}
229+
230+
cgf.getCIRGenModule().errorNYI(loc, "TernaryOp for non vector types");
231+
return {};
232+
}
233+
204234
mlir::Value VisitMemberExpr(MemberExpr *e);
205235

206236
mlir::Value VisitInitListExpr(InitListExpr *e);

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,6 +1589,24 @@ LogicalResult cir::VecShuffleDynamicOp::verify() {
15891589
return success();
15901590
}
15911591

1592+
//===----------------------------------------------------------------------===//
1593+
// VecTernaryOp
1594+
//===----------------------------------------------------------------------===//
1595+
1596+
LogicalResult cir::VecTernaryOp::verify() {
1597+
// Verify that the condition operand has the same number of elements as the
1598+
// other operands. (The automatic verification already checked that all
1599+
// operands are vector types and that the second and third operands are the
1600+
// same type.)
1601+
if (mlir::cast<cir::VectorType>(getCond().getType()).getSize() !=
1602+
getVec1().getType().getSize()) {
1603+
return emitOpError() << ": the number of elements in "
1604+
<< getCond().getType() << " and "
1605+
<< getVec1().getType() << " don't match";
1606+
}
1607+
return success();
1608+
}
1609+
15921610
//===----------------------------------------------------------------------===//
15931611
// TableGen'd op method definitions
15941612
//===----------------------------------------------------------------------===//

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1730,7 +1730,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
17301730
CIRToLLVMVecExtractOpLowering,
17311731
CIRToLLVMVecInsertOpLowering,
17321732
CIRToLLVMVecCmpOpLowering,
1733-
CIRToLLVMVecShuffleDynamicOpLowering
1733+
CIRToLLVMVecShuffleDynamicOpLowering,
1734+
CIRToLLVMVecTernaryOpLowering
17341735
// clang-format on
17351736
>(converter, patterns.getContext());
17361737

@@ -1934,6 +1935,20 @@ mlir::LogicalResult CIRToLLVMVecShuffleDynamicOpLowering::matchAndRewrite(
19341935
return mlir::success();
19351936
}
19361937

1938+
mlir::LogicalResult CIRToLLVMVecTernaryOpLowering::matchAndRewrite(
1939+
cir::VecTernaryOp op, OpAdaptor adaptor,
1940+
mlir::ConversionPatternRewriter &rewriter) const {
1941+
// Convert `cond` into a vector of i1, then use that in a `select` op.
1942+
mlir::Value bitVec = rewriter.create<mlir::LLVM::ICmpOp>(
1943+
op.getLoc(), mlir::LLVM::ICmpPredicate::ne, adaptor.getCond(),
1944+
rewriter.create<mlir::LLVM::ZeroOp>(
1945+
op.getCond().getLoc(),
1946+
typeConverter->convertType(op.getCond().getType())));
1947+
rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
1948+
op, bitVec, adaptor.getVec1(), adaptor.getVec2());
1949+
return mlir::success();
1950+
}
1951+
19371952
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
19381953
return std::make_unique<ConvertCIRToLLVMPass>();
19391954
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,16 @@ class CIRToLLVMVecShuffleDynamicOpLowering
368368
mlir::ConversionPatternRewriter &) const override;
369369
};
370370

371+
class CIRToLLVMVecTernaryOpLowering
372+
: public mlir::OpConversionPattern<cir::VecTernaryOp> {
373+
public:
374+
using mlir::OpConversionPattern<cir::VecTernaryOp>::OpConversionPattern;
375+
376+
mlir::LogicalResult
377+
matchAndRewrite(cir::VecTernaryOp op, OpAdaptor,
378+
mlir::ConversionPatternRewriter &) const override;
379+
};
380+
371381
} // namespace direct
372382
} // namespace cir
373383

clang/test/CIR/CodeGen/vector-ext.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,3 +1091,18 @@ void foo17() {
10911091
// OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16
10921092
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
10931093
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
1094+
1095+
void foo20() {
1096+
vi4 a;
1097+
vi4 b;
1098+
vi4 c;
1099+
vi4 r = c ? a : b;
1100+
}
1101+
1102+
// CIR: %[[RES:.*]] = cir.vec.ternary({{.*}}, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
1103+
1104+
// LLVM: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
1105+
// LLVM: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}
1106+
1107+
// OGCG: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
1108+
// OGCG: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}

clang/test/CIR/CodeGen/vector.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1069,4 +1069,19 @@ void foo17() {
10691069

10701070
// OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16
10711071
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
1072-
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
1072+
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
1073+
1074+
void foo20() {
1075+
vi4 a;
1076+
vi4 b;
1077+
vi4 c;
1078+
vi4 r = c ? a : b;
1079+
}
1080+
1081+
// CIR: %[[RES:.*]] = cir.vec.ternary({{.*}}, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
1082+
1083+
// LLVM: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
1084+
// LLVM: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}
1085+
1086+
// OGCG: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
1087+
// OGCG: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}

0 commit comments

Comments
 (0)