Use tl.exp2 for all gating operations #361
base: main
Conversation
Walkthrough: A new alias for the base-2 exponential function, `safe_exp2`, is introduced, and gating operations are moved onto the `exp2` path.
Actionable comments posted: 0
🧹 Nitpick comments (2)
fla/ops/utils/constant.py (1)

Lines 4-6: Single-source the reciprocal-log(2) constant to avoid drift

`RCP_LN2` is now hard-coded here, but the same literal is also replicated in every Triton kernel (see `parallel.py`). While the duplication is currently harmless, it is a latent maintenance hazard: future changes (e.g. switching to the FP64-exact `1.4426950408889634`) would have to be applied in several places. Because Triton kernels accept default arguments, the kernels can reference the constant defined in this module:

```python
from fla.ops.utils.constant import RCP_LN2 as _RCP_LN2

@triton.jit
def kernel(..., RCP_LN2: tl.constexpr = _RCP_LN2):
    ...
```

This keeps one authoritative value and eliminates the risk of silently diverging approximations.
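As a plain-Python sketch of the pattern the comment suggests: define the value once and forward it as a default argument. The `attn_score_kernel` helper below is hypothetical and only illustrates the idea; in the real code the default would be a `tl.constexpr` on a `@triton.jit` kernel.

```python
import math

# Single authoritative definition (as in fla/ops/utils/constant.py);
# 1.4426950408889634 is the FP64-exact value of 1/ln(2).
RCP_LN2 = 1.4426950408889634

# Hypothetical stand-in for a Triton kernel: the constant arrives as a
# default argument, so every call site shares the same value.
def attn_score_kernel(s: float, scale: float, rcp_ln2: float = RCP_LN2) -> float:
    # exp(s * scale) computed on the exp2 path: 2 ** (s * scale * (1/ln 2))
    return 2.0 ** (s * scale * rcp_ln2)

# The exp2 path matches math.exp to floating-point accuracy.
assert math.isclose(attn_score_kernel(1.7, 0.5), math.exp(0.85), rel_tol=1e-12)
```

Because the default is bound where the kernel is defined, updating the constant in one module propagates everywhere automatically.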
fla/ops/attn/parallel.py (1)

Lines 66-68: Avoid magic duplicates of `RCP_LN2` inside kernels

The constant is repeated three times in this file. Inlining it works, but it is easy to miss one site if the value ever changes. You can import the Python-level constant defined in `fla.ops.utils.constant` and forward it as a `tl.constexpr` default argument, keeping a single source of truth (see comment in `constant.py`).

Also applies to: 224-225, 365-366
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- fla/ops/attn/parallel.py (14 hunks)
- fla/ops/utils/constant.py (1 hunks)
- fla/ops/utils/op.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- fla/ops/utils/op.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
fla/ops/attn/parallel.py (1)
- fla/ops/utils/op.py (1): `safe_exp2` (35-36)
⏰ Context from checks skipped due to timeout of 90000ms (3)
- GitHub Check: test
- GitHub Check: test
- GitHub Check: test
🔇 Additional comments (4)
fla/ops/attn/parallel.py (4)
Lines 23-27: Re-enable shared-memory guard when adding 8-warp configs

`num_warps = 8` profiles well on Hopper, but on devices with ≤64 KiB shared memory the same configuration can spill registers and regress performance. Previously this branch was gated by `check_shared_mem('hopper')`; now the guard is removed. Consider restoring a capability check, e.g.:

```diff
- for num_warps in [1, 2, 4, 8]
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
```

to prevent accidentally autotuning into an infeasible occupancy point on older GPUs.
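A minimal sketch of the suggested gating, with the hardware capability abstracted into a boolean. The `build_num_warps` helper is hypothetical; in the real code the flag would come from `check_shared_mem('hopper')` at autotune-config construction time.

```python
# Hypothetical helper mirroring the suggested guard: include the 8-warp
# configuration only when the device has enough shared memory.
def build_num_warps(has_large_smem: bool) -> list[int]:
    # In the kernel file this flag would be check_shared_mem('hopper').
    return [1, 2, 4] + ([8] if has_large_smem else [])

assert build_num_warps(True) == [1, 2, 4, 8]   # Hopper-class device
assert build_num_warps(False) == [1, 2, 4]     # older / smaller-smem GPU
```

The same list would then feed the `triton.Config` generation loop, so the infeasible configuration is never even offered to the autotuner.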
Lines 105-113: Good use of `exp2`/`safe_exp2` for numerically-stable softmax

Replacing `exp` with `exp2` and folding the `1/ln2` factor into `scale` reduces one fused-multiply-add per element and fully exploits the SFU. Nice!
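The identity behind this rewrite can be sanity-checked in plain Python: folding `1/ln 2` into the scale once lets the entire softmax run on base-2 exponentials while producing the same probabilities. Both helpers below are illustrative scalar code, not the kernel itself.

```python
import math

RCP_LN2 = 1.4426950408889634  # 1 / ln(2)

def softmax_exp(scores, scale):
    # Reference: natural-exponential softmax with max subtraction.
    m = max(s * scale for s in scores)
    e = [math.exp(s * scale - m) for s in scores]
    z = sum(e)
    return [v / z for v in e]

def softmax_exp2(scores, scale):
    # Fold 1/ln2 into the scale once, then use only base-2 exponentials.
    scale2 = scale * RCP_LN2
    m = max(s * scale2 for s in scores)
    e = [2.0 ** (s * scale2 - m) for s in scores]
    z = sum(e)
    return [v / z for v in e]

a = softmax_exp([1.0, -2.0, 0.3], 0.125)
b = softmax_exp2([1.0, -2.0, 0.3], 0.125)
assert all(math.isclose(x, y, rel_tol=1e-12) for x, y in zip(a, b))
```

Since `2 ** (x * RCP_LN2) == exp(x)` and softmax is invariant under the shift by the running maximum, the two paths agree to floating-point accuracy.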
Lines 148-151: Correct base-2 log accumulation

Switching from `log` to `log2` when updating `b_m` keeps the log-sum-exp in the same base as the exponentials and avoids a redundant `* RCP_LN2` conversion.
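The base-2 log-sum-exp update described here can be sketched in plain Python. This is a simplified scalar stand-in for the blockwise kernel logic (running max `m`, rescaled partial sum), not the actual Triton code.

```python
import math

def lse2_streaming(blocks):
    # Running log-sum-exp kept entirely in base 2, as in the kernel:
    # m tracks the running max, acc the partial sum rescaled by 2**(-m).
    m, acc = float("-inf"), 0.0
    for block in blocks:
        m_new = max(m, max(block))
        acc = acc * 2.0 ** (m - m_new) + sum(2.0 ** (s - m_new) for s in block)
        m = m_new
    # Final log2(sum 2^s); staying in base 2 needs no RCP_LN2 conversion.
    return m + math.log2(acc)

scores = [0.3, -1.2, 2.5, 0.0, 1.1]
ref = math.log2(sum(2.0 ** s for s in scores))
got = lse2_streaming([scores[:3], scores[3:]])
assert math.isclose(got, ref, rel_tol=1e-12)
```

Keeping the accumulator and the exponentials in the same base is exactly why the `log` → `log2` switch removes a multiply per update.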
Lines 234-235: ❓ Verification inconclusive

Gradient scaling looks algebraically sound; please run the unit tests.

You rescale Q/K by `scale * RCP_LN2` for the forward pass, then apply only `scale` when writing back `dq`/`dk`. Because `d/ds 2^s = ln(2)·2^s` and `RCP_LN2·ln(2) = 1`, the extra factors indeed cancel, so the gradients match the original natural-log implementation. Nevertheless, a quick regression test (forward+backward vs. PyTorch reference on random tensors) would give high confidence. If this script matches the pre-patch gradients numerically (to ≤1e-4 relative error), everything is wired correctly.

Also applies to: 305-306, 373-374, 457-458

🏁 Script executed:

```bash
#!/bin/bash
python - <<'PY'
import torch, math
from fla.ops.attn.parallel import parallel_attn

B, T, H, K, V = 2, 128, 4, 64, 64
q = torch.randn(B, T, H*2, K, device='cuda', requires_grad=True)
k = torch.randn(B, T, H, K, device='cuda', requires_grad=True)
v = torch.randn(B, T, H, V, device='cuda', requires_grad=True)
scale = 1 / math.sqrt(K)

o = parallel_attn(q, k, v, scale=scale)
loss = o.square().mean()
loss.backward()
print("max|grad|:", q.grad.abs().max().item())
PY
```

Length of output: 128
Unable to run regression script in CI; please verify gradients locally

I attempted to execute the provided forward-and-backward test, but PyTorch isn't available in this environment:

ModuleNotFoundError: No module named 'torch'

To ensure the `scale * RCP_LN2` factor cancels correctly in the backward pass, please run the following script on a machine with PyTorch installed:

```bash
#!/bin/bash
python - <<'PY'
import torch, math
from fla.ops.attn.parallel import parallel_attn

B, T, H, K, V = 2, 128, 4, 64, 64
q = torch.randn(B, T, H*2, K, device='cuda', requires_grad=True)
k = torch.randn(B, T, H, K, device='cuda', requires_grad=True)
v = torch.randn(B, T, H, V, device='cuda', requires_grad=True)
scale = 1 / math.sqrt(K)

o = parallel_attn(q, k, v, scale=scale)
loss = o.square().mean()
loss.backward()
print("max|dq|:", q.grad.abs().max().item())
# Compare this against the pre-patch gradients (<=1e-4 relative error).
PY
```

This check should be performed for all occurrences of the new scaling logic:

- Lines 234-235 in parallel.py
- Lines 305-306
- Lines 373-374
- Lines 457-458
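Independent of the GPU script, the algebraic cancellation the review describes can be checked numerically on the CPU. This is a scalar sketch of the identity, not the kernel's actual backward pass.

```python
import math

RCP_LN2 = 1.4426950408889634  # 1 / ln(2)
LN2 = math.log(2.0)

scale, s = 0.125, 1.7

# Forward on the exp2 path: p = 2 ** (s * scale * RCP_LN2) == exp(s * scale)
p = 2.0 ** (s * scale * RCP_LN2)
assert math.isclose(p, math.exp(s * scale), rel_tol=1e-12)

# Backward: d p / d s = ln(2) * p * (scale * RCP_LN2) = scale * p,
# i.e. the ln(2) from exp2's derivative cancels RCP_LN2 exactly, so
# writing back only `scale` reproduces the natural-log gradient.
grad_exp2_path = LN2 * p * (scale * RCP_LN2)
grad_natural = scale * math.exp(s * scale)
assert math.isclose(grad_exp2_path, grad_natural, rel_tol=1e-12)

# Cross-check against a central finite difference of the forward pass.
h = 1e-6
fd = (2.0 ** ((s + h) * scale * RCP_LN2)
      - 2.0 ** ((s - h) * scale * RCP_LN2)) / (2 * h)
assert math.isclose(fd, grad_natural, rel_tol=1e-6)
```

This confirms the review's algebra (`RCP_LN2 · ln(2) = 1`); the GPU regression test is still needed to verify the factors are applied at the right sites in the kernels.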
Commits aacf017 to 0d815aa (compare)
In triton-lang/triton#2893