Add support for StereoRoPE positional embedding and Strata features #1750
zyhu-hu wants to merge 31 commits into
… physicsnemo.nn Extend physicsnemo/nn/module/rope.py with a stereographic 2D rotary position embedding for tokens on a sphere, reusing this stack's rope primitives (apply_rotary_pos_emb / rotate_half_pairs). Adds build_rope_cos_sin_1d_continuous and build_axial_rope_cos_sin_continuous (the continuous-coordinate generalizations of build_rope_cos_sin_1d and build_axial_rope_cos_sin), stereographic_projection, spherical_centroid (a pole- and seam-robust tile center), and the StereographicRotaryPositionEmbedding2D module, with tests in test/nn/module/test_rope.py. Stacked on PR NVIDIA#1731 (pzharrington/physicsnemo:dit-enhancements); merge after NVIDIA#1731. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Add the DiT3D 3D transformer backbone and the two-stage PixelDiT model under physicsnemo/experimental/models/strata/. Both use the DiT architecture as deterministic regression models (no diffusion / timestep / label conditioning) and reuse the stereographic RoPE + rope primitives from physicsnemo.nn (StereographicRotaryPositionEmbedding2D / apply_rotary_pos_emb) and physicsnemo.nn.functional.na3d for 3D neighborhood attention. Geometry is decoupled to an optional forward input; the stereographic mode derives each tile center from spherical_centroid (pole- and seam-robust). Includes constructor, non-regression golden, checkpoint, and CUDA/NATTEN tests. Stacked on PR NVIDIA#1731 (pzharrington/physicsnemo:dit-enhancements); merge after NVIDIA#1731. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
…nting _should_checkpoint_block gates on self.training, so the test running in eval mode never checkpointed — it compared two plain forwards and passed trivially. Run both models in train mode and assert checkpointing is active, so the test genuinely verifies that activation checkpointing reproduces the non-checkpointed output and gradients. Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Mirror DiT3D's mechanism for the pixel stage: a new activation_checkpointing_pixel arg (bool | float) reusing DiT3D._parse_checkpointing_param; a train-gated, ratio-based _should_checkpoint_pixel_block; and the pixel-block loop wraps each block in torch.utils.checkpoint when selected. Ports the pixel-block checkpointing from the screamcast Strata source. Adds test_pixeldit_activation_checkpointing_matches (train mode) asserting the checkpointed run reproduces the non-checkpointed output and per-parameter gradients. Additive and default-off, so the goldens are unchanged. Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
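A minimal sketch of a train-gated, ratio-based selector, to show why a test run in eval mode never exercises checkpointing. The parsing convention (`True` maps to 1.0, `False` to 0.0) and the "first `round(ratio * num_blocks)` blocks" policy are assumptions for illustration; the PR's `_parse_checkpointing_param` and `_should_checkpoint_pixel_block` may select blocks differently.

```python
def parse_checkpointing_param(value):
    """Map a bool | float setting to a ratio in [0, 1] (assumed convention)."""
    if isinstance(value, bool):
        return 1.0 if value else 0.0
    ratio = float(value)
    if not 0.0 <= ratio <= 1.0:
        raise ValueError(f"checkpointing ratio must be in [0, 1], got {ratio}")
    return ratio


def should_checkpoint_block(index, num_blocks, ratio, training):
    """Checkpoint only in training mode, and only the first ratio * num_blocks blocks."""
    if not training:
        return False  # eval mode: nothing is checkpointed
    return index < round(ratio * num_blocks)


ratio = parse_checkpointing_param(0.5)
train_mask = [should_checkpoint_block(i, 4, ratio, training=True) for i in range(4)]
eval_mask = [should_checkpoint_block(i, 4, ratio, training=False) for i in range(4)]
```

With a gate like this, a checkpointed-vs-plain comparison is only meaningful in train mode, since the eval-mode mask is all `False`.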
Hoist the length-3 (kd, kh, kw) check into _as_kernel_triple and validate eagerly in Natten3DSelfAttention.__init__, which every block (DiT3D and PixelDiT) constructs — so a malformed tuple fails clearly at construction instead of deep inside NATTEN. Removes the now-redundant DiT3D-only check. Adds a test that a non-length-3 tuple is rejected. Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
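The eager-validation idea can be sketched as follows. The helper and class names below are hypothetical stand-ins (not the PR's `_as_kernel_triple` or `Natten3DSelfAttention`), and the sketch assumes the kernel is always given as a sequence.

```python
def as_kernel_triple_sketch(kernel):
    """Validate a (kd, kh, kw) kernel and return it as a tuple of three ints."""
    values = tuple(kernel)
    if len(values) != 3:
        raise ValueError(f"expected a length-3 kernel (kd, kh, kw), got {values!r}")
    return tuple(int(v) for v in values)


class AttentionSketch:
    """Hypothetical layer: validates at construction, not in the first forward."""

    def __init__(self, kernel):
        self.kernel = as_kernel_triple_sketch(kernel)


layer = AttentionSketch([3, 5, 5])
```

Because every block constructs the attention layer, a malformed tuple such as `(3, 3)` fails with a clear message before any kernel is launched.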
Use self.num_adaln_params instead of the literal 6 in the pixel_proj chunk, and rename precompute_bilinear_cond's locals ph/pw -> h/w (they are pixel height/width, not patch sizes, which ph/pw mean everywhere else in the file). Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
…ture The DiT3D test now uses do_alt_depthwise_attn + stereographic RoPE so per-block rope_tables differ (depth blocks get None) — a homogeneous config could not catch a late-binding closure regression. The PixelDiT test is parametrized over adaln_mode with first_block_only_adaln=True so both closure branches (PixelDiTBlock + plain DiT3DBlock) and the s_cond_bilinear capture are checkpointed; its grad tolerance is loosened to 1e-3 for the chunked-vmap DepthwiseConv path (forward still matches exactly). Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
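The late-binding hazard this test guards against can be shown in plain Python. The sketch below uses toy callables in place of transformer blocks; the function names are illustrative only.

```python
def build_runners_late(blocks):
    """Buggy pattern: each closure looks up `block` when called, not when defined."""
    runners = []
    for block in blocks:
        def _run(x):
            return block(x)  # late binding: sees the loop's final `block`
        runners.append(_run)
    return runners


def build_runners_bound(blocks):
    """Safe pattern: a default argument freezes the current block per closure."""
    runners = []
    for block in blocks:
        def _run(x, b=block):
            return b(x)
        runners.append(_run)
    return runners


blocks = [lambda x: x + 1, lambda x: x * 10]
late = [run(2) for run in build_runners_late(blocks)]
bound = [run(2) for run in build_runners_bound(blocks)]
```

Here `late` evaluates to `[20, 20]` and `bound` to `[3, 20]`. The distinction matters for checkpointing because the wrapped function is re-invoked during backward, after the loop has moved on. With identical blocks (a homogeneous config) both patterns give the same answer, which is why the test needs per-block differences.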
…-trip Add test_depthwise_conv_chunked_matches_plain (the torch.vmap chunked path used by bilinear_dw == the plain conv with the same weights), and parametrize test_pixeldit_checkpoint over adaln_mode so the bilinear_dw shadowed-forward DepthwiseConv survives a full .mdlus save/load round-trip. Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
…ath) The existing pixeldit_bilinear golden uses patch_size=(1,2,2) (semantic depth == pixel depth), so it only pins the 2D-bilinear fallback. Add pixeldit_bilinear_pd2 with patch_size=(2,2,2) (sd=2, pixel d=4) so the trilinear depth-upsample path — the riskiest part of bilinear_dw — gets a numeric regression check, not just a shape/finite assertion. Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
…ing hook Add test_dit3d_set_tile_size_axial (re-tiling rebuilds the axial RoPE buffers so forward works at the new size; a stale buffer would make RoPE broadcast fail). Note in the PixelDiT docstring that the pixel stage has no set_tile_size equivalent, so an axial pixel RoPE is fixed to the construction-time grid. Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Be precise that DiT3D/PixelDiT are independent reimplementations of the DiT-style architecture, not a reuse of physicsnemo.models.dit.DiT's components (they define their own 3D blocks and share only the na3d functional, Mlp, and RoPE primitives). Add the naming rationale: 'DiT' names the architecture family, not a training objective; DiT3D is the conceptual 3D analog, and PixelDiT in particular retains DiT's defining adaLN conditioning (driven by semantic features instead of a diffusion timestep). Both remain deterministic regression models. Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
The chunked path bound a `self`-capturing closure to `self.forward` in `__init__`. `copy.deepcopy` copies that closure's cell by reference to the ORIGINAL module, so a deep-copied DepthwiseConv silently computed with the source module's conv weights. This corrupts any EMA / SWA / AveragedModel / snapshot workflow (all of which deepcopy the model): the copy's own parameters are trained and checkpointed, but forward used stale ones. Replace the closure with a normal `forward` that dispatches on `self.chunk_size` and threads `self.weight` / `self.bias` in live. The chunked vmap callable now captures only the static conv configuration, not the module. This also fixes a `bias=False` device bug (the synthetic zero bias was captured at construction and never moved by `.to(device)`) and removes the `self.forward` reassignment that is hostile to `torch.compile`. Add regression tests: a deep-copied chunked conv uses its own (zeroed) weights, and a `bias=False` chunked conv matches the plain conv and survives a device move. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
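The deepcopy failure mode can be reproduced without torch. In this sketch, plain Python classes stand in for the module and a float stands in for the conv weight; the class names are illustrative.

```python
import copy


class ClosureForward:
    """Binds a self-capturing closure to self.forward in __init__ (the buggy pattern)."""

    def __init__(self, weight):
        self.weight = weight

        def _forward(x):
            return x * self.weight  # the closure cell references this instance

        self.forward = _forward


class LiveForward:
    """Uses a normal method, so the weight is read from the instance being called."""

    def __init__(self, weight):
        self.weight = weight

    def forward(self, x):
        return x * self.weight


def copied_output(cls):
    source = cls(2.0)
    clone = copy.deepcopy(source)
    clone.weight = 0.0  # zero the copy's own weight
    return clone.forward(3.0)


stale = copied_output(ClosureForward)
live = copied_output(LiveForward)
```

`copy.deepcopy` treats function objects as atomic, so the clone's `forward` is the source's closure and `stale` is `6.0`; the method-based version gives `0.0`, which is the behavior the regression test pins with zeroed weights.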
Its two siblings in the same production module (RotaryPositionEmbedding1D / RotaryPositionEmbedding2D) subclass physicsnemo.core.Module, but this class subclassed raw torch.nn.Module. Per MOD-001 all model classes in the public nn module must inherit from physicsnemo.Module; the inconsistency meant this exported layer could not participate in the Module.save() / from_checkpoint() recipe its siblings support. Switch the base to Module (the __init__ takes only head_dim / theta scalars, so arg capture works) and drop the now-unused `import torch.nn as nn`. The module is stateless (no buffers/params), so a checkpoint round-trip reconstructs it from the saved __init__ args and reproduces the forward. Fix the stale "Both RoPE modules" comment in the test (there are three) and add a checkpoint round-trip test for the stereographic module. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
set_tile_size updates the expected input tile (height / width / input_shape) for every rope_mode so the forward shape check and unpatchify accept the new resolution; for rope_mode="axial" it additionally rebuilds the cached (cos, sin) RoPE buffers. The stereographic / none modes build their tables per forward (from the supplied pos, or not at all), so updating the expected shape alone re-tiles them -- the regional / tiled-global stereographic inference path. Fix the docstring, which previously claimed the method only affects axial mode. Tests: the axial case compares the rebuilt RoPE buffers against a model built directly at the new size (catching a wrong row/col assignment that a token-count-only check would miss); a parametrized test asserts set_tile_size re-tiles stereographic / none models (a forward at the new tile runs and the old tile is correctly rejected). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
FinalLayer3D and PixelDiTLastLayer force the output projection to fp32 (norm on x.float(), autocast disabled) for numerical stability. That works under bf16_mixed autocast, where weights stay fp32, but crashes when the model is cast wholesale with model.bfloat16() / .half(): the fp32 input then meets a bf16 linear weight and F.linear raises "mat1 and mat2 must have the same dtype". Pure-half casting is a plausible deployment given the models' MetaData advertises bf16. Match the linear input to the weight dtype. Under the common fp32-weight path the cast is a no-op, so existing numerics (and goldens) are unchanged. Add a regression test that casts DiT3D and PixelDiT to bf16 and runs forward (the PixelDiT case also exercises the bf16 DepthwiseConv bilinear_dw path). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Mutation testing showed two correctness paths were caught only by the regen-fragile goldens, never by the targeted unit tests:
- apply_rotary_pos_emb: inverting rotate_half's sign passed every rope unit test (norm preservation and relative-position invariance are sign-blind). Add a test that pins the handedness: rotating (1,0)/(0,1) by a generic angle must follow a proper CCW rotation (1,0)->(cos,sin), (0,1)->(-sin,cos).
- build_stereographic_token_coords: zeroing the projection (collapsing RoPE to identity) or scrambling the patch pooling passed the shape/finite/tiling-only test. Add value assertions: coords are non-degenerate, and the North/East coordinates increase monotonically with the patch row/column (matching the latitude/longitude gradient of the input grid).
Both were verified to fail under the corresponding source mutation and pass on the real implementation.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
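The handedness check can be illustrated on a single channel pair. This is a sketch in plain Python, assuming the usual `x * cos + rotate_half(x) * sin` form with `rotate_half((a, b)) = (-b, a)`; the actual channel layout used by `apply_rotary_pos_emb` is not reproduced here.

```python
import math


def rotate_half_pair(pair, sign=1.0):
    """Rotate-half on one (a, b) pair: (a, b) -> (-b, a); sign=-1.0 models the mutation."""
    a, b = pair
    return (-sign * b, sign * a)


def apply_rope_pair(pair, angle, sign=1.0):
    """Compute x * cos(angle) + rotate_half(x) * sin(angle) for one pair."""
    rx, ry = rotate_half_pair(pair, sign)
    c, s = math.cos(angle), math.sin(angle)
    return (pair[0] * c + rx * s, pair[1] * c + ry * s)


angle = 0.7  # a generic angle, so cos and sin differ
e1 = apply_rope_pair((1.0, 0.0), angle)
e2 = apply_rope_pair((0.0, 1.0), angle)
flipped = apply_rope_pair((1.0, 0.0), angle, sign=-1.0)
flipped_norm = math.hypot(*flipped)
```

The flipped variant still has unit norm, so a norm-preservation test cannot see it; only the value check `(1,0) -> (cos, sin)` separates a counter-clockwise rotation from a clockwise one.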
- Define the `RopeTables` type alias once in layers.py (the lowest-level
strata module) and import it into dit3d.py / pixel.py, instead of repeating
the identical alias in three files (MOD-004 shared utility).
- PixelDiT.semantic_config docstring said `default={}` but the signature
default is `None` (MOD-003h); document the None -> {} coercion.
- precompute_bilinear_cond: clarify that the 2D-bilinear branch at s_d == D is
deliberate (a 3D trilinear upsample agrees only to ~1e-7 there, not exactly),
so a future "simplify to always-trilinear" refactor would silently shift the
numerics of every matching-depth model.
No behavior change.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
The experimental strata package's __all__ also exposes the reusable layers (DiT3DBlock, Natten3DSelfAttention, PatchEmbed3D, FinalLayer3D, PixelDiTBlock, PixelDiTLastLayer); record that they are part of the (experimental) API. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
The existing pure-bf16-cast test uses the SDPA fallback (attn_kernel=-1), so it never reaches the CUDA-only NATTEN neighborhood-attention kernel. Add a CUDA-gated test that casts a DiT3D with attn_kernel>0 (real NA3D), do_alt_depthwise_attn, and gated_attention to bfloat16 and runs a forward, confirming the NA3D kernel, the depth-axis attention, the gate, and the fp32-forced output head all operate under genuine bf16 weights. Skips on CPU like the other NATTEN tests. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Greptile Summary
This PR contributes the Strata weather-emulator architecture: a
for i, block in enumerate(self.pixel_blocks):
    # The two block types take different keyword signatures; wrap the
    # call in a closure so activation checkpointing can drive either.
    if isinstance(block, PixelDiTBlock):

        def _run(inp, b=block):
            return b(
                inp,
                s_cond=s_cond,
                pixel_dhw=pixel_dhw,
                semantic_dhw=semantic_dhw,
                s_cond_bilinear=s_cond_bilinear,
                rope_tables=rope_tables,
            )

    else:

        def _run(inp, b=block):
            return b(inp, latent_dhw=pixel_dhw, rope_tables=rope_tables)

    if self._should_checkpoint_pixel_block(i):
        x_pix = activation_checkpoint(_run, x_pix, use_reentrant=False)
    else:
        x_pix = _run(x_pix)
Tensors captured by closure in activation_checkpoint calls
s_cond, s_cond_bilinear, pixel_dhw, semantic_dhw, and rope_tables are all closed over by _run rather than passed as explicit arguments to activation_checkpoint. With use_reentrant=False this is safe — PyTorch's non-reentrant implementation correctly handles closed-over tensors and the PR's gradient-flow tests validate it — but it differs from the DiT3D.forward_tokens pattern (where rope_tables is bound via a default arg r=block_rope). Passing at minimum s_cond and s_cond_bilinear as explicit arguments would align with the non-reentrant best practice and make the gradient path easier to audit.
@zyhu-hu I think this is a reasonable ask, for code clarity.
Per review: `two` is not conventional jaxtyping; a fixed-size axis should be a literal `2`. Changes the `pos` annotation from "batch two height width" to "batch 2 height width" in coords.py, dit3d.py, and pixel.py. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Per review: PixelDiT had no initialize_weights override, so pixel_patch_embed kept PyTorch default init (not the fan-based Xavier the backbone uses) and pixel_final_layer was never zero-initialized (breaking the near-identity residual start the DiT convention relies on). Add a targeted initialize_weights(): fan-Xavier the pixel patch-embed conv (+zero bias) and zero-init the pixel output head, mirroring DiT3D.initialize_weights. A blanket Xavier pass is deliberately avoided so it does not clobber the AdaLN-zero projections the pixel blocks set at construction. Goldens are unaffected (they seed all parameters via _seed_params). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Per review: StereographicRotaryPositionEmbedding2D was used purely as a 2D RoPE
table builder, never as a stateful module, so the module wrapper added no value
in production nn. Drop the class and relocate the four functional helpers
(build_rope_cos_sin_1d_continuous, build_axial_rope_cos_sin_2d_continuous,
stereographic_projection, spherical_centroid) to a new experimental module
physicsnemo/experimental/nn/rope.py, giving the continuous/stereographic RoPE
API room to mature outside the stable surface. The production integer-grid
builders (build_rope_cos_sin_1d, build_axial_rope_cos_sin_2d, apply_rotary_pos_emb,
RotaryPositionEmbedding{1,2}D) stay in physicsnemo/nn/module/rope.py.
Rewire DiT3D and PixelDiT to call build_axial_rope_cos_sin_2d_continuous
directly (storing head_dim / rope_base) instead of self.rope.build_tables;
coords.py imports the geometry helpers from physicsnemo.experimental.nn.
Tests: move the continuous-builder / projection / centroid tests to
test/experimental/nn/test_rope.py and convert the valuable module-property
tests (relative-position invariance, theta wiring, bf16 dtype) to function-level
tests; drop the module-API and now-redundant tests. The production
apply_rotary_pos_emb tests (incl. rotation-direction) stay in
test/nn/module/test_rope.py. import-linter contract unchanged (no
production -> experimental import).
This supersedes the earlier change that made the class a physicsnemo.Module
(MOD-001): the class is removed entirely.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Per review:
- FinalLayer3D and PixelDiTLastLayer were near-identical fp32 LayerNorm+Linear heads differing only in output dim. Consolidate into a single FinalLayer3D(hidden_size, out_features); the backbone passes p_d*p_h*p_w*C_out (unpatchified to a field), the pixel stage passes C_out. Submodule names are unchanged, so checkpoints/goldens are key-preserving.
- Replace the `del self.semantic.final_layer` hack with an include_head kwarg on DiT3D. include_head=False omits the output head, and forward() then returns post-block tokens (equivalently forward_tokens). PixelDiT builds its semantic stage with include_head=False, so the trunk never creates unused output-head parameters.
Tests: update the constructor assertion (semantic.final_layer is now None rather than absent) and add a test for the include_head toggle (full returns a field; headless returns tokens and has no final_layer params).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
…ackbone"
Per review, the "DiT" naming was confusing (it suggested a 3D diffusion DiT). Rename the public API to the Strata family:
- PixelDiT -> Strata (the composed train/inference model); PixelDiTMetaData -> StrataMetaData
- DiT3D -> StrataTransformer3D (the backbone); DiT3DMetaData -> StrataTransformer3DMetaData; DiT3DBlock -> StrataTransformer3DBlock
- PixelDiTBlock -> StrataPixel3DBlock (parallels StrataTransformer3DBlock; scrubs the last "DiT" -- the PixelDiT paper is still cited in the docstring)
- the "semantic" coarse-stage identifiers -> "backbone" (backbone_config, self.backbone, backbone_cond, ...), since "semantic" is a vision term and these are simply the backbone's coarse tokens
Rename the module/test files to match (dit3d.py -> transformer.py, pixel.py -> strata.py, test_dit3d.py -> test_strata.py, _generate_dit3d_goldens.py -> _generate_strata_goldens.py) and the golden fixtures (dit3d_* -> transformer_*, pixeldit_* -> strata_*). The composed-model goldens were regenerated because self.semantic -> self.backbone changes their state_dict key prefix; the regenerated outputs are bit-identical to the pre-rename outputs (verified), confirming the whole refactor is numerically lossless. The PixelDiT paper citation (arXiv:2511.20645) is preserved.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Per review, the experimental strata package re-exported all its building-block layers. Narrow __all__ (and the package imports) to just the two models and their metadata: Strata, StrataMetaData, StrataTransformer3D, StrataTransformer3DMetaData. The internal layers (StrataTransformer3DBlock, StrataPixel3DBlock, Natten3DSelfAttention, PatchEmbed3D, FinalLayer3D) remain importable from their submodules but are no longer part of the public surface. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Per review, replace the four separate Strata-related Added entries (StereoRoPE, DiT3D, PixelDiT, exported layers) with a single bullet describing the Strata architecture and the experimental continuous/stereographic RoPE helpers, using the renamed classes and the experimental RoPE location. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Found by an adversarial self-review pass; doc/hygiene only, no behavior change:
- strata.py: restore the PixelDiT paper citation in the Notes section; the rename had turned "adapted from PixelDiT" into "adapted from Strata", a self-reference pointing at the PixelDiT reference entry.
- coords.py: repoint two `:func:` cross-references (spherical_centroid / stereographic_projection) to physicsnemo.experimental.nn, where they now live.
- transformer.py: document the include_head=False case in the class Outputs section (forward returns post-block tokens, not a field).
- strata.py / transformer.py: fix import ordering (isort I001) introduced by the module rename (ruff excludes experimental, so CI does not flag it).
- strata.py: drop StrataPixel3DBlock from the module __all__, matching transformer.py and the package's models-only public surface.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
The test migration dropped the assertion that pinned the length_scale
normalization ("doubling length_scale halves the coords"); the new
shape/monotonicity checks are blind to the scale exponent (a `length_scale**2`
bug passed the whole suite). Re-add the ratio check: coords at length_scale 1.0
vs 2.0 must satisfy coords_2x == coords_1x / 2.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
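Why the ratio check is needed can be sketched with made-up numbers. The helper below is hypothetical and assumes the normalization is a plain division of tangent-plane coordinates by `length_scale`; `exponent=2` models the `length_scale**2` bug.

```python
def normalize_coords(coords, length_scale, exponent=1):
    """Divide tangent-plane coords by length_scale**exponent (exponent=2 models the bug)."""
    return [c / length_scale**exponent for c in coords]


def is_monotonic(values):
    """True when the values are strictly increasing."""
    return all(a < b for a, b in zip(values, values[1:]))


raw = [-0.2, -0.1, 0.0, 0.1, 0.2]  # made-up tangent-plane coordinates
good_1x = normalize_coords(raw, 1.0)
good_2x = normalize_coords(raw, 2.0)
bad_2x = normalize_coords(raw, 2.0, exponent=2)

# Monotonicity holds for both the correct and the buggy scaling.
monotonic_blind = is_monotonic(good_2x) and is_monotonic(bad_2x)
# The ratio check (coords_2x == coords_1x / 2) separates them.
ratio_ok = all(abs(b - a / 2) < 1e-12 for a, b in zip(good_1x, good_2x))
ratio_catches_bug = not all(abs(b - a / 2) < 1e-12 for a, b in zip(good_1x, bad_2x))
```

A shape or monotonicity assertion passes in both cases; only the explicit ratio comparison fails under the squared exponent.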
…count) The non-chunked fallback warning computed a byte count (numel * itemsize) and compared it to 2**32, but conv2d's actual limit is element-count: PyTorch's 32-bit indexing guard (canUse32BitIndexMath) raises a RuntimeError when numel > INT_MAX (2**31 - 1). Verified empirically (torch 2.7.1, cuDNN 9.7): fp32 and bf16 both fail at the same numel = 2**31 despite fp32 having 2x the bytes, confirming the limit is elements, not bytes. The byte threshold was also dtype-dependent and ~2x too conservative for fp32. Check `x.numel() > torch.iinfo(torch.int32).max` instead, and correct the warning text: it raises a RuntimeError (not "revert to a slow implementation"), and the remedy is chunk_size (which keeps each conv2d call under the limit and on the fast cuDNN path). Drop the now-unused `import math`. Addresses a Greptile review comment (misleading byte-vs-element threshold). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
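The difference between the two thresholds is plain arithmetic. The sketch below uses only shapes and item sizes (no tensors are allocated); the helper names are illustrative.

```python
import math

INT32_MAX = 2**31 - 1  # same value as torch.iinfo(torch.int32).max


def exceeds_int32_indexing(shape):
    """True when the element count is past the 32-bit indexing limit."""
    return math.prod(shape) > INT32_MAX


def byte_count(shape, itemsize):
    """Total bytes for a dense tensor of this shape and element size."""
    return math.prod(shape) * itemsize


shape = (1, 2**11, 2**10, 2**10)  # 2**31 elements, one past INT32_MAX
over_limit = exceeds_int32_indexing(shape)

# The old byte-based check compared bytes against 2**32.
old_check_fp32 = byte_count(shape, 4) > 2**32  # 2**33 bytes
old_check_bf16 = byte_count(shape, 2) > 2**32  # 2**32 bytes
```

At `numel = 2**31` the element-count check fires regardless of dtype, while the byte-based check fires for fp32 but misses bf16, even though the commit reports that both dtypes fail at that size.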
Match the surrounding Added entries, which use plain prose (code names in backticks) and no arbitrary bold. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
The constructor (via PatchEmbed3D) rejects an input shape not divisible by the patch size, but set_tile_size updated height/width/input_shape without the same check -- so a non-divisible tile was silently accepted and floor-truncated by the patch-embed conv (and a caller reading model.input_shape would see a larger shape than is actually used). Add the matching height/width divisibility check (same messages as PatchEmbed3D) at the top of set_tile_size so re-tiling fails fast for every rope_mode. Add a regression test. Addresses a Greptile review comment. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
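A minimal sketch of a fail-fast divisibility check, assuming a `(depth, height, width)` patch tuple as in `patch_size=(1, 2, 2)`. The function name and error messages are hypothetical, not those of `PatchEmbed3D` or `set_tile_size`.

```python
def check_tile_divisible(height, width, patch_size):
    """Reject a tile whose height/width is not a multiple of the patch size."""
    _, p_h, p_w = patch_size  # (depth, height, width) patch
    if height % p_h != 0:
        raise ValueError(f"height {height} is not divisible by patch height {p_h}")
    if width % p_w != 0:
        raise ValueError(f"width {width} is not divisible by patch width {p_w}")
    return height // p_h, width // p_w  # token grid size


token_grid = check_tile_divisible(8, 12, (1, 2, 2))
truncated_rows = (7 // 2) * 2  # rows actually covered if height=7 were accepted
```

Without the check, a height of 7 with patch height 2 would be floor-truncated to 3 tokens covering only 6 rows, while the stored input shape would still report 7.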
spherical-grid dependency. Also adds the supporting continuous-coordinate and
stereographic RoPE helpers to `physicsnemo.experimental.nn`
(`build_rope_cos_sin_1d_continuous`, `build_axial_rope_cos_sin_2d_continuous`,
`stereographic_projection`, `spherical_centroid`).
Still overly verbose; these details can go in the documentation rather than the changelog. Just listing the added public-level modules and functions is fine.
No need for a separate file, might as well put this in layers.py
# __init__ args); the stereographic module is fully stateless (tables are built
# per forward from the input coordinates). In every case a round-trip reproduces
# the forward exactly without any tables appearing in the checkpoint.
Stale references to the stereographic module should be removed; without it there are now only two RoPE modules.
output head, so the post-block tokens of shape :math:`(B, N, E)` are
returned instead (equivalently use :meth:`forward_tokens`).
"""
_, _, d_in, h_in, w_in = x.shape
This happens before input validation and could lead to "too many values to unpack" error. Since forward_tokens guarantees the shape equals the configured one, you could just use self.depth/height/width for unpatchify, and drop the early unpack here.
        f"{tuple(x.shape)}"
    )
if self.rope_mode_pixel == "stereographic" and pos is None:
    raise ValueError(
This only checks for pos is None but not shape, and the backbone validates pos.shape only when the backbone's rope_mode="stereographic". So the backbone="none"/"axial", pixel="stereographic" config never shape-checks pos; a wrong (B,2,H,W) falls into a cryptic rearrange error in build_stereographic_token_coords. Worth a one-line shape check in Strata.forward.
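The suggested one-line check could look like the sketch below. It assumes `x` has shape `(B, C, D, H, W)` and `pos` has shape `(B, 2, H, W)`; the function name and message are hypothetical, and shapes are passed as tuples so the example runs without torch.

```python
def check_pos_shape(pos_shape, x_shape):
    """Require pos to be (B, 2, H, W), matching x of shape (B, C, D, H, W)."""
    batch, _, _, height, width = x_shape
    expected = (batch, 2, height, width)
    if tuple(pos_shape) != expected:
        raise ValueError(f"pos must have shape {expected}, got {tuple(pos_shape)}")
    return expected


x_shape = (2, 5, 4, 8, 16)
accepted = check_pos_shape((2, 2, 8, 16), x_shape)
```

Running this in the composed model's `forward` whenever the pixel stage uses stereographic RoPE would cover the configuration where the backbone never inspects `pos`.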
    # Insert a heads axis so the tables broadcast over (B, heads, N, head_dim).
    return cos.unsqueeze(1), sin.unsqueeze(1)

def prepare_tokens(
Use jaxtyping or make private to the class
        x = block(x, latent_dhw=latent_dhw, rope_tables=block_rope)
    return x, latent_dhw

def unpatchify(
Use jaxtyping or make private to the class
@staticmethod
def precompute_bilinear_cond(
    backbone_cond: torch.Tensor,
Use jaxtyping annotations
nn.init.constant_(self.pixel_patch_embed.proj.bias, 0.0)

nn.init.constant_(self.pixel_final_layer.linear.weight, 0.0)
nn.init.constant_(self.pixel_final_layer.linear.bias, 0.0)
Nothing here ever initializes the pixel blocks' attention and MLP Linears (they keep PyTorch's nn.Linear default), intended or no? It sorta contradicts the docstring claim of matching StrataTransformer3D.initialize_weights
"Strata",
"StrataMetaData",
"StrataTransformer3D",
"StrataTransformer3DMetaData",
No need to export the MetaData classes
)
from .transformer import StrataTransformer3D

__all__ = ["Strata", "StrataMetaData"]
Same here and in transformer.py, no need to export the MetaData, it's just a module-local helper
    auto_grad: bool = False


class StrataPixel3DBlock(nn.Module):
Since it's a block, maybe move it to layers.py?
/ok to test 78e69d3
PhysicsNeMo Pull Request
Description
Contributes the Strata weather-emulator architecture in two layers:
- `physicsnemo.nn`: a stereographic 2D rotary position embedding primitive (the continuous-coordinate sibling of `physicsnemo.nn`'s `RotaryPositionEmbedding2D`).
- `physicsnemo.experimental.models.strata`: the `DiT3D` and `PixelDiT` regression models that consume it.
Part 1: StereoRoPE (`physicsnemo/nn/module/rope.py`, re-exported from `physicsnemo.nn`)
Token latitude/longitude are mapped to a local tangent plane via a stereographic projection; the resulting continuous `(x, y)` coordinates drive a 2D RoPE that reuses `physicsnemo.nn`'s `apply_rotary_pos_emb` verbatim (the per-pair rotation math is identical, so there is no new rotation code). New public API:
- `StereographicRotaryPositionEmbedding2D`: `project(lat, lon, length_scale, …)` maps lat/lon to continuous tangent-plane coords; `build_tables` / `forward(q, k, x_pos, y_pos)` rotate the query/key tensors. `length_scale` is a required argument (no sensible default can be inferred from data).
- `build_axial_rope_cos_sin_2d_continuous`: continuous-coordinate generalization of `build_axial_rope_cos_sin_2d` (the integer grid is its special case, asserted by a test), built on a reusable `build_rope_cos_sin_1d_continuous` core.
- `stereographic_projection`: antipode-guarded tangent-plane projection.
- `spherical_centroid`: pole- and seam-robust tile center (3D unit-vector mean), used for `project`'s default centering (correct at the poles, where a per-axis lat/lon mean is not).
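The two geometry helpers can be sketched in plain Python. This is an illustration under stated assumptions (angles in radians, unit sphere, the textbook stereographic formulas, no antipode guard), with `_sketch` names to make clear it is not the PR's implementation.

```python
import math


def spherical_centroid_sketch(lats, lons):
    """Center of lat/lon points (radians) via the mean 3D unit vector."""
    n = len(lats)
    x = sum(math.cos(la) * math.cos(lo) for la, lo in zip(lats, lons)) / n
    y = sum(math.cos(la) * math.sin(lo) for la, lo in zip(lats, lons)) / n
    z = sum(math.sin(la) for la in lats) / n
    return math.atan2(z, math.hypot(x, y)), math.atan2(y, x)


def stereographic_projection_sketch(lat, lon, lat0, lon0):
    """Project (lat, lon) onto the plane tangent to the unit sphere at (lat0, lon0)."""
    dlon = lon - lon0
    cos_c = (math.sin(lat0) * math.sin(lat)
             + math.cos(lat0) * math.cos(lat) * math.cos(dlon))
    k = 2.0 / (1.0 + cos_c)  # diverges at the antipode (cos_c = -1); unguarded here
    east = k * math.cos(lat) * math.sin(dlon)
    north = k * (math.cos(lat0) * math.sin(lat)
                 - math.sin(lat0) * math.cos(lat) * math.cos(dlon))
    return east, north


deg = math.pi / 180.0
# Seam: two equatorial points straddling the date line.
_, seam_lon = spherical_centroid_sketch([0.0, 0.0], [179 * deg, -179 * deg])
# Pole: four points on the 80N circle.
pole_lat, _ = spherical_centroid_sketch([80 * deg] * 4, [0.0, 90 * deg, 180 * deg, 270 * deg])
# Closed form: a point at angular distance delta lands at radius 2 * tan(delta / 2).
delta = 0.3
_, north = stereographic_projection_sketch(delta, 0.0, 0.0, 0.0)
```

A per-axis mean would place the seam case at longitude 0 (the wrong side of the globe) and the polar case at 80N; the unit-vector mean gives a longitude of magnitude π and a latitude of π/2 respectively.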
Placement: a small, parameter-free primitive sharing the rotation convention and channel layout of `physicsnemo.nn`'s existing RoPE, so per MOD-002a's maturity exception it lives in production `nn/module/rope.py` (MOD-000a reusable layer), not `experimental`. It is a plain `torch.nn.Module`, not a `physicsnemo.Module` model.
Part 2: DiT3D / PixelDiT (`physicsnemo/experimental/models/strata/`)
- `DiT3D`: a 3D transformer backbone (its own 3D reimplementation of the DiT-style template, not a reuse of `physicsnemo.models.dit.DiT`'s components): 3D patch embedding, 3D neighborhood attention (reusing `physicsnemo.nn.functional.na3d`) with optional depth-axis and gated attention, optional axial or stereographic RoPE, activation checkpointing, and bf16 autocast (CUDA). `set_tile_size` re-tiles a trained model for inference at a new resolution (rebuilding the cached axial tables; the stereographic / none modes build per forward, so only the expected shape updates), supporting regional / tiled-global rollout.
- `PixelDiT`: two-stage: a `DiT3D` semantic stage (coarse patches) conditions a pixel-resolution stage whose blocks inject the semantic tokens via pixel-wise adaptive layer norm (`pixel_proj` / `bilinear_dw`). `bilinear_dw` supports any semantic vertical patch size (trilinear depth upsample, with an exact 2D-bilinear fallback when pixel depth == semantic depth, so the common case is numerically unchanged). Pixel-stage and semantic-stage RoPE are independent; pixel blocks also support activation checkpointing. Adapts the pixel-wise-AdaLN idea from PixelDiT (arXiv:2511.20645). This is an adaptation, not a faithful port: deterministic regression (no diffusion/timestep/label conditioning), an independent AdaLN reimplementation, with the original `bilinear_dw` conditioning path added beyond the paper.
Decoupled geometry: lat/lon is an optional `forward` input used only for stereographic RoPE, so there is no `earth2grid` / grid-library dependency and no hard spherical-grid assumption. The token-coordinate builders are pure free functions in `coords.py` shared by both stages (no cross-stage instance coupling). Per MOD-002a, new models are introduced in `experimental`.
Tests
- `test/nn/module/test_rope.py` (17 new tests): continuous-builder ≡ integer-builder consistency, rotation handedness (a flipped `rotate_half` is caught), projection geometry + closed-form `2·tan(Δ/2)`, antipode finiteness, `spherical_centroid` pole/seam robustness, `project` centering + required `length_scale`, module shapes/validation/norm preservation, relative-position invariance, batched-coord broadcast, end-to-end lat/lon → `project` → `forward`, exact forward-vs-`build_tables`, `physicsnemo.Module` checkpoint round-trip, bf16 dtype preservation, `theta` wiring.
- `test/experimental/models/strata/`: constructor/attribute, validation, forward-shape across RoPE / AdaLN modes, non-regression goldens (`{args, state_dict, y}` reload: version-robust, CPU-reproducible on the `attn_kernel=-1` SDPA path), `physicsnemo.Module` checkpoint round-trip, the NATTEN NA3D path on CUDA, gradient-flow (no dead params; PixelDiT semantic stage reached through the pixel stage), `DepthwiseConv` (chunked ≡ plain, deepcopy uses its own weights, `bias=False` device move), pure-`bfloat16` / `half` cast of both models (SDPA path on CPU/CUDA, plus the real NA3D + depth-axis + gated attention path on CUDA), `set_tile_size` re-tiling (axial buffer rebuild + non-axial shape update, with the new tile accepted and the old one rejected), shape variation (incl. vertical patch > 1), `torch.compile`, bf16 autocast, activation checkpointing (DiT3D + PixelDiT, train-mode), PixelDiT RoPE in both stages independently, and the `coords.py` builders (with value-pinned stereographic geometry).
74 passed, 3 skipped. The three skips are `test_dit3d_natten_forward[cpu]`, `test_pixeldit_natten_forward[cpu]`, and `test_dit3d_natten_bf16_cast[cpu]`; they `pytest.skip` by design because NATTEN's neighborhood-attention kernels are CUDA-only (no CPU kernel exists), not because of any missing coverage or environment gap. Their `[cuda]` counterparts run and pass on a GPU, and the CPU side of these models is exercised by the dense SDPA fallback (`attn_kernel=-1`), which the goldens use for deterministic CPU reproduction. CI without a GPU will skip the NA3D cases identically.
`ruff`, `interrogate` (100%), license headers, and `import-linter` (10/0) all clean.
Checklist
author note: no existing issue is relevant to this PR
Dependencies
None new. 3D neighborhood attention uses `physicsnemo.nn.functional.na3d` (NATTEN, already an optional `physicsnemo` extra), with a CPU SDPA fallback (`attn_kernel=-1`). Geometry is an optional `forward` input, so there is no `earth2grid` / grid-library dependency. `import-linter`'s "Prevent Non-listed external imports" contract passes unchanged.
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.