Skip to content

Fix Wan batch conditioning and config handling#13693

Open
hlky wants to merge 1 commit into
huggingface:mainfrom
hlky:codex/wan-review-fixes
Open

Fix Wan batch conditioning and config handling#13693
hlky wants to merge 1 commit into
huggingface:mainfrom
hlky:codex/wan-review-fixes

Conversation

@hlky
Copy link
Copy Markdown
Contributor

@hlky hlky commented May 7, 2026

Fixes #13578.

What does this PR do?

This PR addresses the Wan systemic review findings:

  • Expands Wan i2v and Animate image embeddings, masks, latents, and conditioning tensors to batch_size * num_videos_per_prompt in per-prompt order.
  • Keeps image required where VAE conditioning still needs it, while allowing precomputed CLIP image_embeds.
  • Accepts documented Wan i2v image list/tuple inputs and validates precomputed image-embed batch sizes.
  • Uses Wan VAE config scale factors in VACE, video-to-video, and modular Wan paths.
  • Trims VACE reference latents for output_type="latent".
  • Preserves Wan Animate image processor config values.
  • Stops silently ignoring unsupported Wan video-to-video num_videos_per_prompt != 1.
  • Keeps modular Wan timesteps in scheduler precision while casting model inputs to the selected transformer dtype.
  • Adds focused fast regression tests and a .ai/pipelines.md checklist note for num_*_per_prompt expansion drift.

Tests

  • ruff check ... on touched Python files
  • ruff format --check ... on touched Python files
  • git diff --check
  • python -m compileall -q src/diffusers/pipelines/wan src/diffusers/modular_pipelines/wan tests/pipelines/wan tests/modular_pipelines/wan
  • python -m pytest tests/pipelines/wan/test_wan_image_to_video.py tests/pipelines/wan/test_wan_vace.py tests/pipelines/wan/test_wan_video_to_video.py tests/pipelines/wan/test_wan_animate.py tests/modular_pipelines/wan/test_modular_pipeline_wan.py -q -k "not test_save_load_float16"

Result: 180 passed, 21 skipped, 5 deselected, 61 warnings, 6 subtests passed.

Note: WanFLFToVideoPipelineFastTests::test_save_load_float16 fails the same way on clean upstream/main (max diff 0.04175 > 0.01), so it was excluded from the clean touched-file verification.

Copy link
Copy Markdown
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left some comments

)
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat_interleave(
block_state.num_videos_per_prompt, dim=0
).to(block_state.dtype)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we already cast dtype later in denoise step no?

]

@property
def intermediate_outputs(self) -> list[OutputParam]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't need to declare it here > this does not go beyond the loop step
https://huggingface.co/docs/diffusers/modular_diffusers/loop_sequential_pipeline_blocks#loop-blocks

hidden_states=block_state.latent_model_input.to(block_state.dtype),
timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype),
hidden_states=block_state.latent_model_input.to(dtype),
timestep=t.expand(block_state.latent_model_input.shape[0]),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the dtype update here makes sense,

Keeps modular Wan timesteps in scheduler precision while casting model inputs to the selected transformer dtype.

I also liked we use transformer.dtype instead of prompt_embeds.dtype -> should be the stanrdard preferred way of infer dtype moving forward

But most of the rest of the change on inputs/outputs does not seem valid to me, loop step works a bit differently https://huggingface.co/docs/diffusers/modular_diffusers/loop_sequential_pipeline_blocks#loop-blocks; let me know if I missed anything


return mask_lat_size

def _expand_tensor_to_effective_batch(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we make it a regular function?

" only forward one of the two."
)
if image is None and image_embeds is None:
if image is None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't think the list of image make sense here for I2V pipeline, no?
Maybe we should just fix the docstring instead?

Copy link
Copy Markdown
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I reviewed the bits around batch expansion. LMK if the comments make sense.

Comment thread .ai/pipelines.md

5. **Copying a method from another pipeline without `# Copied from`.** When you reuse a method like `encode_prompt`, `prepare_latents`, `check_inputs`, or `_prepare_latent_image_ids` from another pipeline, add a `# Copied from` annotation so `make fix-copies` keeps the two in sync. Forgetting it means future refactors to the source drift away from your copy silently — and reviewers waste time spotting near-identical code that should have been linked. The annotation grammar (decorator placement, rename syntax with `with old->new`, etc.) is implemented in [`utils/check_copies.py`](../utils/check_copies.py) — read it for the exact rules.

6. **Partial batch expansion with `num_*_per_prompt`.** When a pipeline accepts `num_images_per_prompt`, `num_videos_per_prompt`, or precomputed conditioning tensors, every per-prompt input that reaches the denoising loop must be expanded to the same effective batch size in the same prompt order. Check prompt embeds, pooled embeds, image/video embeds, masks, control latents, image latents, guidance tensors, and any added conditioning. Prefer `torch.repeat_interleave(tensor, repeats=num_per_prompt, dim=0, output_size=tensor.shape[0] * num_per_prompt)` for contiguous per-prompt grouping, and validate precomputed tensors that are already batched instead of letting a later concat fail. Avoid `tensor.repeat((num_per_prompt, ...))` for batch expansion because it interleaves prompts (`[A, B] -> [A, B, A, B]`) rather than grouping them (`[A, A, B, B]`).
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

every per-prompt input that reaches the denoising loop must be expanded to the same effective batch size in the same prompt order.

Should we try to include an example here?

Check prompt embeds, pooled embeds, image/video embeds, masks, control latents, image latents, guidance tensors, and any added conditioning.

Should we be specific in terms of what we want it to check?

Comment on lines -73 to -77
raise ValueError(
"`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
" if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
" if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we still keep the ValueError, though?

Comment on lines -387 to -395
if image is not None and image_embeds is not None:
raise ValueError(
f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
" only forward one of the two."
)
if image is None and image_embeds is None:
raise ValueError(
"Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we want to remove this check?

Update: nevermind we need image for VAE encodings. In that case, should we tell the users in case image is None and image_embeds is not None?

negative_prompt_embeds: torch.Tensor | None = None,
image_embeds: torch.Tensor | None = None,
output_type: str | None = "np",
output_type: str = "np",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Optional): while we're at it, we could also do a Literal and help make it easier for users to understand what values are allowed here.

tensor: torch.Tensor,
batch_size: int,
num_videos_per_prompt: int,
tensor_name: str,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think tensor_name is a nice-to-have rather than must-have. We could make it optional or completely remove it.

)
with torch.no_grad():
image_embeds = pipe.encode_image(inputs["image"], device)
inputs["image_embeds"] = torch.cat([image_embeds, image_embeds + 1.0], dim=0)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this not be done within the pipeline itself? If so, I would just do video = pipe(**inputs).frames and test image_embeds separately.

self.assertEqual(video.shape, (6, 16, 1, 2, 2))

def test_image_embeds_invalid_batch_size_raises(self):
device = "cpu"
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

def test_last_image_embeds_expand_per_prompt_order(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.transformer.condition_embedder.image_embedder.pos_embed = torch.nn.Parameter(torch.zeros(1, 2, 4))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed?

pipe = self.pipeline_class(**components)
pipe.transformer.condition_embedder.image_embedder.pos_embed = torch.nn.Parameter(torch.zeros(1, 2, 4))

image_embeds = torch.tensor([0.0, 1.0, 10.0, 11.0]).reshape(4, 1, 1)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's give 4, 1, and 1 variable names for easier parsing of info.

)

expected = torch.tensor([0.0, 1.0, 0.0, 1.0, 10.0, 11.0, 10.0, 11.0])
torch.testing.assert_close(image_embeds[:, 0, 0], expected)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be helpful have to have some comments around how this is the expected order.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

wan model/pipeline review

3 participants