diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 78a77ebcfea9..363546dda67d 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -21,6 +21,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config +from ...image_processor import IPAdapterMaskProcessor from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin from ...utils import apply_lora_scale, logging from ...utils.torch_utils import maybe_allow_in_graph @@ -244,28 +245,100 @@ def __call__( # IP-adapter ip_attn_output = torch.zeros_like(hidden_states) - for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip + if ip_adapter_masks is not None: + if not isinstance(ip_adapter_masks, list): + ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) + if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): + raise ValueError( + f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " + f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " + f"({len(ip_hidden_states)})" + ) + for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): + if mask is None: + continue + if not isinstance(mask, torch.Tensor) or mask.ndim != 4: + raise ValueError( + "Each element of the ip_adapter_masks array should be a tensor with shape " + "[1, num_images_for_ip_adapter, height, width]." + " Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) + num_ip_images = 1 if ip_state.ndim == 3 else ip_state.shape[1] + if mask.shape[1] != num_ip_images: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of ip images ({num_ip_images}) at index {index}" + ) + if isinstance(scale, list) and not len(scale) == mask.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of scales ({len(scale)}) at index {index}" + ) + else: + ip_adapter_masks = [None] * len(self.scale) + + for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim) - ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim) - - current_ip_hidden_states = dispatch_attention_fn( - ip_query, - ip_key, - ip_value, - attn_mask=None, - dropout_p=0.0, - is_causal=False, - backend=self._attention_backend, - parallel_config=self._parallel_config, - ) - current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim) - current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) - ip_attn_output += scale * current_ip_hidden_states + if mask is not None: + if current_ip_hidden_states.ndim == 3: + current_ip_hidden_states = current_ip_hidden_states[:, None, :, :] + if not isinstance(scale, list): + scale = [scale] * mask.shape[1] + + for i in range(mask.shape[1]): + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) + + ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim) + ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim) + + _current_ip_hidden_states = dispatch_attention_fn( + ip_query, + ip_key, + ip_value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + _current_ip_hidden_states = _current_ip_hidden_states.reshape( + batch_size, -1, attn.heads * attn.head_dim + ) + _current_ip_hidden_states = _current_ip_hidden_states.to(ip_query.dtype) + + mask_downsample = IPAdapterMaskProcessor.downsample( + mask[:, i, :, :], + batch_size, + _current_ip_hidden_states.shape[1], + _current_ip_hidden_states.shape[2], + ) + mask_downsample = mask_downsample.to(dtype=ip_query.dtype, device=ip_query.device) + + ip_attn_output += scale[i] * (_current_ip_hidden_states * mask_downsample) + else: + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim) + ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim) + + current_ip_hidden_states = dispatch_attention_fn( + ip_query, + ip_key, + ip_value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + current_ip_hidden_states = current_ip_hidden_states.reshape( + batch_size, -1, attn.heads * attn.head_dim + ) + current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) + ip_attn_output += scale * current_ip_hidden_states return hidden_states, encoder_hidden_states, ip_attn_output else: diff --git a/src/diffusers/modular_pipelines/flux/before_denoise.py b/src/diffusers/modular_pipelines/flux/before_denoise.py index c28154775f5a..e6864f3a68c5 100644 --- a/src/diffusers/modular_pipelines/flux/before_denoise.py +++ b/src/diffusers/modular_pipelines/flux/before_denoise.py @@ -17,13 +17,13 @@ import numpy as np import torch -from ...pipelines import FluxPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging from ...utils.torch_utils import randn_tensor from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import FluxModularPipeline +from .pipeline_helpers import pack_latents, prepare_latent_image_ids logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -390,7 +390,7 @@ def prepare_latents( # TODO: move packing latents code to a patchifier similar to Qwen latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width) + latents = pack_latents(latents, batch_size, num_channels_latents, height, width) return latents @@ -470,7 +470,7 @@ def intermediate_outputs(self) -> list[OutputParam]: def check_inputs(image_latents, latents): if image_latents.shape[0] != latents.shape[0]: raise ValueError( - f"`image_latents` must have have same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}" + f"`image_latents` must have the same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}" ) if image_latents.ndim != 3: @@ -541,7 +541,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) - block_state.img_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype) + block_state.img_ids = prepare_latent_image_ids(None, height // 2, width // 2, device, dtype) self.set_block_state(state, block_state) @@ -598,15 +598,13 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip ): image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2)) image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2)) - img_ids = FluxPipeline._prepare_latent_image_ids( - None, image_latent_height // 2, image_latent_width // 2, device, dtype - ) + img_ids = prepare_latent_image_ids(None, image_latent_height // 2, image_latent_width // 2, device, dtype) # image ids are the same as latent ids with the first dimension set to 1 instead of 0 img_ids[..., 0] = 1 height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) - latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype) + latent_ids = prepare_latent_image_ids(None, height // 2, width // 2, device, dtype) if img_ids is not None: latent_ids = torch.cat([latent_ids, img_ids], dim=0) diff --git a/src/diffusers/modular_pipelines/flux/decoders.py b/src/diffusers/modular_pipelines/flux/decoders.py index 5da861e78fcb..b2febbebfd17 100644 --- a/src/diffusers/modular_pipelines/flux/decoders.py +++ b/src/diffusers/modular_pipelines/flux/decoders.py @@ -24,27 +24,12 @@ from ...video_processor import VaeImageProcessor from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .pipeline_helpers import unpack_latents logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (vae_scale_factor * 2)) - width = 2 * (int(width) // (vae_scale_factor * 2)) - - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height, width) - - return latents - - class FluxDecodeStep(ModularPipelineBlocks): model_name = "flux" @@ -95,7 +80,7 @@ def __call__(self, components, state: PipelineState) -> PipelineState: if not block_state.output_type == "latent": latents = block_state.latents - latents = _unpack_latents(latents, block_state.height, block_state.width, components.vae_scale_factor) + latents = unpack_latents(latents, block_state.height, block_state.width, components.vae_scale_factor) latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor block_state.images = vae.decode(latents, return_dict=False)[0] block_state.images = components.image_processor.postprocess( diff --git a/src/diffusers/modular_pipelines/flux/encoders.py b/src/diffusers/modular_pipelines/flux/encoders.py index 583c139ff22e..429a85b9fa6e 100644 --- a/src/diffusers/modular_pipelines/flux/encoders.py +++ b/src/diffusers/modular_pipelines/flux/encoders.py @@ -26,6 +26,7 @@ from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import FluxModularPipeline +from .pipeline_helpers import PREFERRED_KONTEXT_RESOLUTIONS if is_ftfy_available(): @@ -170,8 +171,6 @@ def intermediate_outputs(self) -> list[OutputParam]: @torch.no_grad() def __call__(self, components: FluxModularPipeline, state: PipelineState): - from ...pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS - block_state = self.get_block_state(state) images = block_state.image diff --git a/src/diffusers/modular_pipelines/flux/inputs.py b/src/diffusers/modular_pipelines/flux/inputs.py index 9d2f69dbe26f..f5834b6ae0bf 100644 --- a/src/diffusers/modular_pipelines/flux/inputs.py +++ b/src/diffusers/modular_pipelines/flux/inputs.py @@ -15,14 +15,11 @@ import torch -from ...pipelines import FluxPipeline from ...utils import logging from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import InputParam, OutputParam - -# TODO: consider making these common utilities for modular if they are not pipeline-specific. -from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size from .modular_pipeline import FluxModularPipeline +from .pipeline_helpers import calculate_dimension_from_latents, pack_latents, repeat_tensor_to_batch_size logger = logging.get_logger(__name__) @@ -209,7 +206,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip # 2. Patchify the image latent tensor # TODO: Implement patchifier for Flux. latent_height, latent_width = image_latent_tensor.shape[2:] - image_latent_tensor = FluxPipeline._pack_latents( + image_latent_tensor = pack_latents( image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width ) @@ -266,7 +263,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip # 2. Patchify the image latent tensor # TODO: Implement patchifier for Flux. latent_height, latent_width = image_latent_tensor.shape[2:] - image_latent_tensor = FluxPipeline._pack_latents( + image_latent_tensor = pack_latents( image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width ) diff --git a/src/diffusers/modular_pipelines/flux/pipeline_helpers.py b/src/diffusers/modular_pipelines/flux/pipeline_helpers.py new file mode 100644 index 000000000000..722cfd4751bf --- /dev/null +++ b/src/diffusers/modular_pipelines/flux/pipeline_helpers.py @@ -0,0 +1,110 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +PREFERRED_KONTEXT_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), + (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + + +# Copied from diffusers.pipelines.flux.pipeline_flux +def prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + +# Copied from diffusers.pipelines.flux.pipeline_flux +def pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + +# Copied from diffusers.pipelines.flux.pipeline_flux +def unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + +# Copied from diffusers.modular_pipelines.qwenimage.inputs.calculate_dimension_from_latents +def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: int) -> tuple[int, int]: + if latents.ndim != 4 and latents.ndim != 5: + raise ValueError(f"unpacked latents must have 4 or 5 dimensions, but got {latents.ndim}") + + latent_height, latent_width = latents.shape[-2:] + + height = latent_height * vae_scale_factor + width = latent_width * vae_scale_factor + + return height, width + + +# Copied from diffusers.modular_pipelines.qwenimage.inputs.repeat_tensor_to_batch_size +def repeat_tensor_to_batch_size( + input_name: str, + input_tensor: torch.Tensor, + batch_size: int, + num_images_per_prompt: int = 1, +) -> torch.Tensor: + if not isinstance(input_tensor, torch.Tensor): + raise ValueError(f"`{input_name}` must be a tensor") + + if input_tensor.shape[0] == 1: + repeat_by = batch_size * num_images_per_prompt + elif input_tensor.shape[0] == batch_size: + repeat_by = num_images_per_prompt + else: + raise ValueError(f"`{input_name}` must have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}") + + return input_tensor.repeat_interleave(repeat_by, dim=0) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index e125924adf7f..b57e4d44dd18 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -491,6 +491,14 @@ def check_inputs( f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index d8dcdfcd4640..5317daf4c714 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -860,7 +860,10 @@ def __call__( lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) - do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt ( prompt_embeds, pooled_prompt_embeds, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index fdaff9b0af8a..34b2c12e1f05 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -454,7 +454,7 @@ def check_inputs( if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if height % self.vae_scale_factor * 2 != 0 or width % self.vae_scale_factor * 2 != 0: + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: logger.warning( f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" ) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index cadff7736ff4..ea99c1af85a6 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -899,7 +899,10 @@ def __call__( lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) - do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt ( prompt_embeds, pooled_prompt_embeds, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index b8ce25a4f5a9..553d64dacec1 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -974,7 +974,10 @@ def __call__( lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) - do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt ( prompt_embeds, pooled_prompt_embeds, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py index f4bbe42ef850..a025ae085379 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py @@ -542,6 +542,14 @@ def check_inputs( f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 330e2623b287..03bd060f8726 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -13,6 +13,7 @@ # limitations under the License. +import numpy as np import torch from PIL import Image from transformers import ( @@ -47,6 +48,17 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +def _get_image_batch_size(image): + if isinstance(image, Image.Image): + return 1 + if isinstance(image, list): + return len(image) + if isinstance(image, (np.ndarray, torch.Tensor)): + return image.shape[0] if image.ndim == 4 else 1 + return 1 + + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -164,19 +176,39 @@ def check_inputs( raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") - if prompt is not None and (isinstance(prompt, list) and isinstance(image, list) and len(prompt) != len(image)): + + image_batch_size = _get_image_batch_size(image) + if isinstance(prompt, list) and len(prompt) != image_batch_size: raise ValueError( - f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {len(image)} images" + f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {image_batch_size} images" + ) + if isinstance(prompt_2, list) and len(prompt_2) != image_batch_size: + raise ValueError( + f"number of prompt_2 prompts must be equal to number of images, but {len(prompt_2)} prompts were provided and {image_batch_size} images" ) if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." ) - if isinstance(prompt_embeds_scale, list) and ( - isinstance(image, list) and len(prompt_embeds_scale) != len(image) + if prompt_embeds is not None and prompt_embeds.shape[0] != image_batch_size: + raise ValueError( + f"`prompt_embeds` batch size must be equal to number of images, but {prompt_embeds.shape[0]} prompt embeds were provided and {image_batch_size} images" + ) + if ( + prompt_embeds is not None + and pooled_prompt_embeds is not None + and pooled_prompt_embeds.shape[0] != image_batch_size ): raise ValueError( - f"number of weights must be equal to number of images, but {len(prompt_embeds_scale)} weights were provided and {len(image)} images" + f"`pooled_prompt_embeds` batch size must be equal to number of images, but {pooled_prompt_embeds.shape[0]} pooled prompt embeds were provided and {image_batch_size} images" + ) + if isinstance(prompt_embeds_scale, list) and len(prompt_embeds_scale) != image_batch_size: + raise ValueError( + f"number of weights must be equal to number of images, but {len(prompt_embeds_scale)} weights were provided and {image_batch_size} images" + ) + if isinstance(pooled_prompt_embeds_scale, list) and len(pooled_prompt_embeds_scale) != image_batch_size: + raise ValueError( + f"number of pooled weights must be equal to number of images, but {len(pooled_prompt_embeds_scale)} weights were provided and {image_batch_size} images" ) def encode_image(self, image, device, num_images_per_prompt): @@ -421,12 +453,7 @@ def __call__( ) # 2. Define call parameters - if image is not None and isinstance(image, Image.Image): - batch_size = 1 - elif image is not None and isinstance(image, list): - batch_size = len(image) - else: - batch_size = image.shape[0] + batch_size = _get_image_batch_size(image) if prompt is not None and isinstance(prompt, str): prompt = batch_size * [prompt] if isinstance(prompt_embeds_scale, float): diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index e4e91e52fb80..658068444ca6 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -20,7 +20,7 @@ from diffusers import FluxTransformer2DModel from diffusers.models.embeddings import ImageProjection -from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor +from diffusers.models.transformers.transformer_flux import FluxAttention, FluxIPAdapterAttnProcessor from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import enable_full_determinism, torch_device @@ -272,6 +272,83 @@ def modify_inputs_for_ip_adapter(self, model, inputs_dict): def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]: return create_flux_ip_adapter_state_dict(model) + def test_ip_adapter_masks_are_applied(self): + torch.manual_seed(0) + attn = FluxAttention( + query_dim=4, + heads=1, + dim_head=4, + added_kv_proj_dim=4, + processor=FluxIPAdapterAttnProcessor(hidden_size=4, cross_attention_dim=4, num_tokens=(2,), scale=10.0), + ).to(torch_device) + + hidden_states = torch.randn(1, 4, 4, device=torch_device) + encoder_hidden_states = torch.randn(1, 2, 4, device=torch_device) + ip_hidden_states = [torch.randn(1, 1, 2, 4, device=torch_device)] + zero_mask = [torch.zeros(1, 1, 2, 2, device=torch_device)] + one_mask = [torch.ones(1, 1, 2, 2, device=torch_device)] + + zero_mask_ip_output = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ip_hidden_states=ip_hidden_states, + ip_adapter_masks=zero_mask, + )[2] + one_mask_ip_output = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ip_hidden_states=ip_hidden_states, + ip_adapter_masks=one_mask, + )[2] + + assert torch.allclose(zero_mask_ip_output, torch.zeros_like(zero_mask_ip_output), atol=1e-6) + assert not torch.allclose(zero_mask_ip_output, one_mask_ip_output) + + def test_ip_adapter_masks_validate_num_images(self): + torch.manual_seed(0) + attn = FluxAttention( + query_dim=4, + heads=1, + dim_head=4, + added_kv_proj_dim=4, + processor=FluxIPAdapterAttnProcessor(hidden_size=4, cross_attention_dim=4, num_tokens=(2,), scale=10.0), + ).to(torch_device) + + hidden_states = torch.randn(1, 4, 4, device=torch_device) + encoder_hidden_states = torch.randn(1, 2, 4, device=torch_device) + ip_hidden_states = [torch.randn(1, 1, 2, 4, device=torch_device)] + mismatched_mask = [torch.ones(1, 2, 2, 2, device=torch_device)] + + with pytest.raises(ValueError, match="Number of masks"): + attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ip_hidden_states=ip_hidden_states, + ip_adapter_masks=mismatched_mask, + ) + + def test_ip_adapter_no_mask_keeps_3d_hidden_state_support(self): + torch.manual_seed(0) + attn = FluxAttention( + query_dim=4, + heads=1, + dim_head=4, + added_kv_proj_dim=4, + processor=FluxIPAdapterAttnProcessor(hidden_size=4, cross_attention_dim=4, num_tokens=(2,), scale=10.0), + ).to(torch_device) + + hidden_states = torch.randn(1, 4, 4, device=torch_device) + encoder_hidden_states = torch.randn(1, 2, 4, device=torch_device) + ip_hidden_states = [torch.randn(1, 2, 4, device=torch_device)] + + ip_output = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ip_hidden_states=ip_hidden_states, + )[2] + + assert ip_output.shape == hidden_states.shape + class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin): """LoRA adapter tests for Flux Transformer.""" diff --git a/tests/modular_pipelines/flux/test_modular_pipeline_flux.py b/tests/modular_pipelines/flux/test_modular_pipeline_flux.py index 05fe16e372ec..01471630649a 100644 --- a/tests/modular_pipelines/flux/test_modular_pipeline_flux.py +++ b/tests/modular_pipelines/flux/test_modular_pipeline_flux.py @@ -14,6 +14,7 @@ # limitations under the License. import random +from pathlib import Path import numpy as np import PIL @@ -27,6 +28,11 @@ FluxModularPipeline, ModularPipeline, ) +from diffusers.modular_pipelines.flux.pipeline_helpers import ( + pack_latents, + prepare_latent_image_ids, + unpack_latents, +) from ...testing_utils import floats_tensor, torch_device from ..test_modular_pipelines_common import ModularPipelineTesterMixin @@ -45,6 +51,31 @@ } +def test_flux_modular_blocks_do_not_import_classic_or_qwenimage_helpers(): + flux_dir = Path(__file__).parents[3] / "src" / "diffusers" / "modular_pipelines" / "flux" + offenders = [] + + for path in sorted(flux_dir.glob("*.py")): + for line_no, line in enumerate(path.read_text().splitlines(), 1): + if "from ...pipelines" in line or "from ..qwenimage" in line or "FluxPipeline._" in line: + offenders.append(f"{path.relative_to(flux_dir)}:{line_no}: {line.strip()}") + + assert offenders == [] + + +def test_flux_modular_latent_helpers_roundtrip(): + latents = torch.arange(1 * 4 * 4 * 4, dtype=torch.float32).reshape(1, 4, 4, 4) + + packed = pack_latents(latents, batch_size=1, num_channels_latents=4, height=4, width=4) + unpacked = unpack_latents(packed, height=32, width=32, vae_scale_factor=8) + + assert torch.equal(unpacked, latents) + + latent_ids = prepare_latent_image_ids(None, 2, 2, torch.device("cpu"), torch.float32) + expected_ids = torch.tensor([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1]], dtype=torch.float32) + assert torch.equal(latent_ids, expected_ids) + + class TestFluxModularPipelineFast(ModularPipelineTesterMixin): pipeline_class = FluxModularPipeline pipeline_blocks_class = FluxAutoBlocks diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py index 8607cd6944d9..26bd351f5748 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py @@ -208,6 +208,32 @@ def test_flux_image_output_shape(self): output_height, output_width, _ = image.shape assert (output_height, output_width) == (expected_height, expected_width) + def test_true_cfg_with_negative_prompt_embeds(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device, dtype=torch.float16) + + prompt_embeds = torch.zeros(1, 48, 32, device=torch_device, dtype=torch.float16) + pooled_prompt_embeds = torch.zeros(1, 32, device=torch_device, dtype=torch.float16) + + inputs = self.get_dummy_inputs(torch_device) + inputs.pop("prompt") + inputs.pop("generator") + inputs.update( + { + "prompt_embeds": prompt_embeds, + "pooled_prompt_embeds": pooled_prompt_embeds, + "negative_prompt_embeds": torch.zeros_like(prompt_embeds), + "negative_pooled_prompt_embeds": torch.zeros_like(pooled_prompt_embeds), + "true_cfg_scale": 2.0, + } + ) + output_with_zero_negative = pipe(**inputs, generator=torch.manual_seed(0)).images[0] + + inputs["negative_prompt_embeds"] = torch.ones_like(prompt_embeds) + inputs["negative_pooled_prompt_embeds"] = torch.ones_like(pooled_prompt_embeds) + output_with_one_negative = pipe(**inputs, generator=torch.manual_seed(0)).images[0] + + assert not np.allclose(output_with_zero_negative, output_with_one_negative) + @nightly @require_big_accelerator diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py index a4749188dfd8..bc9615b1f6e5 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py @@ -11,9 +11,10 @@ FluxControlNetModel, FluxTransformer2DModel, ) +from diffusers.pipelines.flux.pipeline_flux_controlnet_image_to_image import logger as controlnet_img2img_logger from diffusers.utils.torch_utils import randn_tensor -from ...testing_utils import torch_device +from ...testing_utils import CaptureLogger, torch_device from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist @@ -216,3 +217,38 @@ def test_flux_image_output_shape(self): image = pipe(**inputs).images[0] output_height, output_width, _ = image.shape assert (output_height, output_width) == (expected_height, expected_width) + + def test_dimension_check_uses_packed_latent_multiple(self): + pipe = object.__new__(self.pipeline_class) + pipe.vae_scale_factor = 8 + pipe._callback_tensor_inputs = ["latents", "prompt_embeds", "control_image"] + + with CaptureLogger(controlnet_img2img_logger) as cap_logger: + pipe.check_inputs( + prompt="x", + prompt_2=None, + strength=0.5, + height=72, + width=72, + callback_on_step_end_tensor_inputs=["latents"], + prompt_embeds=None, + pooled_prompt_embeds=None, + max_sequence_length=48, + ) + + assert "have to be divisible by 16" in cap_logger.out + + with CaptureLogger(controlnet_img2img_logger) as cap_logger: + pipe.check_inputs( + prompt="x", + prompt_2=None, + strength=0.5, + height=64, + width=64, + callback_on_step_end_tensor_inputs=["latents"], + prompt_embeds=None, + pooled_prompt_embeds=None, + max_sequence_length=48, + ) + + assert "have to be divisible by 16" not in cap_logger.out diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 13336f0cde9b..c4a40496b5be 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -234,6 +234,25 @@ def test_flux_true_cfg(self): np.allclose(no_true_cfg_out, true_cfg_out), "Outputs should be different when true_cfg_scale is set." ) + def test_negative_prompt_embeds_shape_mismatch_raises(self): + pipe = object.__new__(self.pipeline_class) + pipe.vae_scale_factor = 8 + pipe._callback_tensor_inputs = ["latents", "prompt_embeds"] + + with self.assertRaisesRegex(ValueError, "negative_prompt_embeds"): + pipe.check_inputs( + prompt=None, + prompt_2=None, + height=64, + width=64, + prompt_embeds=torch.zeros(2, 4, 8), + pooled_prompt_embeds=torch.zeros(2, 8), + negative_prompt_embeds=torch.zeros(1, 4, 8), + negative_pooled_prompt_embeds=torch.zeros(1, 8), + callback_on_step_end_tensor_inputs=["latents"], + max_sequence_length=48, + ) + @nightly @require_big_accelerator diff --git a/tests/pipelines/flux/test_pipeline_flux_img2img.py b/tests/pipelines/flux/test_pipeline_flux_img2img.py index 00587905d337..de45a52070d6 100644 --- a/tests/pipelines/flux/test_pipeline_flux_img2img.py +++ b/tests/pipelines/flux/test_pipeline_flux_img2img.py @@ -140,3 +140,29 @@ def test_flux_image_output_shape(self): image = pipe(**inputs).images[0] output_height, output_width, _ = image.shape assert (output_height, output_width) == (expected_height, expected_width) + + def test_true_cfg_with_negative_prompt_embeds(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + prompt_embeds = torch.zeros(1, 48, 32, device=torch_device) + pooled_prompt_embeds = torch.zeros(1, 32, device=torch_device) + + inputs = self.get_dummy_inputs(torch_device) + inputs.pop("prompt") + inputs.pop("generator") + inputs.update( + { + "prompt_embeds": prompt_embeds, + "pooled_prompt_embeds": pooled_prompt_embeds, + "negative_prompt_embeds": torch.zeros_like(prompt_embeds), + "negative_pooled_prompt_embeds": torch.zeros_like(pooled_prompt_embeds), + "true_cfg_scale": 2.0, + } + ) + output_with_zero_negative = pipe(**inputs, generator=torch.manual_seed(0)).images[0] + + inputs["negative_prompt_embeds"] = torch.ones_like(prompt_embeds) + inputs["negative_pooled_prompt_embeds"] = torch.ones_like(pooled_prompt_embeds) + output_with_one_negative = pipe(**inputs, generator=torch.manual_seed(0)).images[0] + + assert not np.allclose(output_with_zero_negative, output_with_one_negative) diff --git a/tests/pipelines/flux/test_pipeline_flux_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_inpaint.py index 14edb9e441b5..d43390db8ac2 100644 --- a/tests/pipelines/flux/test_pipeline_flux_inpaint.py +++ b/tests/pipelines/flux/test_pipeline_flux_inpaint.py @@ -142,3 +142,29 @@ def test_flux_image_output_shape(self): image = pipe(**inputs).images[0] output_height, output_width, _ = image.shape assert (output_height, output_width) == (expected_height, expected_width) + + def test_true_cfg_with_negative_prompt_embeds(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + prompt_embeds = torch.zeros(1, 48, 32, device=torch_device) + pooled_prompt_embeds = torch.zeros(1, 32, device=torch_device) + + inputs = self.get_dummy_inputs(torch_device) + inputs.pop("prompt") + inputs.pop("generator") + inputs.update( + { + "prompt_embeds": prompt_embeds, + "pooled_prompt_embeds": pooled_prompt_embeds, + "negative_prompt_embeds": torch.zeros_like(prompt_embeds), + "negative_pooled_prompt_embeds": torch.zeros_like(pooled_prompt_embeds), + "true_cfg_scale": 2.0, + } + ) + output_with_zero_negative = pipe(**inputs, generator=torch.manual_seed(0)).images[0] + + inputs["negative_prompt_embeds"] = torch.ones_like(prompt_embeds) + inputs["negative_pooled_prompt_embeds"] = torch.ones_like(pooled_prompt_embeds) + output_with_one_negative = pipe(**inputs, generator=torch.manual_seed(0)).images[0] + + assert not np.allclose(output_with_zero_negative, output_with_one_negative) diff --git a/tests/pipelines/flux/test_pipeline_flux_kontext.py b/tests/pipelines/flux/test_pipeline_flux_kontext.py index 1c018f14b522..027d0c023f52 100644 --- a/tests/pipelines/flux/test_pipeline_flux_kontext.py +++ b/tests/pipelines/flux/test_pipeline_flux_kontext.py @@ -176,3 +176,22 @@ def test_flux_true_cfg(self): inputs["true_cfg_scale"] = 2.0 true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0] assert not np.allclose(no_true_cfg_out, true_cfg_out) + + def test_negative_prompt_embeds_shape_mismatch_raises(self): + pipe = object.__new__(self.pipeline_class) + pipe.vae_scale_factor = 8 + pipe._callback_tensor_inputs = ["latents", "prompt_embeds"] + + with self.assertRaisesRegex(ValueError, "negative_prompt_embeds"): + pipe.check_inputs( + prompt=None, + prompt_2=None, + height=64, + width=64, + prompt_embeds=torch.zeros(2, 4, 8), + pooled_prompt_embeds=torch.zeros(2, 8), + negative_prompt_embeds=torch.zeros(1, 4, 8), + negative_pooled_prompt_embeds=torch.zeros(1, 8), + callback_on_step_end_tensor_inputs=["latents"], + max_sequence_length=48, + ) diff --git a/tests/pipelines/flux/test_pipeline_flux_redux.py b/tests/pipelines/flux/test_pipeline_flux_redux.py index bbeee28e6a62..618bcc86b4c1 100644 --- a/tests/pipelines/flux/test_pipeline_flux_redux.py +++ b/tests/pipelines/flux/test_pipeline_flux_redux.py @@ -17,6 +17,68 @@ ) +class FluxReduxFastTests(unittest.TestCase): + pipeline_class = FluxPriorReduxPipeline + + def test_check_inputs_rejects_tensor_image_prompt_batch_mismatch(self): + pipe = object.__new__(self.pipeline_class) + + with self.assertRaisesRegex(ValueError, "number of prompts"): + pipe.check_inputs( + image=torch.zeros(2, 3, 32, 32), + prompt=["first", "second", "third"], + prompt_2=None, + ) + + def test_check_inputs_allows_string_prompt_for_tensor_image_batch(self): + pipe = object.__new__(self.pipeline_class) + + pipe.check_inputs( + image=torch.zeros(2, 3, 32, 32), + prompt="same prompt", + prompt_2=None, + ) + + def test_check_inputs_rejects_prompt_embed_batch_mismatch(self): + pipe = object.__new__(self.pipeline_class) + + with self.assertRaisesRegex(ValueError, "prompt_embeds"): + pipe.check_inputs( + image=torch.zeros(2, 3, 32, 32), + prompt=None, + prompt_2=None, + prompt_embeds=torch.zeros(1, 4, 8), + pooled_prompt_embeds=torch.zeros(1, 8), + ) + + def test_check_inputs_rejects_prompt_scale_batch_mismatch(self): + pipe = object.__new__(self.pipeline_class) + + with self.assertRaisesRegex(ValueError, "number of weights"): + pipe.check_inputs( + image=torch.zeros(2, 3, 32, 32), + prompt=["first", "second"], + prompt_2=None, + prompt_embeds_scale=[1.0], + ) + + with self.assertRaisesRegex(ValueError, "number of pooled weights"): + pipe.check_inputs( + image=torch.zeros(2, 3, 32, 32), + prompt=["first", "second"], + prompt_2=None, + pooled_prompt_embeds_scale=[1.0], + ) + + pipe.check_inputs( + image=torch.zeros(2, 3, 32, 32), + prompt=["first", "second"], + prompt_2=None, + prompt_embeds_scale=[1.0, 1.0], + pooled_prompt_embeds_scale=[1.0, 1.0], + ) + + @slow @require_big_accelerator class FluxReduxSlowTests(unittest.TestCase):