Skip to content

Commit a3010c7

Browse files
authored
[GlobalISel] Add boolean predicated legalization action methods. (#111287)
Under AArch64 it is common and will become more common to have operation legalization rules dependant on a feature of the architecture. For example HasFP16 or the newer CSSC integer min/max instructions, among many others. With the current legalization rules this either means adding a custom predicate based on the feature as in `legalIf([=](const LegalityQuery &Query) { return HasFP16 && ...; }` or splitting the legalization rules into pieces that place rules optionally into them base on the features available. This patch proposes an alternative where the existing routines like legalFor(..) are provided a boolean predicate, which if false skips adding the rule. It makes the rules cleaner and will hopefully allow them to scale better as we add more features. The SVE predicates for loads/stores I have changed to just be always available. Scalable vectors without SVE have never been supported, but it could also add a condition.
1 parent f5d3c87 commit a3010c7

File tree

3 files changed

+119
-165
lines changed

3 files changed

+119
-165
lines changed

llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -599,11 +599,22 @@ class LegalizeRuleSet {
599599
LegalizeRuleSet &legalFor(std::initializer_list<LLT> Types) {
600600
return actionFor(LegalizeAction::Legal, Types);
601601
}
602+
LegalizeRuleSet &legalFor(bool Pred, std::initializer_list<LLT> Types) {
603+
if (!Pred)
604+
return *this;
605+
return actionFor(LegalizeAction::Legal, Types);
606+
}
602607
/// The instruction is legal when type indexes 0 and 1 is any type pair in the
603608
/// given list.
604609
LegalizeRuleSet &legalFor(std::initializer_list<std::pair<LLT, LLT>> Types) {
605610
return actionFor(LegalizeAction::Legal, Types);
606611
}
612+
LegalizeRuleSet &legalFor(bool Pred,
613+
std::initializer_list<std::pair<LLT, LLT>> Types) {
614+
if (!Pred)
615+
return *this;
616+
return actionFor(LegalizeAction::Legal, Types);
617+
}
607618
/// The instruction is legal when type index 0 is any type in the given list
608619
/// and imm index 0 is anything.
609620
LegalizeRuleSet &legalForTypeWithAnyImm(std::initializer_list<LLT> Types) {
@@ -846,12 +857,23 @@ class LegalizeRuleSet {
846857
LegalizeRuleSet &customFor(std::initializer_list<LLT> Types) {
847858
return actionFor(LegalizeAction::Custom, Types);
848859
}
860+
LegalizeRuleSet &customFor(bool Pred, std::initializer_list<LLT> Types) {
861+
if (!Pred)
862+
return *this;
863+
return actionFor(LegalizeAction::Custom, Types);
864+
}
849865

850-
/// The instruction is custom when type indexes 0 and 1 is any type pair in the
851-
/// given list.
866+
/// The instruction is custom when type indexes 0 and 1 is any type pair in
867+
/// the given list.
852868
LegalizeRuleSet &customFor(std::initializer_list<std::pair<LLT, LLT>> Types) {
853869
return actionFor(LegalizeAction::Custom, Types);
854870
}
871+
LegalizeRuleSet &customFor(bool Pred,
872+
std::initializer_list<std::pair<LLT, LLT>> Types) {
873+
if (!Pred)
874+
return *this;
875+
return actionFor(LegalizeAction::Custom, Types);
876+
}
855877

856878
LegalizeRuleSet &customForCartesianProduct(std::initializer_list<LLT> Types) {
857879
return actionForCartesianProduct(LegalizeAction::Custom, Types);
@@ -990,6 +1012,11 @@ class LegalizeRuleSet {
9901012
scalarNarrowerThan(TypeIdx, Ty.getSizeInBits()),
9911013
changeTo(typeIdx(TypeIdx), Ty));
9921014
}
1015+
LegalizeRuleSet &minScalar(bool Pred, unsigned TypeIdx, const LLT Ty) {
1016+
if (!Pred)
1017+
return *this;
1018+
return minScalar(TypeIdx, Ty);
1019+
}
9931020

9941021
/// Ensure the scalar is at least as wide as Ty if condition is met.
9951022
LegalizeRuleSet &minScalarIf(LegalityPredicate Predicate, unsigned TypeIdx,

llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp

Lines changed: 54 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -215,19 +215,10 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
215215
.legalFor({s64, v8s16, v16s8, v4s32})
216216
.lower();
217217

218-
auto &MinMaxActions = getActionDefinitionsBuilder(
219-
{G_SMIN, G_SMAX, G_UMIN, G_UMAX});
220-
if (HasCSSC)
221-
MinMaxActions
222-
.legalFor({s32, s64, v8s8, v16s8, v4s16, v8s16, v2s32, v4s32})
223-
// Making clamping conditional on CSSC extension as without legal types we
224-
// lower to CMP which can fold one of the two sxtb's we'd otherwise need
225-
// if we detect a type smaller than 32-bit.
226-
.minScalar(0, s32);
227-
else
228-
MinMaxActions
229-
.legalFor({v8s8, v16s8, v4s16, v8s16, v2s32, v4s32});
230-
MinMaxActions
218+
getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX})
219+
.legalFor({v8s8, v16s8, v4s16, v8s16, v2s32, v4s32})
220+
.legalFor(HasCSSC, {s32, s64})
221+
.minScalar(HasCSSC, 0, s32)
231222
.clampNumElements(0, v8s8, v16s8)
232223
.clampNumElements(0, v4s16, v8s16)
233224
.clampNumElements(0, v2s32, v4s32)
@@ -247,11 +238,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
247238
{G_FADD, G_FSUB, G_FMUL, G_FDIV, G_FMA, G_FSQRT, G_FMAXNUM, G_FMINNUM,
248239
G_FMAXIMUM, G_FMINIMUM, G_FCEIL, G_FFLOOR, G_FRINT, G_FNEARBYINT,
249240
G_INTRINSIC_TRUNC, G_INTRINSIC_ROUND, G_INTRINSIC_ROUNDEVEN})
250-
.legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64})
251-
.legalIf([=](const LegalityQuery &Query) {
252-
const auto &Ty = Query.Types[0];
253-
return (Ty == v8s16 || Ty == v4s16) && HasFP16;
254-
})
241+
.legalFor({s32, s64, v2s32, v4s32, v2s64})
242+
.legalFor(HasFP16, {s16, v4s16, v8s16})
255243
.libcallFor({s128})
256244
.scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
257245
.minScalarOrElt(0, MinFPScalar)
@@ -261,11 +249,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
261249
.moreElementsToNextPow2(0);
262250

263251
getActionDefinitionsBuilder({G_FABS, G_FNEG})
264-
.legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64})
265-
.legalIf([=](const LegalityQuery &Query) {
266-
const auto &Ty = Query.Types[0];
267-
return (Ty == v8s16 || Ty == v4s16) && HasFP16;
268-
})
252+
.legalFor({s32, s64, v2s32, v4s32, v2s64})
253+
.legalFor(HasFP16, {s16, v4s16, v8s16})
269254
.scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
270255
.lowerIf(scalarOrEltWiderThan(0, 64))
271256
.clampNumElements(0, v4s16, v8s16)
@@ -350,31 +335,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
350335
return ValTy.isPointerVector() && ValTy.getAddressSpace() == 0;
351336
};
352337

353-
auto &LoadActions = getActionDefinitionsBuilder(G_LOAD);
354-
auto &StoreActions = getActionDefinitionsBuilder(G_STORE);
355-
356-
if (ST.hasSVE()) {
357-
LoadActions.legalForTypesWithMemDesc({
358-
// 128 bit base sizes
359-
{nxv16s8, p0, nxv16s8, 8},
360-
{nxv8s16, p0, nxv8s16, 8},
361-
{nxv4s32, p0, nxv4s32, 8},
362-
{nxv2s64, p0, nxv2s64, 8},
363-
});
364-
365-
// TODO: Add nxv2p0. Consider bitcastIf.
366-
// See #92130
367-
// https://github.com/llvm/llvm-project/pull/92130#discussion_r1616888461
368-
StoreActions.legalForTypesWithMemDesc({
369-
// 128 bit base sizes
370-
{nxv16s8, p0, nxv16s8, 8},
371-
{nxv8s16, p0, nxv8s16, 8},
372-
{nxv4s32, p0, nxv4s32, 8},
373-
{nxv2s64, p0, nxv2s64, 8},
374-
});
375-
}
376-
377-
LoadActions
338+
getActionDefinitionsBuilder(G_LOAD)
378339
.customIf([=](const LegalityQuery &Query) {
379340
return HasRCPC3 && Query.Types[0] == s128 &&
380341
Query.MMODescrs[0].Ordering == AtomicOrdering::Acquire;
@@ -399,6 +360,13 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
399360
// These extends are also legal
400361
.legalForTypesWithMemDesc(
401362
{{s32, p0, s8, 8}, {s32, p0, s16, 8}, {s64, p0, s32, 8}})
363+
.legalForTypesWithMemDesc({
364+
// SVE vscale x 128 bit base sizes
365+
{nxv16s8, p0, nxv16s8, 8},
366+
{nxv8s16, p0, nxv8s16, 8},
367+
{nxv4s32, p0, nxv4s32, 8},
368+
{nxv2s64, p0, nxv2s64, 8},
369+
})
402370
.widenScalarToNextPow2(0, /* MinSize = */ 8)
403371
.clampMaxNumElements(0, s8, 16)
404372
.clampMaxNumElements(0, s16, 8)
@@ -425,7 +393,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
425393
.scalarizeIf(typeInSet(0, {v2s16, v2s8}), 0)
426394
.scalarizeIf(scalarOrEltWiderThan(0, 64), 0);
427395

428-
StoreActions
396+
getActionDefinitionsBuilder(G_STORE)
429397
.customIf([=](const LegalityQuery &Query) {
430398
return HasRCPC3 && Query.Types[0] == s128 &&
431399
Query.MMODescrs[0].Ordering == AtomicOrdering::Release;
@@ -445,6 +413,16 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
445413
{p0, p0, s64, 8}, {s128, p0, s128, 8}, {v16s8, p0, s128, 8},
446414
{v8s8, p0, s64, 8}, {v4s16, p0, s64, 8}, {v8s16, p0, s128, 8},
447415
{v2s32, p0, s64, 8}, {v4s32, p0, s128, 8}, {v2s64, p0, s128, 8}})
416+
.legalForTypesWithMemDesc({
417+
// SVE vscale x 128 bit base sizes
418+
// TODO: Add nxv2p0. Consider bitcastIf.
419+
// See #92130
420+
// https://github.com/llvm/llvm-project/pull/92130#discussion_r1616888461
421+
{nxv16s8, p0, nxv16s8, 8},
422+
{nxv8s16, p0, nxv8s16, 8},
423+
{nxv4s32, p0, nxv4s32, 8},
424+
{nxv2s64, p0, nxv2s64, 8},
425+
})
448426
.clampScalar(0, s8, s64)
449427
.lowerIf([=](const LegalityQuery &Query) {
450428
return Query.Types[0].isScalar() &&
@@ -532,12 +510,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
532510
.widenScalarToNextPow2(0)
533511
.clampScalar(0, s8, s64);
534512
getActionDefinitionsBuilder(G_FCONSTANT)
535-
.legalIf([=](const LegalityQuery &Query) {
536-
const auto &Ty = Query.Types[0];
537-
if (HasFP16 && Ty == s16)
538-
return true;
539-
return Ty == s32 || Ty == s64 || Ty == s128;
540-
})
513+
.legalFor({s32, s64, s128})
514+
.legalFor(HasFP16, {s16})
541515
.clampScalar(0, MinFPScalar, s128);
542516

543517
// FIXME: fix moreElementsToNextPow2
@@ -569,16 +543,12 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
569543
.customIf(isVector(0));
570544

571545
getActionDefinitionsBuilder(G_FCMP)
572-
.legalFor({{s32, MinFPScalar},
573-
{s32, s32},
546+
.legalFor({{s32, s32},
574547
{s32, s64},
575548
{v4s32, v4s32},
576549
{v2s32, v2s32},
577550
{v2s64, v2s64}})
578-
.legalIf([=](const LegalityQuery &Query) {
579-
const auto &Ty = Query.Types[1];
580-
return (Ty == v8s16 || Ty == v4s16) && Ty == Query.Types[0] && HasFP16;
581-
})
551+
.legalFor(HasFP16, {{s32, s16}, {v4s16, v4s16}, {v8s16, v8s16}})
582552
.widenScalarOrEltToNextPow2(1)
583553
.clampScalar(0, s32, s32)
584554
.minScalarOrElt(1, MinFPScalar)
@@ -693,13 +663,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
693663
{v2s64, v2s64},
694664
{v4s32, v4s32},
695665
{v2s32, v2s32}})
696-
.legalIf([=](const LegalityQuery &Query) {
697-
return HasFP16 &&
698-
(Query.Types[1] == s16 || Query.Types[1] == v4s16 ||
699-
Query.Types[1] == v8s16) &&
700-
(Query.Types[0] == s32 || Query.Types[0] == s64 ||
701-
Query.Types[0] == v4s16 || Query.Types[0] == v8s16);
702-
})
666+
.legalFor(HasFP16,
667+
{{s32, s16}, {s64, s16}, {v4s16, v4s16}, {v8s16, v8s16}})
703668
.scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
704669
.scalarizeIf(scalarOrEltWiderThan(1, 64), 1)
705670
// The range of a fp16 value fits into an i17, so we can lower the width
@@ -741,13 +706,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
741706
{v2s64, v2s64},
742707
{v4s32, v4s32},
743708
{v2s32, v2s32}})
744-
.legalIf([=](const LegalityQuery &Query) {
745-
return HasFP16 &&
746-
(Query.Types[1] == s16 || Query.Types[1] == v4s16 ||
747-
Query.Types[1] == v8s16) &&
748-
(Query.Types[0] == s32 || Query.Types[0] == s64 ||
749-
Query.Types[0] == v4s16 || Query.Types[0] == v8s16);
750-
})
709+
.legalFor(HasFP16,
710+
{{s32, s16}, {s64, s16}, {v4s16, v4s16}, {v8s16, v8s16}})
751711
// Handle types larger than i64 by scalarizing/lowering.
752712
.scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
753713
.scalarizeIf(scalarOrEltWiderThan(1, 64), 1)
@@ -790,13 +750,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
790750
{v2s64, v2s64},
791751
{v4s32, v4s32},
792752
{v2s32, v2s32}})
793-
.legalIf([=](const LegalityQuery &Query) {
794-
return HasFP16 &&
795-
(Query.Types[0] == s16 || Query.Types[0] == v4s16 ||
796-
Query.Types[0] == v8s16) &&
797-
(Query.Types[1] == s32 || Query.Types[1] == s64 ||
798-
Query.Types[1] == v4s16 || Query.Types[1] == v8s16);
799-
})
753+
.legalFor(HasFP16,
754+
{{s16, s32}, {s16, s64}, {v4s16, v4s16}, {v8s16, v8s16}})
800755
.scalarizeIf(scalarOrEltWiderThan(1, 64), 1)
801756
.scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
802757
.moreElementsToNextPow2(1)
@@ -1050,12 +1005,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
10501005
.widenScalarToNextPow2(1, /*Min=*/32)
10511006
.clampScalar(1, s32, s64)
10521007
.scalarSameSizeAs(0, 1)
1053-
.legalIf([=](const LegalityQuery &Query) {
1054-
return (HasCSSC && typeInSet(0, {s32, s64})(Query));
1055-
})
1056-
.customIf([=](const LegalityQuery &Query) {
1057-
return (!HasCSSC && typeInSet(0, {s32, s64})(Query));
1058-
});
1008+
.legalFor(HasCSSC, {s32, s64})
1009+
.customFor(!HasCSSC, {s32, s64});
10591010

10601011
getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
10611012
.legalIf([=](const LegalityQuery &Query) {
@@ -1143,11 +1094,9 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
11431094
}
11441095

11451096
// FIXME: Legal vector types are only legal with NEON.
1146-
auto &ABSActions = getActionDefinitionsBuilder(G_ABS);
1147-
if (HasCSSC)
1148-
ABSActions
1149-
.legalFor({s32, s64});
1150-
ABSActions.legalFor(PackedVectorAllTypeList)
1097+
getActionDefinitionsBuilder(G_ABS)
1098+
.legalFor(HasCSSC, {s32, s64})
1099+
.legalFor(PackedVectorAllTypeList)
11511100
.customIf([=](const LegalityQuery &Q) {
11521101
// TODO: Fix suboptimal codegen for 128+ bit types.
11531102
LLT SrcTy = Q.Types[0];
@@ -1171,10 +1120,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
11711120
// later.
11721121
getActionDefinitionsBuilder(G_VECREDUCE_FADD)
11731122
.legalFor({{s32, v2s32}, {s32, v4s32}, {s64, v2s64}})
1174-
.legalIf([=](const LegalityQuery &Query) {
1175-
const auto &Ty = Query.Types[1];
1176-
return (Ty == v4s16 || Ty == v8s16) && HasFP16;
1177-
})
1123+
.legalFor(HasFP16, {{s16, v4s16}, {s16, v8s16}})
11781124
.minScalarOrElt(0, MinFPScalar)
11791125
.clampMaxNumElements(1, s64, 2)
11801126
.clampMaxNumElements(1, s32, 4)
@@ -1215,10 +1161,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
12151161
getActionDefinitionsBuilder({G_VECREDUCE_FMIN, G_VECREDUCE_FMAX,
12161162
G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM})
12171163
.legalFor({{s32, v4s32}, {s32, v2s32}, {s64, v2s64}})
1218-
.legalIf([=](const LegalityQuery &Query) {
1219-
const auto &Ty = Query.Types[1];
1220-
return Query.Types[0] == s16 && (Ty == v8s16 || Ty == v4s16) && HasFP16;
1221-
})
1164+
.legalFor(HasFP16, {{s16, v4s16}, {s16, v8s16}})
12221165
.minScalarOrElt(0, MinFPScalar)
12231166
.clampMaxNumElements(1, s64, 2)
12241167
.clampMaxNumElements(1, s32, 4)
@@ -1295,32 +1238,16 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
12951238
.customFor({{s32, s32}, {s64, s64}});
12961239

12971240
auto always = [=](const LegalityQuery &Q) { return true; };
1298-
auto &CTPOPActions = getActionDefinitionsBuilder(G_CTPOP);
1299-
if (HasCSSC)
1300-
CTPOPActions
1301-
.legalFor({{s32, s32},
1302-
{s64, s64},
1303-
{v8s8, v8s8},
1304-
{v16s8, v16s8}})
1305-
.customFor({{s128, s128},
1306-
{v2s64, v2s64},
1307-
{v2s32, v2s32},
1308-
{v4s32, v4s32},
1309-
{v4s16, v4s16},
1310-
{v8s16, v8s16}});
1311-
else
1312-
CTPOPActions
1313-
.legalFor({{v8s8, v8s8},
1314-
{v16s8, v16s8}})
1315-
.customFor({{s32, s32},
1316-
{s64, s64},
1317-
{s128, s128},
1318-
{v2s64, v2s64},
1319-
{v2s32, v2s32},
1320-
{v4s32, v4s32},
1321-
{v4s16, v4s16},
1322-
{v8s16, v8s16}});
1323-
CTPOPActions
1241+
getActionDefinitionsBuilder(G_CTPOP)
1242+
.legalFor(HasCSSC, {{s32, s32}, {s64, s64}})
1243+
.legalFor({{v8s8, v8s8}, {v16s8, v16s8}})
1244+
.customFor(!HasCSSC, {{s32, s32}, {s64, s64}})
1245+
.customFor({{s128, s128},
1246+
{v2s64, v2s64},
1247+
{v2s32, v2s32},
1248+
{v4s32, v4s32},
1249+
{v4s16, v4s16},
1250+
{v8s16, v8s16}})
13241251
.clampScalar(0, s32, s128)
13251252
.widenScalarToNextPow2(0)
13261253
.minScalarEltSameAsIf(always, 1, 0)

0 commit comments

Comments
 (0)