-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[AArch64] Improve lowering of truncating uzp1 #82457
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21034,12 +21034,8 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG, | |
} | ||
} | ||
|
||
// uzp1(xtn x, xtn y) -> xtn(uzp1 (x, y)) | ||
// Only implemented on little-endian subtargets. | ||
bool IsLittleEndian = DAG.getDataLayout().isLittleEndian(); | ||
|
||
// This optimization only works on little endian. | ||
if (!IsLittleEndian) | ||
// These optimizations only work on little endian. | ||
if (!DAG.getDataLayout().isLittleEndian()) | ||
return SDValue(); | ||
|
||
// uzp1(bitcast(x), bitcast(y)) -> uzp1(x, y) | ||
|
@@ -21058,21 +21054,28 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG, | |
if (ResVT != MVT::v2i32 && ResVT != MVT::v4i16 && ResVT != MVT::v8i8) | ||
return SDValue(); | ||
|
||
auto getSourceOp = [](SDValue Operand) -> SDValue { | ||
const unsigned Opcode = Operand.getOpcode(); | ||
if (Opcode == ISD::TRUNCATE) | ||
return Operand->getOperand(0); | ||
if (Opcode == ISD::BITCAST && | ||
Operand->getOperand(0).getOpcode() == ISD::TRUNCATE) | ||
return Operand->getOperand(0)->getOperand(0); | ||
return SDValue(); | ||
}; | ||
SDValue SourceOp0 = peekThroughBitcasts(Op0); | ||
SDValue SourceOp1 = peekThroughBitcasts(Op1); | ||
|
||
SDValue SourceOp0 = getSourceOp(Op0); | ||
SDValue SourceOp1 = getSourceOp(Op1); | ||
// truncating uzp1(x, y) -> xtn(concat (x, y)) | ||
if (SourceOp0.getValueType() == SourceOp1.getValueType()) { | ||
EVT Op0Ty = SourceOp0.getValueType(); | ||
if ((ResVT == MVT::v4i16 && Op0Ty == MVT::v2i32) || | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just realised that we actually don't care about the operand type. The Arm developer website does say this about uzp1:
So really you should be able to always transform this into truncate(concat(x, y)) regardless of the type of x or y. If x or y have the wrong type you can just create a new bitcast and write the code like this, i.e.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You suggested change does break some tests. I looked at a couple of examples it looks like the bitcast is moved between the concat and trunc which breaks the patterns. Not sure if it is worth the extra effort, I can take it up later if get time. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK yeah it's because when creating CONCAT_VECTORS we should be doing |
||
(ResVT == MVT::v8i8 && Op0Ty == MVT::v4i16)) { | ||
SDValue Concat = | ||
DAG.getNode(ISD::CONCAT_VECTORS, DL, | ||
Op0Ty.getDoubleNumVectorElementsVT(*DAG.getContext()), | ||
SourceOp0, SourceOp1); | ||
return DAG.getNode(ISD::TRUNCATE, DL, ResVT, Concat); | ||
} | ||
} | ||
|
||
if (!SourceOp0 || !SourceOp1) | ||
// uzp1(xtn x, xtn y) -> xtn(uzp1 (x, y)) | ||
if (SourceOp0.getOpcode() != ISD::TRUNCATE || | ||
UsmanNadeem marked this conversation as resolved.
Show resolved
Hide resolved
|
||
SourceOp1.getOpcode() != ISD::TRUNCATE) | ||
return SDValue(); | ||
SourceOp0 = SourceOp0.getOperand(0); | ||
SourceOp1 = SourceOp1.getOperand(0); | ||
|
||
if (SourceOp0.getValueType() != SourceOp1.getValueType() || | ||
!SourceOp0.getValueType().isSimple()) | ||
UsmanNadeem marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6105,26 +6105,39 @@ defm UZP2 : SIMDZipVector<0b101, "uzp2", AArch64uzp2>; | |
defm ZIP1 : SIMDZipVector<0b011, "zip1", AArch64zip1>; | ||
defm ZIP2 : SIMDZipVector<0b111, "zip2", AArch64zip2>; | ||
|
||
def : Pat<(v16i8 (concat_vectors (v8i8 (trunc (v8i16 V128:$Vn))), | ||
(v8i8 (trunc (v8i16 V128:$Vm))))), | ||
(UZP1v16i8 V128:$Vn, V128:$Vm)>; | ||
def : Pat<(v8i16 (concat_vectors (v4i16 (trunc (v4i32 V128:$Vn))), | ||
(v4i16 (trunc (v4i32 V128:$Vm))))), | ||
(UZP1v8i16 V128:$Vn, V128:$Vm)>; | ||
def : Pat<(v4i32 (concat_vectors (v2i32 (trunc (v2i64 V128:$Vn))), | ||
(v2i32 (trunc (v2i64 V128:$Vm))))), | ||
(UZP1v4i32 V128:$Vn, V128:$Vm)>; | ||
// These are the same as above, with an optional assertzext node that can be | ||
// generated from fptoi lowering. | ||
def : Pat<(v16i8 (concat_vectors (v8i8 (assertzext (trunc (v8i16 V128:$Vn)))), | ||
(v8i8 (assertzext (trunc (v8i16 V128:$Vm)))))), | ||
(UZP1v16i8 V128:$Vn, V128:$Vm)>; | ||
def : Pat<(v8i16 (concat_vectors (v4i16 (assertzext (trunc (v4i32 V128:$Vn)))), | ||
(v4i16 (assertzext (trunc (v4i32 V128:$Vm)))))), | ||
(UZP1v8i16 V128:$Vn, V128:$Vm)>; | ||
def : Pat<(v4i32 (concat_vectors (v2i32 (assertzext (trunc (v2i64 V128:$Vn)))), | ||
(v2i32 (assertzext (trunc (v2i64 V128:$Vm)))))), | ||
(UZP1v4i32 V128:$Vn, V128:$Vm)>; | ||
def trunc_optional_assert_ext : PatFrags<(ops node:$op0), | ||
[(trunc node:$op0), | ||
(assertzext (trunc node:$op0)), | ||
(assertsext (trunc node:$op0))]>; | ||
|
||
// concat_vectors(trunc(x), trunc(y)) -> uzp1(x, y) | ||
// concat_vectors(assertzext(trunc(x)), assertzext(trunc(y))) -> uzp1(x, y) | ||
// concat_vectors(assertsext(trunc(x)), assertsext(trunc(y))) -> uzp1(x, y) | ||
class concat_trunc_to_uzp1_pat<ValueType SrcTy, ValueType TruncTy, ValueType ConcatTy> | ||
: Pat<(ConcatTy (concat_vectors (TruncTy (trunc_optional_assert_ext (SrcTy V128:$Vn))), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This pattern is a subset of trunc_concat_trunc_to_xtn_uzp1_pat so it's not obvious to me which one would get picked. I think we normally use the AddedComplexity tablegen property to differentiate between cases like this - perhaps worth seeing if we need to use it here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My understanding is that the more complex pattern will get picked and this one will get picked because it covers more nodes than the other one. |
||
(TruncTy (trunc_optional_assert_ext (SrcTy V128:$Vm))))), | ||
(!cast<Instruction>("UZP1"#ConcatTy) V128:$Vn, V128:$Vm)>; | ||
def : concat_trunc_to_uzp1_pat<v8i16, v8i8, v16i8>; | ||
def : concat_trunc_to_uzp1_pat<v4i32, v4i16, v8i16>; | ||
def : concat_trunc_to_uzp1_pat<v2i64, v2i32, v4i32>; | ||
|
||
// trunc(concat_vectors(trunc(x), trunc(y))) -> xtn(uzp1(x, y)) | ||
// trunc(concat_vectors(assertzext(trunc(x)), assertzext(trunc(y)))) -> xtn(uzp1(x, y)) | ||
// trunc(concat_vectors(assertsext(trunc(x)), assertsext(trunc(y)))) -> xtn(uzp1(x, y)) | ||
class trunc_concat_trunc_to_xtn_uzp1_pat<ValueType SrcTy, ValueType TruncTy, ValueType ConcatTy, | ||
ValueType Ty> | ||
: Pat<(Ty (trunc_optional_assert_ext | ||
(ConcatTy (concat_vectors | ||
(TruncTy (trunc_optional_assert_ext (SrcTy V128:$Vn))), | ||
(TruncTy (trunc_optional_assert_ext (SrcTy V128:$Vm))))))), | ||
(!cast<Instruction>("XTN"#Ty) (!cast<Instruction>("UZP1"#ConcatTy) V128:$Vn, V128:$Vm))>; | ||
def : trunc_concat_trunc_to_xtn_uzp1_pat<v4i32, v4i16, v8i16, v8i8>; | ||
def : trunc_concat_trunc_to_xtn_uzp1_pat<v2i64, v2i32, v4i32, v4i16>; | ||
|
||
def : Pat<(v8i8 (trunc (concat_vectors (v4i16 V64:$Vn), (v4i16 V64:$Vm)))), | ||
(UZP1v8i8 V64:$Vn, V64:$Vm)>; | ||
def : Pat<(v4i16 (trunc (concat_vectors (v2i32 V64:$Vn), (v2i32 V64:$Vm)))), | ||
(UZP1v4i16 V64:$Vn, V64:$Vm)>; | ||
|
||
def : Pat<(v16i8 (concat_vectors | ||
(v8i8 (trunc (AArch64vlshr (v8i16 V128:$Vn), (i32 8)))), | ||
|
Uh oh!
There was an error while loading. Please reload this page.