Skip to content

Commit 4c1b1f6

Browse files
authored
[NVPTX] Add support for clamped funnel shift intrinsics (#113228)
Add support for ``llvm.nvvm.fshl.clamp`` and ``llvm.nvvm.fshr.clamp`` intrinsics. These intrinsics are similar to the generic llvm funnel shift, except that the shift value is clamped to the integer width. Currently only ``i32`` is supported and is implemented with the `shf.[rl].clamp.b32` PTX instruction.
1 parent 9b98455 commit 4c1b1f6

File tree

6 files changed

+217
-1
lines changed

6 files changed

+217
-1
lines changed

llvm/docs/NVPTXUsage.rst

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,64 @@ used in the '``llvm.nvvm.idp4a.[us].u``' variants, while sign-extension is used
319319
with '``llvm.nvvm.idp4a.[us].s``' variants. The dot product of these 4-element
320320
vectors is added to ``%c`` to produce the return.
321321

322+
Bit Manipulation Intrinsics
323+
---------------------------
324+
325+
'``llvm.nvvm.fshl.clamp.*``' Intrinsic
326+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
327+
328+
Syntax:
329+
"""""""
330+
331+
.. code-block:: llvm
332+
333+
declare i32 @llvm.nvvm.fshl.clamp.i32(i32 %hi, i32 %lo, i32 %n)
334+
335+
Overview:
336+
"""""""""
337+
338+
The '``llvm.nvvm.fshl.clamp``' family of intrinsics performs a clamped funnel
339+
shift left. These intrinsics are very similar to '``llvm.fshl``', except the
340+
shift ammont is clamped at the integer width (instead of modulo it). Currently,
341+
only ``i32`` is supported.
342+
343+
Semantics:
344+
""""""""""
345+
346+
The '``llvm.nvvm.fshl.clamp``' family of intrinsic functions performs a clamped
347+
funnel shift left: the first two values are concatenated as { %hi : %lo } (%hi
348+
is the most significant bits of the wide value), the combined value is shifted
349+
left, and the most significant bits are extracted to produce a result that is
350+
the same size as the original arguments. The shift amount is the minimum of the
351+
value of %n and the bit width of the integer type.
352+
353+
'``llvm.nvvm.fshr.clamp.*``' Intrinsic
354+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
355+
356+
Syntax:
357+
"""""""
358+
359+
.. code-block:: llvm
360+
361+
declare i32 @llvm.nvvm.fshr.clamp.i32(i32 %hi, i32 %lo, i32 %n)
362+
363+
Overview:
364+
"""""""""
365+
366+
The '``llvm.nvvm.fshr.clamp``' family of intrinsics perform a clamped funnel
367+
shift right. These intrinsics are very similar to '``llvm.fshr``', except the
368+
shift ammont is clamped at the integer width (instead of modulo it). Currently,
369+
only ``i32`` is supported.
370+
371+
Semantics:
372+
""""""""""
373+
374+
The '``llvm.nvvm.fshr.clamp``' family of intrinsic functions performs a clamped
375+
funnel shift right: the first two values are concatenated as { %hi : %lo } (%hi
376+
is the most significant bits of the wide value), the combined value is shifted
377+
right, and the least significant bits are extracted to produce a result that is
378+
the same size as the original arguments. The shift amount is the minimum of the
379+
value of %n and the bit width of the integer type.
322380

323381

324382
Other Intrinsics

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,6 +1080,16 @@ let TargetPrefix = "nvvm" in {
10801080
}
10811081
}
10821082

1083+
//
1084+
// Funnel-shift
1085+
//
1086+
foreach direction = ["l", "r"] in
1087+
def int_nvvm_fsh # direction # _clamp :
1088+
DefaultAttrsIntrinsic<[llvm_anyint_ty],
1089+
[LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>],
1090+
[IntrNoMem, IntrSpeculatable, IntrWillReturn]>;
1091+
1092+
10831093
//
10841094
// Convert
10851095
//

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3535,6 +3535,15 @@ let hasSideEffects = false in {
35353535
defm SHF_R_WRAP : ShfInst<"r.wrap", fshr>;
35363536
}
35373537

3538+
def : Pat<(i32 (int_nvvm_fshl_clamp (i32 Int32Regs:$hi), (i32 Int32Regs:$lo), (i32 Int32Regs:$amt))),
3539+
(SHF_L_CLAMP_r (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt))>;
3540+
def : Pat<(i32 (int_nvvm_fshl_clamp (i32 Int32Regs:$hi), (i32 Int32Regs:$lo), (i32 imm:$amt))),
3541+
(SHF_L_CLAMP_i (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 imm:$amt))>;
3542+
def : Pat<(i32 (int_nvvm_fshr_clamp (i32 Int32Regs:$hi), (i32 Int32Regs:$lo), (i32 Int32Regs:$amt))),
3543+
(SHF_R_CLAMP_r (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt))>;
3544+
def : Pat<(i32 (int_nvvm_fshr_clamp (i32 Int32Regs:$hi), (i32 Int32Regs:$lo), (i32 imm:$amt))),
3545+
(SHF_R_CLAMP_i (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 imm:$amt))>;
3546+
35383547
// Count leading zeros
35393548
let hasSideEffects = false in {
35403549
def CLZr32 : NVPTXInst<(outs Int32Regs:$d), (ins Int32Regs:$a),

llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@
1414
#include "llvm/CodeGen/BasicTTIImpl.h"
1515
#include "llvm/CodeGen/CostTable.h"
1616
#include "llvm/CodeGen/TargetLowering.h"
17+
#include "llvm/IR/Constants.h"
18+
#include "llvm/IR/Intrinsics.h"
1719
#include "llvm/IR/IntrinsicsNVPTX.h"
18-
#include "llvm/Support/Debug.h"
20+
#include "llvm/IR/Value.h"
21+
#include "llvm/Support/Casting.h"
22+
#include "llvm/Transforms/InstCombine/InstCombiner.h"
1923
#include <optional>
2024
using namespace llvm;
2125

@@ -134,6 +138,7 @@ static Instruction *simplifyNvvmIntrinsic(IntrinsicInst *II, InstCombiner &IC) {
134138
// simplify.
135139
enum SpecialCase {
136140
SPC_Reciprocal,
141+
SCP_FunnelShiftClamp,
137142
};
138143

139144
// SimplifyAction is a poor-man's variant (plus an additional flag) that
@@ -314,6 +319,10 @@ static Instruction *simplifyNvvmIntrinsic(IntrinsicInst *II, InstCombiner &IC) {
314319
case Intrinsic::nvvm_rcp_rn_d:
315320
return {SPC_Reciprocal, FTZ_Any};
316321

322+
case Intrinsic::nvvm_fshl_clamp:
323+
case Intrinsic::nvvm_fshr_clamp:
324+
return {SCP_FunnelShiftClamp, FTZ_Any};
325+
317326
// We do not currently simplify intrinsics that give an approximate
318327
// answer. These include:
319328
//
@@ -384,6 +393,22 @@ static Instruction *simplifyNvvmIntrinsic(IntrinsicInst *II, InstCombiner &IC) {
384393
return BinaryOperator::Create(
385394
Instruction::FDiv, ConstantFP::get(II->getArgOperand(0)->getType(), 1),
386395
II->getArgOperand(0), II->getName());
396+
397+
case SCP_FunnelShiftClamp: {
398+
// Canonicalize a clamping funnel shift to the generic llvm funnel shift
399+
// when possible, as this is easier for llvm to optimize further.
400+
if (const auto *ShiftConst = dyn_cast<ConstantInt>(II->getArgOperand(2))) {
401+
const bool IsLeft = II->getIntrinsicID() == Intrinsic::nvvm_fshl_clamp;
402+
if (ShiftConst->getZExtValue() >= II->getType()->getIntegerBitWidth())
403+
return IC.replaceInstUsesWith(*II, II->getArgOperand(IsLeft ? 1 : 0));
404+
405+
const unsigned FshIID = IsLeft ? Intrinsic::fshl : Intrinsic::fshr;
406+
return CallInst::Create(Intrinsic::getOrInsertDeclaration(
407+
II->getModule(), FshIID, II->getType()),
408+
SmallVector<Value *, 3>(II->args()));
409+
}
410+
return nullptr;
411+
}
387412
}
388413
llvm_unreachable("All SpecialCase enumerators should be handled in switch.");
389414
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -march=nvptx -mcpu=sm_61 | FileCheck %s
3+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_61 | FileCheck %s
4+
5+
target triple = "nvptx-nvidia-cuda"
6+
7+
declare i32 @llvm.nvvm.fshr.clamp.i32(i32, i32, i32)
8+
declare i32 @llvm.nvvm.fshl.clamp.i32(i32, i32, i32)
9+
10+
define i32 @fshr_clamp_r(i32 %hi, i32 %lo, i32 %n) {
11+
; CHECK-LABEL: fshr_clamp_r(
12+
; CHECK: {
13+
; CHECK-NEXT: .reg .b32 %r<5>;
14+
; CHECK-EMPTY:
15+
; CHECK-NEXT: // %bb.0:
16+
; CHECK-NEXT: ld.param.u32 %r1, [fshr_clamp_r_param_0];
17+
; CHECK-NEXT: ld.param.u32 %r2, [fshr_clamp_r_param_1];
18+
; CHECK-NEXT: ld.param.u32 %r3, [fshr_clamp_r_param_2];
19+
; CHECK-NEXT: shf.r.clamp.b32 %r4, %r2, %r1, %r3;
20+
; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
21+
; CHECK-NEXT: ret;
22+
%call = call i32 @llvm.nvvm.fshr.clamp.i32(i32 %hi, i32 %lo, i32 %n)
23+
ret i32 %call
24+
}
25+
26+
define i32 @fshl_clamp_r(i32 %hi, i32 %lo, i32 %n) {
27+
; CHECK-LABEL: fshl_clamp_r(
28+
; CHECK: {
29+
; CHECK-NEXT: .reg .b32 %r<5>;
30+
; CHECK-EMPTY:
31+
; CHECK-NEXT: // %bb.0:
32+
; CHECK-NEXT: ld.param.u32 %r1, [fshl_clamp_r_param_0];
33+
; CHECK-NEXT: ld.param.u32 %r2, [fshl_clamp_r_param_1];
34+
; CHECK-NEXT: ld.param.u32 %r3, [fshl_clamp_r_param_2];
35+
; CHECK-NEXT: shf.l.clamp.b32 %r4, %r2, %r1, %r3;
36+
; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
37+
; CHECK-NEXT: ret;
38+
%call = call i32 @llvm.nvvm.fshl.clamp.i32(i32 %hi, i32 %lo, i32 %n)
39+
ret i32 %call
40+
}
41+
42+
define i32 @fshr_clamp_i(i32 %hi, i32 %lo) {
43+
; CHECK-LABEL: fshr_clamp_i(
44+
; CHECK: {
45+
; CHECK-NEXT: .reg .b32 %r<4>;
46+
; CHECK-EMPTY:
47+
; CHECK-NEXT: // %bb.0:
48+
; CHECK-NEXT: ld.param.u32 %r1, [fshr_clamp_i_param_0];
49+
; CHECK-NEXT: ld.param.u32 %r2, [fshr_clamp_i_param_1];
50+
; CHECK-NEXT: shf.r.clamp.b32 %r3, %r2, %r1, 3;
51+
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
52+
; CHECK-NEXT: ret;
53+
%call = call i32 @llvm.nvvm.fshr.clamp.i32(i32 %hi, i32 %lo, i32 3)
54+
ret i32 %call
55+
}
56+
57+
define i32 @fshl_clamp_i(i32 %hi, i32 %lo) {
58+
; CHECK-LABEL: fshl_clamp_i(
59+
; CHECK: {
60+
; CHECK-NEXT: .reg .b32 %r<4>;
61+
; CHECK-EMPTY:
62+
; CHECK-NEXT: // %bb.0:
63+
; CHECK-NEXT: ld.param.u32 %r1, [fshl_clamp_i_param_0];
64+
; CHECK-NEXT: ld.param.u32 %r2, [fshl_clamp_i_param_1];
65+
; CHECK-NEXT: shf.l.clamp.b32 %r3, %r2, %r1, 3;
66+
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
67+
; CHECK-NEXT: ret;
68+
%call = call i32 @llvm.nvvm.fshl.clamp.i32(i32 %hi, i32 %lo, i32 3)
69+
ret i32 %call
70+
}

llvm/test/Transforms/InstCombine/NVPTX/nvvm-intrins.ll

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,48 @@ define float @test_sqrt_rn_f_ftz(float %a) #0 {
384384
ret float %ret
385385
}
386386

387+
; CHECK-LABEL: @test_fshl_clamp_1
388+
define i32 @test_fshl_clamp_1(i32 %a, i32 %b) {
389+
; CHECK: call i32 @llvm.fshl.i32(i32 %a, i32 %b, i32 3)
390+
%call = call i32 @llvm.nvvm.fshl.clamp.i32(i32 %a, i32 %b, i32 3)
391+
ret i32 %call
392+
}
393+
394+
; CHECK-LABEL: @test_fshl_clamp_2
395+
define i32 @test_fshl_clamp_2(i32 %a, i32 %b) {
396+
; CHECK: ret i32 %b
397+
%call = call i32 @llvm.nvvm.fshl.clamp.i32(i32 %a, i32 %b, i32 300)
398+
ret i32 %call
399+
}
400+
401+
; CHECK-LABEL: @test_fshl_clamp_3
402+
define i32 @test_fshl_clamp_3(i32 %a, i32 %b, i32 %c) {
403+
; CHECK: call i32 @llvm.nvvm.fshl.clamp.i32(i32 %a, i32 %b, i32 %c)
404+
%call = call i32 @llvm.nvvm.fshl.clamp.i32(i32 %a, i32 %b, i32 %c)
405+
ret i32 %call
406+
}
407+
408+
; CHECK-LABEL: @test_fshr_clamp_1
409+
define i32 @test_fshr_clamp_1(i32 %a, i32 %b) {
410+
; CHECK: call i32 @llvm.fshl.i32(i32 %a, i32 %b, i32 29)
411+
%call = call i32 @llvm.nvvm.fshr.clamp.i32(i32 %a, i32 %b, i32 3)
412+
ret i32 %call
413+
}
414+
415+
; CHECK-LABEL: @test_fshr_clamp_2
416+
define i32 @test_fshr_clamp_2(i32 %a, i32 %b) {
417+
; CHECK: ret i32 %a
418+
%call = call i32 @llvm.nvvm.fshr.clamp.i32(i32 %a, i32 %b, i32 300)
419+
ret i32 %call
420+
}
421+
422+
; CHECK-LABEL: @test_fshr_clamp_3
423+
define i32 @test_fshr_clamp_3(i32 %a, i32 %b, i32 %c) {
424+
; CHECK: call i32 @llvm.nvvm.fshr.clamp.i32(i32 %a, i32 %b, i32 %c)
425+
%call = call i32 @llvm.nvvm.fshr.clamp.i32(i32 %a, i32 %b, i32 %c)
426+
ret i32 %call
427+
}
428+
387429
declare double @llvm.nvvm.add.rn.d(double, double)
388430
declare float @llvm.nvvm.add.rn.f(float, float)
389431
declare float @llvm.nvvm.add.rn.ftz.f(float, float)
@@ -454,3 +496,5 @@ declare double @llvm.nvvm.ui2d.rn(i32)
454496
declare float @llvm.nvvm.ui2f.rn(i32)
455497
declare double @llvm.nvvm.ull2d.rn(i64)
456498
declare float @llvm.nvvm.ull2f.rn(i64)
499+
declare i32 @llvm.nvvm.fshr.clamp.i32(i32, i32, i32)
500+
declare i32 @llvm.nvvm.fshl.clamp.i32(i32, i32, i32)

0 commit comments

Comments
 (0)