-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[VPlan] Generalize type inference for binary/cast/shift/logic. NFC #116173
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-vectorizers Author: LiqinWeng (LiqinWeng) ChangesFull diff: https://github.com/llvm/llvm-project/pull/116173.diff 1 Files Affected:
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 8b8ab6be99b0d5..cb42cfe8159b04 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -93,34 +93,19 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenRecipe *R) {
unsigned Opcode = R->getOpcode();
- switch (Opcode) {
- case Instruction::ICmp:
- case Instruction::FCmp:
- return IntegerType::get(Ctx, 1);
- case Instruction::UDiv:
- case Instruction::SDiv:
- case Instruction::SRem:
- case Instruction::URem:
- case Instruction::Add:
- case Instruction::FAdd:
- case Instruction::Sub:
- case Instruction::FSub:
- case Instruction::Mul:
- case Instruction::FMul:
- case Instruction::FDiv:
- case Instruction::FRem:
- case Instruction::Shl:
- case Instruction::LShr:
- case Instruction::AShr:
- case Instruction::And:
- case Instruction::Or:
- case Instruction::Xor: {
+ if (Instruction::isBinaryOp(Opcode) || Instruction::isShift(Opcode) ||
+ Instruction::isBitwiseLogicOp(Opcode)) {
Type *ResTy = inferScalarType(R->getOperand(0));
assert(ResTy == inferScalarType(R->getOperand(1)) &&
"types for both operands must match for binary op");
CachedTypes[R->getOperand(1)] = ResTy;
return ResTy;
}
+
+ switch (Opcode) {
+ case Instruction::ICmp:
+ case Instruction::FCmp:
+ return IntegerType::get(Ctx, 1);
case Instruction::FNeg:
case Instruction::Freeze:
return inferScalarType(R->getOperand(0));
@@ -157,36 +142,26 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenSelectRecipe *R) {
}
Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
- switch (R->getUnderlyingInstr()->getOpcode()) {
- case Instruction::Call: {
- unsigned CallIdx = R->getNumOperands() - (R->isPredicated() ? 2 : 1);
- return cast<Function>(R->getOperand(CallIdx)->getLiveInIRValue())
- ->getReturnType();
- }
- case Instruction::UDiv:
- case Instruction::SDiv:
- case Instruction::SRem:
- case Instruction::URem:
- case Instruction::Add:
- case Instruction::FAdd:
- case Instruction::Sub:
- case Instruction::FSub:
- case Instruction::Mul:
- case Instruction::FMul:
- case Instruction::FDiv:
- case Instruction::FRem:
- case Instruction::Shl:
- case Instruction::LShr:
- case Instruction::AShr:
- case Instruction::And:
- case Instruction::Or:
- case Instruction::Xor: {
+ unsigned Opcode = R->getUnderlyingInstr()->getOpcode();
+
+ if (Instruction::isBinaryOp(Opcode) || Instruction::isShift(Opcode) ||
+ Instruction::isBitwiseLogicOp(Opcode)) {
Type *ResTy = inferScalarType(R->getOperand(0));
assert(ResTy == inferScalarType(R->getOperand(1)) &&
"inferred types for operands of binary op don't match");
CachedTypes[R->getOperand(1)] = ResTy;
return ResTy;
}
+
+ if (Instruction::isCast(Opcode))
+ return R->getUnderlyingInstr()->getType();
+
+ switch (Opcode) {
+ case Instruction::Call: {
+ unsigned CallIdx = R->getNumOperands() - (R->isPredicated() ? 2 : 1);
+ return cast<Function>(R->getOperand(CallIdx)->getLiveInIRValue())
+ ->getReturnType();
+ }
case Instruction::Select: {
Type *ResTy = inferScalarType(R->getOperand(1));
assert(ResTy == inferScalarType(R->getOperand(2)) &&
@@ -197,21 +172,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
case Instruction::ICmp:
case Instruction::FCmp:
return IntegerType::get(Ctx, 1);
- case Instruction::AddrSpaceCast:
case Instruction::Alloca:
- case Instruction::BitCast:
- case Instruction::Trunc:
- case Instruction::SExt:
- case Instruction::ZExt:
- case Instruction::FPExt:
- case Instruction::FPTrunc:
case Instruction::ExtractValue:
- case Instruction::SIToFP:
- case Instruction::UIToFP:
- case Instruction::FPToSI:
- case Instruction::FPToUI:
- case Instruction::PtrToInt:
- case Instruction::IntToPtr:
return R->getUnderlyingInstr()->getType();
case Instruction::Freeze:
case Instruction::FNeg:
|
@llvm/pr-subscribers-llvm-transforms Author: LiqinWeng (LiqinWeng) ChangesFull diff: https://github.com/llvm/llvm-project/pull/116173.diff 1 Files Affected:
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 8b8ab6be99b0d5..cb42cfe8159b04 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -93,34 +93,19 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenRecipe *R) {
unsigned Opcode = R->getOpcode();
- switch (Opcode) {
- case Instruction::ICmp:
- case Instruction::FCmp:
- return IntegerType::get(Ctx, 1);
- case Instruction::UDiv:
- case Instruction::SDiv:
- case Instruction::SRem:
- case Instruction::URem:
- case Instruction::Add:
- case Instruction::FAdd:
- case Instruction::Sub:
- case Instruction::FSub:
- case Instruction::Mul:
- case Instruction::FMul:
- case Instruction::FDiv:
- case Instruction::FRem:
- case Instruction::Shl:
- case Instruction::LShr:
- case Instruction::AShr:
- case Instruction::And:
- case Instruction::Or:
- case Instruction::Xor: {
+ if (Instruction::isBinaryOp(Opcode) || Instruction::isShift(Opcode) ||
+ Instruction::isBitwiseLogicOp(Opcode)) {
Type *ResTy = inferScalarType(R->getOperand(0));
assert(ResTy == inferScalarType(R->getOperand(1)) &&
"types for both operands must match for binary op");
CachedTypes[R->getOperand(1)] = ResTy;
return ResTy;
}
+
+ switch (Opcode) {
+ case Instruction::ICmp:
+ case Instruction::FCmp:
+ return IntegerType::get(Ctx, 1);
case Instruction::FNeg:
case Instruction::Freeze:
return inferScalarType(R->getOperand(0));
@@ -157,36 +142,26 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenSelectRecipe *R) {
}
Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
- switch (R->getUnderlyingInstr()->getOpcode()) {
- case Instruction::Call: {
- unsigned CallIdx = R->getNumOperands() - (R->isPredicated() ? 2 : 1);
- return cast<Function>(R->getOperand(CallIdx)->getLiveInIRValue())
- ->getReturnType();
- }
- case Instruction::UDiv:
- case Instruction::SDiv:
- case Instruction::SRem:
- case Instruction::URem:
- case Instruction::Add:
- case Instruction::FAdd:
- case Instruction::Sub:
- case Instruction::FSub:
- case Instruction::Mul:
- case Instruction::FMul:
- case Instruction::FDiv:
- case Instruction::FRem:
- case Instruction::Shl:
- case Instruction::LShr:
- case Instruction::AShr:
- case Instruction::And:
- case Instruction::Or:
- case Instruction::Xor: {
+ unsigned Opcode = R->getUnderlyingInstr()->getOpcode();
+
+ if (Instruction::isBinaryOp(Opcode) || Instruction::isShift(Opcode) ||
+ Instruction::isBitwiseLogicOp(Opcode)) {
Type *ResTy = inferScalarType(R->getOperand(0));
assert(ResTy == inferScalarType(R->getOperand(1)) &&
"inferred types for operands of binary op don't match");
CachedTypes[R->getOperand(1)] = ResTy;
return ResTy;
}
+
+ if (Instruction::isCast(Opcode))
+ return R->getUnderlyingInstr()->getType();
+
+ switch (Opcode) {
+ case Instruction::Call: {
+ unsigned CallIdx = R->getNumOperands() - (R->isPredicated() ? 2 : 1);
+ return cast<Function>(R->getOperand(CallIdx)->getLiveInIRValue())
+ ->getReturnType();
+ }
case Instruction::Select: {
Type *ResTy = inferScalarType(R->getOperand(1));
assert(ResTy == inferScalarType(R->getOperand(2)) &&
@@ -197,21 +172,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
case Instruction::ICmp:
case Instruction::FCmp:
return IntegerType::get(Ctx, 1);
- case Instruction::AddrSpaceCast:
case Instruction::Alloca:
- case Instruction::BitCast:
- case Instruction::Trunc:
- case Instruction::SExt:
- case Instruction::ZExt:
- case Instruction::FPExt:
- case Instruction::FPTrunc:
case Instruction::ExtractValue:
- case Instruction::SIToFP:
- case Instruction::UIToFP:
- case Instruction::FPToSI:
- case Instruction::FPToUI:
- case Instruction::PtrToInt:
- case Instruction::IntToPtr:
return R->getUnderlyingInstr()->getType();
case Instruction::Freeze:
case Instruction::FNeg:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
No description provided.