Fix Wan batch conditioning and config handling #13693
Conversation
```python
            )
            block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat_interleave(
                block_state.num_videos_per_prompt, dim=0
            ).to(block_state.dtype)
```
I think we already cast the dtype later in the denoise step, no?
```python
    ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
```
I think we don't need to declare it here; this does not go beyond the loop step:
https://huggingface.co/docs/diffusers/modular_diffusers/loop_sequential_pipeline_blocks#loop-blocks
```diff
-                    hidden_states=block_state.latent_model_input.to(block_state.dtype),
-                    timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype),
+                    hidden_states=block_state.latent_model_input.to(dtype),
+                    timestep=t.expand(block_state.latent_model_input.shape[0]),
```
The dtype update here makes sense: it keeps modular Wan timesteps in scheduler precision while casting model inputs to the selected transformer dtype. I also like that we use `transformer.dtype` instead of `prompt_embeds.dtype` -> this should be the standard, preferred way of inferring dtype moving forward.

But most of the rest of the changes on inputs/outputs do not seem valid to me; the loop step works a bit differently (https://huggingface.co/docs/diffusers/modular_diffusers/loop_sequential_pipeline_blocks#loop-blocks). Let me know if I missed anything.
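For reference, a loop-body sketch of the endorsed pattern. The free variables (`transformer`, `block_state`, `t`) come from the surrounding denoise step in the diff; the extra kwargs are assumptions, not the exact call:

```python
# Sketch only: infer dtype from the module rather than from prompt_embeds.
dtype = transformer.dtype
noise_pred = transformer(
    hidden_states=block_state.latent_model_input.to(dtype),  # cast model inputs
    timestep=t.expand(block_state.latent_model_input.shape[0]),  # keep scheduler precision
    encoder_hidden_states=block_state.prompt_embeds.to(dtype),
    return_dict=False,
)[0]
```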
```python
        return mask_lat_size

    def _expand_tensor_to_effective_batch(
```
Can we make it a regular function?
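Something like this module-level sketch; the body is an assumption based on the batch-expansion discussion elsewhere in this PR:

```python
import torch


def _expand_tensor_to_effective_batch(
    tensor: torch.Tensor,
    batch_size: int,
    num_videos_per_prompt: int,
) -> torch.Tensor:
    effective_batch_size = batch_size * num_videos_per_prompt
    if tensor.shape[0] == effective_batch_size:
        # Already expanded, e.g. a precomputed conditioning tensor.
        return tensor
    if tensor.shape[0] != batch_size:
        raise ValueError(
            f"Expected a batch dimension of {batch_size} or {effective_batch_size}, got {tensor.shape[0]}."
        )
    # repeat_interleave keeps per-prompt grouping: [A, B] -> [A, A, B, B].
    return tensor.repeat_interleave(num_videos_per_prompt, dim=0)
```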
| " only forward one of the two." | ||
| ) | ||
| if image is None and image_embeds is None: | ||
| if image is None: |
I don't think a list of images makes sense here for the I2V pipeline, no?
Maybe we should just fix the docstring instead?
sayakpaul left a comment
Thanks! I reviewed the bits around batch expansion. LMK if the comments make sense.
5. **Copying a method from another pipeline without `# Copied from`.** When you reuse a method like `encode_prompt`, `prepare_latents`, `check_inputs`, or `_prepare_latent_image_ids` from another pipeline, add a `# Copied from` annotation so `make fix-copies` keeps the two in sync. Forgetting it means future refactors to the source drift away from your copy silently, and reviewers waste time spotting near-identical code that should have been linked. The annotation grammar (decorator placement, rename syntax with `with old->new`, etc.) is implemented in [`utils/check_copies.py`](../utils/check_copies.py); read it for the exact rules.
6. **Partial batch expansion with `num_*_per_prompt`.** When a pipeline accepts `num_images_per_prompt`, `num_videos_per_prompt`, or precomputed conditioning tensors, every per-prompt input that reaches the denoising loop must be expanded to the same effective batch size in the same prompt order. Check prompt embeds, pooled embeds, image/video embeds, masks, control latents, image latents, guidance tensors, and any added conditioning. Prefer `torch.repeat_interleave(tensor, repeats=num_per_prompt, dim=0, output_size=tensor.shape[0] * num_per_prompt)` for contiguous per-prompt grouping, and validate precomputed tensors that are already batched instead of letting a later concat fail. Avoid `tensor.repeat((num_per_prompt, ...))` for batch expansion because it interleaves prompts (`[A, B] -> [A, B, A, B]`) rather than grouping them (`[A, A, B, B]`).
> every per-prompt input that reaches the denoising loop must be expanded to the same effective batch size in the same prompt order.

Should we try to include an example here?

> Check prompt embeds, pooled embeds, image/video embeds, masks, control latents, image latents, guidance tensors, and any added conditioning.

Should we be specific in terms of what we want it to check?
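For instance, a small self-contained example of the ordering difference (the tensor values are made up):

```python
import torch

# Two prompts, A=0.0 and B=1.0, expanded for num_videos_per_prompt=2.
embeds = torch.tensor([[0.0], [1.0]])

grouped = embeds.repeat_interleave(2, dim=0)  # [A, A, B, B] -- per-prompt order
interleaved = embeds.repeat(2, 1)             # [A, B, A, B] -- wrong pairing

print(grouped.squeeze(-1).tolist())      # [0.0, 0.0, 1.0, 1.0]
print(interleaved.squeeze(-1).tolist())  # [0.0, 1.0, 0.0, 1.0]
```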
```python
            raise ValueError(
                "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
                " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
                " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
            )
```
Should we still keep the ValueError, though?
```python
        if image is not None and image_embeds is not None:
            raise ValueError(
                f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
                " only forward one of the two."
            )
        if image is None and image_embeds is None:
            raise ValueError(
                "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
            )
```
Why do we want to remove this check?

Update: never mind, we need `image` for the VAE encodings. In that case, should we tell users when `image` is None but `image_embeds` is not None?
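One way to surface it (the message wording is an assumption):

```python
if image is None:
    message = "Provide `image`; it is required to compute the VAE conditioning latents."
    if image_embeds is not None:
        message += " Passing `image_embeds` alone only covers the CLIP image encoding."
    raise ValueError(message)
```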
```diff
     negative_prompt_embeds: torch.Tensor | None = None,
     image_embeds: torch.Tensor | None = None,
-    output_type: str | None = "np",
+    output_type: str = "np",
```
(Optional) While we're at it, we could also use a `Literal` to make it easier for users to understand what values are allowed here.
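Roughly like this; the exact set of allowed values is an assumption based on common diffusers output types:

```python
from typing import Literal

output_type: Literal["np", "pt", "latent"] = "np"
```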
```python
        tensor: torch.Tensor,
        batch_size: int,
        num_videos_per_prompt: int,
        tensor_name: str,
```
I think `tensor_name` is a nice-to-have rather than a must-have. We could make it optional or remove it completely.
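If we keep it, a sketch of the optional form (the fallback label is an assumption):

```python
import torch


def _expand_tensor_to_effective_batch(
    tensor: torch.Tensor,
    batch_size: int,
    num_videos_per_prompt: int,
    tensor_name: str | None = None,
) -> torch.Tensor:
    # Fall back to a generic label in error messages when no name is given.
    label = tensor_name or "tensor"
    ...
```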
```python
        )
        with torch.no_grad():
            image_embeds = pipe.encode_image(inputs["image"], device)
        inputs["image_embeds"] = torch.cat([image_embeds, image_embeds + 1.0], dim=0)
```
Should this not be done within the pipeline itself? If so, I would just do `video = pipe(**inputs).frames` and test `image_embeds` separately.
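i.e., something like this split (the helper names are taken from the existing test utilities and assumed to exist):

```python
def test_inference(self):
    pipe = self.pipeline_class(**self.get_dummy_components())
    inputs = self.get_dummy_inputs("cpu")
    # `image` is encoded to CLIP embeds inside the pipeline.
    video = pipe(**inputs).frames

def test_inference_with_precomputed_image_embeds(self):
    pipe = self.pipeline_class(**self.get_dummy_components())
    inputs = self.get_dummy_inputs("cpu")
    with torch.no_grad():
        inputs["image_embeds"] = pipe.encode_image(inputs["image"], "cpu")
    video = pipe(**inputs).frames
```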
```python
        self.assertEqual(video.shape, (6, 16, 1, 2, 2))

    def test_image_embeds_invalid_batch_size_raises(self):
        device = "cpu"
```
```python
    def test_last_image_embeds_expand_per_prompt_order(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.transformer.condition_embedder.image_embedder.pos_embed = torch.nn.Parameter(torch.zeros(1, 2, 4))

        image_embeds = torch.tensor([0.0, 1.0, 10.0, 11.0]).reshape(4, 1, 1)
```
Let's give `4`, `1`, and `1` variable names for easier parsing of the info.
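e.g. (the names are assumptions; dim 0 appears to stack the first/last embeds for two prompts, judging by the expected output below):

```python
num_image_embeds, seq_len, embed_dim = 4, 1, 1
image_embeds = torch.tensor([0.0, 1.0, 10.0, 11.0]).reshape(num_image_embeds, seq_len, embed_dim)
```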
```python
        )

        expected = torch.tensor([0.0, 1.0, 0.0, 1.0, 10.0, 11.0, 10.0, 11.0])
        torch.testing.assert_close(image_embeds[:, 0, 0], expected)
```
It would be helpful to have some comments around why this is the expected order.
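For example (this is my reading of the layout, so treat the comments as an assumption):

```python
# Input rows are [A_first=0, A_last=1, B_first=10, B_last=11]: one (first, last)
# pair per prompt. With num_videos_per_prompt=2, each prompt's pair is repeated
# as a unit, keeping prompts grouped rather than interleaved:
#   prompt A -> [0, 1, 0, 1]
#   prompt B -> [10, 11, 10, 11]
expected = torch.tensor([0.0, 1.0, 0.0, 1.0, 10.0, 11.0, 10.0, 11.0])
torch.testing.assert_close(image_embeds[:, 0, 0], expected)
```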
Fixes #13578.
What does this PR do?
This PR addresses the Wan systemic review findings:
- Expands per-prompt conditioning tensors to `batch_size * num_videos_per_prompt` in per-prompt order.
- Keeps `image` required where VAE conditioning still needs it, while allowing precomputed CLIP `image_embeds`.
- Fixes handling of `output_type="latent"`.
- Fixes expansion when `num_videos_per_prompt != 1`.
- Adds a `.ai/pipelines.md` checklist note for `num_*_per_prompt` expansion drift.

Tests
- `ruff check ...` on touched Python files
- `ruff format --check ...` on touched Python files
- `git diff --check`
- `python -m compileall -q src/diffusers/pipelines/wan src/diffusers/modular_pipelines/wan tests/pipelines/wan tests/modular_pipelines/wan`
- `python -m pytest tests/pipelines/wan/test_wan_image_to_video.py tests/pipelines/wan/test_wan_vace.py tests/pipelines/wan/test_wan_video_to_video.py tests/pipelines/wan/test_wan_animate.py tests/modular_pipelines/wan/test_modular_pipeline_wan.py -q -k "not test_save_load_float16"`

Result:
180 passed, 21 skipped, 5 deselected, 61 warnings, 6 subtests passed.

Note:
`WanFLFToVideoPipelineFastTests::test_save_load_float16` fails the same way on clean `upstream/main` (max diff 0.04175 > 0.01), so it was excluded from the clean touched-file verification.