Use tl.exp2 for all gating operations #361


Open · wants to merge 7 commits into main
Conversation

@yzhangcs (Member) commented on Apr 17, 2025

From triton-lang/triton#2893:

exp2 is faster than exp because the special function unit inside the GPU implements exp2. When you use the exp intrinsic, it expands to a multiplication by 1/log(2) followed by the exp2 instruction. In this case, we can fold this 1/log(2) multiplier into the sm_scale multiplier so that we need only perform one multiplication per element rather than two.
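To make the fold concrete, here is a minimal, self-contained sketch (illustrative only, not the PR's kernel; the kernel and variable names are hypothetical):

```python
import torch
import triton
import triton.language as tl

RCP_LN2 = 1.4426950408889634  # 1/ln(2), so exp(x) == exp2(x * RCP_LN2)

@triton.jit
def scaled_exp_kernel(x_ptr, y_ptr, scale, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask)
    # One multiply per element: 1/ln(2) is already folded into `scale`,
    # and exp2 maps directly onto the SFU instruction.
    tl.store(y_ptr + offs, tl.exp2(x * scale), mask=mask)

x = torch.randn(1024, device='cuda')
y = torch.empty_like(x)
sm_scale = 0.125
# Fold 1/ln(2) into sm_scale once on the host instead of per element.
scaled_exp_kernel[(triton.cdiv(1024, 256),)](x, y, sm_scale * RCP_LN2, 1024, BLOCK=256)
assert torch.allclose(y, torch.exp(x * sm_scale), atol=1e-4)
```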

Summary by CodeRabbit

  • New Features
    • Added support for a base-2 exponential function, ensuring consistent behavior across different operation modes.
    • Introduced a new mathematical constant to enhance computational accuracy in logarithmic and exponential calculations.
    • Updated attention kernel computations to use base-2 logarithms and exponentials with improved scaling for better performance and precision.

@coderabbitai bot commented on Apr 17, 2025

Walkthrough

A new alias for the base-2 exponential function, exp2, is introduced in the utility operations module. Depending on the FLA_USE_FAST_OPS environment variable, exp2 resolves to either the fast implementation from Triton's libdevice or the standard implementation from Triton's math module. A new JIT-compiled function, safe_exp2, safely computes base-2 exponentials with input masking. Additionally, a new constant RCP_LN2 (an approximation of 1/ln(2)) is introduced in a new constant.py file. The parallel attention Triton kernels now use base-2 logarithms and exponentials in place of their natural-base counterparts, with inputs scaled by RCP_LN2 accordingly. Kernel autotuning configurations are also simplified.
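A minimal sketch of the alias and helper described above (assuming current Triton import paths; the actual op.py may differ):

```python
import os
import triton
import triton.language as tl
import triton.language.extra.libdevice as tldevice

if os.getenv('FLA_USE_FAST_OPS', '0') == '1':
    exp2 = tldevice.exp2  # fast libdevice implementation
else:
    exp2 = tl.math.exp2   # standard implementation

@triton.jit
def safe_exp2(x):
    # Clamp positive inputs to -inf so the result stays in [0, 1],
    # mirroring the pattern of the existing safe_exp helper.
    return exp2(tl.where(x <= 0, x, float('-inf')))
```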

Changes

| File(s) | Change Summary |
| --- | --- |
| fla/ops/utils/op.py | Added exp2 alias selecting between tldevice.exp2 and tl.math.exp2 based on FLA_USE_FAST_OPS; added JIT-compiled safe_exp2 function. |
| fla/ops/utils/constant.py | Added new file defining the constant RCP_LN2 with value 1.4426950216 (an approximation of 1/ln(2)), including a copyright header and the constant's hex representation. |
| fla/ops/attn/parallel.py | Replaced exp, log, and safe_exp with exp2, log2, and safe_exp2 in Triton kernels; scaled inputs by RCP_LN2; simplified autotuning configs; scaled final gradients appropriately. |

Poem

In the warren of code, a new path appears,
exp2 now hops between fast and safe gears.
Whether speed is your wish or precision your friend,
The right base-2 exponent is here to the end!
With a twitch of the nose and a flick of the ear,
The rabbits rejoice—math’s future is clear!
🐇✨


@yzhangcs linked an issue on Apr 17, 2025 that may be closed by this pull request
@coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
fla/ops/utils/constant.py (1)

4-6: Single‑source the reciprocal‑log(2) constant to avoid drift

RCP_LN2 is now hard‑coded here, but the same literal is also replicated in every Triton kernel (see parallel.py). While the duplication is currently harmless, it is a latent maintenance hazard—future changes (e.g. switching to the FP64‑exact 1.4426950408889634) would have to be applied in several places.

Because Triton kernels accept default arguments, the kernels can reference the constant defined in this module:

```python
from fla.ops.utils.constant import RCP_LN2 as _RCP_LN2

@triton.jit
def kernel(..., RCP_LN2: tl.constexpr = _RCP_LN2):
    ...
```

This keeps one authoritative value and eliminates the risk of silently diverging approximations.

fla/ops/attn/parallel.py (1)

66-68: Avoid magic duplicates of RCP_LN2 inside kernels

The constant is repeated three times in this file. Inlining it works, but it is easy to miss one site if the value ever changes. You can import the Python‑level constant defined in fla.ops.utils.constant and forward it as a tl.constexpr default argument, keeping a single source of truth (see comment in constant.py).

Also applies to: 224-225, 365-366

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between ac051ec and 344a034.

📒 Files selected for processing (3)
  • fla/ops/attn/parallel.py (14 hunks)
  • fla/ops/utils/constant.py (1 hunks)
  • fla/ops/utils/op.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • fla/ops/utils/op.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
fla/ops/attn/parallel.py (1)
fla/ops/utils/op.py (1)
  • safe_exp2 (35-36)
🔇 Additional comments (4)
fla/ops/attn/parallel.py (4)

23-27: Re‑enable shared‑memory guard when adding 8‑warp configs

num_warps = 8 profiles well on Hopper, but on devices with ≤64 KiB shared memory the same configuration can spill registers and regress performance.
Previously this branch was gated by check_shared_mem('hopper'); now the guard is removed. Consider restoring a capability check, e.g.:

```diff
-        for num_warps in [1, 2, 4, 8]
+        for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
```

to prevent accidentally autotuning into an infeasible occupancy point on older GPUs.


105-113: Good use of exp2/safe_exp2 for numerically‑stable softmax

Replacing exp with exp2 and folding the 1/ln(2) factor into scale saves one multiplication per element and fully exploits the SFU's native exp2 instruction. Nice!
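For reference, the base-2 online-softmax recurrence being applied (per row, with running maximum $m$; notation mine, not the kernel's variable names):

$$m' = \max\bigl(m, \max_j s_j\bigr), \qquad \tilde p_j = 2^{\,s_j - m'}, \qquad \mathrm{acc}' = 2^{\,m - m'}\,\mathrm{acc} + \sum_j \tilde p_j\, v_j$$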


148-151: Correct base‑2 log accumulation

Switching from log to log2 when updating b_m keeps the log-sum-exp in the same base as the exponentials and avoids a redundant * RCP_LN2 conversion.
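This relies on the base-2 log-sum-exp identity, which keeps everything in a single base:

$$\log_2 \sum_j 2^{\,s_j} \;=\; m + \log_2 \sum_j 2^{\,s_j - m}, \qquad m = \max_j s_j.$$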


234-235:

❓ Verification inconclusive

Gradient scaling looks algebraically sound—please run the unit tests

You rescale Q/K by scale * RCP_LN2 for the forward pass, then apply only scale when writing back dq/dk.
Because d/ds 2^s = ln(2)·2^s and RCP_LN2·ln(2) = 1, the extra factors indeed cancel, so the gradients match the original natural‑log implementation.
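Spelled out, with $s = q\cdot k$, $c$ = scale, and $\kappa$ = RCP_LN2 $= 1/\ln 2$ (so $2^{c\kappa s} = e^{cs}$):

$$\frac{d}{ds}\,2^{\,c\kappa s} \;=\; (\ln 2)\,c\kappa\,2^{\,c\kappa s} \;=\; c\,e^{\,c s} \;=\; \frac{d}{ds}\,e^{\,c s},$$

which is why writing back dq/dk with only `scale` reproduces the original gradients.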

Nevertheless, a quick regression test (forward+backward vs. PyTorch reference on random tensors) would give high confidence:

If this script matches the pre‑patch gradients numerically (to ≤1e‑4 relative error), everything is wired correctly.

Also applies to: 305-306, 373-374, 457-458




Unable to run regression script in CI—please verify gradients locally

I attempted to execute the provided forward‑and‑backward test, but PyTorch isn’t available in this environment:

ModuleNotFoundError: No module named 'torch'

To ensure the scale * RCP_LN2 factor cancels correctly in the backward pass, please run the following script on a machine with PyTorch installed:

```bash
#!/bin/bash
python - <<'PY'
import torch, math
from fla.ops.attn.parallel import parallel_attn

torch.manual_seed(0)  # fixed seed so pre- and post-patch runs see the same inputs

B, T, H, K, V = 2, 128, 4, 64, 64
q = torch.randn(B, T, H*2, K, device='cuda', requires_grad=True)
k = torch.randn(B, T, H, K, device='cuda', requires_grad=True)
v = torch.randn(B, T, H, V, device='cuda', requires_grad=True)
scale = 1/math.sqrt(K)

o = parallel_attn(q, k, v, scale=scale)
loss = o.square().mean()
loss.backward()

# Save the gradients, rerun this script on the pre-patch commit, and diff
# the two files (expect <=1e-4 relative error).
torch.save({'dq': q.grad, 'dk': k.grad, 'dv': v.grad}, 'grads.pt')
print("max|dq|:", q.grad.abs().max().item())
PY
```

This check should be performed for all occurrences of the new scaling logic:

  • Lines 234–235 in parallel.py
  • Lines 305–306
  • Lines 373–374
  • Lines 457–458

@yzhangcs requested a review from zhiyuan1i on Apr 21, 2025 at 06:47
Development

Successfully merging this pull request may close these issues.

[RFC] Use tl.exp2 for all gating operations
2 participants