[AArch64] Improve lowering of truncating uzp1 #82457

Merged · 3 commits · Mar 13, 2024
39 changes: 21 additions & 18 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21034,12 +21034,8 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
}
}

// uzp1(xtn x, xtn y) -> xtn(uzp1 (x, y))
// Only implemented on little-endian subtargets.
bool IsLittleEndian = DAG.getDataLayout().isLittleEndian();

// This optimization only works on little endian.
if (!IsLittleEndian)
// These optimizations only work on little endian.
if (!DAG.getDataLayout().isLittleEndian())
return SDValue();

// uzp1(bitcast(x), bitcast(y)) -> uzp1(x, y)
@@ -21058,21 +21054,28 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
if (ResVT != MVT::v2i32 && ResVT != MVT::v4i16 && ResVT != MVT::v8i8)
return SDValue();

auto getSourceOp = [](SDValue Operand) -> SDValue {
const unsigned Opcode = Operand.getOpcode();
if (Opcode == ISD::TRUNCATE)
return Operand->getOperand(0);
if (Opcode == ISD::BITCAST &&
Operand->getOperand(0).getOpcode() == ISD::TRUNCATE)
return Operand->getOperand(0)->getOperand(0);
return SDValue();
};
SDValue SourceOp0 = peekThroughBitcasts(Op0);
SDValue SourceOp1 = peekThroughBitcasts(Op1);

SDValue SourceOp0 = getSourceOp(Op0);
SDValue SourceOp1 = getSourceOp(Op1);
// truncating uzp1(x, y) -> xtn(concat (x, y))
if (SourceOp0.getValueType() == SourceOp1.getValueType()) {
EVT Op0Ty = SourceOp0.getValueType();
if ((ResVT == MVT::v4i16 && Op0Ty == MVT::v2i32) ||
Contributor:

I just realised that we actually don't care about the operand type. The Arm developer website does say this about uzp1:

Note: UZP1 is equivalent to truncating and packing each element from two source vectors into a single destination vector with elements of half the size.

So really you should be able to always transform this into truncate(concat(x, y)) regardless of the type of x or y. If x or y have the wrong type you can just create a new bitcast and write the code like this, i.e.

  if (SourceOp0.getValueType() == SourceOp1.getValueType() && (ResVT == MVT::v4i16 || ResVT == MVT::v8i8)) {
    EVT Op0Ty = SourceOp0.getValueType();
    EVT RequiredOpTy = ResVT == MVT::v4i16 ? MVT::v2i32 : MVT::v4i16;
    if (Op0Ty != RequiredOpTy) {
      SourceOp0 = DAG.getNode(ISD::BITCAST, DL, RequiredOpTy, SourceOp0);
      SourceOp1 = DAG.getNode(ISD::BITCAST, DL, RequiredOpTy, SourceOp1);
    }
    SDValue Concat =
           DAG.getNode(ISD::CONCAT_VECTORS, DL,
                       Op0Ty.getDoubleNumVectorElementsVT(*DAG.getContext()),
                       SourceOp0, SourceOp1);
    return DAG.getNode(ISD::TRUNCATE, DL, ResVT, Concat);
  }
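To make that equivalence concrete, here is a minimal Python model of UZP1, XTN-style truncation, and a little-endian bitcast (an illustrative sketch, not part of the patch; vector lanes are modeled as unsigned Python ints):

```python
def bitcast(v, wide_bits, narrow_bits):
    # Reinterpret a little-endian vector of wide lanes as narrow lanes,
    # e.g. v2i64 -> v4i32.
    mask = (1 << narrow_bits) - 1
    return [(x >> (narrow_bits * i)) & mask
            for x in v
            for i in range(wide_bits // narrow_bits)]

def uzp1(a, b):
    # UZP1: the even-indexed lanes of Vn, then the even-indexed lanes of Vm.
    return a[0::2] + b[0::2]

def trunc(v, narrow_bits):
    # XTN-style narrowing: keep the low narrow_bits of every lane.
    return [x & ((1 << narrow_bits) - 1) for x in v]

x = [0x1111111122222222, 0x3333333344444444]  # v2i64
y = [0x5555555566666666, 0x7777777788888888]  # v2i64

# UZP1 on the v4i32 views of x and y ...
packed = uzp1(bitcast(x, 64, 32), bitcast(y, 64, 32))
# ... is exactly "truncate each element and pack": trunc(x) ++ trunc(y).
assert packed == trunc(x, 32) + trunc(y, 32)
```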

Contributor Author:

Your suggested change does break some tests. I looked at a couple of examples, and it looks like the bitcast is moved between the concat and the trunc, which breaks the patterns. I'm not sure it's worth the extra effort; I can take it up later if I get time.

Contributor:

OK, yeah, it's because when creating the CONCAT_VECTORS we should be doing RequiredOpTy.getDoubleNumVectorElementsVT(*DAG.getContext()). Anyway, I tried out this change myself and it makes the code worse for many tests, so just ignore me. :)

(ResVT == MVT::v8i8 && Op0Ty == MVT::v4i16)) {
SDValue Concat =
DAG.getNode(ISD::CONCAT_VECTORS, DL,
Op0Ty.getDoubleNumVectorElementsVT(*DAG.getContext()),
SourceOp0, SourceOp1);
return DAG.getNode(ISD::TRUNCATE, DL, ResVT, Concat);
}
}

if (!SourceOp0 || !SourceOp1)
// uzp1(xtn x, xtn y) -> xtn(uzp1 (x, y))
if (SourceOp0.getOpcode() != ISD::TRUNCATE ||
SourceOp1.getOpcode() != ISD::TRUNCATE)
return SDValue();
SourceOp0 = SourceOp0.getOperand(0);
SourceOp1 = SourceOp1.getOperand(0);

if (SourceOp0.getValueType() != SourceOp1.getValueType() ||
!SourceOp0.getValueType().isSimple())
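The `truncating uzp1(x, y) -> xtn(concat (x, y))` rewrite in this combine can be sanity-checked with a small Python model (a sketch, not part of the patch; lanes are unsigned ints in little-endian lane order), shown here for the ResVT == v4i16, Op0Ty == v2i32 case:

```python
def bitcast(v, wide_bits, narrow_bits):
    # Little-endian reinterpretation, e.g. v2i32 -> v4i16.
    mask = (1 << narrow_bits) - 1
    return [(x >> (narrow_bits * i)) & mask
            for x in v
            for i in range(wide_bits // narrow_bits)]

def uzp1(a, b):
    # Even-indexed lanes of each source, concatenated.
    return a[0::2] + b[0::2]

def trunc(v, narrow_bits):
    # Keep the low narrow_bits of every lane.
    return [x & ((1 << narrow_bits) - 1) for x in v]

x = [0x11112222, 0x33334444]  # v2i32
y = [0x55556666, 0x77778888]  # v2i32

# "truncating uzp1": a v4i16 uzp1 whose operands are the v4i16 views of x, y
lhs = uzp1(bitcast(x, 32, 16), bitcast(y, 32, 16))
# rewritten form: concat_vectors(x, y) into a v4i32, then truncate to v4i16
rhs = trunc(x + y, 16)
assert lhs == rhs
```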
53 changes: 33 additions & 20 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -6105,26 +6105,39 @@ defm UZP2 : SIMDZipVector<0b101, "uzp2", AArch64uzp2>;
defm ZIP1 : SIMDZipVector<0b011, "zip1", AArch64zip1>;
defm ZIP2 : SIMDZipVector<0b111, "zip2", AArch64zip2>;

def : Pat<(v16i8 (concat_vectors (v8i8 (trunc (v8i16 V128:$Vn))),
(v8i8 (trunc (v8i16 V128:$Vm))))),
(UZP1v16i8 V128:$Vn, V128:$Vm)>;
def : Pat<(v8i16 (concat_vectors (v4i16 (trunc (v4i32 V128:$Vn))),
(v4i16 (trunc (v4i32 V128:$Vm))))),
(UZP1v8i16 V128:$Vn, V128:$Vm)>;
def : Pat<(v4i32 (concat_vectors (v2i32 (trunc (v2i64 V128:$Vn))),
(v2i32 (trunc (v2i64 V128:$Vm))))),
(UZP1v4i32 V128:$Vn, V128:$Vm)>;
// These are the same as above, with an optional assertzext node that can be
// generated from fptoi lowering.
def : Pat<(v16i8 (concat_vectors (v8i8 (assertzext (trunc (v8i16 V128:$Vn)))),
(v8i8 (assertzext (trunc (v8i16 V128:$Vm)))))),
(UZP1v16i8 V128:$Vn, V128:$Vm)>;
def : Pat<(v8i16 (concat_vectors (v4i16 (assertzext (trunc (v4i32 V128:$Vn)))),
(v4i16 (assertzext (trunc (v4i32 V128:$Vm)))))),
(UZP1v8i16 V128:$Vn, V128:$Vm)>;
def : Pat<(v4i32 (concat_vectors (v2i32 (assertzext (trunc (v2i64 V128:$Vn)))),
(v2i32 (assertzext (trunc (v2i64 V128:$Vm)))))),
(UZP1v4i32 V128:$Vn, V128:$Vm)>;
def trunc_optional_assert_ext : PatFrags<(ops node:$op0),
[(trunc node:$op0),
(assertzext (trunc node:$op0)),
(assertsext (trunc node:$op0))]>;

// concat_vectors(trunc(x), trunc(y)) -> uzp1(x, y)
// concat_vectors(assertzext(trunc(x)), assertzext(trunc(y))) -> uzp1(x, y)
// concat_vectors(assertsext(trunc(x)), assertsext(trunc(y))) -> uzp1(x, y)
class concat_trunc_to_uzp1_pat<ValueType SrcTy, ValueType TruncTy, ValueType ConcatTy>
: Pat<(ConcatTy (concat_vectors (TruncTy (trunc_optional_assert_ext (SrcTy V128:$Vn))),
Contributor:

This pattern is a subset of trunc_concat_trunc_to_xtn_uzp1_pat so it's not obvious to me which one would get picked. I think we normally use the AddedComplexity tablegen property to differentiate between cases like this - perhaps worth seeing if we need to use it here?

Contributor Author:

My understanding is that the more complex pattern will get picked over this one, because it covers more nodes.

(TruncTy (trunc_optional_assert_ext (SrcTy V128:$Vm))))),
(!cast<Instruction>("UZP1"#ConcatTy) V128:$Vn, V128:$Vm)>;
def : concat_trunc_to_uzp1_pat<v8i16, v8i8, v16i8>;
def : concat_trunc_to_uzp1_pat<v4i32, v4i16, v8i16>;
def : concat_trunc_to_uzp1_pat<v2i64, v2i32, v4i32>;

// trunc(concat_vectors(trunc(x), trunc(y))) -> xtn(uzp1(x, y))
// trunc(concat_vectors(assertzext(trunc(x)), assertzext(trunc(y)))) -> xtn(uzp1(x, y))
// trunc(concat_vectors(assertsext(trunc(x)), assertsext(trunc(y)))) -> xtn(uzp1(x, y))
class trunc_concat_trunc_to_xtn_uzp1_pat<ValueType SrcTy, ValueType TruncTy, ValueType ConcatTy,
ValueType Ty>
: Pat<(Ty (trunc_optional_assert_ext
(ConcatTy (concat_vectors
(TruncTy (trunc_optional_assert_ext (SrcTy V128:$Vn))),
(TruncTy (trunc_optional_assert_ext (SrcTy V128:$Vm))))))),
(!cast<Instruction>("XTN"#Ty) (!cast<Instruction>("UZP1"#ConcatTy) V128:$Vn, V128:$Vm))>;
def : trunc_concat_trunc_to_xtn_uzp1_pat<v4i32, v4i16, v8i16, v8i8>;
def : trunc_concat_trunc_to_xtn_uzp1_pat<v2i64, v2i32, v4i32, v4i16>;

def : Pat<(v8i8 (trunc (concat_vectors (v4i16 V64:$Vn), (v4i16 V64:$Vm)))),
(UZP1v8i8 V64:$Vn, V64:$Vm)>;
def : Pat<(v4i16 (trunc (concat_vectors (v2i32 V64:$Vn), (v2i32 V64:$Vm)))),
(UZP1v4i16 V64:$Vn, V64:$Vm)>;
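The double narrowing encoded by trunc_concat_trunc_to_xtn_uzp1_pat can be modeled the same way; below is the v2i64/v2i32/v4i32/v4i16 instantiation (an illustrative Python sketch, not part of the patch; lanes are unsigned ints in little-endian lane order):

```python
def bitcast(v, wide_bits, narrow_bits):
    # Little-endian reinterpretation, e.g. v2i64 -> v4i32.
    mask = (1 << narrow_bits) - 1
    return [(x >> (narrow_bits * i)) & mask
            for x in v
            for i in range(wide_bits // narrow_bits)]

def uzp1(a, b):
    # Even-indexed lanes of each source, concatenated.
    return a[0::2] + b[0::2]

def trunc(v, narrow_bits):
    # Keep the low narrow_bits of every lane.
    return [x & ((1 << narrow_bits) - 1) for x in v]

x = [0x0102030405060708, 0x1112131415161718]  # v2i64
y = [0x2122232425262728, 0x3132333435363738]  # v2i64

# DAG shape matched by the pattern:
# v4i16 (trunc (concat_vectors (v2i32 (trunc x)), (v2i32 (trunc y))))
dag = trunc(trunc(x, 32) + trunc(y, 32), 16)
# Selected machine code: XTNv4i16 (UZP1v4i32 Vn, Vm)
selected = trunc(uzp1(bitcast(x, 64, 32), bitcast(y, 64, 32)), 16)
assert dag == selected
```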

def : Pat<(v16i8 (concat_vectors
(v8i8 (trunc (AArch64vlshr (v8i16 V128:$Vn), (i32 8)))),
21 changes: 8 additions & 13 deletions llvm/test/CodeGen/AArch64/arm64-convert-v4f64.ll
@@ -8,9 +8,8 @@ define <4 x i16> @fptosi_v4f64_to_v4i16(ptr %ptr) {
; CHECK-NEXT: ldp q0, q1, [x0]
; CHECK-NEXT: fcvtzs v1.2d, v1.2d
; CHECK-NEXT: fcvtzs v0.2d, v0.2d
; CHECK-NEXT: xtn v1.2s, v1.2d
; CHECK-NEXT: xtn v0.2s, v0.2d
; CHECK-NEXT: uzp1 v0.4h, v0.4h, v1.4h
; CHECK-NEXT: uzp1 v0.4s, v0.4s, v1.4s
; CHECK-NEXT: xtn v0.4h, v0.4s
; CHECK-NEXT: ret
%tmp1 = load <4 x double>, ptr %ptr
%tmp2 = fptosi <4 x double> %tmp1 to <4 x i16>
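The before/after instruction sequences in this test compute the same v4i16 result. Here is a Python re-check (a sketch, not part of the patch; the fcvtzs step is omitted and v0/v1 start as already-converted v2i64 lane values):

```python
def bitcast(v, wide_bits, narrow_bits):
    # Little-endian reinterpretation, e.g. v2i64 -> v4i32.
    mask = (1 << narrow_bits) - 1
    return [(x >> (narrow_bits * i)) & mask
            for x in v
            for i in range(wide_bits // narrow_bits)]

def uzp1(a, b):
    # Even-indexed lanes of each source, concatenated.
    return a[0::2] + b[0::2]

def trunc(v, narrow_bits):
    # Keep the low narrow_bits of every lane.
    return [x & ((1 << narrow_bits) - 1) for x in v]

v0 = [0x0000AAAA0000BBBB, 0x0000CCCC0000DDDD]  # v2i64 after fcvtzs
v1 = [0x0000EEEE0000FFFF, 0x0000111100002222]  # v2i64 after fcvtzs

# old: xtn v1.2s ; xtn v0.2s ; uzp1 v0.4h, v0.4h, v1.4h
before = uzp1(bitcast(trunc(v0, 32), 32, 16), bitcast(trunc(v1, 32), 32, 16))
# new: uzp1 v0.4s, v0.4s, v1.4s ; xtn v0.4h, v0.4s
after = trunc(uzp1(bitcast(v0, 64, 32), bitcast(v1, 64, 32)), 16)
assert before == after
```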
@@ -26,13 +25,10 @@ define <8 x i8> @fptosi_v4f64_to_v4i8(ptr %ptr) {
; CHECK-NEXT: fcvtzs v1.2d, v1.2d
; CHECK-NEXT: fcvtzs v3.2d, v3.2d
; CHECK-NEXT: fcvtzs v2.2d, v2.2d
; CHECK-NEXT: xtn v0.2s, v0.2d
; CHECK-NEXT: xtn v1.2s, v1.2d
; CHECK-NEXT: xtn v3.2s, v3.2d
; CHECK-NEXT: xtn v2.2s, v2.2d
; CHECK-NEXT: uzp1 v0.4h, v1.4h, v0.4h
; CHECK-NEXT: uzp1 v1.4h, v2.4h, v3.4h
; CHECK-NEXT: uzp1 v0.8b, v1.8b, v0.8b
; CHECK-NEXT: uzp1 v0.4s, v1.4s, v0.4s
; CHECK-NEXT: uzp1 v1.4s, v2.4s, v3.4s
; CHECK-NEXT: uzp1 v0.8h, v1.8h, v0.8h
; CHECK-NEXT: xtn v0.8b, v0.8h
; CHECK-NEXT: ret
%tmp1 = load <8 x double>, ptr %ptr
%tmp2 = fptosi <8 x double> %tmp1 to <8 x i8>
@@ -72,9 +68,8 @@ define <4 x i16> @fptoui_v4f64_to_v4i16(ptr %ptr) {
; CHECK-NEXT: ldp q0, q1, [x0]
; CHECK-NEXT: fcvtzs v1.2d, v1.2d
; CHECK-NEXT: fcvtzs v0.2d, v0.2d
; CHECK-NEXT: xtn v1.2s, v1.2d
; CHECK-NEXT: xtn v0.2s, v0.2d
; CHECK-NEXT: uzp1 v0.4h, v0.4h, v1.4h
; CHECK-NEXT: uzp1 v0.4s, v0.4s, v1.4s
; CHECK-NEXT: xtn v0.4h, v0.4s
; CHECK-NEXT: ret
%tmp1 = load <4 x double>, ptr %ptr
%tmp2 = fptoui <4 x double> %tmp1 to <4 x i16>
31 changes: 16 additions & 15 deletions llvm/test/CodeGen/AArch64/extbinopload.ll
@@ -650,7 +650,7 @@ define <16 x i32> @extrause_load(ptr %p, ptr %q, ptr %r, ptr %s, ptr %z) {
; CHECK-NEXT: add x11, x3, #12
; CHECK-NEXT: str s1, [x4]
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
; CHECK-NEXT: ldp s0, s5, [x2]
; CHECK-NEXT: ldp s0, s4, [x2]
; CHECK-NEXT: ushll v2.8h, v0.8b, #0
; CHECK-NEXT: umov w9, v2.h[0]
; CHECK-NEXT: umov w10, v2.h[1]
@@ -662,24 +662,25 @@ define <16 x i32> @extrause_load(ptr %p, ptr %q, ptr %r, ptr %s, ptr %z) {
; CHECK-NEXT: ushll v2.8h, v2.8b, #0
; CHECK-NEXT: mov v0.b[10], w9
; CHECK-NEXT: add x9, x1, #4
; CHECK-NEXT: uzp1 v1.8b, v1.8b, v2.8b
; CHECK-NEXT: mov v1.d[1], v2.d[0]
; CHECK-NEXT: mov v0.b[11], w10
; CHECK-NEXT: add x10, x1, #12
; CHECK-NEXT: bic v1.8h, #255, lsl #8
; CHECK-NEXT: ld1 { v0.s }[3], [x3], #4
; CHECK-NEXT: ldr s4, [x0, #12]
; CHECK-NEXT: ldp s3, s16, [x0, #4]
; CHECK-NEXT: ld1 { v5.s }[1], [x3]
; CHECK-NEXT: ldp s6, s7, [x2, #8]
; CHECK-NEXT: ld1 { v4.s }[1], [x10]
; CHECK-NEXT: ld1 { v3.s }[1], [x9]
; CHECK-NEXT: ld1 { v6.s }[1], [x8]
; CHECK-NEXT: ld1 { v7.s }[1], [x11]
; CHECK-NEXT: ldr s3, [x0, #12]
; CHECK-NEXT: ldp s2, s7, [x0, #4]
; CHECK-NEXT: ld1 { v4.s }[1], [x3]
; CHECK-NEXT: ldp s5, s6, [x2, #8]
; CHECK-NEXT: ld1 { v3.s }[1], [x10]
; CHECK-NEXT: ld1 { v2.s }[1], [x9]
; CHECK-NEXT: ld1 { v5.s }[1], [x8]
; CHECK-NEXT: ld1 { v6.s }[1], [x11]
; CHECK-NEXT: add x8, x1, #8
; CHECK-NEXT: ld1 { v16.s }[1], [x8]
; CHECK-NEXT: uaddl v2.8h, v3.8b, v4.8b
; CHECK-NEXT: ushll v3.8h, v6.8b, #0
; CHECK-NEXT: uaddl v4.8h, v5.8b, v7.8b
; CHECK-NEXT: uaddl v1.8h, v1.8b, v16.8b
; CHECK-NEXT: ld1 { v7.s }[1], [x8]
; CHECK-NEXT: uaddl v2.8h, v2.8b, v3.8b
; CHECK-NEXT: ushll v3.8h, v5.8b, #0
; CHECK-NEXT: uaddl v4.8h, v4.8b, v6.8b
; CHECK-NEXT: uaddw v1.8h, v1.8h, v7.8b
; CHECK-NEXT: uaddw2 v5.8h, v3.8h, v0.16b
; CHECK-NEXT: ushll v0.4s, v2.4h, #3
; CHECK-NEXT: ushll2 v2.4s, v2.8h, #3
5 changes: 2 additions & 3 deletions llvm/test/CodeGen/AArch64/fp-conversion-to-tbl.ll
@@ -73,9 +73,8 @@ define void @fptoui_v8f32_to_v8i8_no_loop(ptr %A, ptr %dst) {
; CHECK-NEXT: ldp q0, q1, [x0]
; CHECK-NEXT: fcvtzs.4s v1, v1
; CHECK-NEXT: fcvtzs.4s v0, v0
; CHECK-NEXT: xtn.4h v1, v1
; CHECK-NEXT: xtn.4h v0, v0
; CHECK-NEXT: uzp1.8b v0, v0, v1
; CHECK-NEXT: uzp1.8h v0, v0, v1
; CHECK-NEXT: xtn.8b v0, v0
; CHECK-NEXT: str d0, [x1]
; CHECK-NEXT: ret
entry: