Skip to content

Commit 7863cf0

Browse files
authored
Add SeedVR2 support (CORE-6) (Comfy-Org#14110)
1 parent 739061d commit 7863cf0

26 files changed

Lines changed: 7383 additions & 40 deletions

comfy/latent_formats.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ class LatentFormat:
44
scale_factor = 1.0
55
latent_channels = 4
66
latent_dimensions = 2
7+
preserve_empty_channel_multiples = False
78
latent_rgb_factors = None
89
latent_rgb_factors_bias = None
910
latent_rgb_factors_reshape = None
@@ -779,6 +780,10 @@ class ACEAudio(LatentFormat):
779780
latent_channels = 8
780781
latent_dimensions = 2
781782

783+
class SeedVR2(LatentFormat):
784+
latent_channels = 16
785+
preserve_empty_channel_multiples = True
786+
782787
class ACEAudio15(LatentFormat):
783788
latent_channels = 64
784789
latent_dimensions = 1

comfy/ldm/modules/attention.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,86 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
735735
)
736736
return out
737737

738+
def _var_attention_qkv(q, k, v, heads, skip_reshape):
739+
if skip_reshape:
740+
return q, k, v, q.shape[-1]
741+
total_tokens, embed_dim = q.shape
742+
head_dim = embed_dim // heads
743+
return (
744+
q.view(total_tokens, heads, head_dim),
745+
k.view(k.shape[0], heads, head_dim),
746+
v.view(v.shape[0], heads, head_dim),
747+
head_dim,
748+
)
738749

750+
751+
def _var_attention_output(out, heads, head_dim, skip_output_reshape):
752+
if skip_output_reshape:
753+
return out
754+
return out.reshape(-1, heads * head_dim)
755+
756+
757+
def _use_blackwell_attention():
758+
device = model_management.get_torch_device()
759+
if device.type != "cuda":
760+
return False
761+
major, minor = torch.cuda.get_device_capability(device)
762+
return (major, minor) >= (12, 0)
763+
764+
765+
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
766+
if cu_seqlens.dtype not in (torch.int32, torch.int64):
767+
raise ValueError(f"{name} must use an integer dtype")
768+
if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
769+
raise ValueError(f"{name} must be a 1D tensor with at least two offsets")
770+
if cu_seqlens[0].item() != 0:
771+
raise ValueError(f"{name} must start at 0")
772+
if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item():
773+
raise ValueError(f"{name} must be strictly increasing")
774+
if cu_seqlens[-1].item() != token_count:
775+
raise ValueError(f"{name} does not match token count")
776+
777+
778+
def _split_indices(cu_seqlens):
779+
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
780+
781+
782+
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
783+
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
784+
785+
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
786+
_validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
787+
if cu_seqlens_k[-1].item() != v.shape[0]:
788+
raise ValueError("cu_seqlens_k does not match v token count")
789+
790+
q_split_indices = _split_indices(cu_seqlens_q)
791+
k_split_indices = _split_indices(cu_seqlens_k)
792+
q_splits = torch.tensor_split(q, q_split_indices, dim=0)
793+
k_splits = torch.tensor_split(k, k_split_indices, dim=0)
794+
v_splits = torch.tensor_split(v, k_split_indices, dim=0)
795+
if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits):
796+
raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")
797+
798+
out = []
799+
for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
800+
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
801+
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
802+
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
803+
out_dtype = q_i.dtype
804+
if optimized_attention is attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16):
805+
q_i = q_i.to(torch.bfloat16)
806+
k_i = k_i.to(torch.bfloat16)
807+
v_i = v_i.to(torch.bfloat16)
808+
out_i = optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
809+
if out_i.dtype != out_dtype:
810+
out_i = out_i.to(out_dtype)
811+
out.append(out_i.squeeze(0).permute(1, 0, 2))
812+
813+
out = torch.cat(out, dim=0)
814+
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
815+
816+
817+
optimized_var_attention = var_attention_optimized_split
739818
optimized_attention = attention_basic
740819

741820
if model_management.sage_attention_enabled():
@@ -758,6 +837,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
758837
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
759838
optimized_attention = attention_sub_quad
760839

840+
logging.info("Using optimized_attention split-loop for variable-length attention")
841+
761842
optimized_attention_masked = optimized_attention
762843

763844

@@ -773,6 +854,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
773854
register_attention_function("pytorch", attention_pytorch)
774855
register_attention_function("sub_quad", attention_sub_quad)
775856
register_attention_function("split", attention_split)
857+
register_attention_function("var_attention_optimized_split", var_attention_optimized_split)
776858

777859

778860
def optimized_attention_for_device(device, mask=False, small_input=False):
@@ -1209,5 +1291,3 @@ def forward(
12091291
x = self.proj_out(x)
12101292
out = x + x_in
12111293
return out
1212-
1213-

comfy/ldm/modules/diffusionmodules/model.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import xformers
1414
import xformers.ops
1515

16+
1617
def torch_cat_if_needed(xl, dim):
1718
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
1819
if len(xl) > 1:
@@ -22,7 +23,8 @@ def torch_cat_if_needed(xl, dim):
2223
else:
2324
return None
2425

25-
def get_timestep_embedding(timesteps, embedding_dim):
26+
27+
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
2628
"""
2729
This matches the implementation in Denoising Diffusion Probabilistic Models:
2830
From Fairseq.
@@ -33,11 +35,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
3335
assert len(timesteps.shape) == 1
3436

3537
half_dim = embedding_dim // 2
36-
emb = math.log(10000) / (half_dim - 1)
38+
emb = math.log(10000) / (half_dim - downscale_freq_shift)
3739
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
3840
emb = emb.to(device=timesteps.device)
3941
emb = timesteps.float()[:, None] * emb[None, :]
4042
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
43+
if flip_sin_to_cos:
44+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
4145
if embedding_dim % 2 == 1: # zero pad
4246
emb = torch.nn.functional.pad(emb, (0,1,0,0))
4347
return emb

0 commit comments

Comments
 (0)