Skip to content

Commit 120e9a6

Browse files
committed
[CIR] Upstream TernaryOp for VectorType
1 parent 3a45d55 commit 120e9a6

File tree

7 files changed

+141
-2
lines changed

7 files changed

+141
-2
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2190,4 +2190,40 @@ def VecShuffleDynamicOp : CIR_Op<"vec.shuffle.dynamic",
21902190
let hasVerifier = 1;
21912191
}
21922192

2193+
//===----------------------------------------------------------------------===//
2194+
// VecTernaryOp
2195+
//===----------------------------------------------------------------------===//
2196+
2197+
def VecTernaryOp : CIR_Op<"vec.ternary",
2198+
[Pure, AllTypesMatch<["result", "vec1", "vec2"]>]> {
2199+
let summary = "The `cond ? a : b` ternary operator for vector types";
2200+
let description = [{
2201+
The `cir.vec.ternary` operation represents the C/C++ ternary operator,
2202+
`?:`, for vector types, which does a `select` on individual elements of the
2203+
vectors. Unlike a regular `?:` operator, there is no short circuiting. All
2204+
three arguments are always evaluated. Because there is no short
2205+
circuiting, there are no regions in this operation, unlike cir.ternary.
2206+
2207+
The first argument is a vector of integral type. The second and third
2208+
arguments are vectors of the same type and have the same number of elements
2209+
as the first argument.
2210+
2211+
The result is a vector of the same type as the second and third arguments.
2212+
Each element of the result is `(bool)a[n] ? b[n] : c[n]`.
2213+
}];
2214+
2215+
let arguments = (ins
2216+
IntegerVector:$cond,
2217+
CIR_VectorType:$vec1,
2218+
CIR_VectorType:$vec2
2219+
);
2220+
2221+
let results = (outs CIR_VectorType:$result);
2222+
let assemblyFormat = [{
2223+
`(` $cond `,` $vec1 `,` $vec2 `)` `:` qualified(type($cond)) `,`
2224+
qualified(type($vec1)) attr-dict
2225+
}];
2226+
let hasVerifier = 1;
2227+
}
2228+
21932229
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,36 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
193193
e->getSourceRange().getBegin());
194194
}
195195

196+
mlir::Value
197+
VisitAbstractConditionalOperator(const AbstractConditionalOperator *e) {
198+
mlir::Location loc = cgf.getLoc(e->getSourceRange());
199+
Expr *condExpr = e->getCond();
200+
Expr *lhsExpr = e->getTrueExpr();
201+
Expr *rhsExpr = e->getFalseExpr();
202+
203+
// OpenCL: If the condition is a vector, we can treat this condition like
204+
// the select function.
205+
if ((cgf.getLangOpts().OpenCL && condExpr->getType()->isVectorType()) ||
206+
condExpr->getType()->isExtVectorType()) {
207+
cgf.getCIRGenModule().errorNYI(loc,
208+
"TernaryOp OpenCL VectorType condition");
209+
return {};
210+
}
211+
212+
if (condExpr->getType()->isVectorType() ||
213+
condExpr->getType()->isSveVLSBuiltinType()) {
214+
assert(condExpr->getType()->isVectorType() && "?: op for SVE vector NYI");
215+
mlir::Value condValue = Visit(condExpr);
216+
mlir::Value lhsValue = Visit(lhsExpr);
217+
mlir::Value rhsValue = Visit(rhsExpr);
218+
return builder.create<cir::VecTernaryOp>(loc, condValue, lhsValue,
219+
rhsValue);
220+
}
221+
222+
cgf.getCIRGenModule().errorNYI(loc, "TernaryOp for non vector types");
223+
return {};
224+
}
225+
196226
mlir::Value VisitMemberExpr(MemberExpr *e);
197227

198228
mlir::Value VisitInitListExpr(InitListExpr *e);

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,6 +1589,24 @@ LogicalResult cir::VecShuffleDynamicOp::verify() {
15891589
return success();
15901590
}
15911591

1592+
//===----------------------------------------------------------------------===//
1593+
// VecTernaryOp
1594+
//===----------------------------------------------------------------------===//
1595+
1596+
LogicalResult cir::VecTernaryOp::verify() {
1597+
// Verify that the condition operand has the same number of elements as the
1598+
// other operands. (The automatic verification already checked that all
1599+
// operands are vector types and that the second and third operands are the
1600+
// same type.)
1601+
if (mlir::cast<cir::VectorType>(getCond().getType()).getSize() !=
1602+
getVec1().getType().getSize()) {
1603+
return emitOpError() << ": the number of elements in "
1604+
<< getCond().getType() << " and "
1605+
<< getVec1().getType() << " don't match";
1606+
}
1607+
return success();
1608+
}
1609+
15921610
//===----------------------------------------------------------------------===//
15931611
// TableGen'd op method definitions
15941612
//===----------------------------------------------------------------------===//

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1708,7 +1708,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
17081708
CIRToLLVMVecExtractOpLowering,
17091709
CIRToLLVMVecInsertOpLowering,
17101710
CIRToLLVMVecCmpOpLowering,
1711-
CIRToLLVMVecShuffleDynamicOpLowering
1711+
CIRToLLVMVecShuffleDynamicOpLowering,
1712+
CIRToLLVMVecTernaryOpLowering
17121713
// clang-format on
17131714
>(converter, patterns.getContext());
17141715

@@ -1916,6 +1917,20 @@ mlir::LogicalResult CIRToLLVMVecShuffleDynamicOpLowering::matchAndRewrite(
19161917
return mlir::success();
19171918
}
19181919

1920+
mlir::LogicalResult CIRToLLVMVecTernaryOpLowering::matchAndRewrite(
1921+
cir::VecTernaryOp op, OpAdaptor adaptor,
1922+
mlir::ConversionPatternRewriter &rewriter) const {
1923+
// Convert `cond` into a vector of i1, then use that in a `select` op.
1924+
mlir::Value bitVec = rewriter.create<mlir::LLVM::ICmpOp>(
1925+
op.getLoc(), mlir::LLVM::ICmpPredicate::ne, adaptor.getCond(),
1926+
rewriter.create<mlir::LLVM::ZeroOp>(
1927+
op.getCond().getLoc(),
1928+
typeConverter->convertType(op.getCond().getType())));
1929+
rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
1930+
op, bitVec, adaptor.getVec1(), adaptor.getVec2());
1931+
return mlir::success();
1932+
}
1933+
19191934
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
19201935
return std::make_unique<ConvertCIRToLLVMPass>();
19211936
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,16 @@ class CIRToLLVMVecShuffleDynamicOpLowering
363363
mlir::ConversionPatternRewriter &) const override;
364364
};
365365

366+
class CIRToLLVMVecTernaryOpLowering
367+
: public mlir::OpConversionPattern<cir::VecTernaryOp> {
368+
public:
369+
using mlir::OpConversionPattern<cir::VecTernaryOp>::OpConversionPattern;
370+
371+
mlir::LogicalResult
372+
matchAndRewrite(cir::VecTernaryOp op, OpAdaptor,
373+
mlir::ConversionPatternRewriter &) const override;
374+
};
375+
366376
} // namespace direct
367377
} // namespace cir
368378

clang/test/CIR/CodeGen/vector-ext.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,3 +1091,18 @@ void foo17() {
10911091
// OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16
10921092
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
10931093
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
1094+
1095+
void foo20() {
1096+
vi4 a;
1097+
vi4 b;
1098+
vi4 c;
1099+
vi4 r = c ? a : b;
1100+
}
1101+
1102+
// CIR: %[[RES:.*]] = cir.vec.ternary({{.*}}, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
1103+
1104+
// LLVM: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
1105+
// LLVM: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}
1106+
1107+
// OGCG: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
1108+
// OGCG: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}

clang/test/CIR/CodeGen/vector.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1069,4 +1069,19 @@ void foo17() {
10691069

10701070
// OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16
10711071
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
1072-
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
1072+
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
1073+
1074+
void foo20() {
1075+
vi4 a;
1076+
vi4 b;
1077+
vi4 c;
1078+
vi4 r = c ? a : b;
1079+
}
1080+
1081+
// CIR: %[[RES:.*]] = cir.vec.ternary({{.*}}, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
1082+
1083+
// LLVM: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
1084+
// LLVM: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}
1085+
1086+
// OGCG: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
1087+
// OGCG: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}

0 commit comments

Comments
 (0)