Commit 7605d27

NickGuy-Arm authored and akuhlens committed

[AArch64] Add fixed-length SVE USDOT support (llvm#143730)

1 parent e49613b commit 7605d27

File tree

2 files changed: +238 −3 lines changed


llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Lines changed: 11 additions & 0 deletions

@@ -2272,6 +2272,17 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
       setPartialReduceMLAAction(MLAOps, VT,
                                 MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
     }
+
+    if (Subtarget->hasMatMulInt8()) {
+      if (VT.getVectorElementType() == MVT::i32)
+        setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
+                                  MVT::getVectorVT(MVT::i8, NumElts * 4),
+                                  Custom);
+      else if (VT.getVectorElementType() == MVT::i64)
+        setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
+                                  MVT::getVectorVT(MVT::i8, NumElts * 8),
+                                  Custom);
+    }
   }

   // Lower fixed length vector operations to scalable equivalents.
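
Note on the hunk above: USDOT is AArch64's mixed-sign dot product from the i8mm extension; it multiplies unsigned i8 elements by signed i8 elements and accumulates each group of four products into an i32 lane. Subtarget->hasMatMulInt8() is the backend query for +i8mm, and ISD::PARTIAL_REDUCE_SUMLA is the partial-reduction node used for this signed-by-unsigned multiply-accumulate shape, so the hook marks it Custom for fixed-length vectors whose i8 input has 4x (i32 accumulator) or 8x (i64 accumulator) the accumulator's element count. A minimal LLVM IR sketch of the pattern that can now select to usdot, mirroring the tests below (the function name here is illustrative, not from the commit):

define <4 x i32> @usdot_shape(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
  ; zext-by-sext products, partially reduced 4:1 into the i32 accumulator
  %u.wide = zext <16 x i8> %u to <16 x i32>
  %s.wide = sext <16 x i8> %s to <16 x i32>
  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
  %red = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult)
  ret <4 x i32> %red
}

With -mattr=+sve,+i8mm this lowers to a single usdot z0.s, z1.b, z2.b rather than an extend/multiply/add chain, as the tests below check.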

llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
Lines changed: 227 additions & 3 deletions

@@ -1,7 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mattr=+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
-; RUN: llc -mattr=+sve,+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
-; RUN: llc -mattr=+sme -aarch64-enable-partial-reduce-nodes=true -force-streaming < %s | FileCheck %s --check-prefix=SME
+; RUN: llc -mattr=+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
+; RUN: llc -mattr=+sve,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
+; RUN: llc -mattr=+sme,+i8mm -aarch64-enable-partial-reduce-nodes=true -force-streaming < %s | FileCheck %s --check-prefix=SME

 target triple = "aarch64"

@@ -407,6 +407,154 @@ define <4 x i32> @four_way_i8_i32_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
   ret <4 x i32> %partial.reduce
 }

+define <4 x i32> @four_way_i8_i32_vl128_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+; COMMON-LABEL: four_way_i8_i32_vl128_usdot:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q0, [x0]
+; COMMON-NEXT:    ldr q1, [x1]
+; COMMON-NEXT:    ldr q2, [x2]
+; COMMON-NEXT:    usdot v0.4s, v1.16b, v2.16b
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr q0, [x0]
+; SME-NEXT:    ldr q1, [x1]
+; SME-NEXT:    ldr q2, [x2]
+; SME-NEXT:    usdot z0.s, z1.b, z2.b
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <4 x i32>, ptr %accptr
+  %u = load <16 x i8>, ptr %uptr
+  %s = load <16 x i8>, ptr %sptr
+  %u.wide = zext <16 x i8> %u to <16 x i32>
+  %s.wide = sext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <4 x i32> @four_way_i8_i32_vl128_sudot(ptr %accptr, ptr %uptr, ptr %sptr) {
+; COMMON-LABEL: four_way_i8_i32_vl128_sudot:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q0, [x0]
+; COMMON-NEXT:    ldr q1, [x1]
+; COMMON-NEXT:    ldr q2, [x2]
+; COMMON-NEXT:    usdot v0.4s, v2.16b, v1.16b
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_sudot:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr q0, [x0]
+; SME-NEXT:    ldr q1, [x1]
+; SME-NEXT:    ldr q2, [x2]
+; SME-NEXT:    usdot z0.s, z2.b, z1.b
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <4 x i32>, ptr %accptr
+  %u = load <16 x i8>, ptr %uptr
+  %s = load <16 x i8>, ptr %sptr
+  %u.wide = sext <16 x i8> %u to <16 x i32>
+  %s.wide = zext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
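
Note: there is no separate SUDOT path here. The multiply is commutative, so the signed-by-unsigned variant is lowered to the same usdot instruction with its source operands swapped, putting the zero-extended input in the unsigned (first) operand slot: usdot v0.4s, v2.16b, v1.16b in the sudot test versus usdot v0.4s, v1.16b, v2.16b in the usdot test.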
+
+define <2 x i64> @four_way_i8_i64_vl128_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+; NEON-LABEL: four_way_i8_i64_vl128_usdot:
+; NEON:       // %bb.0:
+; NEON-NEXT:    movi v0.2d, #0000000000000000
+; NEON-NEXT:    ldr q1, [x1]
+; NEON-NEXT:    ldr q2, [x2]
+; NEON-NEXT:    usdot v0.4s, v1.16b, v2.16b
+; NEON-NEXT:    ldr q1, [x0]
+; NEON-NEXT:    saddw v1.2d, v1.2d, v0.2s
+; NEON-NEXT:    saddw2 v0.2d, v1.2d, v0.4s
+; NEON-NEXT:    ret
+;
+; SVE-LABEL: four_way_i8_i64_vl128_usdot:
+; SVE:       // %bb.0:
+; SVE-NEXT:    movi v0.2d, #0000000000000000
+; SVE-NEXT:    ldr q1, [x1]
+; SVE-NEXT:    ldr q2, [x2]
+; SVE-NEXT:    usdot z0.s, z1.b, z2.b
+; SVE-NEXT:    ldr q2, [x0]
+; SVE-NEXT:    sunpklo z1.d, z0.s
+; SVE-NEXT:    sunpkhi z0.d, z0.s
+; SVE-NEXT:    add z1.d, z2.d, z1.d
+; SVE-NEXT:    add z0.d, z1.d, z0.d
+; SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i64_vl128_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    mov z0.s, #0 // =0x0
+; SME-NEXT:    ldr q1, [x1]
+; SME-NEXT:    ldr q2, [x2]
+; SME-NEXT:    usdot z0.s, z1.b, z2.b
+; SME-NEXT:    ldr q1, [x0]
+; SME-NEXT:    saddwb z1.d, z1.d, z0.s
+; SME-NEXT:    saddwt z0.d, z1.d, z0.s
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <2 x i64>, ptr %accptr
+  %u = load <16 x i8>, ptr %uptr
+  %s = load <16 x i8>, ptr %sptr
+  %u.wide = zext <16 x i8> %u to <16 x i64>
+  %s.wide = sext <16 x i8> %s to <16 x i64>
+  %mult = mul nuw nsw <16 x i64> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add(<2 x i64> %acc, <16 x i64> %mult)
+  ret <2 x i64> %partial.reduce
+}
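
Note: for the i64 accumulator, usdot still produces i32 partial sums (into a zeroed register); the CHECK lines then fold those into the i64 accumulator with signed widening adds: saddw/saddw2 on NEON, sunpklo/sunpkhi plus add on SVE, and saddwb/saddwt in streaming mode.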
+
+define <2 x i64> @four_way_i16_i64_vl128_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+; COMMON-LABEL: four_way_i16_i64_vl128_usdot:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q1, [x1]
+; COMMON-NEXT:    ldr q2, [x2]
+; COMMON-NEXT:    ldr q0, [x0]
+; COMMON-NEXT:    ushll v3.4s, v1.4h, #0
+; COMMON-NEXT:    sshll v4.4s, v2.4h, #0
+; COMMON-NEXT:    ushll2 v1.4s, v1.8h, #0
+; COMMON-NEXT:    sshll2 v2.4s, v2.8h, #0
+; COMMON-NEXT:    smlal v0.2d, v4.2s, v3.2s
+; COMMON-NEXT:    smlal2 v0.2d, v4.4s, v3.4s
+; COMMON-NEXT:    smlal v0.2d, v2.2s, v1.2s
+; COMMON-NEXT:    smlal2 v0.2d, v2.4s, v1.4s
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i16_i64_vl128_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    ptrue p0.d, vl2
+; SME-NEXT:    ldr q2, [x0]
+; SME-NEXT:    mov x8, #2 // =0x2
+; SME-NEXT:    ld1h { z0.d }, p0/z, [x1]
+; SME-NEXT:    ld1sh { z1.d }, p0/z, [x2]
+; SME-NEXT:    mad z0.d, p0/m, z1.d, z2.d
+; SME-NEXT:    ld1h { z1.d }, p0/z, [x1, x8, lsl #1]
+; SME-NEXT:    ld1sh { z2.d }, p0/z, [x2, x8, lsl #1]
+; SME-NEXT:    mov x8, #4 // =0x4
+; SME-NEXT:    mla z0.d, p0/m, z2.d, z1.d
+; SME-NEXT:    ld1h { z1.d }, p0/z, [x1, x8, lsl #1]
+; SME-NEXT:    ld1sh { z2.d }, p0/z, [x2, x8, lsl #1]
+; SME-NEXT:    mov x8, #6 // =0x6
+; SME-NEXT:    mla z0.d, p0/m, z2.d, z1.d
+; SME-NEXT:    ld1h { z1.d }, p0/z, [x1, x8, lsl #1]
+; SME-NEXT:    ld1sh { z2.d }, p0/z, [x2, x8, lsl #1]
+; SME-NEXT:    mla z0.d, p0/m, z2.d, z1.d
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <2 x i64>, ptr %accptr
+  %u = load <8 x i16>, ptr %uptr
+  %s = load <8 x i16>, ptr %sptr
+  %u.wide = zext <8 x i16> %u to <8 x i64>
+  %s.wide = sext <8 x i16> %s to <8 x i64>
+  %mult = mul nuw nsw <8 x i64> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add(<2 x i64> %acc, <8 x i64> %mult)
+  ret <2 x i64> %partial.reduce
+}
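
Note: the i16 -> i64 case acts as a negative test. USDOT only accepts i8 inputs, and the new hook is registered only for i8 input types (NumElts * 4 or NumElts * 8 i8 elements), so this pattern stays on the generic widening multiply-accumulate path: ushll/sshll plus smlal/smlal2, or predicated mad/mla in streaming mode.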
+
 define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
 ;
 ; COMMON-LABEL: four_way_i8_i32_vl128_double_width:

@@ -438,6 +586,37 @@ define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr
   ret <8 x i32> %partial.reduce
 }

+define <8 x i32> @four_way_i8_i32_vl128_double_width_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: four_way_i8_i32_vl128_double_width_usdot:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldp q0, q1, [x0]
+; COMMON-NEXT:    ldp q3, q2, [x1]
+; COMMON-NEXT:    ldp q5, q4, [x2]
+; COMMON-NEXT:    usdot v0.4s, v3.16b, v5.16b
+; COMMON-NEXT:    usdot v1.4s, v2.16b, v4.16b
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_double_width_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    ldp q0, q1, [x0]
+; SME-NEXT:    ldp q3, q2, [x1]
+; SME-NEXT:    ldp q5, q4, [x2]
+; SME-NEXT:    usdot z0.s, z3.b, z5.b
+; SME-NEXT:    usdot z1.s, z2.b, z4.b
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <8 x i32>, ptr %accptr
+  %u = load <32 x i8>, ptr %uptr
+  %s = load <32 x i8>, ptr %sptr
+  %u.wide = zext <32 x i8> %u to <32 x i32>
+  %s.wide = sext <32 x i8> %s to <32 x i32>
+  %mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
+  ret <8 x i32> %partial.reduce
+}
+
 define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
 ;
 ;
@@ -483,6 +662,51 @@ define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscal
   ret <8 x i32> %partial.reduce
 }

+define <8 x i32> @four_way_i8_i32_vl256_usdot(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: four_way_i8_i32_vl256_usdot:
+; NEON:       // %bb.0:
+; NEON-NEXT:    ldp q0, q1, [x0]
+; NEON-NEXT:    ldp q3, q2, [x1]
+; NEON-NEXT:    ldp q5, q4, [x2]
+; NEON-NEXT:    usdot v0.4s, v3.16b, v5.16b
+; NEON-NEXT:    usdot v1.4s, v2.16b, v4.16b
+; NEON-NEXT:    ret
+;
+; SVE-LABEL: four_way_i8_i32_vl256_usdot:
+; SVE:       // %bb.0:
+; SVE-NEXT:    ldr z0, [x0]
+; SVE-NEXT:    ldr z1, [x1]
+; SVE-NEXT:    ldr z2, [x2]
+; SVE-NEXT:    usdot z0.s, z1.b, z2.b
+; SVE-NEXT:    mov z1.d, z0.d
+; SVE-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl256_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr z0, [x0]
+; SME-NEXT:    ldr z1, [x1]
+; SME-NEXT:    ldr z2, [x2]
+; SME-NEXT:    usdot z0.s, z1.b, z2.b
+; SME-NEXT:    mov z1.d, z0.d
+; SME-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <8 x i32>, ptr %accptr
+  %u = load <32 x i8>, ptr %uptr
+  %s = load <32 x i8>, ptr %sptr
+  %u.wide = zext <32 x i8> %u to <32 x i32>
+  %s.wide = sext <32 x i8> %s to <32 x i32>
+  %mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
+  ret <8 x i32> %partial.reduce
+}
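
Note: with vscale_range(2,2) and +sve, the whole <32 x i8> input fits in one Z register, so a single full-width usdot does all the work; the trailing mov/ext pair just splits the 256-bit result into the two q-register halves used to return the <8 x i32> value. Without SVE, the NEON path splits the same work across two 128-bit usdot instructions.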
+
 ;
 ; Four-way dot (i16 -> i64)
 ;
