fix: decoders and pipeline parity gaps of linen to nnx migrations #4288
Open
mesakhcienet wants to merge 1 commit into
Open
Conversation
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
783a66a to
7064594
Compare
mesakhcienet
commented
Jun 29, 2026
Comment on lines
-598
to
-630
| def get_layer_to_pipeline(blocks, cfg): | ||
| if cfg.decoder_block == DecoderBlockType.DEEPSEEK: | ||
| return blocks[1] # return the sparse block | ||
| else: | ||
| return blocks[0] | ||
|
|
||
| cfg = self.config | ||
| base_stage = get_layer_to_pipeline(decoder_blocks, cfg) | ||
| if cfg.set_remat_policy_on_layers_per_stage: | ||
| policy = self.get_remat_policy() | ||
| base_stage = self.set_remat_policy([base_stage], policy)[0] | ||
| if cfg.num_layers_per_pipeline_stage == 1: | ||
| stage_module = base_stage(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode) | ||
| elif cfg.scan_layers_per_stage: | ||
| stage_module = self.scan_decoder_layers( | ||
| cfg, | ||
| base_stage, | ||
| base_stage_cls, | ||
| cfg.num_layers_per_pipeline_stage, | ||
| "layers_per_stage", | ||
| cfg, | ||
| self.mesh, | ||
| in_axes_tuple=(nn.broadcast,) * 4, | ||
| model_mode=self.model_mode, | ||
| ) | ||
| else: | ||
| stage_module = SequentialBlockDecoderLayers( | ||
| decoder_layer=base_stage, | ||
| num_decoder_layers=cfg.num_layers_per_pipeline_stage, | ||
| config=cfg, | ||
| mesh=self.mesh, | ||
| quant=self.quant, | ||
| model_mode=self.model_mode, | ||
| self.quant, | ||
| self.model_mode, | ||
| rngs=rngs, | ||
| remat_policy=per_stage_remat, | ||
| apply_remat=apply_per_stage_remat, | ||
| ) | ||
| return stage_module |
Collaborator
Author
There was a problem hiding this comment.
remove get_layer_to_pipeline dead code (unused anymore)
262fa8e to
bc467c2
Compare
c8206f0 to
b2663d2
Compare
50c83c4 to
a776654
Compare
71b97ed to
cf1a449
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Closes the remaining behavioral gaps between the pure-NNX decoder/pipeline path and the Linen reference.
The NNX path now reproduces Linen for DeepSeek-V4, per-stage remat, and the pipeline's non-trainable / repeat-level-remat handlning.
What's included
DeepSeek-V4 NNX decoder port
DEEPSEEK4inNNXDecoder.get_decoder_layer(was missing → ValueError at construction) and added full decoder-level handling: norm dispatch (RMSNorm), scanned + non-scanned init,_apply_deepseek4_scanned_blocks(prefix first_num_hash_layers unroll + paired HCA/CSA scan), global layer_idx, and decoder_input_tokens threading — matching Linen_apply_deepseek4_scanned_blocks.Per-stage pipeline remat parity (set_remat_policy_on_layers_per_stage)
NNXSequentialPipelineStage/NNXScannedPipelineStage, wired from both stage builders, incl.num_layers_per_pipeline_stage == 1.remat_policy='full'resolves toNone(== full remat, as Linen nn.remat(policy=None)); the oldif policy is not Nonegate silently dropped remat for the default 'full' policy. Now gated on the flag via an explicitapply_rematargument.Pipeline Linen→NNX migration parity (pipeline.py)
non_trainablecollection: the migration asserted the iteration-scan catch-all was RngState-only, crashing any pipelined model with a non-trainable variable (e.g. theDeepSeek-V4hash-routing table). Non-circular now broadcastsnon_trainableas a loop-invariant constant (4-way state split); circular carries it viacarry_state.Unit Tests
Tests
Sheet combination of
set_remat_policy_on_layers_per_stageflag.Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.