Skip to content

Add support for StereoRoPE positional embedding and Strata features#1750

Open
zyhu-hu wants to merge 31 commits into
NVIDIA:mainfrom
zyhu-hu:feat/strata
Open

Add support for StereoRoPE positional embedding and Strata features#1750
zyhu-hu wants to merge 31 commits into
NVIDIA:mainfrom
zyhu-hu:feat/strata

Conversation

@zyhu-hu

@zyhu-hu zyhu-hu commented Jun 24, 2026

Copy link
Copy Markdown

PhysicsNeMo Pull Request

Description

Contributes the Strata weather-emulator architecture in two layers:

  1. physicsnemo.nn — a stereographic 2D rotary position embedding primitive
    (the continuous-coordinate sibling of physicsnemo.nn's RotaryPositionEmbedding2D).
  2. physicsnemo.experimental.models.strata — the DiT3D and PixelDiT
    regression models that consume it.

On the "DiT" name. "DiT" denotes the architecture family, not a training
objective. DiT3D is a conceptual 3D analog of physicsnemo.models.dit.DiT:
it follows the same DiT-style template (patch-embed → pre-norm transformer →
linear decode) but is an independent reimplementation, not a reuse of that
class's components — it defines its own 3D blocks and shares only the na3d
functional, Mlp, and the RoPE primitives. PixelDiT additionally keeps DiT's
defining adaptive-layer-norm (adaLN) conditioning — only the conditioning
signal differs (the semantic stage's features vs. a diffusion timestep) — and
its name also credits the PixelDiT paper (arXiv:2511.20645) it adapts.

Both are deterministic regression models — not generative diffusion
models: no diffusion process and no noise / timestep / class-label / text
conditioning (DiT3D carries no adaLN at all; PixelDiT's adaLN is driven by
the semantic stage, not a timestep).

Part 1 — StereoRoPE (physicsnemo/nn/module/rope.py, re-exported from physicsnemo.nn)

Token latitude/longitude are mapped to a local tangent plane via a stereographic
projection; the resulting continuous (x, y) coordinates drive a 2D RoPE that
reuses physicsnemo.nn's apply_rotary_pos_emb verbatim (the per-pair
rotation math is identical — no new rotation code). New public API:

  • StereographicRotaryPositionEmbedding2Dproject(lat, lon, length_scale, …)
    maps lat/lon to continuous tangent-plane coords; build_tables / forward(q, k, x_pos, y_pos)
    rotate the query/key tensors. length_scale is a required argument (no sensible
    default can be inferred from data).
  • build_axial_rope_cos_sin_2d_continuous — continuous-coordinate generalization of
    build_axial_rope_cos_sin_2d (the integer grid is its special case, asserted by a
    test), built on a reusable build_rope_cos_sin_1d_continuous core.
  • stereographic_projection — antipode-guarded tangent-plane projection.
  • spherical_centroid — pole- and seam-robust tile center (3D unit-vector mean),
    used for project's default centering (correct at the poles, where a per-axis
    lat/lon mean is not).

Placement: a small, parameter-free primitive sharing the rotation convention
and channel layout of physicsnemo.nn's existing RoPE, so per MOD-002a's maturity
exception it lives in production
nn/module/rope.py (MOD-000a reusable layer), not experimental. It is a plain
torch.nn.Module, not a physicsnemo.Module model.

Part 2 — DiT3D / PixelDiT (physicsnemo/experimental/models/strata/)

  • DiT3D — a 3D transformer backbone (its own 3D reimplementation of the
    DiT-style template, not a reuse of physicsnemo.models.dit.DiT's components):
    3D patch embedding, 3D neighborhood attention (reusing physicsnemo.nn.functional.na3d)
    with optional depth-axis and gated attention, optional axial or stereographic
    RoPE, activation checkpointing, and bf16 autocast (CUDA). set_tile_size re-tiles
    a trained model for inference at a new resolution (rebuilding the cached axial tables;
    the stereographic / none modes build per forward, so only the expected shape updates) —
    supporting regional / tiled-global rollout.
  • PixelDiT — two-stage: a DiT3D semantic stage (coarse patches) conditions a
    pixel-resolution stage whose blocks inject the semantic tokens via pixel-wise
    adaptive layer norm
    (pixel_proj / bilinear_dw). bilinear_dw supports any
    semantic vertical patch size (trilinear depth upsample, with an exact 2D-bilinear
    fallback when pixel depth == semantic depth, so the common case is numerically
    unchanged). Pixel-stage and semantic-stage RoPE are independent; pixel blocks also
    support activation checkpointing. Adapts the pixel-wise-AdaLN idea from PixelDiT
    (arXiv:2511.20645) — an adaptation, not a faithful port: deterministic regression
    (no diffusion/timestep/label conditioning), an independent AdaLN reimplementation,
    with the original bilinear_dw conditioning path added beyond the paper.

Decoupled geometry: lat/lon is an optional forward input used only for
stereographic RoPE — no earth2grid / grid-library dependency, no hard spherical-grid
assumption. The token-coordinate builders are pure free functions in coords.py
shared by both stages (no cross-stage instance coupling). Per MOD-002a, new models
are introduced in experimental.

Tests

  • test/nn/module/test_rope.py (17 new tests): continuous-builder ≡ integer-builder
    consistency, rotation handedness (a flipped rotate_half is caught), projection
    geometry + closed-form 2·tan(Δ/2), antipode finiteness, spherical_centroid
    pole/seam robustness, project centering + required length_scale, module
    shapes/validation/norm preservation, relative-position invariance, batched-coord
    broadcast, end-to-end lat/lon → projectforward, exact forward-vs-build_tables,
    physicsnemo.Module checkpoint round-trip, bf16 dtype preservation, theta wiring.
  • test/experimental/models/strata/: constructor/attribute, validation, forward-shape
    across RoPE / AdaLN modes, non-regression goldens ({args, state_dict, y} reload —
    version-robust, CPU-reproducible on the attn_kernel=-1 SDPA path), physicsnemo.Module
    checkpoint round-trip, the NATTEN NA3D path on CUDA, gradient-flow (no dead params;
    PixelDiT semantic stage reached through the pixel stage), DepthwiseConv (chunked ≡
    plain, deepcopy uses its own weights, bias=False device move), pure-bfloat16/half
    cast of both models (SDPA path on CPU/CUDA, plus the real NA3D + depth-axis + gated
    attention path on CUDA), set_tile_size re-tiling (axial buffer rebuild + non-axial shape
    update, with the new tile accepted and the old one rejected), shape variation (incl.
    vertical patch > 1), torch.compile, bf16 autocast, activation checkpointing (DiT3D +
    PixelDiT, train-mode), PixelDiT RoPE in both stages independently, and the coords.py
    builders (with value-pinned stereographic geometry).
  • About the CPU skips. On a GPU runner the strata suite reports 74 passed, 3 skipped. The three skips are test_dit3d_natten_forward[cpu],
    test_pixeldit_natten_forward[cpu], and test_dit3d_natten_bf16_cast[cpu] — they
    pytest.skip by design because NATTEN's neighborhood-attention kernels are
    CUDA-only (no CPU kernel exists), not because of any missing coverage or
    environment gap. Their [cuda] counterparts run and pass on a GPU, and the CPU side
    of these models is exercised by the dense SDPA fallback (attn_kernel=-1), which
    the goldens use for deterministic CPU reproduction. CI without a GPU will skip the
    NA3D cases identically.
  • ruff, interrogate (100%), license headers, and import-linter (10/0) all clean.

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • [na] An issue is linked to this pull request.
    author note: no existing issue is relevant to this PR
  • If I am implementing a new model or modifying any existing model, I have followed the Models Implementation Coding Standards.

Dependencies

None new. 3D neighborhood attention uses physicsnemo.nn.functional.na3d (NATTEN —
already an optional physicsnemo extra), with a CPU SDPA fallback (attn_kernel=-1).
Geometry is an optional forward input, so there is no earth2grid / grid-library
dependency. import-linter's "Prevent Non-listed external imports" contract passes
unchanged.

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

zyhu-hu and others added 18 commits June 23, 2026 14:41
… physicsnemo.nn

Extend physicsnemo/nn/module/rope.py with a stereographic 2D rotary position embedding for tokens on a sphere, reusing this stack's rope primitives (apply_rotary_pos_emb / rotate_half_pairs). Adds build_rope_cos_sin_1d_continuous and build_axial_rope_cos_sin_continuous (the continuous-coordinate generalizations of build_rope_cos_sin_1d and build_axial_rope_cos_sin), stereographic_projection, spherical_centroid (a pole- and seam-robust tile center), and the StereographicRotaryPositionEmbedding2D module, with tests in test/nn/module/test_rope.py.

Stacked on PR NVIDIA#1731 (pzharrington/physicsnemo:dit-enhancements); merge after NVIDIA#1731.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Add the DiT3D 3D transformer backbone and the two-stage PixelDiT model under physicsnemo/experimental/models/strata/. Both use the DiT architecture as deterministic regression models (no diffusion / timestep / label conditioning) and reuse the stereographic RoPE + rope primitives from physicsnemo.nn (StereographicRotaryPositionEmbedding2D / apply_rotary_pos_emb) and physicsnemo.nn.functional.na3d for 3D neighborhood attention. Geometry is decoupled to an optional forward input; the stereographic mode derives each tile center from spherical_centroid (pole- and seam-robust). Includes constructor, non-regression golden, checkpoint, and CUDA/NATTEN tests.

Stacked on PR NVIDIA#1731 (pzharrington/physicsnemo:dit-enhancements); merge after NVIDIA#1731.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
…nting

_should_checkpoint_block gates on self.training, so the test running in eval mode never checkpointed — it compared two plain forwards and passed trivially. Run both models in train mode and assert checkpointing is active, so the test genuinely verifies that activation checkpointing reproduces the non-checkpointed output and gradients.

Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Mirror DiT3D's mechanism for the pixel stage: a new activation_checkpointing_pixel arg (bool | float) reusing DiT3D._parse_checkpointing_param; a train-gated, ratio-based _should_checkpoint_pixel_block; and the pixel-block loop wraps each block in torch.utils.checkpoint when selected. Ports the pixel-block checkpointing from the screamcast Strata source. Adds test_pixeldit_activation_checkpointing_matches (train mode) asserting the checkpointed run reproduces the non-checkpointed output and per-parameter gradients. Additive and default-off, so the goldens are unchanged.

Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Hoist the length-3 (kd, kh, kw) check into _as_kernel_triple and validate eagerly in Natten3DSelfAttention.__init__, which every block (DiT3D and PixelDiT) constructs — so a malformed tuple fails clearly at construction instead of deep inside NATTEN. Removes the now-redundant DiT3D-only check. Adds a test that a non-length-3 tuple is rejected.

Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Use self.num_adaln_params instead of the literal 6 in the pixel_proj chunk, and rename precompute_bilinear_cond's locals ph/pw -> h/w (they are pixel height/width, not patch sizes, which ph/pw mean everywhere else in the file).

Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
…ture

The DiT3D test now uses do_alt_depthwise_attn + stereographic RoPE so per-block rope_tables differ (depth blocks get None) — a homogeneous config could not catch a late-binding closure regression. The PixelDiT test is parametrized over adaln_mode with first_block_only_adaln=True so both closure branches (PixelDiTBlock + plain DiT3DBlock) and the s_cond_bilinear capture are checkpointed; its grad tolerance is loosened to 1e-3 for the chunked-vmap DepthwiseConv path (forward still matches exactly).

Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
…-trip

Add test_depthwise_conv_chunked_matches_plain (the torch.vmap chunked path used by bilinear_dw == the plain conv with the same weights), and parametrize test_pixeldit_checkpoint over adaln_mode so the bilinear_dw shadowed-forward DepthwiseConv survives a full .mdlus save/load round-trip.

Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
…ath)

The existing pixeldit_bilinear golden uses patch_size=(1,2,2) (semantic depth == pixel depth), so it only pins the 2D-bilinear fallback. Add pixeldit_bilinear_pd2 with patch_size=(2,2,2) (sd=2, pixel d=4) so the trilinear depth-upsample path — the riskiest part of bilinear_dw — gets a numeric regression check, not just a shape/finite assertion.

Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
…ing hook

Add test_dit3d_set_tile_size_axial (re-tiling rebuilds the axial RoPE buffers so forward works at the new size; a stale buffer would make RoPE broadcast fail). Note in the PixelDiT docstring that the pixel stage has no set_tile_size equivalent, so an axial pixel RoPE is fixed to the construction-time grid.

Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Be precise that DiT3D/PixelDiT are independent reimplementations of the DiT-style architecture, not a reuse of physicsnemo.models.dit.DiT's components (they define their own 3D blocks and share only the na3d functional, Mlp, and RoPE primitives). Add the naming rationale: 'DiT' names the architecture family, not a training objective; DiT3D is the conceptual 3D analog, and PixelDiT in particular retains DiT's defining adaLN conditioning (driven by semantic features instead of a diffusion timestep). Both remain deterministic regression models.

Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
The chunked path bound a `self`-capturing closure to `self.forward` in
`__init__`. `copy.deepcopy` copies that closure's cell by reference to the
ORIGINAL module, so a deep-copied DepthwiseConv silently computed with the
source module's conv weights. This corrupts any EMA / SWA / AveragedModel /
snapshot workflow (all of which deepcopy the model): the copy's own
parameters are trained and checkpointed, but forward used stale ones.

Replace the closure with a normal `forward` that dispatches on
`self.chunk_size` and threads `self.weight` / `self.bias` in live. The
chunked vmap callable now captures only the static conv configuration, not
the module. This also fixes a `bias=False` device bug (the synthetic zero
bias was captured at construction and never moved by `.to(device)`) and
removes the `self.forward` reassignment that is hostile to `torch.compile`.

Add regression tests: a deep-copied chunked conv uses its own (zeroed)
weights, and a `bias=False` chunked conv matches the plain conv and survives
a device move.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Its two siblings in the same production module (RotaryPositionEmbedding1D /
RotaryPositionEmbedding2D) subclass physicsnemo.core.Module, but this class
subclassed raw torch.nn.Module. Per MOD-001 all model classes in the public
nn module must inherit from physicsnemo.Module; the inconsistency meant this
exported layer could not participate in the Module.save() / from_checkpoint()
recipe its siblings support.

Switch the base to Module (the __init__ takes only head_dim / theta scalars,
so arg capture works) and drop the now-unused `import torch.nn as nn`. The
module is stateless (no buffers/params), so a checkpoint round-trip
reconstructs it from the saved __init__ args and reproduces the forward.

Fix the stale "Both RoPE modules" comment in the test (there are three) and
add a checkpoint round-trip test for the stereographic module.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
set_tile_size updates the expected input tile (height / width / input_shape)
for every rope_mode so the forward shape check and unpatchify accept the new
resolution; for rope_mode="axial" it additionally rebuilds the cached
(cos, sin) RoPE buffers. The stereographic / none modes build their tables per
forward (from the supplied pos, or not at all), so updating the expected shape
alone re-tiles them -- the regional / tiled-global stereographic inference
path. Fix the docstring, which previously claimed the method only affects
axial mode.

Tests: the axial case compares the rebuilt RoPE buffers against a model built
directly at the new size (catching a wrong row/col assignment that a
token-count-only check would miss); a parametrized test asserts set_tile_size
re-tiles stereographic / none models (a forward at the new tile runs and the
old tile is correctly rejected).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
FinalLayer3D and PixelDiTLastLayer force the output projection to fp32
(norm on x.float(), autocast disabled) for numerical stability. That works
under bf16_mixed autocast, where weights stay fp32, but crashes when the
model is cast wholesale with model.bfloat16() / .half(): the fp32 input then
meets a bf16 linear weight and F.linear raises "mat1 and mat2 must have the
same dtype". Pure-half casting is a plausible deployment given the models'
MetaData advertises bf16.

Match the linear input to the weight dtype. Under the common fp32-weight
path the cast is a no-op, so existing numerics (and goldens) are unchanged.

Add a regression test that casts DiT3D and PixelDiT to bf16 and runs forward
(the PixelDiT case also exercises the bf16 DepthwiseConv bilinear_dw path).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Mutation testing showed two correctness paths were caught only by the
regen-fragile goldens, never by the targeted unit test:

- apply_rotary_pos_emb: inverting rotate_half's sign passed every rope unit
  test (norm preservation and relative-position invariance are sign-blind).
  Add a test that pins the handedness: rotating (1,0)/(0,1) by a generic angle
  must follow a proper CCW rotation (1,0)->(cos,sin), (0,1)->(-sin,cos).

- build_stereographic_token_coords: zeroing the projection (collapsing RoPE to
  identity) or scrambling the patch pooling passed the shape/finite/tiling-only
  test. Add value assertions: coords are non-degenerate, and the North/East
  coordinates increase monotonically with the patch row/column (matching the
  latitude/longitude gradient of the input grid).

Both were verified to fail under the corresponding source mutation and pass on
the real implementation.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
- Define the `RopeTables` type alias once in layers.py (the lowest-level
  strata module) and import it into dit3d.py / pixel.py, instead of repeating
  the identical alias in three files (MOD-004 shared utility).
- PixelDiT.semantic_config docstring said `default={}` but the signature
  default is `None` (MOD-003h); document the None -> {} coercion.
- precompute_bilinear_cond: clarify that the 2D-bilinear branch at s_d == D is
  deliberate (a 3D trilinear upsample agrees only to ~1e-7 there, not exactly),
  so a future "simplify to always-trilinear" refactor would silently shift the
  numerics of every matching-depth model.

No behavior change.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
The experimental strata package's __all__ also exposes the reusable layers
(DiT3DBlock, Natten3DSelfAttention, PatchEmbed3D, FinalLayer3D, PixelDiTBlock,
PixelDiTLastLayer); record that they are part of the (experimental) API.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 24, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

The existing pure-bf16-cast test uses the SDPA fallback (attn_kernel=-1), so it
never reaches the CUDA-only NATTEN neighborhood-attention kernel. Add a
CUDA-gated test that casts a DiT3D with attn_kernel>0 (real NA3D),
do_alt_depthwise_attn, and gated_attention to bfloat16 and runs a forward,
confirming the NA3D kernel, the depth-axis attention, the gate, and the
fp32-forced output head all operate under genuine bf16 weights. Skips on CPU
like the other NATTEN tests.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
@zyhu-hu zyhu-hu marked this pull request as ready for review June 24, 2026 17:31
@zyhu-hu zyhu-hu requested a review from loliverhennigh as a code owner June 24, 2026 17:31
@greptile-apps

greptile-apps Bot commented Jun 24, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR contributes the Strata weather-emulator architecture: a StereographicRotaryPositionEmbedding2D primitive in physicsnemo.nn (continuous-coordinate sphere-aware RoPE using stereographic projection) and two deterministic regression models — DiT3D (a 3D patch-embedding + neighborhood-attention transformer) and PixelDiT (DiT3D semantic stage conditioning a pixel-resolution stage via pixel-wise AdaLN) — placed under physicsnemo.experimental.models.strata.

  • rope.py: adds StereographicRotaryPositionEmbedding2D, build_axial_rope_cos_sin_2d_continuous, stereographic_projection, and spherical_centroid; all reuse the existing apply_rotary_pos_emb rotation kernel and are gated behind the existing module's head_dim % 4 validation.
  • dit3d.py / pixel.py: full physicsnemo.Module models with ModelMetaData, DiT-style zero-init output heads, bf16 autocast, activation checkpointing, and optional axial or stereographic RoPE; DiT3D.set_tile_size rebuilds axial buffers for inference re-tiling.
  • Tests: 17 new RoPE tests and a comprehensive strata suite (non-regression goldens, gradient flow, torch.compile, NATTEN, set_tile_size, DepthwiseConv correctness); all CI-skipped cases are by-design CUDA-only NATTEN paths.

Important Files Changed

Filename Overview
physicsnemo/nn/module/rope.py Adds StereographicRotaryPositionEmbedding2D, build_rope_cos_sin_1d_continuous, build_axial_rope_cos_sin_2d_continuous, stereographic_projection, and spherical_centroid. Well-documented, antipode-guarded, correctly composes with existing apply_rotary_pos_emb; no issues found.
physicsnemo/experimental/models/strata/dit3d.py New DiT3D model; construction, shape validation, activation checkpointing, and set_tile_size are well-implemented. Minor: set_tile_size doesn't validate divisibility of new height/width by patch size when rope_mode=axial.
physicsnemo/experimental/models/strata/pixel.py Two-stage PixelDiT model with pixel-wise AdaLN conditioning; activation checkpointing uses closed-over tensors (s_cond, s_cond_bilinear) rather than explicit checkpoint arguments, which works with use_reentrant=False but diverges from DiT3D's pattern.
physicsnemo/experimental/models/strata/layers.py Natten3DSelfAttention, DiT3DBlock, PatchEmbed3D, FinalLayer3D building blocks; fp32 output heads, correct RoPE application, proper dtype handling. No issues.
physicsnemo/experimental/models/strata/depthwise_conv.py DepthwiseConv with vmap-based chunked path correctly passes weight/bias as explicit arguments for deepcopy safety. Warning threshold uses byte count compared to 2^32 bytes, which is overly conservative for float32 (warns at 2^30 elements, actual limit ~2^31).
physicsnemo/experimental/models/strata/coords.py Pure helpers for axial and stereographic token coordinate building; spherical_centroid used correctly for patch pooling and tile centring; no issues.
test/experimental/models/strata/test_dit3d.py Comprehensive test suite covering construction, forward shapes, non-regression goldens, checkpoint round-trips, gradient flow, set_tile_size, bf16 casts, and activation checkpointing. No issues.
test/nn/module/test_rope.py 17 new tests for the stereographic RoPE additions; verifies consistency with the integer-grid builder, handedness, projection geometry, norm preservation, and dtype fidelity.

Reviews (1): Last reviewed commit: "Test pure bf16 cast on the real NA3D att..." | Re-trigger Greptile

Comment thread physicsnemo/experimental/models/strata/depthwise_conv.py Outdated
Comment on lines +846 to +869
for i, block in enumerate(self.pixel_blocks):
# The two block types take different keyword signatures; wrap the
# call in a closure so activation checkpointing can drive either.
if isinstance(block, PixelDiTBlock):

def _run(inp, b=block):
return b(
inp,
s_cond=s_cond,
pixel_dhw=pixel_dhw,
semantic_dhw=semantic_dhw,
s_cond_bilinear=s_cond_bilinear,
rope_tables=rope_tables,
)

else:

def _run(inp, b=block):
return b(inp, latent_dhw=pixel_dhw, rope_tables=rope_tables)

if self._should_checkpoint_pixel_block(i):
x_pix = activation_checkpoint(_run, x_pix, use_reentrant=False)
else:
x_pix = _run(x_pix)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Tensors captured by closure in activation_checkpoint calls

s_cond, s_cond_bilinear, pixel_dhw, semantic_dhw, and rope_tables are all closed over by _run rather than passed as explicit arguments to activation_checkpoint. With use_reentrant=False this is safe — PyTorch's non-reentrant implementation correctly handles closed-over tensors and the PR's gradient-flow tests validate it — but it differs from the DiT3D.forward_tokens pattern (where rope_tables is bound via a default arg r=block_rope). Passing at minimum s_cond and s_cond_bilinear as explicit arguments would align with the non-reentrant best practice and make the gradient path easier to audit.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zyhu-hu I think this is reasonable ask, for code clarity

Comment thread physicsnemo/experimental/models/strata/transformer.py
Comment thread CHANGELOG.md Outdated
Comment thread physicsnemo/experimental/models/strata/dit3d.py Outdated
Comment thread physicsnemo/experimental/models/strata/transformer.py
Comment thread physicsnemo/experimental/models/strata/__init__.py Outdated
Comment thread physicsnemo/experimental/models/strata/pixel.py Outdated
Comment thread physicsnemo/experimental/models/strata/transformer.py
Comment thread physicsnemo/experimental/models/strata/pixel.py Outdated
Comment thread physicsnemo/experimental/models/strata/pixel.py Outdated
Comment thread physicsnemo/nn/module/rope.py
zyhu-hu and others added 12 commits June 25, 2026 14:29
Per review: `two` is not conventional jaxtyping; a fixed-size axis should be a
literal `2`. Changes the `pos` annotation from "batch two height width" to
"batch 2 height width" in coords.py, dit3d.py, and pixel.py.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Per review: PixelDiT had no initialize_weights override, so pixel_patch_embed
kept PyTorch default init (not the fan-based Xavier the backbone uses) and
pixel_final_layer was never zero-initialized (breaking the near-identity
residual start the DiT convention relies on).

Add a targeted initialize_weights(): fan-Xavier the pixel patch-embed conv
(+zero bias) and zero-init the pixel output head, mirroring
DiT3D.initialize_weights. A blanket Xavier pass is deliberately avoided so it
does not clobber the AdaLN-zero projections the pixel blocks set at
construction.

Goldens are unaffected (they seed all parameters via _seed_params).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Per review: StereographicRotaryPositionEmbedding2D was used purely as a 2D RoPE
table builder, never as a stateful module, so the module wrapper added no value
in production nn. Drop the class and relocate the four functional helpers
(build_rope_cos_sin_1d_continuous, build_axial_rope_cos_sin_2d_continuous,
stereographic_projection, spherical_centroid) to a new experimental module
physicsnemo/experimental/nn/rope.py, giving the continuous/stereographic RoPE
API room to mature outside the stable surface. The production integer-grid
builders (build_rope_cos_sin_1d, build_axial_rope_cos_sin_2d, apply_rotary_pos_emb,
RotaryPositionEmbedding{1,2}D) stay in physicsnemo/nn/module/rope.py.

Rewire DiT3D and PixelDiT to call build_axial_rope_cos_sin_2d_continuous
directly (storing head_dim / rope_base) instead of self.rope.build_tables;
coords.py imports the geometry helpers from physicsnemo.experimental.nn.

Tests: move the continuous-builder / projection / centroid tests to
test/experimental/nn/test_rope.py and convert the valuable module-property
tests (relative-position invariance, theta wiring, bf16 dtype) to function-level
tests; drop the module-API and now-redundant tests. The production
apply_rotary_pos_emb tests (incl. rotation-direction) stay in
test/nn/module/test_rope.py. import-linter contract unchanged (no
production -> experimental import).

This supersedes the earlier change that made the class a physicsnemo.Module
(MOD-001): the class is removed entirely.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Per review:

- FinalLayer3D and PixelDiTLastLayer were near-identical fp32 LayerNorm+Linear
  heads differing only in output dim. Consolidate into a single
  FinalLayer3D(hidden_size, out_features); the backbone passes
  p_d*p_h*p_w*C_out (unpatchified to a field), the pixel stage passes C_out.
  Submodule names are unchanged, so checkpoints/goldens are key-preserving.

- Replace the `del self.semantic.final_layer` hack with an include_head kwarg
  on DiT3D. include_head=False omits the output head, and forward() then
  returns post-block tokens (equivalently forward_tokens). PixelDiT builds its
  semantic stage with include_head=False, so the trunk never creates unused
  output-head parameters.

Tests: update the constructor assertion (semantic.final_layer is now None
rather than absent) and add a test for the include_head toggle (full returns a
field; headless returns tokens and has no final_layer params).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
…ackbone"

Per review, the "DiT" naming was confusing (it suggested a 3D diffusion DiT).
Rename the public API to the Strata family:

- PixelDiT -> Strata (the composed train/inference model); PixelDiTMetaData -> StrataMetaData
- DiT3D -> StrataTransformer3D (the backbone); DiT3DMetaData -> StrataTransformer3DMetaData;
  DiT3DBlock -> StrataTransformer3DBlock
- PixelDiTBlock -> StrataPixel3DBlock (parallels StrataTransformer3DBlock; scrubs
  the last "DiT" -- the PixelDiT paper is still cited in the docstring)
- the "semantic" coarse-stage identifiers -> "backbone" (backbone_config,
  self.backbone, backbone_cond, ...), since "semantic" is a vision term and these
  are simply the backbone's coarse tokens

Rename the module/test files to match (dit3d.py -> transformer.py,
pixel.py -> strata.py, test_dit3d.py -> test_strata.py,
_generate_dit3d_goldens.py -> _generate_strata_goldens.py) and the golden
fixtures (dit3d_* -> transformer_*, pixeldit_* -> strata_*). The composed-model
goldens were regenerated because self.semantic -> self.backbone changes their
state_dict key prefix; the regenerated outputs are bit-identical to the
pre-rename outputs (verified), confirming the whole refactor is numerically
lossless. The PixelDiT paper citation (arXiv:2511.20645) is preserved.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Per review, the experimental strata package re-exported all its building-block
layers. Narrow __all__ (and the package imports) to just the two models and
their metadata: Strata, StrataMetaData, StrataTransformer3D,
StrataTransformer3DMetaData. The internal layers (StrataTransformer3DBlock,
StrataPixel3DBlock, Natten3DSelfAttention, PatchEmbed3D, FinalLayer3D) remain
importable from their submodules but are no longer part of the public surface.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Per review, replace the four separate Strata-related Added entries (StereoRoPE,
DiT3D, PixelDiT, exported layers) with a single bullet describing the Strata
architecture and the experimental continuous/stereographic RoPE helpers, using
the renamed classes and the experimental RoPE location.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Found by an adversarial self-review pass; doc/hygiene only, no behavior change:

- strata.py: restore the PixelDiT paper citation in the Notes section — the
  rename had turned "adapted from PixelDiT" into "adapted from Strata", a
  self-reference pointing at the PixelDiT reference entry.
- coords.py: repoint two `:func:` cross-references (spherical_centroid /
  stereographic_projection) to physicsnemo.experimental.nn, where they now live.
- transformer.py: document the include_head=False case in the class Outputs
  section (forward returns post-block tokens, not a field).
- strata.py / transformer.py: fix import ordering (isort I001) introduced by the
  module rename (ruff excludes experimental, so CI does not flag it).
- strata.py: drop StrataPixel3DBlock from the module __all__, matching
  transformer.py and the package's models-only public surface.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
The test migration dropped the assertion that pinned the length_scale
normalization ("doubling length_scale halves the coords"); the new
shape/monotonicity checks are blind to the scale exponent (a `length_scale**2`
bug passed the whole suite). Re-add the ratio check: coords at length_scale 1.0
vs 2.0 must satisfy coords_2x == coords_1x / 2.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
…count)

The non-chunked fallback warning computed a byte count (numel * itemsize) and
compared it to 2**32, but conv2d's actual limit is element-count: PyTorch's
32-bit indexing guard (canUse32BitIndexMath) raises a RuntimeError when
numel > INT_MAX (2**31 - 1). Verified empirically (torch 2.7.1, cuDNN 9.7):
fp32 and bf16 both fail at the same numel = 2**31 despite fp32 having 2x the
bytes, confirming the limit is elements, not bytes. The byte threshold was also
dtype-dependent and ~2x too conservative for fp32.

Check `x.numel() > torch.iinfo(torch.int32).max` instead, and correct the
warning text: it raises a RuntimeError (not "revert to a slow implementation"),
and the remedy is chunk_size (which keeps each conv2d call under the limit and
on the fast cuDNN path). Drop the now-unused `import math`.

Addresses a Greptile review comment (misleading byte-vs-element threshold).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Match the surrounding Added entries, which use plain prose (code names in
backticks) and no arbitrary bold.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
The constructor (via PatchEmbed3D) rejects an input shape not divisible by the
patch size, but set_tile_size updated height/width/input_shape without the same
check -- so a non-divisible tile was silently accepted and floor-truncated by
the patch-embed conv (and a caller reading model.input_shape would see a
larger shape than is actually used).

Add the matching height/width divisibility check (same messages as
PatchEmbed3D) at the top of set_tile_size so re-tiling fails fast for every
rope_mode. Add a regression test.

Addresses a Greptile review comment.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Zeyuan Hu <zeyuanh@nvidia.com>
Comment thread CHANGELOG.md
spherical-grid dependency. Also adds the supporting continuous-coordinate and
stereographic RoPE helpers to `physicsnemo.experimental.nn`
(`build_rope_cos_sin_1d_continuous`, `build_axial_rope_cos_sin_2d_continuous`,
`stereographic_projection`, `spherical_centroid`).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still overly verbose, these details can be put in documentation rather than changelog. Just listing the added public-level modules and functions is fine

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for a separate file, might as well put this in layers.py

# __init__ args); the stereographic module is fully stateless (tables are built
# per forward from the input coordinates). In every case a round-trip reproduces
# the forward exactly without any tables appearing in the checkpoint.

@pzharrington pzharrington Jun 26, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stale references to stereographic module should be removed, and there are only two RoPE modules now without it.

output head, so the post-block tokens of shape :math:`(B, N, E)` are
returned instead (equivalently use :meth:`forward_tokens`).
"""
_, _, d_in, h_in, w_in = x.shape

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This happens before input validation and could lead to "too many values to unpack" error. Since forward_tokens guarantees the shape equals the configured one, you could just use self.depth/height/width for unpatchify, and drop the early unpack here.

f"{tuple(x.shape)}"
)
if self.rope_mode_pixel == "stereographic" and pos is None:
raise ValueError(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only checks for pos is None but not shape, and the backbone validates pos.shape only when the backbone's rope_mode="stereographic". So the backbone="none"/"axial", pixel="stereographic" config never shape-checks pos; a wrong (B,2,H,W) falls into a cryptic rearrange error in build_stereographic_token_coords. Worth a one-line shape check in Strata.forward.

# Insert a heads axis so the tables broadcast over (B, heads, N, head_dim).
return cos.unsqueeze(1), sin.unsqueeze(1)

def prepare_tokens(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use jaxtyping or make private to the class

x = block(x, latent_dhw=latent_dhw, rope_tables=block_rope)
return x, latent_dhw

def unpatchify(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use jaxtyping or make private to the class


@staticmethod
def precompute_bilinear_cond(
backbone_cond: torch.Tensor,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use jaxtyping annotations

nn.init.constant_(self.pixel_patch_embed.proj.bias, 0.0)

nn.init.constant_(self.pixel_final_layer.linear.weight, 0.0)
nn.init.constant_(self.pixel_final_layer.linear.bias, 0.0)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing here ever initializes the pixel blocks' attention and MLP Linears (they keep PyTorch's nn.Linear default), intended or no? It sorta contradicts the docstring claim of matching StrataTransformer3D.initialize_weights

"Strata",
"StrataMetaData",
"StrataTransformer3D",
"StrataTransformer3DMetaData",

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to export the MetaData classes

)
from .transformer import StrataTransformer3D

__all__ = ["Strata", "StrataMetaData"]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here and in transformer.py, no need to export the MetaData, it's just a module-local helper

auto_grad: bool = False


class StrataPixel3DBlock(nn.Module):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it's a block, maybe move it to layers.py?

@pzharrington

Copy link
Copy Markdown
Collaborator

/ok to test 78e69d3

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants