Skip to content

Commit 8e662a8

Browse files
committed
[CIR] Upstream ShuffleDynamicOp for VectorType
1 parent cf56b53 commit 8e662a8

File tree

8 files changed

+229
-1
lines changed

8 files changed

+229
-1
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2141,4 +2141,37 @@ def VecCmpOp : CIR_Op<"vec.cmp", [Pure, SameTypeOperands]> {
21412141
}];
21422142
}
21432143

2144+
//===----------------------------------------------------------------------===//
2145+
// VecShuffleDynamicOp
2146+
//===----------------------------------------------------------------------===//
2147+
2148+
def VecShuffleDynamicOp : CIR_Op<"vec.shuffle.dynamic",
2149+
[Pure, AllTypesMatch<["vec", "result"]>]> {
2150+
let summary = "Shuffle a vector using indices in another vector";
2151+
let description = [{
2152+
The `cir.vec.shuffle.dynamic` operation implements the undocumented form of
2153+
Clang's __builtin_shufflevector, where the indices of the shuffled result
2154+
can be runtime values.
2155+
2156+
There are two input vectors, which must have the same number of elements.
2157+
The second input vector must have an integral element type. The elements of
2158+
the second vector are interpreted as indices into the first vector. The
2159+
result vector is constructed by taking the elements from the first input
2160+
vector from the indices indicated by the elements of the second vector.
2161+
2162+
```mlir
2163+
%new_vec = cir.vec.shuffle.dynamic %vec : !cir.vector<4 x !s32i>, %indices : !cir.vector<4 x !s32i>
2164+
```
2165+
}];
2166+
2167+
let arguments = (ins CIR_VectorType:$vec, IntegerVector:$indices);
2168+
let results = (outs CIR_VectorType:$result);
2169+
let assemblyFormat = [{
2170+
$vec `:` qualified(type($vec)) `,` $indices `:` qualified(type($indices))
2171+
attr-dict
2172+
}];
2173+
2174+
let hasVerifier = 1;
2175+
}
2176+
21442177
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,20 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
171171
return emitLoadOfLValue(e);
172172
}
173173

174+
mlir::Value VisitShuffleVectorExpr(ShuffleVectorExpr *e) {
175+
if (e->getNumSubExprs() == 2) {
176+
// The undocumented form of __builtin_shufflevector.
177+
mlir::Value inputVec = Visit(e->getExpr(0));
178+
mlir::Value indexVec = Visit(e->getExpr(1));
179+
return cgf.builder.create<cir::VecShuffleDynamicOp>(
180+
cgf.getLoc(e->getSourceRange()), inputVec, indexVec);
181+
}
182+
183+
cgf.getCIRGenModule().errorNYI(e->getSourceRange(),
184+
"ShuffleVectorExpr with indices");
185+
return {};
186+
}
187+
174188
mlir::Value VisitMemberExpr(MemberExpr *e);
175189

176190
mlir::Value VisitInitListExpr(InitListExpr *e);

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,6 +1564,20 @@ OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
15641564
return elements[index];
15651565
}
15661566

1567+
//===----------------------------------------------------------------------===//
1568+
// VecShuffleDynamicOp
1569+
//===----------------------------------------------------------------------===//
1570+
1571+
LogicalResult cir::VecShuffleDynamicOp::verify() {
1572+
// The number of elements in the two input vectors must match.
1573+
if (getVec().getType().getSize() !=
1574+
mlir::cast<cir::VectorType>(getIndices().getType()).getSize()) {
1575+
return emitOpError() << ": the number of elements in " << getVec().getType()
1576+
<< " and " << getIndices().getType() << " don't match";
1577+
}
1578+
return success();
1579+
}
1580+
15671581
//===----------------------------------------------------------------------===//
15681582
// TableGen'd op method definitions
15691583
//===----------------------------------------------------------------------===//

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1717,7 +1717,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
17171717
CIRToLLVMVecCreateOpLowering,
17181718
CIRToLLVMVecExtractOpLowering,
17191719
CIRToLLVMVecInsertOpLowering,
1720-
CIRToLLVMVecCmpOpLowering
1720+
CIRToLLVMVecCmpOpLowering,
1721+
CIRToLLVMVecShuffleDynamicOpLowering
17211722
// clang-format on
17221723
>(converter, patterns.getContext());
17231724

@@ -1871,6 +1872,54 @@ mlir::LogicalResult CIRToLLVMVecCmpOpLowering::matchAndRewrite(
18711872
return mlir::success();
18721873
}
18731874

1875+
mlir::LogicalResult CIRToLLVMVecShuffleDynamicOpLowering::matchAndRewrite(
1876+
cir::VecShuffleDynamicOp op, OpAdaptor adaptor,
1877+
mlir::ConversionPatternRewriter &rewriter) const {
1878+
// LLVM IR does not have an operation that corresponds to this form of
1879+
// the built-in.
1880+
// __builtin_shufflevector(V, I)
1881+
// is implemented as this pseudocode, where the for loop is unrolled
1882+
// and N is the number of elements:
1883+
// masked = I & (N-1)
1884+
// for (i in 0 <= i < N)
1885+
// result[i] = V[masked[i]]
1886+
mlir::Location loc = op.getLoc();
1887+
mlir::Value input = adaptor.getVec();
1888+
mlir::Type llvmIndexVecType =
1889+
getTypeConverter()->convertType(op.getIndices().getType());
1890+
mlir::Type llvmIndexType = getTypeConverter()->convertType(
1891+
elementTypeIfVector(op.getIndices().getType()));
1892+
uint64_t numElements =
1893+
mlir::cast<cir::VectorType>(op.getVec().getType()).getSize();
1894+
mlir::Value maskValue = rewriter.create<mlir::LLVM::ConstantOp>(
1895+
loc, llvmIndexType,
1896+
mlir::IntegerAttr::get(llvmIndexType, numElements - 1));
1897+
mlir::Value maskVector =
1898+
rewriter.create<mlir::LLVM::UndefOp>(loc, llvmIndexVecType);
1899+
for (uint64_t i = 0; i < numElements; ++i) {
1900+
mlir::Value iValue =
1901+
rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i);
1902+
maskVector = rewriter.create<mlir::LLVM::InsertElementOp>(
1903+
loc, maskVector, maskValue, iValue);
1904+
}
1905+
mlir::Value maskedIndices = rewriter.create<mlir::LLVM::AndOp>(
1906+
loc, llvmIndexVecType, adaptor.getIndices(), maskVector);
1907+
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(
1908+
loc, getTypeConverter()->convertType(op.getVec().getType()));
1909+
for (uint64_t i = 0; i < numElements; ++i) {
1910+
mlir::Value iValue =
1911+
rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i);
1912+
mlir::Value indexValue = rewriter.create<mlir::LLVM::ExtractElementOp>(
1913+
loc, maskedIndices, iValue);
1914+
mlir::Value valueAtIndex =
1915+
rewriter.create<mlir::LLVM::ExtractElementOp>(loc, input, indexValue);
1916+
result = rewriter.create<mlir::LLVM::InsertElementOp>(loc, result,
1917+
valueAtIndex, iValue);
1918+
}
1919+
rewriter.replaceOp(op, result);
1920+
return mlir::success();
1921+
}
1922+
18741923
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
18751924
return std::make_unique<ConvertCIRToLLVMPass>();
18761925
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,17 @@ class CIRToLLVMVecCmpOpLowering
352352
mlir::ConversionPatternRewriter &) const override;
353353
};
354354

355+
class CIRToLLVMVecShuffleDynamicOpLowering
356+
: public mlir::OpConversionPattern<cir::VecShuffleDynamicOp> {
357+
public:
358+
using mlir::OpConversionPattern<
359+
cir::VecShuffleDynamicOp>::OpConversionPattern;
360+
361+
mlir::LogicalResult
362+
matchAndRewrite(cir::VecShuffleDynamicOp op, OpAdaptor,
363+
mlir::ConversionPatternRewriter &) const override;
364+
};
365+
355366
} // namespace direct
356367
} // namespace cir
357368

clang/test/CIR/CodeGen/vector-ext.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,3 +988,45 @@ void foo14() {
988988
// OGCG: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
989989
// OGCG: %[[GE:.*]] = fcmp oge <4 x float> %[[TMP_A]], %[[TMP_B]]
990990
// OGCG: %[[RES:.*]] = sext <4 x i1> %[[GE]] to <4 x i32>
991+
992+
void foo15() {
993+
vi4 a;
994+
vi4 b;
995+
vi4 r = __builtin_shufflevector(a, b);
996+
}
997+
998+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
999+
// CIR: %[[TMP_B:.*]] = cir.load{{>*}} {{.*}} : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
1000+
// CIR: %[[NEW_VEC:.*]] = cir.vec.shuffle.dynamic %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>
1001+
1002+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr {{.*}}, align 16
1003+
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr {{.*}}, align 16
1004+
// LLVM: %[[MASK:.*]] = and <4 x i32> %[[TMP_B]], splat (i32 3)
1005+
// LLVM: %[[SHUF_IDX_0:.*]] = extractelement <4 x i32> %[[MASK]], i64 0
1006+
// LLVM: %[[SHUF_ELE_0:.*]] = extractelement <4 x i32> %[[TMP_A]], i32 %[[SHUF_IDX_0]]
1007+
// LLVM: %[[SHUF_INS_0:.*]] = insertelement <4 x i32> undef, i32 %[[SHUF_ELE_0]], i64 0
1008+
// LLVM: %[[SHUF_IDX_1:.*]] = extractelement <4 x i32> %[[MASK]], i64 1
1009+
// LLVM: %[[SHUF_ELE_1:.*]] = extractelement <4 x i32> %[[TMP_A]], i32 %[[SHUF_IDX_1]]
1010+
// LLVM: %[[SHUF_INS_1:.*]] = insertelement <4 x i32> %[[SHUF_INS_0]], i32 %[[SHUF_ELE_1]], i64 1
1011+
// LLVM: %[[SHUF_IDX_2:.*]] = extractelement <4 x i32> %[[MASK]], i64 2
1012+
// LLVM: %[[SHUF_ELE_2:.*]] = extractelement <4 x i32> %[[TMP_A]], i32 %[[SHUF_IDX_2]]
1013+
// LLVM: %[[SHUF_INS_2:.*]] = insertelement <4 x i32> %[[SHUF_INS_1]], i32 %[[SHUF_ELE_2]], i64 2
1014+
// LLVM: %[[SHUF_IDX_3:.*]] = extractelement <4 x i32> %[[MASK]], i64 3
1015+
// LLVM: %[[SHUF_ELE_3:.*]] = extractelement <4 x i32> %[[TMP_A]], i32 %[[SHUF_IDX_3]]
1016+
// LLVM: %[[SHUF_INS_3:.*]] = insertelement <4 x i32> %[[SHUF_INS_2]], i32 %[[SHUF_ELE_3]], i64 3
1017+
1018+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr {{.*}}, align 16
1019+
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr {{.*}}, align 16
1020+
// OGCG: %[[MASK:.*]] = and <4 x i32> %[[TMP_B]], splat (i32 3)
1021+
// OGCG: %[[SHUF_IDX_0:.*]] = extractelement <4 x i32> %[[MASK]], i64 0
1022+
// OGCG: %[[SHUF_ELE_0:.*]] = extractelement <4 x i32> %[[TMP_A]], i32 %[[SHUF_IDX_0]]
1023+
// OGCG: %[[SHUF_INS_0:.*]] = insertelement <4 x i32> poison, i32 %[[SHUF_ELE_0]], i64 0
1024+
// OGCG: %[[SHUF_IDX_1:.*]] = extractelement <4 x i32> %[[MASK]], i64 1
1025+
// OGCG: %[[SHUF_ELE_1:.*]] = extractelement <4 x i32> %[[TMP_A]], i32 %[[SHUF_IDX_1]]
1026+
// OGCG: %[[SHUF_INS_1:.*]] = insertelement <4 x i32> %[[SHUF_INS_0]], i32 %[[SHUF_ELE_1]], i64 1
1027+
// OGCG: %[[SHUF_IDX_2:.*]] = extractelement <4 x i32> %[[MASK]], i64 2
1028+
// OGCG: %[[SHUF_ELE_2:.*]] = extractelement <4 x i32> %[[TMP_A]], i32 %[[SHUF_IDX_2]]
1029+
// OGCG: %[[SHUF_INS_2:.*]] = insertelement <4 x i32> %[[SHUF_INS_1]], i32 %[[SHUF_ELE_2]], i64 2
1030+
// OGCG: %[[SHUF_IDX_3:.*]] = extractelement <4 x i32> %[[MASK]], i64 3
1031+
// OGCG: %[[SHUF_ELE_3:.*]] = extractelement <4 x i32> %[[TMP_A]], i32 %[[SHUF_IDX_3]]
1032+
// OGCG: %[[SHUF_INS_3:.*]] = insertelement <4 x i32> %[[SHUF_INS_2]], i32 %[[SHUF_ELE_3]], i64 3

clang/test/CIR/CodeGen/vector.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -967,3 +967,46 @@ void foo14() {
967967
// OGCG: %[[GE:.*]] = fcmp oge <4 x float> %[[TMP_A]], %[[TMP_B]]
968968
// OGCG: %[[RES:.*]] = sext <4 x i1> %[[GE]] to <4 x i32>
969969
// OGCG: store <4 x i32> %[[RES]], ptr {{.*}}, align 16
970+
971+
void foo15() {
972+
vi4 a;
973+
vi4 b;
974+
vi4 r = __builtin_shufflevector(a, b);
975+
}
976+
977+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
978+
// CIR: %[[TMP_B:.*]] = cir.load{{>*}} {{.*}} : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
979+
// CIR: %[[NEW_VEC:.*]] = cir.vec.shuffle.dynamic %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>
980+
981+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr {{.*}}, align 16
982+
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr {{.*}}, align 16
983+
// LLVM: %[[MASK:.*]] = and <4 x i32> %[[TMP_B]], splat (i32 3)
984+
// LLVM: %[[SHUF_IDX_0:.*]] = extractelement <4 x i32> %[[MASK]], i64 0
985+
// LLVM: %[[SHUF_ELE_0:.*]] = extractelement <4 x i32> %[[TMP_A]], i32 %[[SHUF_IDX_0]]
986+
// LLVM: %[[SHUF_INS_0:.*]] = insertelement <4 x i32> undef, i32 %[[SHUF_ELE_0]], i64 0
987+
// LLVM: %[[SHUF_IDX_1:.*]] = extractelement <4 x i32> %[[MASK]], i64 1
988+
// LLVM: %[[SHUF_ELE_1:.*]] = extractelement <4 x i32> %[[TMP_A]], i32 %[[SHUF_IDX_1]]
989+
// LLVM: %[[SHUF_INS_1:.*]] = insertelement <4 x i32> %[[SHUF_INS_0]], i32 %[[SHUF_ELE_1]], i64 1
990+
// LLVM: %[[SHUF_IDX_2:.*]] = extractelement <4 x i32> %[[MASK]], i64 2
991+
// LLVM: %[[SHUF_ELE_2:.*]] = extractelement <4 x i32> %[[TMP_A]], i32 %[[SHUF_IDX_2]]
992+
// LLVM: %[[SHUF_INS_2:.*]] = insertelement <4 x i32> %[[SHUF_INS_1]], i32 %[[SHUF_ELE_2]], i64 2
993+
// LLVM: %[[SHUF_IDX_3:.*]] = extractelement <4 x i32> %[[MASK]], i64 3
994+
// LLVM: %[[SHUF_ELE_3:.*]] = extractelement <4 x i32> %[[TMP_A]], i32 %[[SHUF_IDX_3]]
995+
// LLVM: %[[SHUF_INS_3:.*]] = insertelement <4 x i32> %[[SHUF_INS_2]], i32 %[[SHUF_ELE_3]], i64 3
996+
997+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr {{.*}}, align 16
998+
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr {{.*}}, align 16
999+
// OGCG: %[[MASK:.*]] = and <4 x i32> %[[TMP_B]], splat (i32 3)
1000+
// OGCG: %[[SHUF_IDX_0:.*]] = extractelement <4 x i32> %[[MASK]], i64 0
1001+
// OGCG: %[[SHUF_ELE_0:.*]] = extractelement <4 x i32> %[[TMP_A]], i32 %[[SHUF_IDX_0]]
1002+
// OGCG: %[[SHUF_INS_0:.*]] = insertelement <4 x i32> poison, i32 %[[SHUF_ELE_0]], i64 0
1003+
// OGCG: %[[SHUF_IDX_1:.*]] = extractelement <4 x i32> %[[MASK]], i64 1
1004+
// OGCG: %[[SHUF_ELE_1:.*]] = extractelement <4 x i32> %[[TMP_A]], i32 %[[SHUF_IDX_1]]
1005+
// OGCG: %[[SHUF_INS_1:.*]] = insertelement <4 x i32> %[[SHUF_INS_0]], i32 %[[SHUF_ELE_1]], i64 1
1006+
// OGCG: %[[SHUF_IDX_2:.*]] = extractelement <4 x i32> %[[MASK]], i64 2
1007+
// OGCG: %[[SHUF_ELE_2:.*]] = extractelement <4 x i32> %[[TMP_A]], i32 %[[SHUF_IDX_2]]
1008+
// OGCG: %[[SHUF_INS_2:.*]] = insertelement <4 x i32> %[[SHUF_INS_1]], i32 %[[SHUF_ELE_2]], i64 2
1009+
// OGCG: %[[SHUF_IDX_3:.*]] = extractelement <4 x i32> %[[MASK]], i64 3
1010+
// OGCG: %[[SHUF_ELE_3:.*]] = extractelement <4 x i32> %[[TMP_A]], i32 %[[SHUF_IDX_3]]
1011+
// OGCG: %[[SHUF_INS_3:.*]] = insertelement <4 x i32> %[[SHUF_INS_2]], i32 %[[SHUF_ELE_3]], i64 3
1012+

clang/test/CIR/IR/vector.cir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,4 +165,26 @@ cir.func @vector_compare_test() {
165165
// CHECK: cir.return
166166
// CHECK: }
167167

168+
cir.func @vector_shuffle_dynamic_test() {
169+
%0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
170+
%1 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["b"]
171+
%2 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["r", init]
172+
%3 = cir.load align(16) %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
173+
%4 = cir.load align(16) %1 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
174+
%5 = cir.vec.shuffle.dynamic %3 : !cir.vector<4 x !s32i>, %4 : !cir.vector<4 x !s32i>
175+
cir.store align(16) %5, %2 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
176+
cir.return
177+
}
178+
179+
// CHECK: cir.func @vector_shuffle_dynamic_test() {
180+
// CHECK: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
181+
// CHECK: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["b"]
182+
// CHECK: %[[RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["r", init]
183+
// CHECK: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
184+
// CHECK: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
185+
// CHECK: %[[VEC_SHUF:.*]] = cir.vec.shuffle.dynamic %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>
186+
// CHECK: cir.store{{.*}} %[[VEC_SHUF]], %[[RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
187+
// CHECK: cir.return
188+
// CHECK: }
189+
168190
}

0 commit comments

Comments
 (0)