Skip to content

Commit 7d94043

Browse files
committed
[AArch64] Legalize MVT::i64x8 in DAG isel lowering
This patch legalizes the Machine Value Type introduced in D94096 for loads and stores. A new target hook named getAsmOperandValueType() is added which maps i512 to MVT::i64x8. GlobalISel falls back to DAG for legalization. Differential Revision: https://reviews.llvm.org/D94097
1 parent 3094e53 commit 7d94043

File tree

12 files changed

+253
-8
lines changed

12 files changed

+253
-8
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1396,6 +1396,11 @@ class TargetLoweringBase {
13961396
return NVT;
13971397
}
13981398

1399+
virtual EVT getAsmOperandValueType(const DataLayout &DL, Type *Ty,
1400+
bool AllowUnknown = false) const {
1401+
return getValueType(DL, Ty, AllowUnknown);
1402+
}
1403+
13991404
/// Return the EVT corresponding to this LLVM type. This is fixed by the LLVM
14001405
/// operations except for the pointer size. If AllowUnknown is true, this
14011406
/// will return MVT::Other for types with no EVT counterpart (e.g. structs),

llvm/lib/CodeGen/GlobalISel/InlineAsmLowering.cpp

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -325,7 +325,8 @@ bool InlineAsmLowering::lowerInlineAsm(
325325
return false;
326326
}
327327

328-
OpInfo.ConstraintVT = TLI->getValueType(DL, OpTy, true).getSimpleVT();
328+
OpInfo.ConstraintVT =
329+
TLI->getAsmOperandValueType(DL, OpTy, true).getSimpleVT();
329330

330331
} else if (OpInfo.Type == InlineAsm::isOutput && !OpInfo.isIndirect) {
331332
assert(!Call.getType()->isVoidTy() && "Bad inline asm!");
@@ -334,13 +335,17 @@ bool InlineAsmLowering::lowerInlineAsm(
334335
TLI->getSimpleValueType(DL, STy->getElementType(ResNo));
335336
} else {
336337
assert(ResNo == 0 && "Asm only has one result!");
337-
OpInfo.ConstraintVT = TLI->getSimpleValueType(DL, Call.getType());
338+
OpInfo.ConstraintVT =
339+
TLI->getAsmOperandValueType(DL, Call.getType()).getSimpleVT();
338340
}
339341
++ResNo;
340342
} else {
341343
OpInfo.ConstraintVT = MVT::Other;
342344
}
343345

346+
if (OpInfo.ConstraintVT == MVT::i64x8)
347+
return false;
348+
344349
// Compute the constraint code and ConstraintType to use.
345350
computeConstraintToUse(TLI, OpInfo);
346351

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8176,7 +8176,7 @@ class SDISelAsmOperandInfo : public TargetLowering::AsmOperandInfo {
81768176
}
81778177
}
81788178

8179-
return TLI.getValueType(DL, OpTy, true);
8179+
return TLI.getAsmOperandValueType(DL, OpTy, true);
81808180
}
81818181
};
81828182

@@ -8479,8 +8479,8 @@ void SelectionDAGBuilder::visitInlineAsm(const CallBase &Call,
84798479
DAG.getDataLayout(), STy->getElementType(ResNo));
84808480
} else {
84818481
assert(ResNo == 0 && "Asm only has one result!");
8482-
OpInfo.ConstraintVT =
8483-
TLI.getSimpleValueType(DAG.getDataLayout(), Call.getType());
8482+
OpInfo.ConstraintVT = TLI.getAsmOperandValueType(
8483+
DAG.getDataLayout(), Call.getType()).getSimpleVT();
84848484
}
84858485
++ResNo;
84868486
} else {

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4687,7 +4687,8 @@ TargetLowering::ParseConstraints(const DataLayout &DL,
46874687
getSimpleValueType(DL, STy->getElementType(ResNo));
46884688
} else {
46894689
assert(ResNo == 0 && "Asm only has one result!");
4690-
OpInfo.ConstraintVT = getSimpleValueType(DL, Call.getType());
4690+
OpInfo.ConstraintVT =
4691+
getAsmOperandValueType(DL, Call.getType()).getSimpleVT();
46914692
}
46924693
++ResNo;
46934694
break;

llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -653,6 +653,9 @@ bool AArch64AsmPrinter::printAsmMRegister(const MachineOperand &MO, char Mode,
653653
case 'x':
654654
Reg = getXRegFromWReg(Reg);
655655
break;
656+
case 't':
657+
Reg = getXRegFromXRegTuple(Reg);
658+
break;
656659
}
657660

658661
O << AArch64InstPrinter::getRegisterName(Reg);
@@ -749,6 +752,10 @@ bool AArch64AsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNum,
749752
AArch64::GPR64allRegClass.contains(Reg))
750753
return printAsmMRegister(MO, 'x', O);
751754

755+
// If this is an x register tuple, print an x register.
756+
if (AArch64::GPR64x8ClassRegClass.contains(Reg))
757+
return printAsmMRegister(MO, 't', O);
758+
752759
unsigned AltName = AArch64::NoRegAltName;
753760
const TargetRegisterClass *RegClass;
754761
if (AArch64::ZPRRegClass.contains(Reg)) {

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 54 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -246,6 +246,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
246246
addRegisterClass(MVT::i32, &AArch64::GPR32allRegClass);
247247
addRegisterClass(MVT::i64, &AArch64::GPR64allRegClass);
248248

249+
if (Subtarget->hasLS64()) {
250+
addRegisterClass(MVT::i64x8, &AArch64::GPR64x8ClassRegClass);
251+
setOperationAction(ISD::LOAD, MVT::i64x8, Custom);
252+
setOperationAction(ISD::STORE, MVT::i64x8, Custom);
253+
}
254+
249255
if (Subtarget->hasFPARMv8()) {
250256
addRegisterClass(MVT::f16, &AArch64::FPR16RegClass);
251257
addRegisterClass(MVT::bf16, &AArch64::FPR16RegClass);
@@ -2023,6 +2029,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
20232029
MAKE_CASE(AArch64ISD::LASTA)
20242030
MAKE_CASE(AArch64ISD::LASTB)
20252031
MAKE_CASE(AArch64ISD::REINTERPRET_CAST)
2032+
MAKE_CASE(AArch64ISD::LS64_BUILD)
2033+
MAKE_CASE(AArch64ISD::LS64_EXTRACT)
20262034
MAKE_CASE(AArch64ISD::TBL)
20272035
MAKE_CASE(AArch64ISD::FADD_PRED)
20282036
MAKE_CASE(AArch64ISD::FADDA_PRED)
@@ -4611,17 +4619,51 @@ SDValue AArch64TargetLowering::LowerSTORE(SDValue Op,
46114619
{StoreNode->getChain(), Lo, Hi, StoreNode->getBasePtr()},
46124620
StoreNode->getMemoryVT(), StoreNode->getMemOperand());
46134621
return Result;
4622+
} else if (MemVT == MVT::i64x8) {
4623+
SDValue Value = StoreNode->getValue();
4624+
assert(Value->getValueType(0) == MVT::i64x8);
4625+
SDValue Chain = StoreNode->getChain();
4626+
SDValue Base = StoreNode->getBasePtr();
4627+
EVT PtrVT = Base.getValueType();
4628+
for (unsigned i = 0; i < 8; i++) {
4629+
SDValue Part = DAG.getNode(AArch64ISD::LS64_EXTRACT, Dl, MVT::i64,
4630+
Value, DAG.getConstant(i, Dl, MVT::i32));
4631+
SDValue Ptr = DAG.getNode(ISD::ADD, Dl, PtrVT, Base,
4632+
DAG.getConstant(i * 8, Dl, PtrVT));
4633+
Chain = DAG.getStore(Chain, Dl, Part, Ptr, StoreNode->getPointerInfo(),
4634+
StoreNode->getOriginalAlign());
4635+
}
4636+
return Chain;
46144637
}
46154638

46164639
return SDValue();
46174640
}
46184641

4619-
// Custom lowering for extending v4i8 vector loads.
46204642
SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
46214643
SelectionDAG &DAG) const {
46224644
SDLoc DL(Op);
46234645
LoadSDNode *LoadNode = cast<LoadSDNode>(Op);
46244646
assert(LoadNode && "Expected custom lowering of a load node");
4647+
4648+
if (LoadNode->getMemoryVT() == MVT::i64x8) {
4649+
SmallVector<SDValue, 8> Ops;
4650+
SDValue Base = LoadNode->getBasePtr();
4651+
SDValue Chain = LoadNode->getChain();
4652+
EVT PtrVT = Base.getValueType();
4653+
for (unsigned i = 0; i < 8; i++) {
4654+
SDValue Ptr = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
4655+
DAG.getConstant(i * 8, DL, PtrVT));
4656+
SDValue Part = DAG.getLoad(MVT::i64, DL, Chain, Ptr,
4657+
LoadNode->getPointerInfo(),
4658+
LoadNode->getOriginalAlign());
4659+
Ops.push_back(Part);
4660+
Chain = SDValue(Part.getNode(), 1);
4661+
}
4662+
SDValue Loaded = DAG.getNode(AArch64ISD::LS64_BUILD, DL, MVT::i64x8, Ops);
4663+
return DAG.getMergeValues({Loaded, Chain}, DL);
4664+
}
4665+
4666+
// Custom lowering for extending v4i8 vector loads.
46254667
EVT VT = Op->getValueType(0);
46264668
assert((VT == MVT::v4i16 || VT == MVT::v4i32) && "Expected v4i16 or v4i32");
46274669

@@ -8179,6 +8221,8 @@ AArch64TargetLowering::getRegForInlineAsmConstraint(
81798221
case 'r':
81808222
if (VT.isScalableVector())
81818223
return std::make_pair(0U, nullptr);
8224+
if (Subtarget->hasLS64() && VT.getSizeInBits() == 512)
8225+
return std::make_pair(0U, &AArch64::GPR64x8ClassRegClass);
81828226
if (VT.getFixedSizeInBits() == 64)
81838227
return std::make_pair(0U, &AArch64::GPR64commonRegClass);
81848228
return std::make_pair(0U, &AArch64::GPR32commonRegClass);
@@ -8266,6 +8310,15 @@ AArch64TargetLowering::getRegForInlineAsmConstraint(
82668310
return Res;
82678311
}
82688312

8313+
EVT AArch64TargetLowering::getAsmOperandValueType(const DataLayout &DL,
8314+
llvm::Type *Ty,
8315+
bool AllowUnknown) const {
8316+
if (Subtarget->hasLS64() && Ty->isIntegerTy(512))
8317+
return EVT(MVT::i64x8);
8318+
8319+
return TargetLowering::getAsmOperandValueType(DL, Ty, AllowUnknown);
8320+
}
8321+
82698322
/// LowerAsmOperandForConstraint - Lower the specified operand into the Ops
82708323
/// vector. If it is invalid, don't add anything to Ops.
82718324
void AArch64TargetLowering::LowerAsmOperandForConstraint(

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -330,6 +330,10 @@ enum NodeType : unsigned {
330330
// Cast between vectors of the same element type but differ in length.
331331
REINTERPRET_CAST,
332332

333+
// Nodes to build an LD64B / ST64B 64-byte quantity out of i64, and vice versa
334+
LS64_BUILD,
335+
LS64_EXTRACT,
336+
333337
LD1_MERGE_ZERO,
334338
LD1S_MERGE_ZERO,
335339
LDNF1_MERGE_ZERO,
@@ -824,6 +828,9 @@ class AArch64TargetLowering : public TargetLowering {
824828
bool isAllActivePredicate(SDValue N) const;
825829
EVT getPromotedVTForPredicate(EVT VT) const;
826830

831+
EVT getAsmOperandValueType(const DataLayout &DL, Type *Ty,
832+
bool AllowUnknown = false) const override;
833+
827834
private:
828835
/// Keep a pointer to the AArch64Subtarget around so that we can
829836
/// make the right decision when generating code for different targets.

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8104,6 +8104,20 @@ let AddedComplexity = 10 in {
81048104
// FIXME: add SVE dot-product patterns.
81058105
}
81068106

8107+
// Custom DAG nodes and isel rules to make a 64-byte block out of eight GPRs,
8108+
// so that it can be used as input to inline asm, and vice versa.
8109+
def LS64_BUILD : SDNode<"AArch64ISD::LS64_BUILD", SDTypeProfile<1, 8, []>>;
8110+
def LS64_EXTRACT : SDNode<"AArch64ISD::LS64_EXTRACT", SDTypeProfile<1, 2, []>>;
8111+
def : Pat<(i64x8 (LS64_BUILD GPR64:$x0, GPR64:$x1, GPR64:$x2, GPR64:$x3,
8112+
GPR64:$x4, GPR64:$x5, GPR64:$x6, GPR64:$x7)),
8113+
(REG_SEQUENCE GPR64x8Class,
8114+
$x0, x8sub_0, $x1, x8sub_1, $x2, x8sub_2, $x3, x8sub_3,
8115+
$x4, x8sub_4, $x5, x8sub_5, $x6, x8sub_6, $x7, x8sub_7)>;
8116+
foreach i = 0-7 in {
8117+
def : Pat<(i64 (LS64_EXTRACT (i64x8 GPR64x8:$val), (i32 i))),
8118+
(EXTRACT_SUBREG $val, !cast<SubRegIndex>("x8sub_"#i))>;
8119+
}
8120+
81078121
let Predicates = [HasLS64] in {
81088122
def LD64B: LoadStore64B<0b101, "ld64b", (ins GPR64sp:$Rn),
81098123
(outs GPR64x8:$Rt)>;

llvm/lib/Target/AArch64/AArch64RegisterInfo.td

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -732,7 +732,9 @@ def Tuples8X : RegisterTuples<
732732
!foreach(i, [0,1,2,3,4,5,6,7], !cast<SubRegIndex>("x8sub_"#i)),
733733
!foreach(i, [0,1,2,3,4,5,6,7], (trunc (decimate (rotl GPR64, i), 2), 12))>;
734734

735-
def GPR64x8Class : RegisterClass<"AArch64", [i64], 64, (trunc Tuples8X, 12)>;
735+
def GPR64x8Class : RegisterClass<"AArch64", [i64x8], 512, (trunc Tuples8X, 12)> {
736+
let Size = 512;
737+
}
736738
def GPR64x8AsmOp : AsmOperandClass {
737739
let Name = "GPR64x8";
738740
let ParserMethod = "tryParseGPR64x8";

llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -106,6 +106,25 @@ inline static unsigned getXRegFromWReg(unsigned Reg) {
106106
return Reg;
107107
}
108108

109+
inline static unsigned getXRegFromXRegTuple(unsigned RegTuple) {
110+
switch (RegTuple) {
111+
case AArch64::X0_X1_X2_X3_X4_X5_X6_X7: return AArch64::X0;
112+
case AArch64::X2_X3_X4_X5_X6_X7_X8_X9: return AArch64::X2;
113+
case AArch64::X4_X5_X6_X7_X8_X9_X10_X11: return AArch64::X4;
114+
case AArch64::X6_X7_X8_X9_X10_X11_X12_X13: return AArch64::X6;
115+
case AArch64::X8_X9_X10_X11_X12_X13_X14_X15: return AArch64::X8;
116+
case AArch64::X10_X11_X12_X13_X14_X15_X16_X17: return AArch64::X10;
117+
case AArch64::X12_X13_X14_X15_X16_X17_X18_X19: return AArch64::X12;
118+
case AArch64::X14_X15_X16_X17_X18_X19_X20_X21: return AArch64::X14;
119+
case AArch64::X16_X17_X18_X19_X20_X21_X22_X23: return AArch64::X16;
120+
case AArch64::X18_X19_X20_X21_X22_X23_X24_X25: return AArch64::X18;
121+
case AArch64::X20_X21_X22_X23_X24_X25_X26_X27: return AArch64::X20;
122+
case AArch64::X22_X23_X24_X25_X26_X27_X28_FP: return AArch64::X22;
123+
}
124+
// For anything else, return it unchanged.
125+
return RegTuple;
126+
}
127+
109128
static inline unsigned getBRegFromDReg(unsigned Reg) {
110129
switch (Reg) {
111130
case AArch64::D0: return AArch64::B0;

llvm/test/CodeGen/AArch64/GlobalISel/arm64-fallback.ll

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -126,7 +126,32 @@ entry:
126126
ret void
127127
}
128128

129+
%struct.foo = type { [8 x i64] }
130+
131+
; FALLBACK-WITH-REPORT-ERR: remark: <unknown>:0:0: unable to translate instruction:{{.*}}ld64b{{.*}}asm_output_ls64
132+
; FALLBACK-WITH-REPORT-ERR: warning: Instruction selection used fallback path for asm_output_ls64
133+
; FALLBACK-WITH-REPORT-OUT-LABEL: asm_output_ls64
134+
define void @asm_output_ls64(%struct.foo* %output, i8* %addr) #2 {
135+
entry:
136+
%val = call i512 asm sideeffect "ld64b $0,[$1]", "=r,r,~{memory}"(i8* %addr)
137+
%outcast = bitcast %struct.foo* %output to i512*
138+
store i512 %val, i512* %outcast, align 8
139+
ret void
140+
}
141+
142+
; FALLBACK-WITH-REPORT-ERR: remark: <unknown>:0:0: unable to translate instruction:{{.*}}st64b{{.*}}asm_input_ls64
143+
; FALLBACK-WITH-REPORT-ERR: warning: Instruction selection used fallback path for asm_input_ls64
144+
; FALLBACK-WITH-REPORT-OUT-LABEL: asm_input_ls64
145+
define void @asm_input_ls64(%struct.foo* %input, i8* %addr) #2 {
146+
entry:
147+
%incast = bitcast %struct.foo* %input to i512*
148+
%val = load i512, i512* %incast, align 8
149+
call void asm sideeffect "st64b $0,[$1]", "r,r,~{memory}"(i512 %val, i8* %addr)
150+
ret void
151+
}
152+
129153
attributes #1 = { "target-features"="+sve" }
154+
attributes #2 = { "target-features"="+ls64" }
130155

131156
declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 %pattern)
132157
declare <vscale x 16 x i8> @llvm.aarch64.sve.ld1.nxv16i8(<vscale x 16 x i1>, i8*)

0 commit comments

Comments
 (0)