Skip to content

Commit 370e54d

Browse files
authored
[CIR] Upstream splat op for VectorType (#139827)
This change adds support for splat op for VectorType Issue #136487
1 parent 145b1b0 commit 370e54d

File tree

7 files changed

+261
-0
lines changed

7 files changed

+261
-0
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2277,6 +2277,38 @@ def VecTernaryOp : CIR_Op<"vec.ternary",
22772277
let hasFolder = 1;
22782278
}
22792279

2280+
//===----------------------------------------------------------------------===//
2281+
// VecSplatOp
2282+
//===----------------------------------------------------------------------===//
2283+
2284+
// Defines the `cir.vec.splat` operation. The op is marked `Pure`, and the
// `TypesMatchWith` trait requires the type of `value` to equal the element
// type of the `result` vector type.
def VecSplatOp : CIR_Op<"vec.splat", [Pure,
2285+
TypesMatchWith<"type of 'value' matches element type of 'result'", "result",
2286+
"value", "cast<VectorType>($_self).getElementType()">]> {
2287+
2288+
let summary = "Convert a scalar into a vector";
2289+
let description = [{
2290+
The `cir.vec.splat` operation creates a vector value from a scalar value.
2291+
All elements of the vector have the same value, that of the given scalar.
2292+
2293+
It's a separate operation from `cir.vec.create` because more
2294+
efficient LLVM IR can be generated for it, and because some optimization and
2295+
analysis passes can benefit from knowing that all elements of the vector
2296+
have the same value.
2297+
2298+
```mlir
2299+
%value = cir.const #cir.int<3> : !s32i
2300+
%value_vec = cir.vec.splat %value : !s32i, !cir.vector<4 x !s32i>
2301+
```
2302+
}];
2303+
2304+
// Single scalar operand `value`; single vector result `result`.
let arguments = (ins CIR_VectorElementType:$value);
2305+
let results = (outs CIR_VectorType:$result);
2306+
2307+
// Textual form: `<value> : <scalar type>, <vector type>` followed by the
// attribute dictionary.
let assemblyFormat = [{
2308+
$value `:` type($value) `,` qualified(type($result)) attr-dict
2309+
}];
2310+
}
2311+
22802312
//===----------------------------------------------------------------------===//
22812313
// BaseClassAddrOp
22822314
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1780,6 +1780,14 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *ce) {
17801780
cgf.convertType(destTy));
17811781
}
17821782

1783+
case CK_VectorSplat: {
1784+
// Create a vector object and fill all elements with the same scalar value.
1785+
assert(destTy->isVectorType() && "CK_VectorSplat to non-vector type");
1786+
return builder.create<cir::VecSplatOp>(
1787+
cgf.getLoc(subExpr->getSourceRange()), cgf.convertType(destTy),
1788+
Visit(subExpr));
1789+
}
1790+
17831791
default:
17841792
cgf.getCIRGenModule().errorNYI(subExpr->getSourceRange(),
17851793
"CastExpr: ", ce->getCastKindName());

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1803,6 +1803,7 @@ void ConvertCIRToLLVMPass::runOnOperation() {
18031803
CIRToLLVMVecExtractOpLowering,
18041804
CIRToLLVMVecInsertOpLowering,
18051805
CIRToLLVMVecCmpOpLowering,
1806+
CIRToLLVMVecSplatOpLowering,
18061807
CIRToLLVMVecShuffleOpLowering,
18071808
CIRToLLVMVecShuffleDynamicOpLowering,
18081809
CIRToLLVMVecTernaryOpLowering
@@ -1956,6 +1957,56 @@ mlir::LogicalResult CIRToLLVMVecCmpOpLowering::matchAndRewrite(
19561957
return mlir::success();
19571958
}
19581959

1960+
/// Lower `cir.vec.splat` to the LLVM dialect.
///
/// Three strategies are tried in order:
///   1. If the scalar is produced by an `llvm.mlir.poison` op, the whole
///      vector is replaced by a poison vector.
///   2. If the scalar is produced by an `llvm.mlir.constant` holding an
///      integer or floating-point attribute, the splat is folded into a
///      single dense vector constant.
///   3. Otherwise the scalar is inserted into element 0 of a poison vector
///      and broadcast with a `shufflevector` whose mask is all zeros, which
///      is better than emitting one `insertelement` per element.
///
/// \param op       The splat operation being rewritten.
/// \param adaptor  Supplies the scalar operand as seen by the conversion.
/// \param rewriter Rewriter used to create the replacement operations.
/// \returns `mlir::success()` on every path; the op is always replaced.
mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite(
    cir::VecSplatOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  cir::VectorType vecTy = op.getType();
  mlir::Type llvmTy = typeConverter->convertType(vecTy);
  mlir::Location loc = op.getLoc();
  mlir::Value elementValue = adaptor.getValue();

  // Query the producer through `Value::getDefiningOp<OpTy>()` instead of
  // applying `isa`/`dyn_cast` to the raw `getDefiningOp()` result: a block
  // argument (for example a function parameter) has no defining operation,
  // and `isa`/`dyn_cast` assert when handed a null pointer.
  if (elementValue.getDefiningOp<mlir::LLVM::PoisonOp>()) {
    // If the splat value is poison, then we can just use poison value
    // for the entire vector.
    rewriter.replaceOpWithNewOp<mlir::LLVM::PoisonOp>(op, llvmTy);
    return mlir::success();
  }

  if (auto constValue =
          elementValue.getDefiningOp<mlir::LLVM::ConstantOp>()) {
    // Fold a constant integer scalar into a dense integer vector constant.
    if (auto intAttr =
            mlir::dyn_cast<mlir::IntegerAttr>(constValue.getValue())) {
      mlir::DenseIntElementsAttr denseVec = mlir::DenseIntElementsAttr::get(
          mlir::cast<mlir::ShapedType>(llvmTy), intAttr.getValue());
      rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
          op, denseVec.getType(), denseVec);
      return mlir::success();
    }

    // Fold a constant floating-point scalar the same way.
    if (auto fpAttr = mlir::dyn_cast<mlir::FloatAttr>(constValue.getValue())) {
      mlir::DenseFPElementsAttr denseVec = mlir::DenseFPElementsAttr::get(
          mlir::cast<mlir::ShapedType>(llvmTy), fpAttr.getValue());
      rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
          op, denseVec.getType(), denseVec);
      return mlir::success();
    }
    // Any other constant attribute kind falls through to the generic path.
  }

  // Generic path: start with a poison vector, insert the value into the
  // first element, then use a `shufflevector` with a mask of all 0 to fill
  // out the entire vector with that value. The poison vector is only created
  // here so the folding paths above do not leave an unused op behind.
  mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
  mlir::Value indexValue =
      rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
  mlir::Value oneElement = rewriter.create<mlir::LLVM::InsertElementOp>(
      loc, poison, elementValue, indexValue);
  SmallVector<int32_t> zeroValues(vecTy.getSize(), 0);
  rewriter.replaceOpWithNewOp<mlir::LLVM::ShuffleVectorOp>(op, oneElement,
                                                           poison, zeroValues);
  return mlir::success();
}
2009+
19592010
mlir::LogicalResult CIRToLLVMVecShuffleOpLowering::matchAndRewrite(
19602011
cir::VecShuffleOp op, OpAdaptor adaptor,
19612012
mlir::ConversionPatternRewriter &rewriter) const {

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,16 @@ class CIRToLLVMVecCmpOpLowering
367367
mlir::ConversionPatternRewriter &) const override;
368368
};
369369

370+
/// Conversion pattern that rewrites `cir::VecSplatOp`. It inherits the
/// constructors of `mlir::OpConversionPattern` and overrides
/// `matchAndRewrite`.
class CIRToLLVMVecSplatOpLowering
371+
: public mlir::OpConversionPattern<cir::VecSplatOp> {
372+
public:
373+
using mlir::OpConversionPattern<cir::VecSplatOp>::OpConversionPattern;
374+
375+
mlir::LogicalResult
376+
matchAndRewrite(cir::VecSplatOp op, OpAdaptor,
377+
mlir::ConversionPatternRewriter &) const override;
378+
};
379+
370380
class CIRToLLVMVecShuffleOpLowering
371381
: public mlir::OpConversionPattern<cir::VecShuffleOp> {
372382
public:

clang/test/CIR/CodeGen/vector-ext.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,7 @@ void foo14() {
990990
// OGCG: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
991991
// OGCG: %[[GE:.*]] = fcmp oge <4 x float> %[[TMP_A]], %[[TMP_B]]
992992
// OGCG: %[[RES:.*]] = sext <4 x i1> %[[GE]] to <4 x i32>
993+
// OGCG: store <4 x i32> %[[RES]], ptr {{.*}}, align 16
993994

994995
void foo15() {
995996
vi4 a;
@@ -1092,6 +1093,69 @@ void foo17() {
10921093
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
10931094
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
10941095

1096+
// Shifts a signed and an unsigned 4-element vector by a scalar amount. The
// CIR checks that follow expect each scalar shift amount to appear as a
// `cir.vec.splat` feeding the `cir.shift`.
void foo18() {
1097+
vi4 a = {1, 2, 3, 4};
1098+
vi4 shl = a << 3;
1099+
1100+
uvi4 b = {1u, 2u, 3u, 4u};
1101+
uvi4 shr = b >> 3u;
1102+
}
1103+
1104+
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
1105+
// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
1106+
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["b", init]
1107+
// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["shr", init]
1108+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
1109+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
1110+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
1111+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
1112+
// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
1113+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
1114+
// CIR: cir.store{{.*}} %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
1115+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
1116+
// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !s32i
1117+
// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !s32i, !cir.vector<4 x !s32i>
1118+
// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
1119+
// CIR: cir.store{{.*}} %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
1120+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !u32i
1121+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !u32i
1122+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !u32i
1123+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !u32i
1124+
// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
1125+
// CIR-SAME: !u32i, !u32i, !u32i, !u32i) : !cir.vector<4 x !u32i>
1126+
// CIR: cir.store{{.*}} %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
1127+
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !u32i>>, !cir.vector<4 x !u32i>
1128+
// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !u32i
1129+
// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !u32i, !cir.vector<4 x !u32i>
1130+
// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_B]] : !cir.vector<4 x !u32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !u32i>) -> !cir.vector<4 x !u32i>
1131+
// CIR: cir.store{{.*}} %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
1132+
1133+
// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
1134+
// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
1135+
// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
1136+
// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
1137+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
1138+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
1139+
// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
1140+
// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
1141+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
1142+
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
1143+
// LLVM: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
1144+
// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
1145+
1146+
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
1147+
// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
1148+
// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
1149+
// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
1150+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
1151+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
1152+
// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
1153+
// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
1154+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
1155+
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
1156+
// OGCG: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
1157+
// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
1158+
10951159
void foo19() {
10961160
vi4 a;
10971161
vi4 b;

clang/test/CIR/CodeGen/vector.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,6 +1071,69 @@ void foo17() {
10711071
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
10721072
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
10731073

1074+
// Shifts a signed and an unsigned 4-element vector by a scalar amount. The
// CIR checks that follow expect each scalar shift amount to appear as a
// `cir.vec.splat` feeding the `cir.shift`.
void foo18() {
1075+
vi4 a = {1, 2, 3, 4};
1076+
vi4 shl = a << 3;
1077+
1078+
uvi4 b = {1u, 2u, 3u, 4u};
1079+
uvi4 shr = b >> 3u;
1080+
}
1081+
1082+
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
1083+
// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
1084+
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["b", init]
1085+
// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["shr", init]
1086+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
1087+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
1088+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
1089+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
1090+
// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
1091+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
1092+
// CIR: cir.store{{.*}} %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
1093+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
1094+
// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !s32i
1095+
// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !s32i, !cir.vector<4 x !s32i>
1096+
// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
1097+
// CIR: cir.store{{.*}} %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
1098+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !u32i
1099+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !u32i
1100+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !u32i
1101+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !u32i
1102+
// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
1103+
// CIR-SAME: !u32i, !u32i, !u32i, !u32i) : !cir.vector<4 x !u32i>
1104+
// CIR: cir.store{{.*}} %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
1105+
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !u32i>>, !cir.vector<4 x !u32i>
1106+
// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !u32i
1107+
// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !u32i, !cir.vector<4 x !u32i>
1108+
// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_B]] : !cir.vector<4 x !u32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !u32i>) -> !cir.vector<4 x !u32i>
1109+
// CIR: cir.store{{.*}} %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
1110+
1111+
// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
1112+
// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
1113+
// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
1114+
// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
1115+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
1116+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
1117+
// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
1118+
// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
1119+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
1120+
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
1121+
// LLVM: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
1122+
// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
1123+
1124+
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
1125+
// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
1126+
// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
1127+
// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
1128+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
1129+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
1130+
// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
1131+
// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
1132+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
1133+
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
1134+
// OGCG: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
1135+
// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
1136+
10741137
void foo19() {
10751138
vi4 a;
10761139
vi4 b;

clang/test/CIR/IR/vector.cir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,4 +187,37 @@ cir.func @vector_shuffle_dynamic_test() {
187187
// CHECK: cir.return
188188
// CHECK: }
189189

190+
// Round-trip input for `cir.vec.splat`: builds a 4 x !s32i vector, splats the
// constant 3 into a second vector, and shifts the first left by the splat.
cir.func @vector_splat_test() {
191+
%0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
192+
%1 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
193+
%2 = cir.const #cir.int<1> : !s32i
194+
%3 = cir.const #cir.int<2> : !s32i
195+
%4 = cir.const #cir.int<3> : !s32i
196+
%5 = cir.const #cir.int<4> : !s32i
197+
%6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
198+
cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
199+
%7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
200+
%8 = cir.const #cir.int<3> : !s32i
201+
%9 = cir.vec.splat %8 : !s32i, !cir.vector<4 x !s32i>
202+
%10 = cir.shift(left, %7 : !cir.vector<4 x !s32i>, %9 : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
203+
cir.store %10, %1 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
204+
cir.return
205+
}
206+
207+
// CHECK: cir.func @vector_splat_test() {
208+
// CHECK-NEXT: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
209+
// CHECK-NEXT: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
210+
// CHECK-NEXT: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
211+
// CHECK-NEXT: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
212+
// CHECK-NEXT: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
213+
// CHECK-NEXT: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
214+
// CHECK-NEXT: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
215+
// CHECK-NEXT: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
216+
// CHECK-NEXT: %[[TMP:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
217+
// CHECK-NEXT: %[[SPLAT_VAL:.*]] = cir.const #cir.int<3> : !s32i
218+
// CHECK-NEXT: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SPLAT_VAL]] : !s32i, !cir.vector<4 x !s32i>
219+
// CHECK-NEXT: %[[SHL:.*]] = cir.shift(left, %[[TMP]] : !cir.vector<4 x !s32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
220+
// CHECK-NEXT: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
221+
// CHECK-NEXT: cir.return
222+
190223
}

0 commit comments

Comments
 (0)