Skip to content

Commit b5f1f37

Browse files
[GISel] Add support for scalable vectors in getGCDType
This function can be called from buildCopyToRegs where at least one of the types is a scalable vector type. This function crashed because it did not know how to handle scalable vector types. This patch extends the functionality of getGCDType to handle when at least one of the types is a scalable vector. getGCDType between a fixed and scalable vector is not implemented since the docstring of the function explains that getGCDType is used to build MERGE/UNMERGE instructions and we will never build a MERGE/UNMERGE between fixed and scalable vectors.
1 parent d3ab7b1 commit b5f1f37

File tree

3 files changed

+104
-30
lines changed

3 files changed

+104
-30
lines changed

llvm/include/llvm/CodeGen/GlobalISel/Utils.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,10 @@ LLT getCoverTy(LLT OrigTy, LLT TargetTy);
368368
/// If these are vectors with different element types, this will try to produce
369369
/// a vector with a compatible total size, but the element type of \p OrigTy. If
370370
/// this can't be satisfied, this will produce a scalar smaller than the
371-
/// original vector elements.
371+
/// original vector elements. It is an error to call this function where
372+
/// one argument is a fixed vector and the other is a scalable vector, since it
373+
/// is illegal to build a G_{MERGE|UNMERGE}_VALUES between fixed and scalable
374+
/// vectors.
372375
///
373376
/// In the worst case, this returns LLT::scalar(1)
374377
LLVM_READNONE

llvm/lib/CodeGen/GlobalISel/Utils.cpp

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,45 +1156,60 @@ LLT llvm::getCoverTy(LLT OrigTy, LLT TargetTy) {
11561156
}
11571157

11581158
LLT llvm::getGCDType(LLT OrigTy, LLT TargetTy) {
1159-
const unsigned OrigSize = OrigTy.getSizeInBits();
1160-
const unsigned TargetSize = TargetTy.getSizeInBits();
1161-
1162-
if (OrigSize == TargetSize)
1159+
if (OrigTy.getSizeInBits() == TargetTy.getSizeInBits())
11631160
return OrigTy;
11641161

1165-
if (OrigTy.isVector()) {
1162+
if (OrigTy.isVector() && TargetTy.isVector()) {
11661163
LLT OrigElt = OrigTy.getElementType();
1167-
if (TargetTy.isVector()) {
1168-
LLT TargetElt = TargetTy.getElementType();
1169-
if (OrigElt.getSizeInBits() == TargetElt.getSizeInBits()) {
1170-
int GCD = std::gcd(OrigTy.getNumElements(), TargetTy.getNumElements());
1171-
return LLT::scalarOrVector(ElementCount::getFixed(GCD), OrigElt);
1172-
}
1173-
} else {
1174-
// If the source is a vector of pointers, return a pointer element.
1175-
if (OrigElt.getSizeInBits() == TargetSize)
1176-
return OrigElt;
1177-
}
1164+
LLT TargetElt = TargetTy.getElementType();
11781165

1179-
unsigned GCD = std::gcd(OrigSize, TargetSize);
1166+
// TODO: The docstring for this function says the intention is to use this
1167+
// function to build MERGE/UNMERGE instructions. It won't be the case that
1168+
// we generate a MERGE/UNMERGE between fixed and scalable vector types. We
1169+
// could implement getGCDType between the two in the future if there was a
1170+
// need, but it is not worth it now as this function should not be used in
1171+
// that way.
1172+
if ((OrigTy.isScalableVector() && TargetTy.isFixedVector()) ||
1173+
(OrigTy.isFixedVector() && TargetTy.isScalableVector()))
1174+
llvm_unreachable(
1175+
"getGCDType not implemented between fixed and scalable vectors.");
1176+
1177+
unsigned GCD = std::gcd(OrigTy.getElementCount().getKnownMinValue() *
1178+
OrigElt.getSizeInBits().getFixedValue(),
1179+
TargetTy.getElementCount().getKnownMinValue() *
1180+
TargetElt.getSizeInBits().getFixedValue());
11801181
if (GCD == OrigElt.getSizeInBits())
1181-
return OrigElt;
1182+
return LLT::scalarOrVector(ElementCount::get(1, OrigTy.isScalable()),
1183+
OrigElt);
11821184

1183-
// If we can't produce the original element type, we have to use a smaller
1184-
// scalar.
1185+
// Cannot produce original element type, but both have vscale in common.
11851186
if (GCD < OrigElt.getSizeInBits())
1186-
return LLT::scalar(GCD);
1187-
return LLT::fixed_vector(GCD / OrigElt.getSizeInBits(), OrigElt);
1188-
}
1187+
return LLT::scalarOrVector(ElementCount::get(1, OrigTy.isScalable()),
1188+
GCD);
11891189

1190-
if (TargetTy.isVector()) {
1191-
// Try to preserve the original element type.
1192-
LLT TargetElt = TargetTy.getElementType();
1193-
if (TargetElt.getSizeInBits() == OrigSize)
1194-
return OrigTy;
1190+
return LLT::vector(
1191+
ElementCount::get(GCD / OrigElt.getSizeInBits().getFixedValue(),
1192+
OrigTy.isScalable()),
1193+
OrigElt);
11951194
}
11961195

1197-
unsigned GCD = std::gcd(OrigSize, TargetSize);
1196+
// If one type is vector and the element size matches the scalar size, then
1197+
// the gcd is the scalar type.
1198+
if (OrigTy.isVector() &&
1199+
OrigTy.getElementType().getSizeInBits() == TargetTy.getSizeInBits())
1200+
return OrigTy.getElementType();
1201+
if (TargetTy.isVector() &&
1202+
TargetTy.getElementType().getSizeInBits() == OrigTy.getSizeInBits())
1203+
return OrigTy;
1204+
1205+
// At this point, both types are either scalars of different type or one is a
1206+
// vector and one is a scalar. If both types are scalars, the GCD type is the
1207+
// GCD between the two scalar sizes. If one is vector and one is scalar, then
1208+
// the GCD type is the GCD between the scalar and the vector element size.
1209+
LLT OrigScalar = OrigTy.isVector() ? OrigTy.getElementType() : OrigTy;
1210+
LLT TargetScalar = TargetTy.isVector() ? TargetTy.getElementType() : TargetTy;
1211+
unsigned GCD = std::gcd(OrigScalar.getSizeInBits().getFixedValue(),
1212+
TargetScalar.getSizeInBits().getFixedValue());
11981213
return LLT::scalar(GCD);
11991214
}
12001215

llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,62 @@ TEST(GISelUtilsTest, getGCDType) {
183183

184184
EXPECT_EQ(LLT::scalar(4), getGCDType(LLT::fixed_vector(3, 4), S8));
185185
EXPECT_EQ(LLT::scalar(4), getGCDType(S8, LLT::fixed_vector(3, 4)));
186+
187+
// Scalable -> Scalable
188+
EXPECT_EQ(NXV1S1, getGCDType(NXV1S1, NXV1S32));
189+
EXPECT_EQ(NXV1S32, getGCDType(NXV1S64, NXV1S32));
190+
EXPECT_EQ(NXV1S32, getGCDType(NXV1S32, NXV1S64));
191+
EXPECT_EQ(NXV1P0, getGCDType(NXV1P0, NXV1S64));
192+
EXPECT_EQ(NXV1S64, getGCDType(NXV1S64, NXV1P0));
193+
194+
EXPECT_EQ(NXV4S1, getGCDType(NXV4S1, NXV4S32));
195+
EXPECT_EQ(NXV2S64, getGCDType(NXV4S64, NXV4S32));
196+
EXPECT_EQ(NXV4S32, getGCDType(NXV4S32, NXV4S64));
197+
EXPECT_EQ(NXV4P0, getGCDType(NXV4P0, NXV4S64));
198+
EXPECT_EQ(NXV4S64, getGCDType(NXV4S64, NXV4P0));
199+
200+
EXPECT_EQ(NXV4S1, getGCDType(NXV4S1, NXV2S32));
201+
EXPECT_EQ(NXV1S64, getGCDType(NXV4S64, NXV2S32));
202+
EXPECT_EQ(NXV4S32, getGCDType(NXV4S32, NXV2S64));
203+
EXPECT_EQ(NXV2P0, getGCDType(NXV4P0, NXV2S64));
204+
EXPECT_EQ(NXV2S64, getGCDType(NXV4S64, NXV2P0));
205+
206+
EXPECT_EQ(NXV2S1, getGCDType(NXV2S1, NXV4S32));
207+
EXPECT_EQ(NXV2S64, getGCDType(NXV2S64, NXV4S32));
208+
EXPECT_EQ(NXV2S32, getGCDType(NXV2S32, NXV4S64));
209+
EXPECT_EQ(NXV2P0, getGCDType(NXV2P0, NXV4S64));
210+
EXPECT_EQ(NXV2S64, getGCDType(NXV2S64, NXV4P0));
211+
212+
EXPECT_EQ(NXV1S1, getGCDType(NXV3S1, NXV4S32));
213+
EXPECT_EQ(NXV1S64, getGCDType(NXV3S64, NXV4S32));
214+
EXPECT_EQ(NXV1S32, getGCDType(NXV3S32, NXV4S64));
215+
EXPECT_EQ(NXV1P0, getGCDType(NXV3P0, NXV4S64));
216+
EXPECT_EQ(NXV1S64, getGCDType(NXV3S64, NXV4P0));
217+
218+
EXPECT_EQ(NXV1S1, getGCDType(NXV3S1, NXV4S1));
219+
EXPECT_EQ(NXV1S32, getGCDType(NXV3S32, NXV4S32));
220+
EXPECT_EQ(NXV1S64, getGCDType(NXV3S64, NXV4S64));
221+
EXPECT_EQ(NXV1P0, getGCDType(NXV3P0, NXV4P0));
222+
223+
// Scalable, Scalar
224+
225+
EXPECT_EQ(S1, getGCDType(NXV1S1, S1));
226+
EXPECT_EQ(S1, getGCDType(NXV1S1, S32));
227+
EXPECT_EQ(S1, getGCDType(NXV1S32, S1));
228+
EXPECT_EQ(S32, getGCDType(NXV1S32, S32));
229+
EXPECT_EQ(S32, getGCDType(NXV1S32, S64));
230+
EXPECT_EQ(S1, getGCDType(NXV2S32, S1));
231+
EXPECT_EQ(S32, getGCDType(NXV2S32, S32));
232+
EXPECT_EQ(S32, getGCDType(NXV2S32, S64));
233+
234+
EXPECT_EQ(S1, getGCDType(S1, NXV1S1));
235+
EXPECT_EQ(S1, getGCDType(S32, NXV1S1));
236+
EXPECT_EQ(S1, getGCDType(S1, NXV1S32));
237+
EXPECT_EQ(S32, getGCDType(S32, NXV1S32));
238+
EXPECT_EQ(S32, getGCDType(S64, NXV1S32));
239+
EXPECT_EQ(S1, getGCDType(S1, NXV2S32));
240+
EXPECT_EQ(S32, getGCDType(S32, NXV2S32));
241+
EXPECT_EQ(S32, getGCDType(S64, NXV2S32));
186242
}
187243

188244
TEST(GISelUtilsTest, getLCMType) {

0 commit comments

Comments
 (0)