@@ -260,6 +260,31 @@ struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
260
260
}
261
261
};
262
262
263
+ struct SimplifyVecSplat : public OpRewritePattern <VecSplatOp> {
264
+ using OpRewritePattern<VecSplatOp>::OpRewritePattern;
265
+ LogicalResult matchAndRewrite (VecSplatOp op,
266
+ PatternRewriter &rewriter) const override {
267
+ mlir::Value splatValue = op.getValue ();
268
+ auto constant =
269
+ mlir::dyn_cast_if_present<cir::ConstantOp>(splatValue.getDefiningOp ());
270
+ if (!constant)
271
+ return mlir::failure ();
272
+
273
+ auto value = constant.getValue ();
274
+ if (!mlir::isa_and_nonnull<cir::IntAttr>(value) &&
275
+ !mlir::isa_and_nonnull<cir::FPAttr>(value))
276
+ return mlir::failure ();
277
+
278
+ cir::VectorType resultType = op.getResult ().getType ();
279
+ SmallVector<mlir::Attribute, 16 > elements (resultType.getSize (), value);
280
+ auto constVecAttr = cir::ConstVectorAttr::get (
281
+ resultType, mlir::ArrayAttr::get (getContext (), elements));
282
+
283
+ rewriter.replaceOpWithNewOp <cir::ConstantOp>(op, constVecAttr);
284
+ return mlir::success ();
285
+ }
286
+ };
287
+
263
288
// ===----------------------------------------------------------------------===//
264
289
// CIRSimplifyPass
265
290
// ===----------------------------------------------------------------------===//
@@ -275,7 +300,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
275
300
patterns.add <
276
301
SimplifyTernary,
277
302
SimplifySelect,
278
- SimplifySwitch
303
+ SimplifySwitch,
304
+ SimplifyVecSplat
279
305
>(patterns.getContext ());
280
306
// clang-format on
281
307
}
@@ -288,7 +314,7 @@ void CIRSimplifyPass::runOnOperation() {
288
314
// Collect operations to apply patterns.
289
315
llvm::SmallVector<Operation *, 16 > ops;
290
316
getOperation ()->walk ([&](Operation *op) {
291
- if (isa<TernaryOp, SelectOp, SwitchOp>(op))
317
+ if (isa<TernaryOp, SelectOp, SwitchOp, VecSplatOp >(op))
292
318
ops.push_back (op);
293
319
});
294
320
0 commit comments