Skip to content

Commit a9e65fd

Browse files
get f16 working
1 parent be9a152 commit a9e65fd

File tree

4 files changed

+501
-11
lines changed

4 files changed

+501
-11
lines changed

llvm/lib/CodeGen/GlobalISel/CallLowering.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef<Register> OrigRegs,
358358
if (PartLLT.isVector() == LLTy.isVector() &&
359359
PartLLT.getScalarSizeInBits() > LLTy.getScalarSizeInBits() &&
360360
(!PartLLT.isVector() ||
361-
PartLLT.getNumElements() == LLTy.getNumElements()) &&
361+
PartLLT.getElementCount() == LLTy.getElementCount()) &&
362362
OrigRegs.size() == 1 && Regs.size() == 1) {
363363
Register SrcReg = Regs[0];
364364

@@ -406,6 +406,7 @@ static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef<Register> OrigRegs,
406406
// If PartLLT is a mismatched vector in both number of elements and element
407407
// size, e.g. PartLLT == v2s64 and LLTy is v3s32, then first coerce it to
408408
// have the same elt type, i.e. v4s32.
409+
// TODO: Extend this coersion to element multiples other than just 2.
409410
if (PartLLT.getSizeInBits() > LLTy.getSizeInBits() &&
410411
PartLLT.getScalarSizeInBits() == LLTy.getScalarSizeInBits() * 2 &&
411412
Regs.size() == 1) {
@@ -472,7 +473,7 @@ static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef<Register> OrigRegs,
472473
} else {
473474
// Vector was split, and elements promoted to a wider type.
474475
// FIXME: Should handle floating point promotions.
475-
LLT BVType = LLT::fixed_vector(LLTy.getNumElements(), PartLLT);
476+
LLT BVType = LLT::vector(LLTy.getElementCount(), PartLLT);
476477
auto BV = B.buildBuildVector(BVType, Regs);
477478
B.buildTrunc(OrigRegs[0], BV);
478479
}

llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,16 +1065,16 @@ void MachineIRBuilder::validateTruncExt(const LLT DstTy, const LLT SrcTy,
10651065
#ifndef NDEBUG
10661066
if (DstTy.isVector()) {
10671067
assert(SrcTy.isVector() && "mismatched cast between vector and non-vector");
1068-
assert(SrcTy.getNumElements() == DstTy.getNumElements() &&
1068+
assert(SrcTy.getElementCount() == DstTy.getElementCount() &&
10691069
"different number of elements in a trunc/ext");
10701070
} else
10711071
assert(DstTy.isScalar() && SrcTy.isScalar() && "invalid extend/trunc");
10721072

10731073
if (IsExtend)
1074-
assert(DstTy.getSizeInBits() > SrcTy.getSizeInBits() &&
1074+
assert(TypeSize::isKnownGT(DstTy.getSizeInBits(), SrcTy.getSizeInBits()) &&
10751075
"invalid narrowing extend");
10761076
else
1077-
assert(DstTy.getSizeInBits() < SrcTy.getSizeInBits() &&
1077+
assert(TypeSize::isKnownLT(DstTy.getSizeInBits(), SrcTy.getSizeInBits()) &&
10781078
"invalid widening trunc");
10791079
#endif
10801080
}
@@ -1281,10 +1281,19 @@ MachineIRBuilder::buildInstr(unsigned Opc, ArrayRef<DstOp> DstOps,
12811281
SrcOps[0].getLLTTy(*getMRI());
12821282
}) &&
12831283
"type mismatch in input list");
1284-
assert((TypeSize::ScalarTy)SrcOps.size() *
1285-
SrcOps[0].getLLTTy(*getMRI()).getSizeInBits() ==
1286-
DstOps[0].getLLTTy(*getMRI()).getSizeInBits() &&
1287-
"input scalars do not exactly cover the output vector register");
1284+
if (DstOps[0].getLLTTy(*getMRI()).isScalable())
1285+
assert((TypeSize::ScalarTy)SrcOps.size() *
1286+
SrcOps[0].getLLTTy(*getMRI()).getSizeInBits() >=
1287+
DstOps[0]
1288+
.getLLTTy(*getMRI())
1289+
.getSizeInBits()
1290+
.getKnownMinValue() &&
1291+
"input scalars does not cover the output vector register");
1292+
else
1293+
assert((TypeSize::ScalarTy)SrcOps.size() *
1294+
SrcOps[0].getLLTTy(*getMRI()).getSizeInBits() ==
1295+
DstOps[0].getLLTTy(*getMRI()).getSizeInBits() &&
1296+
"input scalars do not exactly cover the output vector register");
12881297
break;
12891298
}
12901299
case TargetOpcode::G_BUILD_VECTOR_TRUNC: {

llvm/lib/CodeGen/MachineVerifier.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,7 @@ bool MachineVerifier::verifyVectorElementMatch(LLT Ty0, LLT Ty1,
965965
return false;
966966
}
967967

968-
if (Ty0.isVector() && Ty0.getNumElements() != Ty1.getNumElements()) {
968+
if (Ty0.isVector() && Ty0.getElementCount() != Ty1.getElementCount()) {
969969
report("operand types must preserve number of vector elements", MI);
970970
return false;
971971
}
@@ -1435,7 +1435,7 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) {
14351435
if (DstTy.getElementType() != SrcEltTy)
14361436
report("G_BUILD_VECTOR result element type must match source type", MI);
14371437

1438-
if (DstTy.getNumElements() != MI->getNumOperands() - 1)
1438+
if (DstTy.getElementCount().getKnownMinValue() > MI->getNumOperands() - 1)
14391439
report("G_BUILD_VECTOR must have an operand for each elemement", MI);
14401440

14411441
for (const MachineOperand &MO : llvm::drop_begin(MI->operands(), 2))

0 commit comments

Comments
 (0)