Skip to content

Commit a2eaa7f

Browse files
committed
[CIR] Implement folder for VecShuffleDynamicOp
1 parent 33bbce5 commit a2eaa7f

File tree

4 files changed

+54
-3
lines changed

4 files changed

+54
-3
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2188,6 +2188,7 @@ def VecShuffleDynamicOp : CIR_Op<"vec.shuffle.dynamic",
21882188
}];
21892189

21902190
let hasVerifier = 1;
2191+
let hasFolder = 1;
21912192
}
21922193

21932194
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1579,6 +1579,38 @@ OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
15791579
// VecShuffleDynamicOp
15801580
//===----------------------------------------------------------------------===//
15811581

1582+
OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
1583+
mlir::Attribute vec = adaptor.getVec();
1584+
mlir::Attribute indices = adaptor.getIndices();
1585+
if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec) &&
1586+
mlir::isa_and_nonnull<cir::ConstVectorAttr>(indices)) {
1587+
auto vecAttr = mlir::cast<cir::ConstVectorAttr>(vec);
1588+
auto indicesAttr = mlir::cast<cir::ConstVectorAttr>(indices);
1589+
auto vecTy = cast<cir::VectorType>(vecAttr.getType());
1590+
1591+
mlir::ArrayAttr vecElts = vecAttr.getElts();
1592+
mlir::ArrayAttr indicesElts = indicesAttr.getElts();
1593+
1594+
const uint64_t numElements = vecElts.size();
1595+
1596+
SmallVector<mlir::Attribute, 16> elements;
1597+
elements.reserve(numElements);
1598+
1599+
const uint64_t maskBits = llvm::NextPowerOf2(numElements - 1) - 1;
1600+
for (uint64_t i = 0; i < numElements; i++) {
1601+
cir::IntAttr idxAttr = mlir::cast<cir::IntAttr>(indicesElts[i]);
1602+
uint64_t idxValue = idxAttr.getUInt();
1603+
uint64_t newIdx = idxValue & maskBits;
1604+
elements.push_back(vecElts[newIdx]);
1605+
}
1606+
1607+
return cir::ConstVectorAttr::get(
1608+
vecTy, mlir::ArrayAttr::get(getContext(), elements));
1609+
}
1610+
1611+
return {};
1612+
}
1613+
15821614
LogicalResult cir::VecShuffleDynamicOp::verify() {
15831615
// The number of elements in the two input vectors must match.
15841616
if (getVec().getType().getSize() !=

clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,10 @@ void CIRCanonicalizePass::runOnOperation() {
138138
assert(!cir::MissingFeatures::complexRealOp());
139139
assert(!cir::MissingFeatures::complexImagOp());
140140
assert(!cir::MissingFeatures::callOp());
141-
// CastOp, UnaryOp and VecExtractOp are here to perform a manual `fold` in
142-
// applyOpPatternsGreedily.
141+
// CastOp, UnaryOp, VecExtractOp and VecShuffleDynamicOp are here to perform
142+
// a manual `fold` in applyOpPatternsGreedily.
143143
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
144-
VecExtractOp>(op))
144+
VecExtractOp, VecShuffleDynamicOp>(op))
145145
ops.push_back(op);
146146
});
147147

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: cir-opt %s -cir-canonicalize -o - | FileCheck %s
2+
3+
!s32i = !cir.int<s, 32>
4+
5+
module {
6+
cir.func @fold_shuffle_dynamic_vector_op_test() {
7+
%alloca = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["r", init]
8+
%vec = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
9+
%indices = cir.const #cir.const_vector<[#cir.int<8> : !s32i, #cir.int<7> : !s32i, #cir.int<6> : !s32i, #cir.int<5> : !s32i]> : !cir.vector<4 x !s32i>
10+
%new_vec = cir.vec.shuffle.dynamic %vec : !cir.vector<4 x !s32i>, %indices : !cir.vector<4 x !s32i>
11+
cir.store align(16) %new_vec, %alloca : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
12+
cir.return
13+
}
14+
15+
// CHECK: %[[NEW_VEC:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<4> : !s32i, #cir.int<3> : !s32i, #cir.int<2> : !s32i]> : !cir.vector<4 x !s32i>
16+
}
17+
18+

0 commit comments

Comments
 (0)