Skip to content

Commit 2f03a9d

Browse files
AmrDevelopertomtor
authored andcommitted
[CIR] Implement folder for VecShuffleOp (llvm#143260)
This change adds a folder for the VecShuffleOp Issue llvm#136487
1 parent 360a43f commit 2f03a9d

File tree

4 files changed

+98
-4
lines changed

4 files changed

+98
-4
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2199,7 +2199,9 @@ def VecShuffleOp : CIR_Op<"vec.shuffle",
21992199
`(` $vec1 `,` $vec2 `:` qualified(type($vec1)) `)` $indices `:`
22002200
qualified(type($result)) attr-dict
22012201
}];
2202+
22022203
let hasVerifier = 1;
2204+
let hasFolder = 1;
22032205
}
22042206

22052207
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,9 +1580,43 @@ OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
15801580
}
15811581

15821582
//===----------------------------------------------------------------------===//
1583-
// VecShuffle
1583+
// VecShuffleOp
15841584
//===----------------------------------------------------------------------===//
15851585

1586+
OpFoldResult cir::VecShuffleOp::fold(FoldAdaptor adaptor) {
1587+
auto vec1Attr =
1588+
mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec1());
1589+
auto vec2Attr =
1590+
mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec2());
1591+
if (!vec1Attr || !vec2Attr)
1592+
return {};
1593+
1594+
mlir::Type vec1ElemTy =
1595+
mlir::cast<cir::VectorType>(vec1Attr.getType()).getElementType();
1596+
1597+
mlir::ArrayAttr vec1Elts = vec1Attr.getElts();
1598+
mlir::ArrayAttr vec2Elts = vec2Attr.getElts();
1599+
mlir::ArrayAttr indicesElts = adaptor.getIndices();
1600+
1601+
SmallVector<mlir::Attribute, 16> elements;
1602+
elements.reserve(indicesElts.size());
1603+
1604+
uint64_t vec1Size = vec1Elts.size();
1605+
for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
1606+
if (idxAttr.getSInt() == -1) {
1607+
elements.push_back(cir::UndefAttr::get(vec1ElemTy));
1608+
continue;
1609+
}
1610+
1611+
uint64_t idxValue = idxAttr.getUInt();
1612+
elements.push_back(idxValue < vec1Size ? vec1Elts[idxValue]
1613+
: vec2Elts[idxValue - vec1Size]);
1614+
}
1615+
1616+
return cir::ConstVectorAttr::get(
1617+
getType(), mlir::ArrayAttr::get(getContext(), elements));
1618+
}
1619+
15861620
LogicalResult cir::VecShuffleOp::verify() {
15871621
// The number of elements in the indices array must match the number of
15881622
// elements in the result type.
@@ -1613,7 +1647,6 @@ OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
16131647
mlir::isa_and_nonnull<cir::ConstVectorAttr>(indices)) {
16141648
auto vecAttr = mlir::cast<cir::ConstVectorAttr>(vec);
16151649
auto indicesAttr = mlir::cast<cir::ConstVectorAttr>(indices);
1616-
auto vecTy = mlir::cast<cir::VectorType>(vecAttr.getType());
16171650

16181651
mlir::ArrayAttr vecElts = vecAttr.getElts();
16191652
mlir::ArrayAttr indicesElts = indicesAttr.getElts();
@@ -1631,7 +1664,7 @@ OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
16311664
}
16321665

16331666
return cir::ConstVectorAttr::get(
1634-
vecTy, mlir::ArrayAttr::get(getContext(), elements));
1667+
getType(), mlir::ArrayAttr::get(getContext(), elements));
16351668
}
16361669

16371670
return {};

clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ void CIRCanonicalizePass::runOnOperation() {
142142
// Many operations are here to perform a manual `fold` in
143143
// applyOpPatternsGreedily.
144144
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
145-
VecExtractOp, VecShuffleDynamicOp, VecTernaryOp>(op))
145+
VecExtractOp, VecShuffleOp, VecShuffleDynamicOp, VecTernaryOp>(op))
146146
ops.push_back(op);
147147
});
148148

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// RUN: cir-opt %s -cir-canonicalize -o - -split-input-file | FileCheck %s
2+
3+
!s32i = !cir.int<s, 32>
4+
!s64i = !cir.int<s, 64>
5+
6+
module {
7+
cir.func @fold_shuffle_vector_op_test() -> !cir.vector<4 x !s32i> {
8+
%vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
9+
%vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
10+
%new_vec = cir.vec.shuffle(%vec_1, %vec_2 : !cir.vector<4 x !s32i>) [#cir.int<0> : !s64i, #cir.int<4> : !s64i,
11+
#cir.int<1> : !s64i, #cir.int<5> : !s64i] : !cir.vector<4 x !s32i>
12+
cir.return %new_vec : !cir.vector<4 x !s32i>
13+
}
14+
15+
// CHECK: cir.func @fold_shuffle_vector_op_test() -> !cir.vector<4 x !s32i> {
16+
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i,
17+
// CHECK-SAME: #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
18+
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
19+
}
20+
21+
// -----
22+
23+
!s32i = !cir.int<s, 32>
24+
!s64i = !cir.int<s, 64>
25+
26+
module {
27+
cir.func @fold_shuffle_vector_op_test() -> !cir.vector<6 x !s32i> {
28+
%vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
29+
%vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
30+
%new_vec = cir.vec.shuffle(%vec_1, %vec_2 : !cir.vector<4 x !s32i>) [#cir.int<0> : !s64i, #cir.int<4> : !s64i,
31+
#cir.int<1> : !s64i, #cir.int<5> : !s64i, #cir.int<2> : !s64i, #cir.int<6> : !s64i] : !cir.vector<6 x !s32i>
32+
cir.return %new_vec : !cir.vector<6 x !s32i>
33+
}
34+
35+
// CHECK: cir.func @fold_shuffle_vector_op_test() -> !cir.vector<6 x !s32i> {
36+
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i,
37+
// CHECK-SAME: #cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i]> : !cir.vector<6 x !s32i>
38+
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<6 x !s32i>
39+
}
40+
41+
// -----
42+
43+
!s32i = !cir.int<s, 32>
44+
!s64i = !cir.int<s, 64>
45+
46+
module {
47+
cir.func @fold_shuffle_vector_op_test() -> !cir.vector<4 x !s32i> {
48+
%vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
49+
%vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
50+
%new_vec = cir.vec.shuffle(%vec_1, %vec_2 : !cir.vector<4 x !s32i>) [#cir.int<-1> : !s64i, #cir.int<4> : !s64i,
51+
#cir.int<1> : !s64i, #cir.int<5> : !s64i] : !cir.vector<4 x !s32i>
52+
cir.return %new_vec : !cir.vector<4 x !s32i>
53+
}
54+
55+
// CHECK: cir.func @fold_shuffle_vector_op_test() -> !cir.vector<4 x !s32i> {
56+
// CHECK: cir.const #cir.const_vector<[#cir.undef : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i,
57+
// CHECK-SAME: #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
58+
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
59+
}

0 commit comments

Comments
 (0)