Skip to content

Commit 9ee55e7

Browse files
authored
[CIR] Implement folder for VecSplatOp (#143771)
This change adds a folder for the VecSplatOp Issue #136487
1 parent 2c2ad9a commit 9ee55e7

File tree

3 files changed

+45
-5
lines changed

3 files changed

+45
-5
lines changed

clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,31 @@ struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
260260
}
261261
};
262262

263+
struct SimplifyVecSplat : public OpRewritePattern<VecSplatOp> {
264+
using OpRewritePattern<VecSplatOp>::OpRewritePattern;
265+
LogicalResult matchAndRewrite(VecSplatOp op,
266+
PatternRewriter &rewriter) const override {
267+
mlir::Value splatValue = op.getValue();
268+
auto constant =
269+
mlir::dyn_cast_if_present<cir::ConstantOp>(splatValue.getDefiningOp());
270+
if (!constant)
271+
return mlir::failure();
272+
273+
auto value = constant.getValue();
274+
if (!mlir::isa_and_nonnull<cir::IntAttr>(value) &&
275+
!mlir::isa_and_nonnull<cir::FPAttr>(value))
276+
return mlir::failure();
277+
278+
cir::VectorType resultType = op.getResult().getType();
279+
SmallVector<mlir::Attribute, 16> elements(resultType.getSize(), value);
280+
auto constVecAttr = cir::ConstVectorAttr::get(
281+
resultType, mlir::ArrayAttr::get(getContext(), elements));
282+
283+
rewriter.replaceOpWithNewOp<cir::ConstantOp>(op, constVecAttr);
284+
return mlir::success();
285+
}
286+
};
287+
263288
//===----------------------------------------------------------------------===//
264289
// CIRSimplifyPass
265290
//===----------------------------------------------------------------------===//
@@ -275,7 +300,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
275300
patterns.add<
276301
SimplifyTernary,
277302
SimplifySelect,
278-
SimplifySwitch
303+
SimplifySwitch,
304+
SimplifyVecSplat
279305
>(patterns.getContext());
280306
// clang-format on
281307
}
@@ -288,7 +314,7 @@ void CIRSimplifyPass::runOnOperation() {
288314
// Collect operations to apply patterns.
289315
llvm::SmallVector<Operation *, 16> ops;
290316
getOperation()->walk([&](Operation *op) {
291-
if (isa<TernaryOp, SelectOp, SwitchOp>(op))
317+
if (isa<TernaryOp, SelectOp, SwitchOp, VecSplatOp>(op))
292318
ops.push_back(op);
293319
});
294320

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -979,9 +979,7 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
979979
}
980980

981981
attr = rewriter.getArrayAttr(components);
982-
}
983-
984-
else {
982+
} else {
985983
return op.emitError() << "unsupported constant type " << op.getType();
986984
}
987985

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: cir-opt %s -cir-simplify -o - | FileCheck %s
2+
3+
!s32i = !cir.int<s, 32>
4+
5+
module {
6+
cir.func @fold_shuffle_vector_op_test() -> !cir.vector<4 x !s32i> {
7+
%v = cir.const #cir.int<3> : !s32i
8+
%vec = cir.vec.splat %v : !s32i, !cir.vector<4 x !s32i>
9+
cir.return %vec : !cir.vector<4 x !s32i>
10+
}
11+
12+
// CHECK: cir.func @fold_shuffle_vector_op_test() -> !cir.vector<4 x !s32i> {
13+
// CHECK-NEXT: %0 = cir.const #cir.const_vector<[#cir.int<3> : !s32i, #cir.int<3> : !s32i,
14+
// CHECK-SAME: #cir.int<3> : !s32i, #cir.int<3> : !s32i]> : !cir.vector<4 x !s32i>
15+
// CHECK-NEXT: cir.return %0 : !cir.vector<4 x !s32i>
16+
}

0 commit comments

Comments
 (0)