Skip to content

[CIR] Implement folder for VecSplatOp #143771

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 19, 2025

Conversation

AmrDeveloper
Copy link
Member

This change adds a folder for the VecSplatOp

Issue #136487

@llvmbot llvmbot added clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project labels Jun 11, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 11, 2025

@llvm/pr-subscribers-clangir

@llvm/pr-subscribers-clang

Author: Amr Hesham (AmrDeveloper)

Changes

This change adds a folder for the VecSplatOp

Issue #136487


Full diff: https://github.com/llvm/llvm-project/pull/143771.diff

6 Files Affected:

  • (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+2)
  • (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+12)
  • (modified) clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp (+2-1)
  • (modified) clang/test/CIR/CodeGen/vector-ext.cpp (+40-22)
  • (modified) clang/test/CIR/CodeGen/vector.cpp (+40-22)
  • (added) clang/test/CIR/Transforms/vector-splat-fold.cir (+16)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 634f0dd554c77..ab132ee5feee0 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2307,6 +2307,8 @@ def VecSplatOp : CIR_Op<"vec.splat", [Pure,
   let assemblyFormat = [{
     $value `:` type($value) `,` qualified(type($result)) attr-dict
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index a6cf0a6b5d75e..65ad7488be7ca 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1697,6 +1697,18 @@ LogicalResult cir::VecTernaryOp::verify() {
   return success();
 }
 
+// VecSplatOp
+OpFoldResult cir::VecSplatOp::fold(FoldAdaptor adaptor) {
+  mlir::Attribute value = adaptor.getValue();
+  if (!mlir::isa_and_nonnull<cir::IntAttr>(value) &&
+      !mlir::isa_and_nonnull<cir::FPAttr>(value))
+    return {};
+
+  SmallVector<mlir::Attribute, 16> elements(getType().getSize(), value);
+  return cir::ConstVectorAttr::get(
+      getType(), mlir::ArrayAttr::get(getContext(), elements));
+}
+
 OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
   mlir::Attribute cond = adaptor.getCond();
   mlir::Attribute lhs = adaptor.getLhs();
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index 29f9942638964..dc9b0a1546c9c 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -142,7 +142,8 @@ void CIRCanonicalizePass::runOnOperation() {
     // Many operations are here to perform a manual `fold` in
     // applyOpPatternsGreedily.
     if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
-            VecExtractOp, VecShuffleOp, VecShuffleDynamicOp, VecTernaryOp>(op))
+            VecExtractOp, VecShuffleOp, VecShuffleDynamicOp, VecSplatOp,
+            VecTernaryOp>(op))
       ops.push_back(op);
   });
 
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index 965c44c9461a8..55f917e97ca67 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -1095,65 +1095,83 @@ void foo17() {
 
 void foo18() {
   vi4 a = {1, 2, 3, 4};
-  vi4 shl = a << 3;
+  int sv = 3;
+  vi4 shl = a << sv;
 
   uvi4 b = {1u, 2u, 3u, 4u};
-  uvi4 shr = b >> 3u;
+  unsigned int usv = 3u;
+  uvi4 shr = b >> usv;
 }
 
 // CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[SV:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["sv", init]
 // CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
 // CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["b", init]
+// CIR: %[[USV:.*]] = cir.alloca !u32i, !cir.ptr<!u32i>, ["usv", init]
 // CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["shr", init]
-// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
-// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
-// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
-// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
-// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
-// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create({{.*}}, {{.*}}, {{.*}}, {{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
 // CIR: cir.store{{.*}} %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[SV_VAL:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: cir.store{{.*}} %[[SV_VAL]], %[[SV]] : !s32i, !cir.ptr<!s32i>
 // CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
-// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !s32i
-// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !s32i, !cir.vector<4 x !s32i>
+// CIR: %[[TMP_SV:.*]] = cir.load{{.*}} %[[SV]] : !cir.ptr<!s32i>, !s32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[TMP_SV]] : !s32i, !cir.vector<4 x !s32i>
 // CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
 // CIR: cir.store{{.*}} %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
-// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !u32i
-// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !u32i
-// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !u32i
-// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !u32i
-// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
-// CIR-SAME: !u32i, !u32i, !u32i, !u32i) : !cir.vector<4 x !u32i>
+// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create({{.*}}, {{.*}}, {{.*}}, {{.*}} : !u32i, !u32i, !u32i, !u32i) : !cir.vector<4 x !u32i>
 // CIR: cir.store{{.*}} %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
+// CIR: %[[USV_VAL:.*]] = cir.const #cir.int<3> : !u32i
+// CIR: cir.store{{.*}} %[[USV_VAL]], %[[USV]] : !u32i, !cir.ptr<!u32i>
 // CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !u32i>>, !cir.vector<4 x !u32i>
-// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !u32i
-// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !u32i, !cir.vector<4 x !u32i>
+// CIR: %[[TMP_USV:.*]] = cir.load{{.*}} %[[USV]] : !cir.ptr<!u32i>, !u32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[TMP_USV]] : !u32i, !cir.vector<4 x !u32i>
 // CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_B]] : !cir.vector<4 x !u32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !u32i>) -> !cir.vector<4 x !u32i>
 // CIR: cir.store{{.*}} %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
 
 // LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SV:.*]] = alloca i32, i64 1, align 4
 // LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
 // LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[USV:.*]] = alloca i32, i64 1, align 4
 // LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
 // LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// LLVM: store i32 3, ptr %[[SV]], align 4
 // LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
-// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// LLVM: %[[TMP_SV:.*]] = load i32, ptr %[[SV]], align 4
+// LLVM: %[[SI:.*]] = insertelement <4 x i32> poison, i32 %[[TMP_SV]], i64 0
+// LLVM: %[[S_VEC:.*]] = shufflevector <4 x i32> %[[SI]], <4 x i32> poison, <4 x i32> zeroinitializer
+// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[S_VEC]]
 // LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
 // LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// LLVM: store i32 3, ptr %[[USV]], align 4
 // LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
-// LLVM: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// LLVM: %[[TMP_USV:.*]] = load i32, ptr %[[USV]], align 4
+// LLVM: %[[USI:.*]] = insertelement <4 x i32> poison, i32 %[[TMP_USV]], i64 0
+// LLVM: %[[US_VEC:.*]] = shufflevector <4 x i32> %[[USI]], <4 x i32> poison, <4 x i32> zeroinitializer
+// LLVM: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], %[[US_VEC]]
 // LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
 
 // OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SV:.*]] = alloca i32, align 4
 // OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
 // OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[USV:.*]] = alloca i32, align 4
 // OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
 // OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// OGCG: store i32 3, ptr %[[SV]], align 4
 // OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
-// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// OGCG: %[[TMP_SV:.*]] = load i32, ptr %[[SV]], align 4
+// OGCG: %[[SI:.*]] = insertelement <4 x i32> poison, i32 %[[TMP_SV]], i64 0
+// OGCG: %[[S_VEC:.*]] = shufflevector <4 x i32> %[[SI]], <4 x i32> poison, <4 x i32> zeroinitializer
+// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[S_VEC]]
 // OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
 // OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// OGCG: store i32 3, ptr %[[USV]], align 4
 // OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
-// OGCG: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// OGCG: %[[TMP_USV:.*]] = load i32, ptr %[[USV]], align 4
+// OGCG: %[[USI:.*]] = insertelement <4 x i32> poison, i32 %[[TMP_USV]], i64 0
+// OGCG: %[[US_VEC:.*]] = shufflevector <4 x i32> %[[USI]], <4 x i32> poison, <4 x i32> zeroinitializer
+// OGCG: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], %[[US_VEC]]
 // OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
 
 void foo19() {
diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp
index 23e91724dc0f3..dfb0078148cbe 100644
--- a/clang/test/CIR/CodeGen/vector.cpp
+++ b/clang/test/CIR/CodeGen/vector.cpp
@@ -1073,65 +1073,83 @@ void foo17() {
 
 void foo18() {
   vi4 a = {1, 2, 3, 4};
-  vi4 shl = a << 3;
+  int sv = 3;
+  vi4 shl = a << sv;
 
   uvi4 b = {1u, 2u, 3u, 4u};
-  uvi4 shr = b >> 3u;
+  unsigned int usv = 3u;
+  uvi4 shr = b >> usv;
 }
 
 // CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[SV:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["sv", init]
 // CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
 // CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["b", init]
+// CIR: %[[USV:.*]] = cir.alloca !u32i, !cir.ptr<!u32i>, ["usv", init]
 // CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["shr", init]
-// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
-// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
-// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
-// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
-// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
-// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create({{.*}}, {{.*}}, {{.*}}, {{.*}} : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
 // CIR: cir.store{{.*}} %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[SV_VAL:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: cir.store{{.*}} %[[SV_VAL]], %[[SV]] : !s32i, !cir.ptr<!s32i>
 // CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
-// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !s32i
-// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !s32i, !cir.vector<4 x !s32i>
+// CIR: %[[TMP_SV:.*]] = cir.load{{.*}} %[[SV]] : !cir.ptr<!s32i>, !s32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[TMP_SV]] : !s32i, !cir.vector<4 x !s32i>
 // CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
 // CIR: cir.store{{.*}} %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
-// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !u32i
-// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !u32i
-// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !u32i
-// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !u32i
-// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
-// CIR-SAME: !u32i, !u32i, !u32i, !u32i) : !cir.vector<4 x !u32i>
+// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create({{.*}}, {{.*}}, {{.*}}, {{.*}} : !u32i, !u32i, !u32i, !u32i) : !cir.vector<4 x !u32i>
 // CIR: cir.store{{.*}} %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
+// CIR: %[[USV_VAL:.*]] = cir.const #cir.int<3> : !u32i
+// CIR: cir.store{{.*}} %[[USV_VAL]], %[[USV]] : !u32i, !cir.ptr<!u32i>
 // CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !u32i>>, !cir.vector<4 x !u32i>
-// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !u32i
-// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !u32i, !cir.vector<4 x !u32i>
+// CIR: %[[TMP_USV:.*]] = cir.load{{.*}} %[[USV]] : !cir.ptr<!u32i>, !u32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[TMP_USV]] : !u32i, !cir.vector<4 x !u32i>
 // CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_B]] : !cir.vector<4 x !u32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !u32i>) -> !cir.vector<4 x !u32i>
 // CIR: cir.store{{.*}} %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
 
 // LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SV:.*]] = alloca i32, i64 1, align 4
 // LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
 // LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[USV:.*]] = alloca i32, i64 1, align 4
 // LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
 // LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// LLVM: store i32 3, ptr %[[SV]], align 4
 // LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
-// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// LLVM: %[[TMP_SV:.*]] = load i32, ptr %[[SV]], align 4
+// LLVM: %[[SI:.*]] = insertelement <4 x i32> poison, i32 %[[TMP_SV]], i64 0
+// LLVM: %[[S_VEC:.*]] = shufflevector <4 x i32> %[[SI]], <4 x i32> poison, <4 x i32> zeroinitializer
+// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[S_VEC]]
 // LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
 // LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// LLVM: store i32 3, ptr %[[USV]], align 4
 // LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
-// LLVM: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// LLVM: %[[TMP_USV:.*]] = load i32, ptr %[[USV]], align 4
+// LLVM: %[[USI:.*]] = insertelement <4 x i32> poison, i32 %[[TMP_USV]], i64 0
+// LLVM: %[[US_VEC:.*]] = shufflevector <4 x i32> %[[USI]], <4 x i32> poison, <4 x i32> zeroinitializer
+// LLVM: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], %[[US_VEC]]
 // LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
 
 // OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SV:.*]] = alloca i32, align 4
 // OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
 // OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[USV:.*]] = alloca i32, align 4
 // OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
 // OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// OGCG: store i32 3, ptr %[[SV]], align 4
 // OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
-// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// OGCG: %[[TMP_SV:.*]] = load i32, ptr %[[SV]], align 4
+// OGCG: %[[SI:.*]] = insertelement <4 x i32> poison, i32 %[[TMP_SV]], i64 0
+// OGCG: %[[S_VEC:.*]] = shufflevector <4 x i32> %[[SI]], <4 x i32> poison, <4 x i32> zeroinitializer
+// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[S_VEC]]
 // OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
 // OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// OGCG: store i32 3, ptr %[[USV]], align 4
 // OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
-// OGCG: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// OGCG: %[[TMP_USV:.*]] = load i32, ptr %[[USV]], align 4
+// OGCG: %[[USI:.*]] = insertelement <4 x i32> poison, i32 %[[TMP_USV]], i64 0
+// OGCG: %[[US_VEC:.*]] = shufflevector <4 x i32> %[[USI]], <4 x i32> poison, <4 x i32> zeroinitializer
+// OGCG: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], %[[US_VEC]]
 // OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
 
 void foo19() {
diff --git a/clang/test/CIR/Transforms/vector-splat-fold.cir b/clang/test/CIR/Transforms/vector-splat-fold.cir
new file mode 100644
index 0000000000000..5e3e8ba0eec42
--- /dev/null
+++ b/clang/test/CIR/Transforms/vector-splat-fold.cir
@@ -0,0 +1,16 @@
+// RUN: cir-opt %s -cir-canonicalize -o - | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module  {
+  cir.func @fold_shuffle_vector_op_test() -> !cir.vector<4 x !s32i> {
+    %v = cir.const #cir.int<3> : !s32i
+    %vec = cir.vec.splat %v : !s32i, !cir.vector<4 x !s32i>
+    cir.return %vec : !cir.vector<4 x !s32i>
+  }
+
+  // CHECK: cir.func @fold_shuffle_vector_op_test() -> !cir.vector<4 x !s32i> {
+  // CHECK-NEXT: %0 = cir.const #cir.const_vector<[#cir.int<3> : !s32i, #cir.int<3> : !s32i,
+  // CHECK-SAME: #cir.int<3> : !s32i, #cir.int<3> : !s32i]> : !cir.vector<4 x !s32i>
+  // CHECK-NEXT: cir.return %0 : !cir.vector<4 x !s32i>
+}

Copy link
Contributor

@andykaylor andykaylor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good. I just have a small request regarding the test change.

@@ -1095,65 +1095,83 @@ void foo17() {

void foo18() {
vi4 a = {1, 2, 3, 4};
vi4 shl = a << 3;
int sv = 3;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you had to make this change to keep the vec.splat from being folded away. I'd like to see both forms tested.

@AmrDeveloper AmrDeveloper force-pushed the cir_splat_op_folder branch from 4eca449 to 301cfa7 Compare June 18, 2025 19:15
@AmrDeveloper AmrDeveloper merged commit 9ee55e7 into llvm:main Jun 19, 2025
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants