Skip to content

[VPlan] Generalize type inference for binary/cast/shift/logic. NFC #116173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 24, 2024

Conversation

LiqinWeng
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Nov 14, 2024

@llvm/pr-subscribers-vectorizers

Author: LiqinWeng (LiqinWeng)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/116173.diff

1 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp (+21-59)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 8b8ab6be99b0d5..cb42cfe8159b04 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -93,34 +93,19 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
 
 Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenRecipe *R) {
   unsigned Opcode = R->getOpcode();
-  switch (Opcode) {
-  case Instruction::ICmp:
-  case Instruction::FCmp:
-    return IntegerType::get(Ctx, 1);
-  case Instruction::UDiv:
-  case Instruction::SDiv:
-  case Instruction::SRem:
-  case Instruction::URem:
-  case Instruction::Add:
-  case Instruction::FAdd:
-  case Instruction::Sub:
-  case Instruction::FSub:
-  case Instruction::Mul:
-  case Instruction::FMul:
-  case Instruction::FDiv:
-  case Instruction::FRem:
-  case Instruction::Shl:
-  case Instruction::LShr:
-  case Instruction::AShr:
-  case Instruction::And:
-  case Instruction::Or:
-  case Instruction::Xor: {
+  if (Instruction::isBinaryOp(Opcode) || Instruction::isShift(Opcode) ||
+      Instruction::isBitwiseLogicOp(Opcode)) {
     Type *ResTy = inferScalarType(R->getOperand(0));
     assert(ResTy == inferScalarType(R->getOperand(1)) &&
            "types for both operands must match for binary op");
     CachedTypes[R->getOperand(1)] = ResTy;
     return ResTy;
   }
+
+  switch (Opcode) {
+  case Instruction::ICmp:
+  case Instruction::FCmp:
+    return IntegerType::get(Ctx, 1);
   case Instruction::FNeg:
   case Instruction::Freeze:
     return inferScalarType(R->getOperand(0));
@@ -157,36 +142,26 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenSelectRecipe *R) {
 }
 
 Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
-  switch (R->getUnderlyingInstr()->getOpcode()) {
-  case Instruction::Call: {
-    unsigned CallIdx = R->getNumOperands() - (R->isPredicated() ? 2 : 1);
-    return cast<Function>(R->getOperand(CallIdx)->getLiveInIRValue())
-        ->getReturnType();
-  }
-  case Instruction::UDiv:
-  case Instruction::SDiv:
-  case Instruction::SRem:
-  case Instruction::URem:
-  case Instruction::Add:
-  case Instruction::FAdd:
-  case Instruction::Sub:
-  case Instruction::FSub:
-  case Instruction::Mul:
-  case Instruction::FMul:
-  case Instruction::FDiv:
-  case Instruction::FRem:
-  case Instruction::Shl:
-  case Instruction::LShr:
-  case Instruction::AShr:
-  case Instruction::And:
-  case Instruction::Or:
-  case Instruction::Xor: {
+  unsigned Opcode = R->getUnderlyingInstr()->getOpcode();
+
+  if (Instruction::isBinaryOp(Opcode) || Instruction::isShift(Opcode) ||
+      Instruction::isBitwiseLogicOp(Opcode)) {
     Type *ResTy = inferScalarType(R->getOperand(0));
     assert(ResTy == inferScalarType(R->getOperand(1)) &&
            "inferred types for operands of binary op don't match");
     CachedTypes[R->getOperand(1)] = ResTy;
     return ResTy;
   }
+
+  if (Instruction::isCast(Opcode))
+    return R->getUnderlyingInstr()->getType();
+
+  switch (Opcode) {
+  case Instruction::Call: {
+    unsigned CallIdx = R->getNumOperands() - (R->isPredicated() ? 2 : 1);
+    return cast<Function>(R->getOperand(CallIdx)->getLiveInIRValue())
+        ->getReturnType();
+  }
   case Instruction::Select: {
     Type *ResTy = inferScalarType(R->getOperand(1));
     assert(ResTy == inferScalarType(R->getOperand(2)) &&
@@ -197,21 +172,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
   case Instruction::ICmp:
   case Instruction::FCmp:
     return IntegerType::get(Ctx, 1);
-  case Instruction::AddrSpaceCast:
   case Instruction::Alloca:
-  case Instruction::BitCast:
-  case Instruction::Trunc:
-  case Instruction::SExt:
-  case Instruction::ZExt:
-  case Instruction::FPExt:
-  case Instruction::FPTrunc:
   case Instruction::ExtractValue:
-  case Instruction::SIToFP:
-  case Instruction::UIToFP:
-  case Instruction::FPToSI:
-  case Instruction::FPToUI:
-  case Instruction::PtrToInt:
-  case Instruction::IntToPtr:
     return R->getUnderlyingInstr()->getType();
   case Instruction::Freeze:
   case Instruction::FNeg:

@llvmbot
Copy link
Member

llvmbot commented Nov 14, 2024

@llvm/pr-subscribers-llvm-transforms

Author: LiqinWeng (LiqinWeng)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/116173.diff

1 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp (+21-59)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 8b8ab6be99b0d5..cb42cfe8159b04 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -93,34 +93,19 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
 
 Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenRecipe *R) {
   unsigned Opcode = R->getOpcode();
-  switch (Opcode) {
-  case Instruction::ICmp:
-  case Instruction::FCmp:
-    return IntegerType::get(Ctx, 1);
-  case Instruction::UDiv:
-  case Instruction::SDiv:
-  case Instruction::SRem:
-  case Instruction::URem:
-  case Instruction::Add:
-  case Instruction::FAdd:
-  case Instruction::Sub:
-  case Instruction::FSub:
-  case Instruction::Mul:
-  case Instruction::FMul:
-  case Instruction::FDiv:
-  case Instruction::FRem:
-  case Instruction::Shl:
-  case Instruction::LShr:
-  case Instruction::AShr:
-  case Instruction::And:
-  case Instruction::Or:
-  case Instruction::Xor: {
+  if (Instruction::isBinaryOp(Opcode) || Instruction::isShift(Opcode) ||
+      Instruction::isBitwiseLogicOp(Opcode)) {
     Type *ResTy = inferScalarType(R->getOperand(0));
     assert(ResTy == inferScalarType(R->getOperand(1)) &&
            "types for both operands must match for binary op");
     CachedTypes[R->getOperand(1)] = ResTy;
     return ResTy;
   }
+
+  switch (Opcode) {
+  case Instruction::ICmp:
+  case Instruction::FCmp:
+    return IntegerType::get(Ctx, 1);
   case Instruction::FNeg:
   case Instruction::Freeze:
     return inferScalarType(R->getOperand(0));
@@ -157,36 +142,26 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenSelectRecipe *R) {
 }
 
 Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
-  switch (R->getUnderlyingInstr()->getOpcode()) {
-  case Instruction::Call: {
-    unsigned CallIdx = R->getNumOperands() - (R->isPredicated() ? 2 : 1);
-    return cast<Function>(R->getOperand(CallIdx)->getLiveInIRValue())
-        ->getReturnType();
-  }
-  case Instruction::UDiv:
-  case Instruction::SDiv:
-  case Instruction::SRem:
-  case Instruction::URem:
-  case Instruction::Add:
-  case Instruction::FAdd:
-  case Instruction::Sub:
-  case Instruction::FSub:
-  case Instruction::Mul:
-  case Instruction::FMul:
-  case Instruction::FDiv:
-  case Instruction::FRem:
-  case Instruction::Shl:
-  case Instruction::LShr:
-  case Instruction::AShr:
-  case Instruction::And:
-  case Instruction::Or:
-  case Instruction::Xor: {
+  unsigned Opcode = R->getUnderlyingInstr()->getOpcode();
+
+  if (Instruction::isBinaryOp(Opcode) || Instruction::isShift(Opcode) ||
+      Instruction::isBitwiseLogicOp(Opcode)) {
     Type *ResTy = inferScalarType(R->getOperand(0));
     assert(ResTy == inferScalarType(R->getOperand(1)) &&
            "inferred types for operands of binary op don't match");
     CachedTypes[R->getOperand(1)] = ResTy;
     return ResTy;
   }
+
+  if (Instruction::isCast(Opcode))
+    return R->getUnderlyingInstr()->getType();
+
+  switch (Opcode) {
+  case Instruction::Call: {
+    unsigned CallIdx = R->getNumOperands() - (R->isPredicated() ? 2 : 1);
+    return cast<Function>(R->getOperand(CallIdx)->getLiveInIRValue())
+        ->getReturnType();
+  }
   case Instruction::Select: {
     Type *ResTy = inferScalarType(R->getOperand(1));
     assert(ResTy == inferScalarType(R->getOperand(2)) &&
@@ -197,21 +172,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
   case Instruction::ICmp:
   case Instruction::FCmp:
     return IntegerType::get(Ctx, 1);
-  case Instruction::AddrSpaceCast:
   case Instruction::Alloca:
-  case Instruction::BitCast:
-  case Instruction::Trunc:
-  case Instruction::SExt:
-  case Instruction::ZExt:
-  case Instruction::FPExt:
-  case Instruction::FPTrunc:
   case Instruction::ExtractValue:
-  case Instruction::SIToFP:
-  case Instruction::UIToFP:
-  case Instruction::FPToSI:
-  case Instruction::FPToUI:
-  case Instruction::PtrToInt:
-  case Instruction::IntToPtr:
     return R->getUnderlyingInstr()->getType();
   case Instruction::Freeze:
   case Instruction::FNeg:

Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@LiqinWeng LiqinWeng merged commit 042a1cc into llvm:main Nov 24, 2024
11 checks passed
@LiqinWeng LiqinWeng deleted the refactor-vplananalysis branch December 11, 2024 07:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants