Skip to content

Commit 2b4807b

Browse files
committed
[AArch64][SVE] Predicated mla/mls patterns
To go with D149267 and D149967, this adds predicated mla/mls patterns, selected from select(mask, add(a, mul(b, c)), a) -> mla(a, mask, b, c). The existing patterns are eventually removed by D149967. Differential Revision: https://reviews.llvm.org/D149969
1 parent b447dc5 commit 2b4807b

File tree

2 files changed

+60
-66
lines changed

2 files changed

+60
-66
lines changed

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -408,15 +408,19 @@ def AArch64sub_m1 : PatFrags<(ops node:$pred, node:$op1, node:$op2),
408408
def AArch64mla_m1 : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3),
409409
[(int_aarch64_sve_mla node:$pred, node:$op1, node:$op2, node:$op3),
410410
// add(a, select(mask, mul(b, c), splat(0))) -> mla(a, mask, b, c)
411-
(add node:$op1, (vselect node:$pred, (AArch64mul_p_oneuse (SVEAllActive), node:$op2, node:$op3), (SVEDup0)))]>;
411+
(add node:$op1, (vselect node:$pred, (AArch64mul_p_oneuse (SVEAllActive), node:$op2, node:$op3), (SVEDup0))),
412+
// select(mask, add(a, mul(b, c)), a) -> mla(a, mask, b, c)
413+
(vselect node:$pred, (add node:$op1, (AArch64mul_p_oneuse (SVEAllActive), node:$op2, node:$op3)), node:$op1)]>;
412414
// pattern for generating pseudo for MLA_ZPmZZ/MAD_ZPmZZ
413415
def AArch64mla_p : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3),
414416
[(int_aarch64_sve_mla_u node:$pred, node:$op1, node:$op2, node:$op3),
415417
(add node:$op1, (AArch64mul_p_oneuse node:$pred, node:$op2, node:$op3))]>;
416418
def AArch64mls_m1 : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3),
417419
[(int_aarch64_sve_mls node:$pred, node:$op1, node:$op2, node:$op3),
418420
// sub(a, select(mask, mul(b, c), splat(0))) -> mls(a, mask, b, c)
419-
(sub node:$op1, (vselect node:$pred, (AArch64mul_p_oneuse (SVEAllActive), node:$op2, node:$op3), (SVEDup0)))]>;
421+
(sub node:$op1, (vselect node:$pred, (AArch64mul_p_oneuse (SVEAllActive), node:$op2, node:$op3), (SVEDup0))),
422+
// select(mask, sub(a, mul(b, c)), a) -> mls(a, mask, b, c)
423+
(vselect node:$pred, (sub node:$op1, (AArch64mul_p_oneuse (SVEAllActive), node:$op2, node:$op3)), node:$op1)]>;
420424
def AArch64mls_p : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3),
421425
[(int_aarch64_sve_mls_u node:$pred, node:$op1, node:$op2, node:$op3),
422426
(sub node:$op1, (AArch64mul_p_oneuse node:$pred, node:$op2, node:$op3))]>;

llvm/test/CodeGen/AArch64/sve-pred-selectop2.ll

Lines changed: 54 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -362,8 +362,7 @@ define <vscale x 2 x i64> @srem_nxv2i64_x(<vscale x 2 x i64> %x, <vscale x 2 x i
362362
; CHECK-NEXT: cmpgt p1.d, p0/z, z2.d, #0
363363
; CHECK-NEXT: movprfx z2, z0
364364
; CHECK-NEXT: sdiv z2.d, p0/m, z2.d, z1.d
365-
; CHECK-NEXT: msb z1.d, p0/m, z2.d, z0.d
366-
; CHECK-NEXT: mov z0.d, p1/m, z1.d
365+
; CHECK-NEXT: mls z0.d, p1/m, z2.d, z1.d
367366
; CHECK-NEXT: ret
368367
entry:
369368
%c = icmp sgt <vscale x 2 x i64> %n, zeroinitializer
@@ -379,8 +378,7 @@ define <vscale x 4 x i32> @srem_nxv4i32_x(<vscale x 4 x i32> %x, <vscale x 4 x i
379378
; CHECK-NEXT: cmpgt p1.s, p0/z, z2.s, #0
380379
; CHECK-NEXT: movprfx z2, z0
381380
; CHECK-NEXT: sdiv z2.s, p0/m, z2.s, z1.s
382-
; CHECK-NEXT: msb z1.s, p0/m, z2.s, z0.s
383-
; CHECK-NEXT: mov z0.s, p1/m, z1.s
381+
; CHECK-NEXT: mls z0.s, p1/m, z2.s, z1.s
384382
; CHECK-NEXT: ret
385383
entry:
386384
%c = icmp sgt <vscale x 4 x i32> %n, zeroinitializer
@@ -392,19 +390,18 @@ entry:
392390
define <vscale x 8 x i16> @srem_nxv8i16_x(<vscale x 8 x i16> %x, <vscale x 8 x i16> %y, <vscale x 8 x i16> %n) {
393391
; CHECK-LABEL: srem_nxv8i16_x:
394392
; CHECK: // %bb.0: // %entry
395-
; CHECK-NEXT: ptrue p0.s
393+
; CHECK-NEXT: ptrue p1.s
396394
; CHECK-NEXT: sunpkhi z3.s, z1.h
397395
; CHECK-NEXT: sunpkhi z4.s, z0.h
396+
; CHECK-NEXT: ptrue p0.h
397+
; CHECK-NEXT: sdivr z3.s, p1/m, z3.s, z4.s
398398
; CHECK-NEXT: sunpklo z5.s, z1.h
399-
; CHECK-NEXT: sdivr z3.s, p0/m, z3.s, z4.s
400399
; CHECK-NEXT: sunpklo z6.s, z0.h
401400
; CHECK-NEXT: movprfx z4, z6
402-
; CHECK-NEXT: sdiv z4.s, p0/m, z4.s, z5.s
403-
; CHECK-NEXT: ptrue p0.h
404-
; CHECK-NEXT: uzp1 z3.h, z4.h, z3.h
405-
; CHECK-NEXT: cmpgt p1.h, p0/z, z2.h, #0
406-
; CHECK-NEXT: msb z1.h, p0/m, z3.h, z0.h
407-
; CHECK-NEXT: mov z0.h, p1/m, z1.h
401+
; CHECK-NEXT: sdiv z4.s, p1/m, z4.s, z5.s
402+
; CHECK-NEXT: cmpgt p0.h, p0/z, z2.h, #0
403+
; CHECK-NEXT: uzp1 z2.h, z4.h, z3.h
404+
; CHECK-NEXT: mls z0.h, p0/m, z2.h, z1.h
408405
; CHECK-NEXT: ret
409406
entry:
410407
%c = icmp sgt <vscale x 8 x i16> %n, zeroinitializer
@@ -421,25 +418,25 @@ define <vscale x 16 x i8> @srem_nxv16i8_x(<vscale x 16 x i8> %x, <vscale x 16 x
421418
; CHECK-NEXT: ptrue p0.s
422419
; CHECK-NEXT: sunpkhi z5.s, z3.h
423420
; CHECK-NEXT: sunpkhi z6.s, z4.h
424-
; CHECK-NEXT: sunpklo z3.s, z3.h
425-
; CHECK-NEXT: sunpklo z4.s, z4.h
421+
; CHECK-NEXT: sunpklo z7.h, z1.b
426422
; CHECK-NEXT: sdivr z5.s, p0/m, z5.s, z6.s
427-
; CHECK-NEXT: sdivr z3.s, p0/m, z3.s, z4.s
428-
; CHECK-NEXT: sunpklo z4.h, z1.b
429423
; CHECK-NEXT: sunpklo z6.h, z0.b
430-
; CHECK-NEXT: sunpkhi z7.s, z4.h
431-
; CHECK-NEXT: sunpkhi z24.s, z6.h
424+
; CHECK-NEXT: sunpklo z3.s, z3.h
432425
; CHECK-NEXT: sunpklo z4.s, z4.h
426+
; CHECK-NEXT: sunpkhi z24.s, z7.h
427+
; CHECK-NEXT: sunpkhi z25.s, z6.h
428+
; CHECK-NEXT: sunpklo z7.s, z7.h
433429
; CHECK-NEXT: sunpklo z6.s, z6.h
434-
; CHECK-NEXT: sdivr z7.s, p0/m, z7.s, z24.s
435-
; CHECK-NEXT: sdivr z4.s, p0/m, z4.s, z6.s
436-
; CHECK-NEXT: uzp1 z3.h, z3.h, z5.h
437-
; CHECK-NEXT: uzp1 z4.h, z4.h, z7.h
430+
; CHECK-NEXT: sdivr z3.s, p0/m, z3.s, z4.s
431+
; CHECK-NEXT: movprfx z4, z25
432+
; CHECK-NEXT: sdiv z4.s, p0/m, z4.s, z24.s
433+
; CHECK-NEXT: sdiv z6.s, p0/m, z6.s, z7.s
438434
; CHECK-NEXT: ptrue p0.b
439-
; CHECK-NEXT: uzp1 z3.b, z4.b, z3.b
440-
; CHECK-NEXT: cmpgt p1.b, p0/z, z2.b, #0
441-
; CHECK-NEXT: msb z1.b, p0/m, z3.b, z0.b
442-
; CHECK-NEXT: mov z0.b, p1/m, z1.b
435+
; CHECK-NEXT: uzp1 z3.h, z3.h, z5.h
436+
; CHECK-NEXT: uzp1 z4.h, z6.h, z4.h
437+
; CHECK-NEXT: cmpgt p0.b, p0/z, z2.b, #0
438+
; CHECK-NEXT: uzp1 z2.b, z4.b, z3.b
439+
; CHECK-NEXT: mls z0.b, p0/m, z2.b, z1.b
443440
; CHECK-NEXT: ret
444441
entry:
445442
%c = icmp sgt <vscale x 16 x i8> %n, zeroinitializer
@@ -455,8 +452,7 @@ define <vscale x 2 x i64> @urem_nxv2i64_x(<vscale x 2 x i64> %x, <vscale x 2 x i
455452
; CHECK-NEXT: cmpgt p1.d, p0/z, z2.d, #0
456453
; CHECK-NEXT: movprfx z2, z0
457454
; CHECK-NEXT: udiv z2.d, p0/m, z2.d, z1.d
458-
; CHECK-NEXT: msb z1.d, p0/m, z2.d, z0.d
459-
; CHECK-NEXT: mov z0.d, p1/m, z1.d
455+
; CHECK-NEXT: mls z0.d, p1/m, z2.d, z1.d
460456
; CHECK-NEXT: ret
461457
entry:
462458
%c = icmp sgt <vscale x 2 x i64> %n, zeroinitializer
@@ -472,8 +468,7 @@ define <vscale x 4 x i32> @urem_nxv4i32_x(<vscale x 4 x i32> %x, <vscale x 4 x i
472468
; CHECK-NEXT: cmpgt p1.s, p0/z, z2.s, #0
473469
; CHECK-NEXT: movprfx z2, z0
474470
; CHECK-NEXT: udiv z2.s, p0/m, z2.s, z1.s
475-
; CHECK-NEXT: msb z1.s, p0/m, z2.s, z0.s
476-
; CHECK-NEXT: mov z0.s, p1/m, z1.s
471+
; CHECK-NEXT: mls z0.s, p1/m, z2.s, z1.s
477472
; CHECK-NEXT: ret
478473
entry:
479474
%c = icmp sgt <vscale x 4 x i32> %n, zeroinitializer
@@ -485,19 +480,18 @@ entry:
485480
define <vscale x 8 x i16> @urem_nxv8i16_x(<vscale x 8 x i16> %x, <vscale x 8 x i16> %y, <vscale x 8 x i16> %n) {
486481
; CHECK-LABEL: urem_nxv8i16_x:
487482
; CHECK: // %bb.0: // %entry
488-
; CHECK-NEXT: ptrue p0.s
483+
; CHECK-NEXT: ptrue p1.s
489484
; CHECK-NEXT: uunpkhi z3.s, z1.h
490485
; CHECK-NEXT: uunpkhi z4.s, z0.h
486+
; CHECK-NEXT: ptrue p0.h
487+
; CHECK-NEXT: udivr z3.s, p1/m, z3.s, z4.s
491488
; CHECK-NEXT: uunpklo z5.s, z1.h
492-
; CHECK-NEXT: udivr z3.s, p0/m, z3.s, z4.s
493489
; CHECK-NEXT: uunpklo z6.s, z0.h
494490
; CHECK-NEXT: movprfx z4, z6
495-
; CHECK-NEXT: udiv z4.s, p0/m, z4.s, z5.s
496-
; CHECK-NEXT: ptrue p0.h
497-
; CHECK-NEXT: uzp1 z3.h, z4.h, z3.h
498-
; CHECK-NEXT: cmpgt p1.h, p0/z, z2.h, #0
499-
; CHECK-NEXT: msb z1.h, p0/m, z3.h, z0.h
500-
; CHECK-NEXT: mov z0.h, p1/m, z1.h
491+
; CHECK-NEXT: udiv z4.s, p1/m, z4.s, z5.s
492+
; CHECK-NEXT: cmpgt p0.h, p0/z, z2.h, #0
493+
; CHECK-NEXT: uzp1 z2.h, z4.h, z3.h
494+
; CHECK-NEXT: mls z0.h, p0/m, z2.h, z1.h
501495
; CHECK-NEXT: ret
502496
entry:
503497
%c = icmp sgt <vscale x 8 x i16> %n, zeroinitializer
@@ -514,25 +508,25 @@ define <vscale x 16 x i8> @urem_nxv16i8_x(<vscale x 16 x i8> %x, <vscale x 16 x
514508
; CHECK-NEXT: ptrue p0.s
515509
; CHECK-NEXT: uunpkhi z5.s, z3.h
516510
; CHECK-NEXT: uunpkhi z6.s, z4.h
517-
; CHECK-NEXT: uunpklo z3.s, z3.h
518-
; CHECK-NEXT: uunpklo z4.s, z4.h
511+
; CHECK-NEXT: uunpklo z7.h, z1.b
519512
; CHECK-NEXT: udivr z5.s, p0/m, z5.s, z6.s
520-
; CHECK-NEXT: udivr z3.s, p0/m, z3.s, z4.s
521-
; CHECK-NEXT: uunpklo z4.h, z1.b
522513
; CHECK-NEXT: uunpklo z6.h, z0.b
523-
; CHECK-NEXT: uunpkhi z7.s, z4.h
524-
; CHECK-NEXT: uunpkhi z24.s, z6.h
514+
; CHECK-NEXT: uunpklo z3.s, z3.h
525515
; CHECK-NEXT: uunpklo z4.s, z4.h
516+
; CHECK-NEXT: uunpkhi z24.s, z7.h
517+
; CHECK-NEXT: uunpkhi z25.s, z6.h
518+
; CHECK-NEXT: uunpklo z7.s, z7.h
526519
; CHECK-NEXT: uunpklo z6.s, z6.h
527-
; CHECK-NEXT: udivr z7.s, p0/m, z7.s, z24.s
528-
; CHECK-NEXT: udivr z4.s, p0/m, z4.s, z6.s
529-
; CHECK-NEXT: uzp1 z3.h, z3.h, z5.h
530-
; CHECK-NEXT: uzp1 z4.h, z4.h, z7.h
520+
; CHECK-NEXT: udivr z3.s, p0/m, z3.s, z4.s
521+
; CHECK-NEXT: movprfx z4, z25
522+
; CHECK-NEXT: udiv z4.s, p0/m, z4.s, z24.s
523+
; CHECK-NEXT: udiv z6.s, p0/m, z6.s, z7.s
531524
; CHECK-NEXT: ptrue p0.b
532-
; CHECK-NEXT: uzp1 z3.b, z4.b, z3.b
533-
; CHECK-NEXT: cmpgt p1.b, p0/z, z2.b, #0
534-
; CHECK-NEXT: msb z1.b, p0/m, z3.b, z0.b
535-
; CHECK-NEXT: mov z0.b, p1/m, z1.b
525+
; CHECK-NEXT: uzp1 z3.h, z3.h, z5.h
526+
; CHECK-NEXT: uzp1 z4.h, z6.h, z4.h
527+
; CHECK-NEXT: cmpgt p0.b, p0/z, z2.b, #0
528+
; CHECK-NEXT: uzp1 z2.b, z4.b, z3.b
529+
; CHECK-NEXT: mls z0.b, p0/m, z2.b, z1.b
536530
; CHECK-NEXT: ret
537531
entry:
538532
%c = icmp sgt <vscale x 16 x i8> %n, zeroinitializer
@@ -905,9 +899,8 @@ define <vscale x 2 x i64> @mla_nxv2i64_x(<vscale x 2 x i64> %x, <vscale x 2 x i6
905899
; CHECK-LABEL: mla_nxv2i64_x:
906900
; CHECK: // %bb.0: // %entry
907901
; CHECK-NEXT: ptrue p0.d
908-
; CHECK-NEXT: cmpgt p1.d, p0/z, z3.d, #0
909-
; CHECK-NEXT: mad z1.d, p0/m, z2.d, z0.d
910-
; CHECK-NEXT: mov z0.d, p1/m, z1.d
902+
; CHECK-NEXT: cmpgt p0.d, p0/z, z3.d, #0
903+
; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
911904
; CHECK-NEXT: ret
912905
entry:
913906
%c = icmp sgt <vscale x 2 x i64> %n, zeroinitializer
@@ -921,9 +914,8 @@ define <vscale x 4 x i32> @mla_nxv4i32_x(<vscale x 4 x i32> %x, <vscale x 4 x i3
921914
; CHECK-LABEL: mla_nxv4i32_x:
922915
; CHECK: // %bb.0: // %entry
923916
; CHECK-NEXT: ptrue p0.s
924-
; CHECK-NEXT: cmpgt p1.s, p0/z, z3.s, #0
925-
; CHECK-NEXT: mad z1.s, p0/m, z2.s, z0.s
926-
; CHECK-NEXT: mov z0.s, p1/m, z1.s
917+
; CHECK-NEXT: cmpgt p0.s, p0/z, z3.s, #0
918+
; CHECK-NEXT: mla z0.s, p0/m, z1.s, z2.s
927919
; CHECK-NEXT: ret
928920
entry:
929921
%c = icmp sgt <vscale x 4 x i32> %n, zeroinitializer
@@ -937,9 +929,8 @@ define <vscale x 8 x i16> @mla_nxv8i16_x(<vscale x 8 x i16> %x, <vscale x 8 x i1
937929
; CHECK-LABEL: mla_nxv8i16_x:
938930
; CHECK: // %bb.0: // %entry
939931
; CHECK-NEXT: ptrue p0.h
940-
; CHECK-NEXT: cmpgt p1.h, p0/z, z3.h, #0
941-
; CHECK-NEXT: mad z1.h, p0/m, z2.h, z0.h
942-
; CHECK-NEXT: mov z0.h, p1/m, z1.h
932+
; CHECK-NEXT: cmpgt p0.h, p0/z, z3.h, #0
933+
; CHECK-NEXT: mla z0.h, p0/m, z1.h, z2.h
943934
; CHECK-NEXT: ret
944935
entry:
945936
%c = icmp sgt <vscale x 8 x i16> %n, zeroinitializer
@@ -953,9 +944,8 @@ define <vscale x 16 x i8> @mla_nxv16i8_x(<vscale x 16 x i8> %x, <vscale x 16 x i
953944
; CHECK-LABEL: mla_nxv16i8_x:
954945
; CHECK: // %bb.0: // %entry
955946
; CHECK-NEXT: ptrue p0.b
956-
; CHECK-NEXT: cmpgt p1.b, p0/z, z3.b, #0
957-
; CHECK-NEXT: mad z1.b, p0/m, z2.b, z0.b
958-
; CHECK-NEXT: mov z0.b, p1/m, z1.b
947+
; CHECK-NEXT: cmpgt p0.b, p0/z, z3.b, #0
948+
; CHECK-NEXT: mla z0.b, p0/m, z1.b, z2.b
959949
; CHECK-NEXT: ret
960950
entry:
961951
%c = icmp sgt <vscale x 16 x i8> %n, zeroinitializer

0 commit comments

Comments
 (0)