Skip to content

Commit 715a043

Browse files
committed
[RISCVGatherScatterLowering] Support shl in non-recursive matching
We can apply the same logic as for multiply since a left shift is just a multiply by a power of two. Note that since shl is not commutative, we do need to be careful to match sure that the splat is the RHS of the instruction. Differential Revision: https://reviews.llvm.org/D150471
1 parent 5d57a9f commit 715a043

File tree

2 files changed

+106
-8
lines changed

2 files changed

+106
-8
lines changed

llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -135,15 +135,16 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
135135
// multipled.
136136
auto *BO = dyn_cast<BinaryOperator>(Start);
137137
if (!BO || (BO->getOpcode() != Instruction::Add &&
138+
BO->getOpcode() != Instruction::Shl &&
138139
BO->getOpcode() != Instruction::Mul))
139140
return std::make_pair(nullptr, nullptr);
140141

141142
// Look for an operand that is splatted.
142-
unsigned OtherIndex = 1;
143-
Value *Splat = getSplatValue(BO->getOperand(0));
144-
if (!Splat) {
145-
Splat = getSplatValue(BO->getOperand(1));
146-
OtherIndex = 0;
143+
unsigned OtherIndex = 0;
144+
Value *Splat = getSplatValue(BO->getOperand(1));
145+
if (!Splat && Instruction::isCommutative(BO->getOpcode())) {
146+
Splat = getSplatValue(BO->getOperand(0));
147+
OtherIndex = 1;
147148
}
148149
if (!Splat)
149150
return std::make_pair(nullptr, nullptr);
@@ -158,13 +159,22 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
158159
Builder.SetCurrentDebugLocation(DebugLoc());
159160
// Add the splat value to the start or multiply the start and stride by the
160161
// splat.
161-
if (BO->getOpcode() == Instruction::Add) {
162+
switch (BO->getOpcode()) {
163+
default:
164+
llvm_unreachable("Unexpected opcode");
165+
case Instruction::Add:
162166
Start = Builder.CreateAdd(Start, Splat);
163-
} else {
164-
assert(BO->getOpcode() == Instruction::Mul && "Unexpected opcode");
167+
break;
168+
case Instruction::Mul:
165169
Start = Builder.CreateMul(Start, Splat);
166170
Stride = Builder.CreateMul(Stride, Splat);
171+
break;
172+
case Instruction::Shl:
173+
Start = Builder.CreateShl(Start, Splat);
174+
Stride = Builder.CreateShl(Stride, Splat);
175+
break;
167176
}
177+
168178
return std::make_pair(Start, Stride);
169179
}
170180

llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,94 @@ define <vscale x 1 x i64> @gather_loopless(ptr %p, i64 %stride) {
112112
ret <vscale x 1 x i64> %x
113113
}
114114

115+
define <vscale x 1 x i64> @straightline_offset_add(ptr %p, i64 %offset) {
116+
; CHECK-LABEL: @straightline_offset_add(
117+
; CHECK-NEXT: [[TMP1:%.*]] = add i64 0, [[OFFSET:%.*]]
118+
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
119+
; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP2]], i64 4, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
120+
; CHECK-NEXT: ret <vscale x 1 x i64> [[X]]
121+
;
122+
%step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
123+
%splat.insert = insertelement <vscale x 1 x i64> poison, i64 %offset, i64 0
124+
%splat = shufflevector <vscale x 1 x i64> %splat.insert, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
125+
%offsetv = add <vscale x 1 x i64> %step, %splat
126+
%ptrs = getelementptr i32, ptr %p, <vscale x 1 x i64> %offsetv
127+
%x = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(
128+
<vscale x 1 x ptr> %ptrs,
129+
i32 8,
130+
<vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 1, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer),
131+
<vscale x 1 x i64> poison
132+
)
133+
ret <vscale x 1 x i64> %x
134+
}
135+
136+
define <vscale x 1 x i64> @straightline_offset_shl(ptr %p) {
137+
; CHECK-LABEL: @straightline_offset_shl(
138+
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 0
139+
; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP1]], i64 32, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
140+
; CHECK-NEXT: ret <vscale x 1 x i64> [[X]]
141+
;
142+
%step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
143+
%splat.insert = insertelement <vscale x 1 x i64> poison, i64 3, i64 0
144+
%splat = shufflevector <vscale x 1 x i64> %splat.insert, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
145+
%offset = shl <vscale x 1 x i64> %step, %splat
146+
%ptrs = getelementptr i32, ptr %p, <vscale x 1 x i64> %offset
147+
%x = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(
148+
<vscale x 1 x ptr> %ptrs,
149+
i32 8,
150+
<vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 1, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer),
151+
<vscale x 1 x i64> poison
152+
)
153+
ret <vscale x 1 x i64> %x
154+
}
155+
156+
define <vscale x 1 x i64> @neg_shl_is_not_commutative(ptr %p) {
157+
; CHECK-LABEL: @neg_shl_is_not_commutative(
158+
; CHECK-NEXT: [[STEP:%.*]] = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
159+
; CHECK-NEXT: [[SPLAT_INSERT:%.*]] = insertelement <vscale x 1 x i64> poison, i64 3, i64 0
160+
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 1 x i64> [[SPLAT_INSERT]], <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
161+
; CHECK-NEXT: [[OFFSET:%.*]] = shl <vscale x 1 x i64> [[SPLAT]], [[STEP]]
162+
; CHECK-NEXT: [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], <vscale x 1 x i64> [[OFFSET]]
163+
; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> [[PTRS]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), <vscale x 1 x i64> poison)
164+
; CHECK-NEXT: ret <vscale x 1 x i64> [[X]]
165+
;
166+
%step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
167+
%splat.insert = insertelement <vscale x 1 x i64> poison, i64 3, i64 0
168+
%splat = shufflevector <vscale x 1 x i64> %splat.insert, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
169+
%offset = shl <vscale x 1 x i64> %splat, %step
170+
%ptrs = getelementptr i32, ptr %p, <vscale x 1 x i64> %offset
171+
%x = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(
172+
<vscale x 1 x ptr> %ptrs,
173+
i32 8,
174+
<vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 1, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer),
175+
<vscale x 1 x i64> poison
176+
)
177+
ret <vscale x 1 x i64> %x
178+
}
179+
180+
define <vscale x 1 x i64> @straightline_offset_shl_nonc(ptr %p, i64 %shift) {
181+
; CHECK-LABEL: @straightline_offset_shl_nonc(
182+
; CHECK-NEXT: [[TMP1:%.*]] = shl i64 0, [[SHIFT:%.*]]
183+
; CHECK-NEXT: [[TMP2:%.*]] = shl i64 1, [[SHIFT]]
184+
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
185+
; CHECK-NEXT: [[TMP4:%.*]] = mul i64 [[TMP2]], 4
186+
; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP3]], i64 [[TMP4]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
187+
; CHECK-NEXT: ret <vscale x 1 x i64> [[X]]
188+
;
189+
%step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
190+
%splat.insert = insertelement <vscale x 1 x i64> poison, i64 %shift, i64 0
191+
%splat = shufflevector <vscale x 1 x i64> %splat.insert, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
192+
%offset = shl <vscale x 1 x i64> %step, %splat
193+
%ptrs = getelementptr i32, ptr %p, <vscale x 1 x i64> %offset
194+
%x = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(
195+
<vscale x 1 x ptr> %ptrs,
196+
i32 8,
197+
<vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 1, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer),
198+
<vscale x 1 x i64> poison
199+
)
200+
ret <vscale x 1 x i64> %x
201+
}
202+
115203
define void @scatter_loopless(<vscale x 1 x i64> %x, ptr %p, i64 %stride) {
116204
; CHECK-LABEL: @scatter_loopless(
117205
; CHECK-NEXT: [[TMP1:%.*]] = mul i64 0, [[STRIDE:%.*]]

0 commit comments

Comments
 (0)