[ARM] Disable UpperBound loop unrolling for MVE tail predicated loops. #69709

Merged: 1 commit, Oct 31, 2023

llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp (12 changes: 9 additions & 3 deletions)

@@ -2430,9 +2430,15 @@ ARMTTIImpl::getPreferredTailFoldingStyle(bool IVUpdateMayOverflow) const {
 void ARMTTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
                                          TTI::UnrollingPreferences &UP,
                                          OptimizationRemarkEmitter *ORE) {
-  // Enable Upper bound unrolling universally, not dependant upon the conditions
-  // below.
-  UP.UpperBound = true;
+  // Enable Upper bound unrolling universally, providing that we do not see an
+  // active lane mask, which will be better kept as a loop to become tail
+  // predicated than to be conditionally unrolled.
+  UP.UpperBound =
+      !ST->hasMVEIntegerOps() || !any_of(*L->getHeader(), [](Instruction &I) {
+        return isa<IntrinsicInst>(I) &&
+               cast<IntrinsicInst>(I).getIntrinsicID() ==
+                   Intrinsic::get_active_lane_mask;
+      });
 
   // Only currently enable these preferences for M-Class cores.
   if (!ST->isMClass())
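
In isolation, the check the patch adds amounts to scanning the loop header for a call to llvm.get.active.lane.mask. A minimal standalone sketch of that scan, for readers skimming the diff (the helper name hasActiveLaneMask and the self-contained framing are illustrative, not part of the patch):

// Sketch only: the same header scan the patch performs, pulled out into a
// helper. hasActiveLaneMask is a name chosen here for illustration.
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"

using namespace llvm;

// Returns true if the block contains a call to llvm.get.active.lane.mask,
// i.e. a loop the MVE backend would rather tail-predicate than unroll.
static bool hasActiveLaneMask(BasicBlock &Header) {
  return any_of(Header, [](Instruction &I) {
    auto *II = dyn_cast<IntrinsicInst>(&I);
    return II && II->getIntrinsicID() == Intrinsic::get_active_lane_mask;
  });
}

// With that helper, the patched preference would read:
//   UP.UpperBound = !ST->hasMVEIntegerOps() || !hasActiveLaneMask(*L->getHeader());
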
llvm/test/Transforms/LoopUnroll/ARM/mve-upperbound.ll (79 changes: 79 additions & 0 deletions, new file)
@@ -0,0 +1,79 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -passes=loop-unroll -S -mtriple thumbv8.1m.main-none-eabi -mattr=+mve %s | FileCheck %s

; The vector loop here is better kept as a loop than conditionally unrolled,
; letting it transform into a tail predicated loop.

define void @unroll_upper(ptr noundef %pSrc, ptr nocapture noundef writeonly %pDst, i32 noundef %blockSize) {
; CHECK-LABEL: @unroll_upper(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[CMP_NOT23:%.*]] = icmp ult i32 [[BLOCKSIZE:%.*]], 16
; CHECK-NEXT: [[AND:%.*]] = and i32 [[BLOCKSIZE]], 15
; CHECK-NEXT: [[CMP6_NOT28:%.*]] = icmp eq i32 [[AND]], 0
; CHECK-NEXT: br i1 [[CMP6_NOT28]], label [[WHILE_END12:%.*]], label [[VECTOR_MEMCHECK:%.*]]
; CHECK: vector.memcheck:
; CHECK-NEXT: [[SCEVGEP:%.*]] = getelementptr i8, ptr [[PDST:%.*]], i32 [[AND]]
; CHECK-NEXT: [[TMP0:%.*]] = shl nuw nsw i32 [[AND]], 1
; CHECK-NEXT: [[SCEVGEP32:%.*]] = getelementptr i8, ptr [[PSRC:%.*]], i32 [[TMP0]]
; CHECK-NEXT: [[BOUND0:%.*]] = icmp ult ptr [[PDST]], [[SCEVGEP32]]
; CHECK-NEXT: [[BOUND1:%.*]] = icmp ult ptr [[PSRC]], [[SCEVGEP]]
; CHECK-NEXT: [[FOUND_CONFLICT:%.*]] = and i1 [[BOUND0]], [[BOUND1]]
; CHECK-NEXT: [[N_RND_UP:%.*]] = add nuw nsw i32 [[AND]], 7
; CHECK-NEXT: [[N_VEC:%.*]] = and i32 [[N_RND_UP]], 24
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_MEMCHECK]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[PDST]], i32 [[INDEX]]
; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[INDEX]], 1
; CHECK-NEXT: [[NEXT_GEP37:%.*]] = getelementptr i8, ptr [[PSRC]], i32 [[TMP1]]
; CHECK-NEXT: [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[AND]])
; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[NEXT_GEP37]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
; CHECK-NEXT: [[TMP2:%.*]] = lshr <8 x i16> [[WIDE_MASKED_LOAD]], <i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8>
; CHECK-NEXT: [[TMP3:%.*]] = trunc <8 x i16> [[TMP2]] to <8 x i8>
; CHECK-NEXT: call void @llvm.masked.store.v8i8.p0(<8 x i8> [[TMP3]], ptr [[NEXT_GEP]], i32 1, <8 x i1> [[ACTIVE_LANE_MASK]])
; CHECK-NEXT: [[INDEX_NEXT]] = add i32 [[INDEX]], 8
; CHECK-NEXT: [[TMP4:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEXT: br i1 [[TMP4]], label [[WHILE_END12_LOOPEXIT:%.*]], label [[VECTOR_BODY]]
; CHECK: while.end12.loopexit:
; CHECK-NEXT: br label [[WHILE_END12]]
; CHECK: while.end12:
; CHECK-NEXT: ret void
;
entry:
%cmp.not23 = icmp ult i32 %blockSize, 16
%and = and i32 %blockSize, 15
%cmp6.not28 = icmp eq i32 %and, 0
br i1 %cmp6.not28, label %while.end12, label %vector.memcheck

vector.memcheck: ; preds = %entry
%scevgep = getelementptr i8, ptr %pDst, i32 %and
%0 = shl nuw nsw i32 %and, 1
%scevgep32 = getelementptr i8, ptr %pSrc, i32 %0
%bound0 = icmp ult ptr %pDst, %scevgep32
%bound1 = icmp ult ptr %pSrc, %scevgep
%found.conflict = and i1 %bound0, %bound1
%n.rnd.up = add nuw nsw i32 %and, 7
%n.vec = and i32 %n.rnd.up, 24
br label %vector.body

vector.body: ; preds = %vector.body, %vector.memcheck
%index = phi i32 [ 0, %vector.memcheck ], [ %index.next, %vector.body ]
%next.gep = getelementptr i8, ptr %pDst, i32 %index
%1 = shl i32 %index, 1
%next.gep37 = getelementptr i8, ptr %pSrc, i32 %1
%active.lane.mask = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 %index, i32 %and)
%wide.masked.load = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr %next.gep37, i32 2, <8 x i1> %active.lane.mask, <8 x i16> poison)
%2 = lshr <8 x i16> %wide.masked.load, <i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8>
%3 = trunc <8 x i16> %2 to <8 x i8>
call void @llvm.masked.store.v8i8.p0(<8 x i8> %3, ptr %next.gep, i32 1, <8 x i1> %active.lane.mask)
%index.next = add i32 %index, 8
%4 = icmp eq i32 %index.next, %n.vec
br i1 %4, label %while.end12, label %vector.body

while.end12: ; preds = %vector.body, %entry
ret void
}

declare <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32, i32)
declare <8 x i16> @llvm.masked.load.v8i16.p0(ptr nocapture, i32 immarg, <8 x i1>, <8 x i16>)
declare void @llvm.masked.store.v8i8.p0(<8 x i8>, ptr nocapture, i32 immarg, <8 x i1>)
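
For orientation, the IR above looks like the vectorized remainder of a shift-and-narrow copy. A rough scalar equivalent, reconstructed from the test rather than taken from any original source, might be:

// Assumption: a scalar reconstruction of the loop this test's IR appears to
// come from (store the high byte of each u16 element). Only the remainder of
// blockSize & 15 elements is handled, matching the extracted test.
#include <cstdint>

void unroll_upper(const uint16_t *pSrc, uint8_t *pDst, uint32_t blockSize) {
  uint32_t tail = blockSize & 15;               // %and in the IR
  for (uint32_t i = 0; i < tail; ++i)           // the vector.body loop, scalarized
    pDst[i] = static_cast<uint8_t>(pSrc[i] >> 8);
}

With the patch, this tail loop keeps its get.active.lane.mask form instead of being conditionally unrolled against its upper bound, leaving it eligible for MVE tail predication later in the pipeline.
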