Skip to content

Commit d3ab7b1

Browse files
[GISel] Add support for scalable vectors in getLCMType
This function can be called from buildCopyToRegs where at least one of the types is a scalable vector type. This function crashed because it did not know how to handle scalable vector types. This patch extends the functionality of getLCMType to handle when at least one of the types is a scalable vector. getLCMType between a fixed and scalable vector is not implemented since the docstring of the function explains that getLCMType is used to build MERGE/UNMERGE instructions and we will never build a MERGE/UNMERGE between fixed and scalable vectors.
1 parent 59eadcd commit d3ab7b1

File tree

3 files changed

+149
-35
lines changed

3 files changed

+149
-35
lines changed

llvm/include/llvm/CodeGen/GlobalISel/Utils.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -343,10 +343,13 @@ Register getFunctionLiveInPhysReg(MachineFunction &MF,
343343
const TargetRegisterClass &RC,
344344
const DebugLoc &DL, LLT RegTy = LLT());
345345

346-
/// Return the least common multiple type of \p OrigTy and \p TargetTy, by changing the
347-
/// number of vector elements or scalar bitwidth. The intent is a
346+
/// Return the least common multiple type of \p OrigTy and \p TargetTy, by
347+
/// changing the number of vector elements or scalar bitwidth. The intent is a
348348
/// G_MERGE_VALUES, G_BUILD_VECTOR, or G_CONCAT_VECTORS can be constructed from
349-
/// \p OrigTy elements, and unmerged into \p TargetTy
349+
/// \p OrigTy elements, and unmerged into \p TargetTy. It is an error to call
350+
/// this function where one argument is a fixed vector and the other is a
351+
/// scalable vector, since it is illegal to build a G_{MERGE|UNMERGE}_VALUES
352+
/// between fixed and scalable vectors.
350353
LLVM_READNONE
351354
LLT getLCMType(LLT OrigTy, LLT TargetTy);
352355

llvm/lib/CodeGen/GlobalISel/Utils.cpp

Lines changed: 56 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,49 +1071,73 @@ void llvm::getSelectionDAGFallbackAnalysisUsage(AnalysisUsage &AU) {
10711071
}
10721072

10731073
LLT llvm::getLCMType(LLT OrigTy, LLT TargetTy) {
1074-
const unsigned OrigSize = OrigTy.getSizeInBits();
1075-
const unsigned TargetSize = TargetTy.getSizeInBits();
1076-
1077-
if (OrigSize == TargetSize)
1074+
if (OrigTy.getSizeInBits() == TargetTy.getSizeInBits())
10781075
return OrigTy;
10791076

1080-
if (OrigTy.isVector()) {
1081-
const LLT OrigElt = OrigTy.getElementType();
1082-
1083-
if (TargetTy.isVector()) {
1084-
const LLT TargetElt = TargetTy.getElementType();
1077+
if (OrigTy.isVector() && TargetTy.isVector()) {
1078+
LLT OrigElt = OrigTy.getElementType();
1079+
LLT TargetElt = TargetTy.getElementType();
10851080

1086-
if (OrigElt.getSizeInBits() == TargetElt.getSizeInBits()) {
1087-
int GCDElts =
1088-
std::gcd(OrigTy.getNumElements(), TargetTy.getNumElements());
1089-
// Prefer the original element type.
1090-
ElementCount Mul = OrigTy.getElementCount() * TargetTy.getNumElements();
1091-
return LLT::vector(Mul.divideCoefficientBy(GCDElts),
1092-
OrigTy.getElementType());
1093-
}
1094-
} else {
1095-
if (OrigElt.getSizeInBits() == TargetSize)
1096-
return OrigTy;
1081+
// TODO: The docstring for this function says the intention is to use this
1082+
// function to build MERGE/UNMERGE instructions. It won't be the case that
1083+
// we generate a MERGE/UNMERGE between fixed and scalable vector types. We
1084+
// could implement getLCMType between the two in the future if there was a
1085+
// need, but it is not worth it now as this function should not be used in
1086+
// that way.
1087+
if ((OrigTy.isScalableVector() && TargetTy.isFixedVector()) ||
1088+
(OrigTy.isFixedVector() && TargetTy.isScalableVector()))
1089+
llvm_unreachable(
1090+
"getLCMType not implemented between fixed and scalable vectors.");
1091+
1092+
if (OrigElt.getSizeInBits() == TargetElt.getSizeInBits()) {
1093+
int GCDMinElts = std::gcd(OrigTy.getElementCount().getKnownMinValue(),
1094+
TargetTy.getElementCount().getKnownMinValue());
1095+
// Prefer the original element type.
1096+
ElementCount Mul = OrigTy.getElementCount().multiplyCoefficientBy(
1097+
TargetTy.getElementCount().getKnownMinValue());
1098+
return LLT::vector(Mul.divideCoefficientBy(GCDMinElts),
1099+
OrigTy.getElementType());
10971100
}
1098-
1099-
unsigned LCMSize = std::lcm(OrigSize, TargetSize);
1100-
return LLT::fixed_vector(LCMSize / OrigElt.getSizeInBits(), OrigElt);
1101+
unsigned LCM = std::lcm(OrigTy.getElementCount().getKnownMinValue() *
1102+
OrigElt.getSizeInBits().getFixedValue(),
1103+
TargetTy.getElementCount().getKnownMinValue() *
1104+
TargetElt.getSizeInBits().getFixedValue());
1105+
return LLT::vector(
1106+
ElementCount::get(LCM / OrigElt.getSizeInBits(), OrigTy.isScalable()),
1107+
OrigElt);
11011108
}
11021109

1103-
if (TargetTy.isVector()) {
1104-
unsigned LCMSize = std::lcm(OrigSize, TargetSize);
1105-
return LLT::fixed_vector(LCMSize / OrigSize, OrigTy);
1110+
// One type is scalar, one type is vector
1111+
if (OrigTy.isVector() || TargetTy.isVector()) {
1112+
LLT VecTy = OrigTy.isVector() ? OrigTy : TargetTy;
1113+
LLT ScalarTy = OrigTy.isVector() ? TargetTy : OrigTy;
1114+
LLT EltTy = VecTy.getElementType();
1115+
LLT OrigEltTy = OrigTy.isVector() ? OrigTy.getElementType() : OrigTy;
1116+
1117+
// Prefer scalar type from OrigTy.
1118+
if (EltTy.getSizeInBits() == ScalarTy.getSizeInBits())
1119+
return LLT::vector(VecTy.getElementCount(), OrigEltTy);
1120+
1121+
// Different size scalars. Create vector with the same total size.
1122+
// LCM will take fixed/scalable from VecTy.
1123+
unsigned LCM = std::lcm(EltTy.getSizeInBits().getFixedValue() *
1124+
VecTy.getElementCount().getKnownMinValue(),
1125+
ScalarTy.getSizeInBits().getFixedValue());
1126+
// Prefer type from OrigTy
1127+
return LLT::vector(ElementCount::get(LCM / OrigEltTy.getSizeInBits(),
1128+
VecTy.getElementCount().isScalable()),
1129+
OrigEltTy);
11061130
}
11071131

1108-
unsigned LCMSize = std::lcm(OrigSize, TargetSize);
1109-
1132+
// At this point, both types are scalars of different size
1133+
unsigned LCM = std::lcm(OrigTy.getSizeInBits().getFixedValue(),
1134+
TargetTy.getSizeInBits().getFixedValue());
11101135
// Preserve pointer types.
1111-
if (LCMSize == OrigSize)
1136+
if (LCM == OrigTy.getSizeInBits())
11121137
return OrigTy;
1113-
if (LCMSize == TargetSize)
1138+
if (LCM == TargetTy.getSizeInBits())
11141139
return TargetTy;
1115-
1116-
return LLT::scalar(LCMSize);
1140+
return LLT::scalar(LCM);
11171141
}
11181142

11191143
LLT llvm::getCoverTy(LLT OrigTy, LLT TargetTy) {

llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,37 @@ static const LLT V6P0 = LLT::fixed_vector(6, P0);
4646
static const LLT V2P1 = LLT::fixed_vector(2, P1);
4747
static const LLT V4P1 = LLT::fixed_vector(4, P1);
4848

49+
static const LLT NXV1S1 = LLT::scalable_vector(1, S1);
50+
static const LLT NXV2S1 = LLT::scalable_vector(2, S1);
51+
static const LLT NXV3S1 = LLT::scalable_vector(3, S1);
52+
static const LLT NXV4S1 = LLT::scalable_vector(4, S1);
53+
static const LLT NXV12S1 = LLT::scalable_vector(12, S1);
54+
static const LLT NXV32S1 = LLT::scalable_vector(32, S1);
55+
static const LLT NXV64S1 = LLT::scalable_vector(64, S1);
56+
static const LLT NXV128S1 = LLT::scalable_vector(128, S1);
57+
static const LLT NXV384S1 = LLT::scalable_vector(384, S1);
58+
59+
static const LLT NXV1S32 = LLT::scalable_vector(1, S32);
60+
static const LLT NXV2S32 = LLT::scalable_vector(2, S32);
61+
static const LLT NXV3S32 = LLT::scalable_vector(3, S32);
62+
static const LLT NXV4S32 = LLT::scalable_vector(4, S32);
63+
static const LLT NXV8S32 = LLT::scalable_vector(8, S32);
64+
static const LLT NXV12S32 = LLT::scalable_vector(12, S32);
65+
static const LLT NXV24S32 = LLT::scalable_vector(24, S32);
66+
67+
static const LLT NXV1S64 = LLT::scalable_vector(1, S64);
68+
static const LLT NXV2S64 = LLT::scalable_vector(2, S64);
69+
static const LLT NXV3S64 = LLT::scalable_vector(3, S64);
70+
static const LLT NXV4S64 = LLT::scalable_vector(4, S64);
71+
static const LLT NXV6S64 = LLT::scalable_vector(6, S64);
72+
static const LLT NXV12S64 = LLT::scalable_vector(12, S64);
73+
74+
static const LLT NXV1P0 = LLT::scalable_vector(1, P0);
75+
static const LLT NXV2P0 = LLT::scalable_vector(2, P0);
76+
static const LLT NXV3P0 = LLT::scalable_vector(3, P0);
77+
static const LLT NXV4P0 = LLT::scalable_vector(4, P0);
78+
static const LLT NXV12P0 = LLT::scalable_vector(12, P0);
79+
4980
TEST(GISelUtilsTest, getGCDType) {
5081
EXPECT_EQ(S1, getGCDType(S1, S1));
5182
EXPECT_EQ(S32, getGCDType(S32, S32));
@@ -244,6 +275,62 @@ TEST(GISelUtilsTest, getLCMType) {
244275

245276
EXPECT_EQ(V2S64, getLCMType(V2S64, P1));
246277
EXPECT_EQ(V4P1, getLCMType(P1, V2S64));
278+
279+
// Scalable, Scalable
280+
EXPECT_EQ(NXV32S1, getLCMType(NXV1S1, NXV1S32));
281+
EXPECT_EQ(NXV1S64, getLCMType(NXV1S64, NXV1S32));
282+
EXPECT_EQ(NXV2S32, getLCMType(NXV1S32, NXV1S64));
283+
EXPECT_EQ(NXV1P0, getLCMType(NXV1P0, NXV1S64));
284+
EXPECT_EQ(NXV1S64, getLCMType(NXV1S64, NXV1P0));
285+
286+
EXPECT_EQ(NXV128S1, getLCMType(NXV4S1, NXV4S32));
287+
EXPECT_EQ(NXV4S64, getLCMType(NXV4S64, NXV4S32));
288+
EXPECT_EQ(NXV8S32, getLCMType(NXV4S32, NXV4S64));
289+
EXPECT_EQ(NXV4P0, getLCMType(NXV4P0, NXV4S64));
290+
EXPECT_EQ(NXV4S64, getLCMType(NXV4S64, NXV4P0));
291+
292+
EXPECT_EQ(NXV64S1, getLCMType(NXV4S1, NXV2S32));
293+
EXPECT_EQ(NXV4S64, getLCMType(NXV4S64, NXV2S32));
294+
EXPECT_EQ(NXV4S32, getLCMType(NXV4S32, NXV2S64));
295+
EXPECT_EQ(NXV4P0, getLCMType(NXV4P0, NXV2S64));
296+
EXPECT_EQ(NXV4S64, getLCMType(NXV4S64, NXV2P0));
297+
298+
EXPECT_EQ(NXV128S1, getLCMType(NXV2S1, NXV4S32));
299+
EXPECT_EQ(NXV2S64, getLCMType(NXV2S64, NXV4S32));
300+
EXPECT_EQ(NXV8S32, getLCMType(NXV2S32, NXV4S64));
301+
EXPECT_EQ(NXV4P0, getLCMType(NXV2P0, NXV4S64));
302+
EXPECT_EQ(NXV4S64, getLCMType(NXV2S64, NXV4P0));
303+
304+
EXPECT_EQ(NXV384S1, getLCMType(NXV3S1, NXV4S32));
305+
EXPECT_EQ(NXV6S64, getLCMType(NXV3S64, NXV4S32));
306+
EXPECT_EQ(NXV24S32, getLCMType(NXV3S32, NXV4S64));
307+
EXPECT_EQ(NXV12P0, getLCMType(NXV3P0, NXV4S64));
308+
EXPECT_EQ(NXV12S64, getLCMType(NXV3S64, NXV4P0));
309+
310+
EXPECT_EQ(NXV12S1, getLCMType(NXV3S1, NXV4S1));
311+
EXPECT_EQ(NXV12S32, getLCMType(NXV3S32, NXV4S32));
312+
EXPECT_EQ(NXV12S64, getLCMType(NXV3S64, NXV4S64));
313+
EXPECT_EQ(NXV12P0, getLCMType(NXV3P0, NXV4P0));
314+
315+
// Scalable, Scalar
316+
317+
EXPECT_EQ(NXV1S1, getLCMType(NXV1S1, S1));
318+
EXPECT_EQ(NXV32S1, getLCMType(NXV1S1, S32));
319+
EXPECT_EQ(NXV1S32, getLCMType(NXV1S32, S1));
320+
EXPECT_EQ(NXV1S32, getLCMType(NXV1S32, S32));
321+
EXPECT_EQ(NXV2S32, getLCMType(NXV1S32, S64));
322+
EXPECT_EQ(NXV2S32, getLCMType(NXV2S32, S1));
323+
EXPECT_EQ(NXV2S32, getLCMType(NXV2S32, S32));
324+
EXPECT_EQ(NXV2S32, getLCMType(NXV2S32, S64));
325+
326+
EXPECT_EQ(NXV1S1, getLCMType(S1, NXV1S1));
327+
EXPECT_EQ(NXV1S32, getLCMType(S32, NXV1S1));
328+
EXPECT_EQ(NXV32S1, getLCMType(S1, NXV1S32));
329+
EXPECT_EQ(NXV1S32, getLCMType(S32, NXV1S32));
330+
EXPECT_EQ(NXV1S64, getLCMType(S64, NXV1S32));
331+
EXPECT_EQ(NXV64S1, getLCMType(S1, NXV2S32));
332+
EXPECT_EQ(NXV2S32, getLCMType(S32, NXV2S32));
333+
EXPECT_EQ(NXV1S64, getLCMType(S64, NXV2S32));
247334
}
248335

249336
TEST_F(AArch64GISelMITest, ConstFalseTest) {

0 commit comments

Comments
 (0)