Skip to content

Commit 297e06c

Browse files
committed
[RISCVGatherScatterLowering] Remove restriction that shift must have constant operand
This restriction has been present since the original patch that added the pass, and it doesn't appear to be strongly justified. We do need to be careful about commutativity. Differential Revision: https://reviews.llvm.org/D150468
1 parent 64f1fda commit 297e06c

File tree

2 files changed

+96
-4
lines changed

2 files changed

+96
-4
lines changed

llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,9 +236,6 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
236236
case Instruction::Add:
237237
break;
238238
case Instruction::Shl:
239-
// Only support shift by constant.
240-
if (!isa<Constant>(BO->getOperand(1)))
241-
return false;
242239
break;
243240
case Instruction::Mul:
244241
break;
@@ -251,7 +248,8 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
251248
Index = cast<Instruction>(BO->getOperand(0));
252249
OtherOp = BO->getOperand(1);
253250
} else if (isa<Instruction>(BO->getOperand(1)) &&
254-
L->contains(cast<Instruction>(BO->getOperand(1)))) {
251+
L->contains(cast<Instruction>(BO->getOperand(1))) &&
252+
Instruction::isCommutative(BO->getOpcode())) {
255253
Index = cast<Instruction>(BO->getOperand(1));
256254
OtherOp = BO->getOperand(0);
257255
} else {

llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,100 @@ for.cond.cleanup: ; preds = %vector.body
310310
ret void
311311
}
312312

313+
define void @gather_unknown_pow2(ptr noalias nocapture %A, ptr noalias nocapture readonly %B, i64 %shift) {
314+
; CHECK-LABEL: @gather_unknown_pow2(
315+
; CHECK-NEXT: entry:
316+
; CHECK-NEXT: [[STEP:%.*]] = shl i64 8, [[SHIFT:%.*]]
317+
; CHECK-NEXT: [[STRIDE:%.*]] = shl i64 1, [[SHIFT]]
318+
; CHECK-NEXT: [[TMP0:%.*]] = mul i64 [[STRIDE]], 4
319+
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
320+
; CHECK: vector.body:
321+
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
322+
; CHECK-NEXT: [[VEC_IND_SCALAR:%.*]] = phi i64 [ 0, [[ENTRY]] ], [ [[VEC_IND_NEXT_SCALAR:%.*]], [[VECTOR_BODY]] ]
323+
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[B:%.*]], i64 [[VEC_IND_SCALAR]]
324+
; CHECK-NEXT: [[WIDE_MASKED_GATHER:%.*]] = call <8 x i32> @llvm.riscv.masked.strided.load.v8i32.p0.i64(<8 x i32> undef, ptr [[TMP1]], i64 [[TMP0]], <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
325+
; CHECK-NEXT: [[I2:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
326+
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <8 x i32>, ptr [[I2]], align 1
327+
; CHECK-NEXT: [[I4:%.*]] = add <8 x i32> [[WIDE_LOAD]], [[WIDE_MASKED_GATHER]]
328+
; CHECK-NEXT: store <8 x i32> [[I4]], ptr [[I2]], align 1
329+
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
330+
; CHECK-NEXT: [[VEC_IND_NEXT_SCALAR]] = add i64 [[VEC_IND_SCALAR]], [[STEP]]
331+
; CHECK-NEXT: [[I6:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
332+
; CHECK-NEXT: br i1 [[I6]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
333+
; CHECK: for.cond.cleanup:
334+
; CHECK-NEXT: ret void
335+
;
336+
entry:
337+
%.splatinsert = insertelement <8 x i64> poison, i64 %shift, i64 0
338+
%.splat = shufflevector <8 x i64> %.splatinsert, <8 x i64> poison, <8 x i32> zeroinitializer
339+
br label %vector.body
340+
341+
vector.body: ; preds = %vector.body, %entry
342+
%index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
343+
%vec.ind = phi <8 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7>, %entry ], [ %vec.ind.next, %vector.body ]
344+
%i = shl nsw <8 x i64> %vec.ind, %.splat
345+
%i1 = getelementptr inbounds i32, ptr %B, <8 x i64> %i
346+
%wide.masked.gather = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> %i1, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
347+
%i2 = getelementptr inbounds i32, ptr %A, i64 %index
348+
%wide.load = load <8 x i32>, ptr %i2, align 1
349+
%i4 = add <8 x i32> %wide.load, %wide.masked.gather
350+
store <8 x i32> %i4, ptr %i2, align 1
351+
%index.next = add nuw i64 %index, 8
352+
%vec.ind.next = add <8 x i64> %vec.ind, <i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8>
353+
%i6 = icmp eq i64 %index.next, 1024
354+
br i1 %i6, label %for.cond.cleanup, label %vector.body
355+
356+
for.cond.cleanup: ; preds = %vector.body
357+
ret void
358+
}
359+
360+
define void @negative_shl_non_commute(ptr noalias nocapture %A, ptr noalias nocapture readonly %B, i64 %shift) {
361+
; CHECK-LABEL: @negative_shl_non_commute(
362+
; CHECK-NEXT: entry:
363+
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <8 x i64> poison, i64 [[SHIFT:%.*]], i64 0
364+
; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <8 x i64> [[DOTSPLATINSERT]], <8 x i64> poison, <8 x i32> zeroinitializer
365+
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
366+
; CHECK: vector.body:
367+
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
368+
; CHECK-NEXT: [[VEC_IND:%.*]] = phi <8 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7>, [[ENTRY]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
369+
; CHECK-NEXT: [[I:%.*]] = shl nsw <8 x i64> [[DOTSPLAT]], [[VEC_IND]]
370+
; CHECK-NEXT: [[I1:%.*]] = getelementptr inbounds i32, ptr [[B:%.*]], <8 x i64> [[I]]
371+
; CHECK-NEXT: [[WIDE_MASKED_GATHER:%.*]] = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> [[I1]], i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
372+
; CHECK-NEXT: [[I2:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
373+
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <8 x i32>, ptr [[I2]], align 1
374+
; CHECK-NEXT: [[I4:%.*]] = add <8 x i32> [[WIDE_LOAD]], [[WIDE_MASKED_GATHER]]
375+
; CHECK-NEXT: store <8 x i32> [[I4]], ptr [[I2]], align 1
376+
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
377+
; CHECK-NEXT: [[VEC_IND_NEXT]] = add <8 x i64> [[VEC_IND]], <i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8>
378+
; CHECK-NEXT: [[I6:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
379+
; CHECK-NEXT: br i1 [[I6]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
380+
; CHECK: for.cond.cleanup:
381+
; CHECK-NEXT: ret void
382+
;
383+
entry:
384+
%.splatinsert = insertelement <8 x i64> poison, i64 %shift, i64 0
385+
%.splat = shufflevector <8 x i64> %.splatinsert, <8 x i64> poison, <8 x i32> zeroinitializer
386+
br label %vector.body
387+
388+
vector.body: ; preds = %vector.body, %entry
389+
%index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
390+
%vec.ind = phi <8 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7>, %entry ], [ %vec.ind.next, %vector.body ]
391+
%i = shl nsw <8 x i64> %.splat, %vec.ind
392+
%i1 = getelementptr inbounds i32, ptr %B, <8 x i64> %i
393+
%wide.masked.gather = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> %i1, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
394+
%i2 = getelementptr inbounds i32, ptr %A, i64 %index
395+
%wide.load = load <8 x i32>, ptr %i2, align 1
396+
%i4 = add <8 x i32> %wide.load, %wide.masked.gather
397+
store <8 x i32> %i4, ptr %i2, align 1
398+
%index.next = add nuw i64 %index, 8
399+
%vec.ind.next = add <8 x i64> %vec.ind, <i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8>
400+
%i6 = icmp eq i64 %index.next, 1024
401+
br i1 %i6, label %for.cond.cleanup, label %vector.body
402+
403+
for.cond.cleanup: ; preds = %vector.body
404+
ret void
405+
}
406+
313407
;void scatter_pow2(signed char * __restrict A, signed char * __restrict B) {
314408
; for (int i = 0; i < 1024; ++i)
315409
; A[i * 4] += B[i];

0 commit comments

Comments
 (0)