[WIP] Add MLA layers into fla #395
Conversation
Walkthrough
A new file defines the DeepSeek Multi-head Latent Attention (MLA) module as a PyTorch class, implementing a configurable multi-headed attention mechanism. The module supports rotary positional embeddings, LoRA low-rank adapters, and caching for autoregressive decoding, and integrates Flash Attention for efficient computation. It exposes a constructor, a forward method, and a helper function.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Caller
    participant MLA_Module
    participant FlashAttention
    participant Cache
    Caller->>MLA_Module: forward(hidden_states, attention_mask, past_key_values)
    MLA_Module->>MLA_Module: Project queries, keys, values (with LoRA)
    MLA_Module->>MLA_Module: Apply rotary embeddings
    MLA_Module->>Cache: Update or retrieve cached key/values
    MLA_Module->>FlashAttention: Compute attention (causal, variable length)
    FlashAttention-->>MLA_Module: Attention output
    MLA_Module->>MLA_Module: Project output
    MLA_Module-->>Caller: Return output, updated cache
```
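The walkthrough and diagram above summarize the layer's data flow only at a high level. As a rough illustration (not the code under review), the sketch below walks through the same steps in plain PyTorch; the class name `SimplifiedMLA`, the `q_lora_rank`/`kv_lora_rank` parameters, the hand-rolled rotary helper, and the use of `torch.nn.functional.scaled_dot_product_attention` as a stand-in for Flash Attention are all assumptions made for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def apply_rotary(x, cos, sin):
    # x: [batch, seq, heads, head_dim]; rotate channel halves (GPT-NeoX-style RoPE).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)


class SimplifiedMLA(nn.Module):
    # Illustrative stand-in for the MLA layer described in the walkthrough;
    # names, ranks and projection structure are simplified for the example.
    def __init__(self, hidden_size=512, num_heads=8, q_lora_rank=128, kv_lora_rank=64):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # LoRA-style low-rank projections: down-project, then up-project.
        self.q_a = nn.Linear(hidden_size, q_lora_rank, bias=False)
        self.q_b = nn.Linear(q_lora_rank, hidden_size, bias=False)
        self.kv_a = nn.Linear(hidden_size, kv_lora_rank, bias=False)
        self.kv_b = nn.Linear(kv_lora_rank, 2 * hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, hidden_states, past_key_values=None):
        b, t, _ = hidden_states.shape
        q = self.q_b(self.q_a(hidden_states)).view(b, t, self.num_heads, self.head_dim)
        kv = self.kv_b(self.kv_a(hidden_states)).view(b, t, self.num_heads, 2 * self.head_dim)
        k, v = kv.chunk(2, dim=-1)

        # Rotary embeddings, offset by the number of already-cached positions.
        offset = 0 if past_key_values is None else past_key_values[0].shape[1]
        pos = torch.arange(offset, offset + t, device=hidden_states.device, dtype=torch.float32)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        freqs = torch.outer(pos, inv_freq.to(pos.device))
        cos, sin = freqs.cos()[None, :, None, :], freqs.sin()[None, :, None, :]
        q, k = apply_rotary(q, cos, sin), apply_rotary(k, cos, sin)

        # Cache update: append new keys/values along the sequence dimension.
        if past_key_values is not None:
            k = torch.cat([past_key_values[0], k], dim=1)
            v = torch.cat([past_key_values[1], v], dim=1)

        # SDPA as a stand-in for Flash Attention; expects [batch, heads, seq, dim].
        # Causal masking is only needed for the multi-token prefill; cached decode
        # is assumed to process one token at a time.
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            is_causal=past_key_values is None,
        )
        out = out.transpose(1, 2).reshape(b, t, -1)
        return self.o_proj(out), (k, v)
```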
Actionable comments posted: 4
🧹 Nitpick comments (2)
fla/layers/mla.py (2)
19-21: Remove unused utilities to unblock lint gate
`pad_input` and `unpad_input` are imported but never referenced, triggering `F401` errors and breaking the pipeline. Either delete the imports or wire the helpers in (e.g. to unpad before calling Flash-Attention); a sketch of the second option follows the lint output below.

```diff
-from fla.layers.utils import pad_input, unpad_input
+# (imports removed – re-add only when the helpers are used)
```

🧰 Tools
🪛 Ruff (0.8.2)
20-20: `fla.layers.utils.pad_input` imported but unused. Remove unused import (F401)
20-20: `fla.layers.utils.unpad_input` imported but unused. Remove unused import (F401)
🪛 GitHub Actions: lint
[error] 19-19: flake8: 'fla.layers.utils.pad_input' imported but unused (F401)
[error] 19-19: flake8: 'fla.layers.utils.unpad_input' imported but unused (F401)
[error] pre-commit hooks 'trailing-whitespace' and 'isort' modified the file to fix issues
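If the second option above is chosen (wiring the helpers in rather than deleting the imports), a minimal sketch could look like the following. The helper return values are assumed here to mirror `flash_attn.bert_padding` (packed tensor, indices, `cu_seqlens`, `max_seqlen`); check `fla/layers/utils.py` for the actual signatures before relying on this.

```python
from flash_attn import flash_attn_varlen_func

from fla.layers.utils import pad_input, unpad_input


def varlen_attention(q, k, v, attention_mask):
    # q, k, v: [batch, seq_len, num_heads, head_dim]; attention_mask: [batch, seq_len] of 0/1.
    batch, seq_len = attention_mask.shape
    # Assumed return signature: (packed_tensor, indices, cu_seqlens, max_seqlen).
    q, indices, cu_seqlens, max_seqlen = unpad_input(q, attention_mask)
    k = unpad_input(k, attention_mask)[0]
    v = unpad_input(v, attention_mask)[0]
    out = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
        causal=True,
    )
    # Re-pad the packed [total_tokens, num_heads, head_dim] output back to [batch, seq_len, ...].
    return pad_input(out, indices, batch, seq_len)
```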
173-176: Minor: simplify kwargs access
`kwargs.get("cu_seqlens", None)` → `kwargs.get("cu_seqlens")` (Ruff SIM910).
Not a functional change, but fixing it keeps the linter green.

🧰 Tools
🪛 Ruff (0.8.2)
173-173: Use `kwargs.get("cu_seqlens")` instead of `kwargs.get("cu_seqlens", None)`.
Replace `kwargs.get("cu_seqlens", None)` with `kwargs.get("cu_seqlens")` (SIM910)
🪛 GitHub Actions: lint
[error] pre-commit hooks 'trailing-whitespace' and 'isort' modified the file to fix issues
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
fla/layers/mla.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
fla/layers/mla.py (5)
fla/modules/layernorm.py (1)
  RMSNorm (1059-1106)
fla/modules/rotary.py (1)
  RotaryEmbedding (303-497)
fla/ops/utils/index.py (1)
  prepare_lens_from_mask (41-42)
fla/layers/utils.py (2)
  pad_input (176-197)
  unpad_input (101-173)
fla/models/utils.py (3)
  Cache (11-148)
  get_seq_length (122-126)
  update (43-120)
🪛 Ruff (0.8.2)
fla/layers/mla.py
20-20: `fla.layers.utils.pad_input` imported but unused. Remove unused import (F401)
20-20: `fla.layers.utils.unpad_input` imported but unused. Remove unused import (F401)
26-26: No explicit `stacklevel` keyword argument found (B028)
173-173: Use `kwargs.get("cu_seqlens")` instead of `kwargs.get("cu_seqlens", None)`. Replace `kwargs.get("cu_seqlens", None)` with `kwargs.get("cu_seqlens")` (SIM910)
🪛 GitHub Actions: lint
fla/layers/mla.py
[error] 6-6: flake8: line too long (129 > 127 characters) (E501)
[error] 19-19: flake8: 'fla.layers.utils.pad_input' imported but unused (F401)
[error] 19-19: flake8: 'fla.layers.utils.unpad_input' imported but unused (F401)
[error] 54-54: flake8: at least two spaces before inline comment (E261)
[error] 56-56: flake8: at least two spaces before inline comment (E261)
[error] 59-59: flake8: at least two spaces before inline comment (E261)
[error] 66-66: flake8: too many blank lines (2) (E303)
[error] 74-74: flake8: too many blank lines (2) (E303)
[error] 78-78: flake8: too many blank lines (2) (E303)
[error] 91-91: flake8: too many blank lines (2) (E303)
[error] 180-180: flake8: block comment should start with '# ' (E265)
[error] pre-commit hooks 'trailing-whitespace' and 'isort' modified the file to fix issues
⏰ Context from checks skipped due to timeout of 90000ms (2)
- GitHub Check: test
- GitHub Check: test
🔇 Additional comments (1)
fla/layers/mla.py (1)
203-211: Potential None inputs to Flash-Attn
`cu_seqlens` can be `None` unless the caller provides it, yet `flash_attn_varlen_func` expects valid cumulative length tensors. Consider:
- Falling back to `flash_attn_func` (padded path) when `cu_seqlens is None`, or
- Calling `unpad_input` earlier to always supply valid `cu_seqlens`.
Without this, the code will crash for the common “full padded batch” path; a sketch of the first option follows the lint note below.
🧰 Tools
🪛 GitHub Actions: lint
[error] pre-commit hooks 'trailing-whitespace' and 'isort' modified the file to fix issues
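A hedged sketch of the first suggestion (falling back to the dense kernel when `cu_seqlens` is missing); the call signatures shown are the upstream flash-attn 2.x ones, and the helper name `attend` is purely illustrative.

```python
from flash_attn import flash_attn_func, flash_attn_varlen_func


def attend(q, k, v, cu_seqlens=None, max_seqlen=None):
    if cu_seqlens is None:
        # Padded path: q, k, v are [batch, seq, heads, dim] and no cumulative
        # lengths are available, so use the dense kernel.
        return flash_attn_func(q, k, v, causal=True)
    # Varlen path: q, k, v are packed as [total_tokens, heads, dim], with
    # cumulative sequence lengths marking the per-sample boundaries.
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
        causal=True,
    )
```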
I resolved the comments and fixed the lint today, but the implementation logic is wrong. I'll fix the impl and add more tests later.
Initial impl, draft for now, needs setting up tests
Relates to: #392