Skip to content

[SandboxVec][Legality] Per opcode checks #114145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 1, 2024
Merged

[SandboxVec][Legality] Per opcode checks #114145

merged 1 commit into from
Nov 1, 2024

Conversation

vporpo
Copy link
Contributor

@vporpo vporpo commented Oct 29, 2024

This patch adds more opcode-specific legality checks.

This patch adds more opcode-specific legality checks.
@llvmbot
Copy link
Member

llvmbot commented Oct 29, 2024

@llvm/pr-subscribers-vectorizers

@llvm/pr-subscribers-llvm-transforms

Author: vporpo (vporpo)

Changes

This patch adds more opcode-specific legality checks.


Patch is 33.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114145.diff

7 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h (+16-1)
  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h (+1-1)
  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h (+37)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp (+103-1)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp (+4-1)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp (+86-6)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp (+333)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
index 77ba5cd7f002e9..f43e033e3cc7e3 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
@@ -13,6 +13,8 @@
 #define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H
 
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/IR/DataLayout.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -33,6 +35,9 @@ enum class ResultReason {
   DiffTypes,
   DiffMathFlags,
   DiffWrapFlags,
+  NotConsecutive,
+  Unimplemented,
+  Infeasible,
 };
 
 #ifndef NDEBUG
@@ -59,6 +64,12 @@ struct ToStr {
       return "DiffMathFlags";
     case ResultReason::DiffWrapFlags:
       return "DiffWrapFlags";
+    case ResultReason::NotConsecutive:
+      return "NotConsecutive";
+    case ResultReason::Unimplemented:
+      return "Unimplemented";
+    case ResultReason::Infeasible:
+      return "Infeasible";
     }
     llvm_unreachable("Unknown ResultReason enum");
   }
@@ -142,8 +153,12 @@ class LegalityAnalysis {
   std::optional<ResultReason>
   notVectorizableBasedOnOpcodesAndTypes(ArrayRef<Value *> Bndl);
 
+  ScalarEvolution &SE;
+  const DataLayout &DL;
+
 public:
-  LegalityAnalysis() = default;
+  LegalityAnalysis(ScalarEvolution &SE, const DataLayout &DL)
+      : SE(SE), DL(DL) {}
   /// A LegalityResult factory.
   template <typename ResultT, typename... ArgsT>
   ResultT &createLegalityResult(ArgsT... Args) {
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
index 2b0b3f8192c048..7e0b88ae7197d4 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
@@ -24,7 +24,7 @@ namespace llvm::sandboxir {
 
 class BottomUpVec final : public FunctionPass {
   bool Change = false;
-  LegalityAnalysis Legality;
+  std::unique_ptr<LegalityAnalysis> Legality;
   void vectorizeRec(ArrayRef<Value *> Bndl);
   void tryVectorize(ArrayRef<Value *> Seeds);
 
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
index 9577e8ef7b37cb..8b64ec58da345d 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
@@ -12,7 +12,10 @@
 #ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
 #define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
 
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/IR/DataLayout.h"
 #include "llvm/SandboxIR/Type.h"
+#include "llvm/SandboxIR/Utils.h"
 
 namespace llvm::sandboxir {
 
@@ -29,6 +32,40 @@ class VecUtils {
   static Type *getElementType(Type *Ty) {
     return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getElementType() : Ty;
   }
+
+  /// \Returns true if \p I1 and \p I2 are load/stores accessing consecutive
+  /// memory addresses.
+  template <typename LoadOrStoreT>
+  static bool areConsecutive(LoadOrStoreT *I1, LoadOrStoreT *I2,
+                             ScalarEvolution &SE, const DataLayout &DL) {
+    static_assert(std::is_same<LoadOrStoreT, LoadInst>::value ||
+                      std::is_same<LoadOrStoreT, StoreInst>::value,
+                  "Expected Load or Store!");
+    auto Diff = Utils::getPointerDiffInBytes(I1, I2, SE);
+    if (!Diff)
+      return false;
+    int ElmBytes = Utils::getNumBits(I1) / 8;
+    return *Diff == ElmBytes;
+  }
+
+  template <typename LoadOrStoreT>
+  static bool areConsecutive(ArrayRef<Value *> &Bndl, ScalarEvolution &SE,
+                             const DataLayout &DL) {
+    static_assert(std::is_same<LoadOrStoreT, LoadInst>::value ||
+                      std::is_same<LoadOrStoreT, StoreInst>::value,
+                  "Expected Load or Store!");
+    assert(isa<LoadOrStoreT>(Bndl[0]) && "Expected Load or Store!");
+    auto *LastLS = cast<LoadOrStoreT>(Bndl[0]);
+    for (Value *V : drop_begin(Bndl)) {
+      assert(isa<LoadOrStoreT>(V) &&
+             "Unimplemented: we only support StoreInst!");
+      auto *LS = cast<LoadOrStoreT>(V);
+      if (!VecUtils::areConsecutive(LastLS, LS, SE, DL))
+        return false;
+      LastLS = LS;
+    }
+    return true;
+  }
 };
 
 } // namespace llvm::sandboxir
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index 1cc6356300e492..1efd178778b9f6 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -70,7 +70,109 @@ LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes(
     }
   }
 
-  // TODO: Missing checks
+  // Now we need to do further checks for specific opcodes.
+  switch (Opcode) {
+  case Instruction::Opcode::ZExt:
+  case Instruction::Opcode::SExt:
+  case Instruction::Opcode::FPToUI:
+  case Instruction::Opcode::FPToSI:
+  case Instruction::Opcode::FPExt:
+  case Instruction::Opcode::PtrToInt:
+  case Instruction::Opcode::IntToPtr:
+  case Instruction::Opcode::SIToFP:
+  case Instruction::Opcode::UIToFP:
+  case Instruction::Opcode::Trunc:
+  case Instruction::Opcode::FPTrunc:
+  case Instruction::Opcode::BitCast: {
+    // We have already checked that they are of the same opcode.
+    assert(all_of(Bndl,
+                  [Opcode](Value *V) {
+                    return cast<Instruction>(V)->getOpcode() == Opcode;
+                  }) &&
+           "Different opcodes, should have early returned!");
+    // But for these opcodes we should also check the operand type.
+    Type *FromTy0 = Utils::getExpectedType(I0->getOperand(0));
+    if (any_of(drop_begin(Bndl), [FromTy0](Value *V) {
+          return Utils::getExpectedType(cast<User>(V)->getOperand(0)) !=
+                 FromTy0;
+        }))
+      return ResultReason::DiffTypes;
+    return std::nullopt;
+  }
+  case Instruction::Opcode::FCmp:
+  case Instruction::Opcode::ICmp: {
+    // We need the same predicate..
+    auto Pred0 = cast<CmpInst>(I0)->getPredicate();
+    bool Same = all_of(Bndl, [Pred0](Value *V) {
+      return cast<CmpInst>(V)->getPredicate() == Pred0;
+    });
+    if (Same)
+      return std::nullopt;
+    return ResultReason::DiffOpcodes;
+  }
+  case Instruction::Opcode::Select:
+  case Instruction::Opcode::FNeg:
+  case Instruction::Opcode::Add:
+  case Instruction::Opcode::FAdd:
+  case Instruction::Opcode::Sub:
+  case Instruction::Opcode::FSub:
+  case Instruction::Opcode::Mul:
+  case Instruction::Opcode::FMul:
+  case Instruction::Opcode::FRem:
+  case Instruction::Opcode::UDiv:
+  case Instruction::Opcode::SDiv:
+  case Instruction::Opcode::FDiv:
+  case Instruction::Opcode::URem:
+  case Instruction::Opcode::SRem:
+  case Instruction::Opcode::Shl:
+  case Instruction::Opcode::LShr:
+  case Instruction::Opcode::AShr:
+  case Instruction::Opcode::And:
+  case Instruction::Opcode::Or:
+  case Instruction::Opcode::Xor:
+    return std::nullopt;
+  case Instruction::Opcode::Load:
+    if (VecUtils::areConsecutive<LoadInst>(Bndl, SE, DL))
+      return std::nullopt;
+    return ResultReason::NotConsecutive;
+  case Instruction::Opcode::Store:
+    if (VecUtils::areConsecutive<StoreInst>(Bndl, SE, DL))
+      return std::nullopt;
+    return ResultReason::NotConsecutive;
+  case Instruction::Opcode::PHI:
+    return ResultReason::Unimplemented;
+  case Instruction::Opcode::Opaque:
+    return ResultReason::Unimplemented;
+  case Instruction::Opcode::Br:
+  case Instruction::Opcode::Ret:
+  case Instruction::Opcode::AddrSpaceCast:
+  case Instruction::Opcode::InsertElement:
+  case Instruction::Opcode::InsertValue:
+  case Instruction::Opcode::ExtractElement:
+  case Instruction::Opcode::ExtractValue:
+  case Instruction::Opcode::ShuffleVector:
+  case Instruction::Opcode::Call:
+  case Instruction::Opcode::GetElementPtr:
+  case Instruction::Opcode::Switch:
+    return ResultReason::Unimplemented;
+  case Instruction::Opcode::VAArg:
+  case Instruction::Opcode::Freeze:
+  case Instruction::Opcode::Fence:
+  case Instruction::Opcode::Invoke:
+  case Instruction::Opcode::CallBr:
+  case Instruction::Opcode::LandingPad:
+  case Instruction::Opcode::CatchPad:
+  case Instruction::Opcode::CleanupPad:
+  case Instruction::Opcode::CatchRet:
+  case Instruction::Opcode::CleanupRet:
+  case Instruction::Opcode::Resume:
+  case Instruction::Opcode::CatchSwitch:
+  case Instruction::Opcode::AtomicRMW:
+  case Instruction::Opcode::AtomicCmpXchg:
+  case Instruction::Opcode::Alloca:
+  case Instruction::Opcode::Unreachable:
+    return ResultReason::Infeasible;
+  }
 
   return std::nullopt;
 }
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 66d631edfc4076..339330c64f0caa 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -11,6 +11,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/SandboxIR/Function.h"
 #include "llvm/SandboxIR/Instruction.h"
+#include "llvm/SandboxIR/Module.h"
 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
 
 namespace llvm::sandboxir {
@@ -40,7 +41,7 @@ static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
 }
 
 void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
-  const auto &LegalityRes = Legality.canVectorize(Bndl);
+  const auto &LegalityRes = Legality->canVectorize(Bndl);
   switch (LegalityRes.getSubclassID()) {
   case LegalityResultID::Widen: {
     auto *I = cast<Instruction>(Bndl[0]);
@@ -60,6 +61,8 @@ void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
 void BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) { vectorizeRec(Bndl); }
 
 bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
+  Legality = std::make_unique<LegalityAnalysis>(A.getScalarEvolution(),
+                                                F.getParent()->getDataLayout());
   Change = false;
   // TODO: Start from innermost BBs first
   for (auto &BB : F) {
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index 50b78f6f48afdf..68557cb8b129f2 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -7,7 +7,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
 #include "llvm/SandboxIR/Function.h"
 #include "llvm/SandboxIR/Instruction.h"
 #include "llvm/Support/SourceMgr.h"
@@ -18,6 +24,22 @@ using namespace llvm;
 struct LegalityTest : public testing::Test {
   LLVMContext C;
   std::unique_ptr<Module> M;
+  std::unique_ptr<DominatorTree> DT;
+  std::unique_ptr<TargetLibraryInfoImpl> TLII;
+  std::unique_ptr<TargetLibraryInfo> TLI;
+  std::unique_ptr<AssumptionCache> AC;
+  std::unique_ptr<LoopInfo> LI;
+  std::unique_ptr<ScalarEvolution> SE;
+
+  ScalarEvolution &getSE(llvm::Function &LLVMF) {
+    DT = std::make_unique<DominatorTree>(LLVMF);
+    TLII = std::make_unique<TargetLibraryInfoImpl>();
+    TLI = std::make_unique<TargetLibraryInfo>(*TLII);
+    AC = std::make_unique<AssumptionCache>(LLVMF);
+    LI = std::make_unique<LoopInfo>(*DT);
+    SE = std::make_unique<ScalarEvolution>(LLVMF, *TLI, *AC, *DT, *LI);
+    return *SE;
+  }
 
   void parseIR(LLVMContext &C, const char *IR) {
     SMDiagnostic Err;
@@ -29,12 +51,14 @@ struct LegalityTest : public testing::Test {
 
 TEST_F(LegalityTest, Legality) {
   parseIR(C, R"IR(
-define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1) {
+define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) {
   %gep0 = getelementptr float, ptr %ptr, i32 0
   %gep1 = getelementptr float, ptr %ptr, i32 1
   %gep3 = getelementptr float, ptr %ptr, i32 3
   %ld0 = load float, ptr %gep0
-  %ld1 = load float, ptr %gep0
+  %ld0b = load float, ptr %gep0
+  %ld1 = load float, ptr %gep1
+  %ld3 = load float, ptr %gep3
   store float %ld0, ptr %gep0
   store float %ld1, ptr %gep1
   store <2 x float> %vec2, ptr %gep1
@@ -44,10 +68,17 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
   %fadd1 = fadd fast float %farg1, %farg1
   %trunc0 = trunc nuw nsw i64 %v0 to i8
   %trunc1 = trunc nsw i64 %v1 to i8
+  %trunc64to8 = trunc i64 %v0 to i8
+  %trunc32to8 = trunc i32 %v2 to i8
+  %cmpSLT = icmp slt i64 %v0, %v1
+  %cmpSGT = icmp sgt i64 %v0, %v1
   ret void
 }
 )IR");
   llvm::Function *LLVMF = &*M->getFunction("foo");
+  auto &SE = getSE(*LLVMF);
+  const auto &DL = M->getDataLayout();
+
   sandboxir::Context Ctx(C);
   auto *F = Ctx.createFunction(LLVMF);
   auto *BB = &*F->begin();
@@ -55,8 +86,10 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
   [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
   [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
   [[maybe_unused]] auto *Gep3 = cast<sandboxir::GetElementPtrInst>(&*It++);
-  [[maybe_unused]] auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
-  [[maybe_unused]] auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
+  auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
+  auto *Ld0b = cast<sandboxir::LoadInst>(&*It++);
+  auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
+  auto *Ld3 = cast<sandboxir::LoadInst>(&*It++);
   auto *St0 = cast<sandboxir::StoreInst>(&*It++);
   auto *St1 = cast<sandboxir::StoreInst>(&*It++);
   auto *StVec2 = cast<sandboxir::StoreInst>(&*It++);
@@ -66,8 +99,12 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
   auto *FAdd1 = cast<sandboxir::BinaryOperator>(&*It++);
   auto *Trunc0 = cast<sandboxir::TruncInst>(&*It++);
   auto *Trunc1 = cast<sandboxir::TruncInst>(&*It++);
+  auto *Trunc64to8 = cast<sandboxir::TruncInst>(&*It++);
+  auto *Trunc32to8 = cast<sandboxir::TruncInst>(&*It++);
+  auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
+  auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
 
-  sandboxir::LegalityAnalysis Legality;
+  sandboxir::LegalityAnalysis Legality(SE, DL);
   const auto &Result = Legality.canVectorize({St0, St1});
   EXPECT_TRUE(isa<sandboxir::Widen>(Result));
 
@@ -109,10 +146,52 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
               sandboxir::ResultReason::DiffWrapFlags);
   }
+  {
+    // Check DiffTypes for unary operands that have a different type.
+    const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8});
+    EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+              sandboxir::ResultReason::DiffTypes);
+  }
+  {
+    // Check DiffOpcodes for CMPs with different predicates.
+    const auto &Result = Legality.canVectorize({CmpSLT, CmpSGT});
+    EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+              sandboxir::ResultReason::DiffOpcodes);
+  }
+  {
+    // Check NotConsecutive Ld0,Ld0b
+    const auto &Result = Legality.canVectorize({Ld0, Ld0b});
+    EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+              sandboxir::ResultReason::NotConsecutive);
+  }
+  {
+    // Check NotConsecutive Ld0,Ld3
+    const auto &Result = Legality.canVectorize({Ld0, Ld3});
+    EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+              sandboxir::ResultReason::NotConsecutive);
+  }
+  {
+    // Check Widen Ld0,Ld1
+    const auto &Result = Legality.canVectorize({Ld0, Ld1});
+    EXPECT_TRUE(isa<sandboxir::Widen>(Result));
+  }
 }
 
 #ifndef NDEBUG
 TEST_F(LegalityTest, LegalityResultDump) {
+  parseIR(C, R"IR(
+define void @foo() {
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  auto &SE = getSE(*LLVMF);
+  const auto &DL = M->getDataLayout();
+
   auto Matches = [](const sandboxir::LegalityResult &Result,
                     const std::string &ExpectedStr) -> bool {
     std::string Buff;
@@ -120,7 +199,8 @@ TEST_F(LegalityTest, LegalityResultDump) {
     Result.print(OS);
     return Buff == ExpectedStr;
   };
-  sandboxir::LegalityAnalysis Legality;
+
+  sandboxir::LegalityAnalysis Legality(SE, DL);
   EXPECT_TRUE(
       Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
   EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
index e0b08284964392..75f72ce23fbaac 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
@@ -7,15 +7,47 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
 #include "llvm/SandboxIR/Context.h"
+#include "llvm/SandboxIR/Function.h"
 #include "llvm/SandboxIR/Type.h"
+#include "llvm/Support/SourceMgr.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
 
 struct VecUtilsTest : public testing::Test {
   LLVMContext C;
+  std::unique_ptr<Module> M;
+  std::unique_ptr<AssumptionCache> AC;
+  std::unique_ptr<TargetLibraryInfoImpl> TLII;
+  std::unique_ptr<TargetLibraryInfo> TLI;
+  std::unique_ptr<DominatorTree> DT;
+  std::unique_ptr<LoopInfo> LI;
+  std::unique_ptr<ScalarEvolution> SE;
+  void parseIR(const char *IR) {
+    SMDiagnostic Err;
+    M = parseAssemblyString(IR, Err, C);
+    if (!M)
+      Err.print("VecUtilsTest", errs());
+  }
+  ScalarEvolution &getSE(llvm::Function &LLVMF) {
+    TLII = std::make_unique<TargetLibraryInfoImpl>();
+    TLI = std::make_unique<TargetLibraryInfo>(*TLII);
+    AC = std::make_unique<AssumptionCache>(LLVMF);
+    DT = std::make_unique<DominatorTree>(LLVMF);
+    LI = std::make_unique<LoopInfo>(*DT);
+    SE = std::make_unique<ScalarEvolution>(LLVMF, *TLI, *AC, *DT, *LI);
+    return *SE;
+  }
 };
 
 TEST_F(VecUtilsTest, GetNumElements) {
@@ -35,3 +67,304 @@ TEST_F(VecUtilsTest, GetElementType) {
   auto *VTy = sandboxir::FixedVectorType::get(ElemTy, 2);
   EXPECT_EQ(sandboxir::VecUtils::getElementType(VTy), ElemTy);
 }
+
+TEST_F(VecUtilsTest, AreConsecutive_gep_float) {
+  parseIR(R"IR(
+define void @foo(ptr %ptr) {
+  %gep0 = getelementptr inbounds float, ptr %ptr, i64 0
+  %gep1 = getelementptr inbounds float, ptr %ptr, i64 1
+  %gep2 = getelementptr inbounds float, ptr %ptr, i64 2
+  %gep3 = getelementptr inbounds float, ptr %ptr, i64 3
+
+  %ld0 = load float, ptr %gep0
+  %ld1 = load float, ptr %gep1
+  %ld2 = load float, ptr %gep2
+  %ld3 = load float, ptr %gep3
+
+  %v2ld0 = load <2 x float>, ptr %gep0
+  %v2ld1 = load <2 x float>, ptr %gep1
+  %v2ld2 = load <2 x float>, ptr %gep2
+  %v2ld3 = load <2 x float>, ptr %gep3
+
+  %v3ld0 = load <3 x float>, ptr %gep0
+  %v3ld1 = load <3 x float>, ptr %gep1
+  %v3ld2 = load <3 x float>, ptr %gep2
+  %v3ld3 = load <3 x float>, ptr %gep3
+  ret void
+}
+)IR");
+  Fun...
[truncated]

Copy link
Member

@tmsri tmsri left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@vporpo vporpo merged commit 083369f into llvm:main Nov 1, 2024
9 of 11 checks passed
smallp-o-p pushed a commit to smallp-o-p/llvm-project that referenced this pull request Nov 3, 2024
This patch adds more opcode-specific legality checks.
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
This patch adds more opcode-specific legality checks.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants