
Fix shared transformer masks and temporal output channels #13711

Open

taivu1998 wants to merge 1 commit into huggingface:main from taivu1998:tdv/issue-13651-model-transformers-shared

Conversation

@taivu1998

Summary

Fixes #13651.

This PR addresses the shared transformer regressions covered by issue 13651:

  • threads attention_mask and per-condition encoder_attention_mask through DualTransformer2DModel (sketched after this list)
  • exposes DualTransformer2DModel from the top-level diffusers package and refreshes the dummy PyTorch export
  • makes temporal transformer out_channels handling explicit by accepting only None or values equal to in_channels
  • keeps the discrete Transformer2DModel log-softmax path in float32 instead of promoting logits to float64
  • adds focused regression coverage for the dual transformer mask routing, top-level import, temporal channel validation, and discrete log-softmax dtype behavior
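
A minimal sketch of the per-condition mask slicing from the first bullet; the helper name and the `condition_lengths` parameter are illustrative stand-ins, not the actual diffusers API:

```python
import torch

# Hypothetical helper mirroring the slicing described above: each condition owns a
# contiguous span of encoder tokens, so its attention mask must be sliced with the
# same bounds to stay aligned with its hidden states.
def split_conditions(encoder_hidden_states, encoder_attention_mask, condition_lengths):
    start = 0
    for length in condition_lengths:
        states = encoder_hidden_states[:, start : start + length]
        mask = None
        if encoder_attention_mask is not None:
            mask = encoder_attention_mask[:, start : start + length]
        yield states, mask
        start += length

# Example: two conditions of 77 tokens each, as in a dual text-encoder setup.
hidden = torch.randn(2, 154, 64)
mask = torch.ones(2, 154, dtype=torch.bool)
for states, m in split_conditions(hidden, mask, [77, 77]):
    print(states.shape, m.shape)  # torch.Size([2, 77, 64]) torch.Size([2, 77])
```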

Root Cause

DualTransformer2DModel wrapped Transformer2DModel instances but neither accepted nor forwarded the attention masks that the wrapped transformer already supported, so masked cross-attention was silently dropped inside the dual wrapper. The temporal transformer classes also exposed an out_channels config field that the output projection effectively ignored, which let unsupported channel configurations pass silently. Finally, the discrete transformer output helper promoted logits to double precision before log_softmax, requiring fp64 support where float32 is sufficient.
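
The fp64 issue is reproducible without diffusers; a before/after sketch with illustrative shapes:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 10, 32)  # (batch, num_classes, tokens); shapes are illustrative

# Before: promoting to float64 before log_softmax forces fp64 support on the backend.
log_probs_old = F.log_softmax(logits.double(), dim=1).float()

# After: float32 is numerically sufficient here and runs on fp32-only backends
# such as MPS, which has no float64 support.
log_probs_new = F.log_softmax(logits.float(), dim=1)

assert torch.allclose(log_probs_old, log_probs_new, atol=1e-5)
```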

Changes

  • Added encoder_attention_mask to DualTransformer2DModel.forward and sliced it alongside encoder_hidden_states for each configured condition.
  • Passed both attention_mask and the sliced encoder_attention_mask to the wrapped Transformer2DModel.
  • Added a small CrossAttnDownBlock2D smoke test to cover the downstream wrapper path used in UNet-style blocks.
  • Exported DualTransformer2DModel in src/diffusers/__init__.py and updated dummy_pt_objects.py.
  • Normalized TransformerTemporalModel and TransformerSpatioTemporalModel channel handling so an unsupported out_channels != in_channels configuration fails clearly (see the sketch after this list).
  • Removed the double-precision promotion in the vectorized discrete transformer output path.
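
A sketch of the channel validation described above; the function name and error message are hypothetical, not the exact diffusers code:

```python
from typing import Optional

# Hypothetical validator: the temporal transformers accept only out_channels that is
# None (defaulting to in_channels) or explicitly equal to in_channels.
def resolve_out_channels(in_channels: int, out_channels: Optional[int]) -> int:
    if out_channels is not None and out_channels != in_channels:
        raise ValueError(
            f"`out_channels` ({out_channels}) must be None or equal to "
            f"`in_channels` ({in_channels}) for the temporal transformer."
        )
    return in_channels

print(resolve_out_channels(320, None))  # 320
print(resolve_out_channels(320, 320))   # 320
# resolve_out_channels(320, 640) raises ValueError
```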

Validation

  • PYTHONPATH=src .venv/bin/python -m pytest tests/models/transformers/test_models_dual_transformer_2d.py tests/models/transformers/test_models_transformer_temporal.py tests/models/test_layers_utils.py::Transformer2DModelTests::test_spatial_transformer_discrete -q
    • 24 passed, 35 skipped, 1 warning in 8.29s
  • uvx ruff check src/diffusers/__init__.py src/diffusers/models/transformers/dual_transformer_2d.py src/diffusers/models/transformers/transformer_2d.py src/diffusers/models/transformers/transformer_temporal.py src/diffusers/utils/dummy_pt_objects.py tests/models/test_layers_utils.py tests/models/transformers/test_models_transformer_temporal.py tests/models/transformers/test_models_dual_transformer_2d.py
  • uvx ruff format --check src/diffusers/__init__.py src/diffusers/models/transformers/dual_transformer_2d.py src/diffusers/models/transformers/transformer_2d.py src/diffusers/models/transformers/transformer_temporal.py src/diffusers/utils/dummy_pt_objects.py tests/models/test_layers_utils.py tests/models/transformers/test_models_transformer_temporal.py tests/models/transformers/test_models_dual_transformer_2d.py
  • PYTHONPATH=src .venv/bin/python utils/check_dummies.py
  • PYTHONPATH=src .venv/bin/python -c "import sys; sys.path.append('utils'); import check_inits; check_inits.check_all_inits()"
  • git diff --check

Note: running utils/check_inits.py directly in this checkout looks for src/transformers/__init__.py; the underlying check_all_inits() function passes when invoked with diffusers on PYTHONPATH.

@taivu1998 taivu1998 marked this pull request as ready for review May 11, 2026 03:13
