@@ -735,7 +735,86 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
735735 )
736736 return out
737737
738+ def _var_attention_qkv (q , k , v , heads , skip_reshape ):
739+ if skip_reshape :
740+ return q , k , v , q .shape [- 1 ]
741+ total_tokens , embed_dim = q .shape
742+ head_dim = embed_dim // heads
743+ return (
744+ q .view (total_tokens , heads , head_dim ),
745+ k .view (k .shape [0 ], heads , head_dim ),
746+ v .view (v .shape [0 ], heads , head_dim ),
747+ head_dim ,
748+ )
738749
750+
751+ def _var_attention_output (out , heads , head_dim , skip_output_reshape ):
752+ if skip_output_reshape :
753+ return out
754+ return out .reshape (- 1 , heads * head_dim )
755+
756+
def _use_blackwell_attention():
    """Report whether the active torch device is CUDA with capability >= 12.0.

    Returns ``False`` for any non-CUDA device.
    """
    # NOTE(review): the (12, 0) threshold does not match compute capability
    # 10.x devices; confirm that excluding them is intended.
    device = model_management.get_torch_device()
    if device.type == "cuda":
        capability = torch.cuda.get_device_capability(device)
        return capability >= (12, 0)
    return False
763+
764+
765+ def _validate_split_cu_seqlens (name , cu_seqlens , token_count ):
766+ if cu_seqlens .dtype not in (torch .int32 , torch .int64 ):
767+ raise ValueError (f"{ name } must use an integer dtype" )
768+ if cu_seqlens .ndim != 1 or cu_seqlens .numel () < 2 :
769+ raise ValueError (f"{ name } must be a 1D tensor with at least two offsets" )
770+ if cu_seqlens [0 ].item () != 0 :
771+ raise ValueError (f"{ name } must start at 0" )
772+ if (cu_seqlens [1 :] <= cu_seqlens [:- 1 ]).any ().item ():
773+ raise ValueError (f"{ name } must be strictly increasing" )
774+ if cu_seqlens [- 1 ].item () != token_count :
775+ raise ValueError (f"{ name } does not match token count" )
776+
777+
778+ def _split_indices (cu_seqlens ):
779+ return cu_seqlens [1 :- 1 ].to (device = "cpu" , dtype = torch .long )
780+
781+
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
    """Run variable-length attention one packed sequence at a time.

    The packed q/k/v tensors are cut along their first axis at the offsets
    given by ``cu_seqlens_q`` / ``cu_seqlens_k``, each piece is sent through
    the module-level ``optimized_attention`` backend, and the per-sequence
    results are concatenated back along the first axis.

    Args:
        q, k, v: Packed tensors. With ``skip_reshape=False`` they are
            reshaped by ``_var_attention_qkv``; with ``skip_reshape=True``
            they are used as-is (assumed to already be
            ``(tokens, heads, head_dim)`` -- TODO confirm against callers).
        heads: Number of attention heads.
        cu_seqlens_q: Cumulative query offsets; must end at ``q.shape[0]``.
        cu_seqlens_k: Cumulative key/value offsets; must end at
            ``k.shape[0]``.
        *args, **kwargs: Accepted and ignored.
        skip_reshape: Forwarded to ``_var_attention_qkv``.
        skip_output_reshape: Forwarded to ``_var_attention_output``.

    Returns:
        The concatenated attention output, post-processed by
        ``_var_attention_output``.

    Raises:
        ValueError: If either offset tensor is invalid, if k and v have
            different token counts, or if the two offset tensors describe
            a different number of sequences.
    """
    q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)

    _validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
    _validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
    # Validation pins cu_seqlens_k[-1] to k.shape[0], so comparing the two
    # shapes is equivalent to re-reading the last offset and avoids another
    # device-to-host read.
    if k.shape[0] != v.shape[0]:
        raise ValueError("cu_seqlens_k does not match v token count")
    # Each split yields numel() - 1 pieces, so equal lengths mean equal
    # sequence counts; checking here fails before any splitting work.
    if cu_seqlens_q.numel() != cu_seqlens_k.numel():
        raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")

    q_splits = torch.tensor_split(q, _split_indices(cu_seqlens_q), dim=0)
    kv_split_indices = _split_indices(cu_seqlens_k)
    k_splits = torch.tensor_split(k, kv_split_indices, dim=0)
    v_splits = torch.tensor_split(v, kv_split_indices, dim=0)

    # Every piece shares q's dtype, so these decisions are loop-invariant.
    out_dtype = q.dtype
    # When the sage backend is active, anything that is not fp16/bf16 is cast
    # to bf16 for the call and the result is cast back afterwards.
    cast_to_bf16 = optimized_attention is attention_sage and out_dtype not in (torch.float16, torch.bfloat16)

    out = []
    for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
        # Swap the first two axes and add a leading batch axis of size 1.
        q_i = q_i.permute(1, 0, 2).unsqueeze(0)
        k_i = k_i.permute(1, 0, 2).unsqueeze(0)
        v_i = v_i.permute(1, 0, 2).unsqueeze(0)
        if cast_to_bf16:
            q_i = q_i.to(torch.bfloat16)
            k_i = k_i.to(torch.bfloat16)
            v_i = v_i.to(torch.bfloat16)
        out_i = optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
        if out_i.dtype != out_dtype:
            out_i = out_i.to(out_dtype)
        # Undo the batch axis and the axis swap before concatenation.
        out.append(out_i.squeeze(0).permute(1, 0, 2))

    out = torch.cat(out, dim=0)
    return _var_attention_output(out, heads, head_dim, skip_output_reshape)
815+
816+
817+ optimized_var_attention = var_attention_optimized_split
739818optimized_attention = attention_basic
740819
741820if model_management .sage_attention_enabled ():
@@ -758,6 +837,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
758837 logging .info ("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention" )
759838 optimized_attention = attention_sub_quad
760839
840+ logging .info ("Using optimized_attention split-loop for variable-length attention" )
841+
761842optimized_attention_masked = optimized_attention
762843
763844
@@ -773,6 +854,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
773854register_attention_function ("pytorch" , attention_pytorch )
774855register_attention_function ("sub_quad" , attention_sub_quad )
775856register_attention_function ("split" , attention_split )
857+ register_attention_function ("var_attention_optimized_split" , var_attention_optimized_split )
776858
777859
778860def optimized_attention_for_device (device , mask = False , small_input = False ):
@@ -1209,5 +1291,3 @@ def forward(
12091291 x = self .proj_out (x )
12101292 out = x + x_in
12111293 return out
1212-
1213-
0 commit comments