|
26 | 26 | #include "llvm/IR/PatternMatch.h"
|
27 | 27 | #include "llvm/IR/Value.h"
|
28 | 28 | #include "llvm/Support/CommandLine.h"
|
29 |
| -#include <optional> |
30 | 29 |
|
31 | 30 | #define DEBUG_TYPE "vectorutils"
|
32 | 31 |
|
@@ -1482,27 +1481,26 @@ void VFABI::getVectorVariantNames(
|
1482 | 1481 |
|
1483 | 1482 | FunctionType *VFABI::createFunctionType(const VFInfo &Info,
|
1484 | 1483 | const FunctionType *ScalarFTy) {
|
1485 |
| - ElementCount VF = Info.Shape.VF; |
1486 | 1484 | // Create vector parameter types
|
1487 | 1485 | SmallVector<Type *, 8> VecTypes;
|
1488 |
| - for (auto [STy, VFParam] : zip(ScalarFTy->params(), Info.Shape.Parameters)) { |
| 1486 | + ElementCount VF = Info.Shape.VF; |
| 1487 | + for (auto [Idx, VFParam] : enumerate(Info.Shape.Parameters)) { |
| 1488 | + if (VFParam.ParamKind == VFParamKind::GlobalPredicate) { |
| 1489 | + VectorType *MaskTy = |
| 1490 | + VectorType::get(Type::getInt1Ty(ScalarFTy->getContext()), VF); |
| 1491 | + VecTypes.push_back(MaskTy); |
| 1492 | + continue; |
| 1493 | + } |
| 1494 | + |
| 1495 | + Type *OperandTy = ScalarFTy->getParamType(Idx); |
1489 | 1496 | if (VFParam.ParamKind == VFParamKind::Vector)
|
1490 |
| - VecTypes.push_back(VectorType::get(STy, VF)); |
1491 |
| - else |
1492 |
| - VecTypes.push_back(STy); |
| 1497 | + OperandTy = VectorType::get(OperandTy, VF); |
| 1498 | + VecTypes.push_back(OperandTy); |
1493 | 1499 | }
|
1494 | 1500 |
|
1495 |
| - // Get mask's position mask and append one if not present in the Instruction. |
1496 |
| - if (auto OptMaskPos = Info.getParamIndexForOptionalMask()) { |
1497 |
| - if (!OptMaskPos) |
1498 |
| - return nullptr; |
1499 |
| - VectorType *MaskTy = |
1500 |
| - VectorType::get(Type::getInt1Ty(ScalarFTy->getContext()), VF); |
1501 |
| - VecTypes.insert(VecTypes.begin() + OptMaskPos.value(), MaskTy); |
1502 |
| - } |
1503 | 1501 | auto *RetTy = ScalarFTy->getReturnType();
|
1504 | 1502 | if (!RetTy->isVoidTy())
|
1505 |
| - RetTy = VectorType::get(ScalarFTy->getReturnType(), VF); |
| 1503 | + RetTy = VectorType::get(RetTy, VF); |
1506 | 1504 | return FunctionType::get(RetTy, VecTypes, false);
|
1507 | 1505 | }
|
1508 | 1506 |
|
|
0 commit comments