Skip to content

Commit 86d282b

Browse files
committed
[RISCV] Add intrinsics for vmv.x.s and vmv.s.x
This adds intrinsics for vmv.x.s and vmv.s.x. I've used stricter type constraints on these intrinsics than what we've been doing on the arithmetic intrinsics so far. This will allow us to not need to pass the scalar type to the Intrinsic::getDeclaration call when creating these intrinsics. A custom ISD is used for vmv.x.s in order to implement the change in computeNumSignBitsForTargetNode which can remove sign extends on the result. I also modified the MC layer description of these instructions to show the tied source/dest operand. This is different than what we do for masked instructions where we drop the tied source operand when converting to MC. But it is a more accurate description of the instruction. We can't do this for masked instructions since we use the same MC instruction for masked and unmasked. Tools like llvm-mca operate in the MC layer and rely on ins/outs and Uses/Defs for analysis so I don't know if we'll be able to maintain the current behavior for masked instructions. So I went with the accurate description here since it was easy. Reviewed By: frasercrmck Differential Revision: https://reviews.llvm.org/D93365
1 parent 5ac3772 commit 86d282b

File tree

9 files changed

+1149
-10
lines changed

9 files changed

+1149
-10
lines changed

llvm/include/llvm/IR/IntrinsicsRISCV.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,4 +339,14 @@ let TargetPrefix = "riscv" in {
339339

340340
def int_riscv_vmv_v_v : RISCVUnary;
341341
def int_riscv_vmv_v_x : RISCVUnary;
342+
343+
def int_riscv_vmv_x_s : Intrinsic<[LLVMVectorElementType<0>],
344+
[llvm_anyint_ty],
345+
[IntrNoMem]>, RISCVVIntrinsic;
346+
def int_riscv_vmv_s_x : Intrinsic<[llvm_anyint_ty],
347+
[LLVMMatchType<0>, LLVMVectorElementType<0>,
348+
llvm_anyint_ty],
349+
[IntrNoMem]>, RISCVVIntrinsic {
350+
let ExtendOperand = 2;
351+
}
342352
} // TargetPrefix = "riscv"

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -348,14 +348,17 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
348348
setBooleanVectorContents(ZeroOrOneBooleanContent);
349349

350350
// RVV intrinsics may have illegal operands.
351+
// We also need to custom legalize vmv.x.s.
351352
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
352353
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i16, Custom);
353354
setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
354355
setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i16, Custom);
356+
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i32, Custom);
357+
setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i32, Custom);
355358

356359
if (Subtarget.is64Bit()) {
357-
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i32, Custom);
358-
setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i32, Custom);
360+
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i64, Custom);
361+
setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i64, Custom);
359362
}
360363
}
361364

@@ -1039,9 +1042,9 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
10391042
assert(II->ExtendedOperand < Op.getNumOperands());
10401043
SmallVector<SDValue, 8> Operands(Op->op_begin(), Op->op_end());
10411044
SDValue &ScalarOp = Operands[II->ExtendedOperand];
1042-
if (ScalarOp.getValueType() == MVT::i8 ||
1043-
ScalarOp.getValueType() == MVT::i16 ||
1044-
ScalarOp.getValueType() == MVT::i32) {
1045+
EVT OpVT = ScalarOp.getValueType();
1046+
if (OpVT == MVT::i8 || OpVT == MVT::i16 ||
1047+
(OpVT == MVT::i32 && Subtarget.is64Bit())) {
10451048
ScalarOp =
10461049
DAG.getNode(ISD::ANY_EXTEND, DL, Subtarget.getXLenVT(), ScalarOp);
10471050
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
@@ -1058,6 +1061,10 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
10581061
EVT PtrVT = getPointerTy(DAG.getDataLayout());
10591062
return DAG.getRegister(RISCV::X4, PtrVT);
10601063
}
1064+
case Intrinsic::riscv_vmv_x_s:
1065+
assert(Op.getValueType() == Subtarget.getXLenVT() && "Unexpected VT!");
1066+
return DAG.getNode(RISCVISD::VMV_X_S, DL, Op.getValueType(),
1067+
Op.getOperand(1));
10611068
}
10621069
}
10631070

@@ -1077,9 +1084,9 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
10771084
assert(ExtendOp < Op.getNumOperands());
10781085
SmallVector<SDValue, 8> Operands(Op->op_begin(), Op->op_end());
10791086
SDValue &ScalarOp = Operands[ExtendOp];
1080-
if (ScalarOp.getValueType() == MVT::i32 ||
1081-
ScalarOp.getValueType() == MVT::i16 ||
1082-
ScalarOp.getValueType() == MVT::i8) {
1087+
EVT OpVT = ScalarOp.getValueType();
1088+
if (OpVT == MVT::i8 || OpVT == MVT::i16 ||
1089+
(OpVT == MVT::i32 && Subtarget.is64Bit())) {
10831090
ScalarOp =
10841091
DAG.getNode(ISD::ANY_EXTEND, DL, Subtarget.getXLenVT(), ScalarOp);
10851092
return DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL, Op->getVTList(), Operands);
@@ -1309,6 +1316,25 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
13091316
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewOp));
13101317
break;
13111318
}
1319+
case ISD::INTRINSIC_WO_CHAIN: {
1320+
unsigned IntNo = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue();
1321+
switch (IntNo) {
1322+
default:
1323+
llvm_unreachable(
1324+
"Don't know how to custom type legalize this intrinsic!");
1325+
case Intrinsic::riscv_vmv_x_s: {
1326+
EVT VT = N->getValueType(0);
1327+
assert((VT == MVT::i8 || VT == MVT::i16 ||
1328+
(Subtarget.is64Bit() && VT == MVT::i32)) &&
1329+
"Unexpected custom legalisation!");
1330+
SDValue Extract = DAG.getNode(RISCVISD::VMV_X_S, DL,
1331+
Subtarget.getXLenVT(), N->getOperand(1));
1332+
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, Extract));
1333+
break;
1334+
}
1335+
}
1336+
break;
1337+
}
13121338
}
13131339
}
13141340

@@ -1730,6 +1756,11 @@ unsigned RISCVTargetLowering::ComputeNumSignBitsForTargetNode(
17301756
// more precise answer could be calculated for SRAW depending on known
17311757
// bits in the shift amount.
17321758
return 33;
1759+
case RISCVISD::VMV_X_S:
1760+
// The number of sign bits of the scalar result is computed by obtaining the
1761+
// element type of the input vector operand, substracting its width from the
1762+
// XLEN, and then adding one (sign bit within the element type).
1763+
return Subtarget.getXLen() - Op.getOperand(0).getScalarValueSizeInBits() + 1;
17331764
}
17341765

17351766
return 1;
@@ -3369,6 +3400,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
33693400
NODE_NAME_CASE(GREVIW)
33703401
NODE_NAME_CASE(GORCI)
33713402
NODE_NAME_CASE(GORCIW)
3403+
NODE_NAME_CASE(VMV_X_S)
33723404
}
33733405
// clang-format on
33743406
return nullptr;

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ enum NodeType : unsigned {
7777
GREVIW,
7878
GORCI,
7979
GORCIW,
80+
// Vector Extension
81+
// VMV_X_S matches the semantics of vmv.x.s. The result is always XLenVT
82+
// sign extended from the vector element size. NOTE: The result size will
83+
// never be less than the vector element size.
84+
VMV_X_S,
8085
};
8186
} // namespace RISCVISD
8287

llvm/lib/Target/RISCV/RISCVInstrInfoV.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -967,8 +967,9 @@ def VID_V : RVInstV<0b010100, 0b10001, OPMVV, (outs VR:$vd),
967967
let vm = 1 in {
968968
def VMV_X_S : RVInstV<0b010000, 0b00000, OPMVV, (outs GPR:$vd),
969969
(ins VR:$vs2), "vmv.x.s", "$vd, $vs2">;
970-
def VMV_S_X : RVInstV2<0b010000, 0b00000, OPMVX, (outs VR:$vd),
971-
(ins GPR:$rs1), "vmv.s.x", "$vd, $rs1">;
970+
let Constraints = "$vd = $vd_wb" in
971+
def VMV_S_X : RVInstV2<0b010000, 0b00000, OPMVX, (outs VR:$vd_wb),
972+
(ins VR:$vd, GPR:$rs1), "vmv.s.x", "$vd, $rs1">;
972973

973974
}
974975
} // hasSideEffects = 0, mayLoad = 0, mayStore = 0

llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
///
1515
//===----------------------------------------------------------------------===//
1616

17+
def riscv_vmv_x_s : SDNode<"RISCVISD::VMV_X_S",
18+
SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisVec<1>,
19+
SDTCisInt<1>]>>;
20+
1721
// X0 has special meaning for vsetvl/vsetvli.
1822
// rd | rs1 | AVL value | Effect on vl
1923
//--------------------------------------------------------------
@@ -1349,6 +1353,30 @@ defm PseudoVFRSUB : VPseudoBinaryV_VX</*IsFloat=*/1>;
13491353

13501354
} // Predicates = [HasStdExtV, HasStdExtF]
13511355

1356+
//===----------------------------------------------------------------------===//
1357+
// 17.1. Integer Scalar Move Instructions
1358+
//===----------------------------------------------------------------------===//
1359+
1360+
let Predicates = [HasStdExtV] in {
1361+
let mayLoad = 0, mayStore = 0, hasSideEffects = 0, usesCustomInserter = 1,
1362+
Uses = [VL, VTYPE] in {
1363+
foreach m = MxList.m in {
1364+
let VLMul = m.value in {
1365+
let SEWIndex = 2, BaseInstr = VMV_X_S in
1366+
def PseudoVMV_X_S # "_" # m.MX: Pseudo<(outs GPR:$rd),
1367+
(ins m.vrclass:$rs2, ixlenimm:$sew),
1368+
[]>, RISCVVPseudo;
1369+
let VLIndex = 3, SEWIndex = 4, BaseInstr = VMV_S_X,
1370+
Constraints = "$rd = $rs1" in
1371+
def PseudoVMV_S_X # "_" # m.MX: Pseudo<(outs m.vrclass:$rd),
1372+
(ins m.vrclass:$rs1, GPR:$rs2,
1373+
GPR:$vl, ixlenimm:$sew),
1374+
[]>, RISCVVPseudo;
1375+
}
1376+
}
1377+
}
1378+
}
1379+
13521380
//===----------------------------------------------------------------------===//
13531381
// Patterns.
13541382
//===----------------------------------------------------------------------===//
@@ -1514,3 +1542,18 @@ defm "" : VPatBinaryV_VV_VX<"int_riscv_vfsub", "PseudoVFSUB", AllFloatVectors>;
15141542
defm "" : VPatBinaryV_VX<"int_riscv_vfrsub", "PseudoVFRSUB", AllFloatVectors>;
15151543

15161544
} // Predicates = [HasStdExtV, HasStdExtF]
1545+
1546+
//===----------------------------------------------------------------------===//
1547+
// 17.1. Integer Scalar Move Instructions
1548+
//===----------------------------------------------------------------------===//
1549+
1550+
let Predicates = [HasStdExtV] in {
1551+
foreach vti = AllIntegerVectors in {
1552+
def : Pat<(riscv_vmv_x_s (vti.Vector vti.RegClass:$rs2)),
1553+
(!cast<Instruction>("PseudoVMV_X_S_" # vti.LMul.MX) $rs2, vti.SEW)>;
1554+
def : Pat<(vti.Vector (int_riscv_vmv_s_x (vti.Vector vti.RegClass:$rs1),
1555+
GPR:$rs2, GPR:$vl)),
1556+
(!cast<Instruction>("PseudoVMV_S_X_" # vti.LMul.MX)
1557+
(vti.Vector $rs1), $rs2, (NoX0 GPR:$vl), vti.SEW)>;
1558+
}
1559+
} // Predicates = [HasStdExtV]

0 commit comments

Comments
 (0)