@@ -23,30 +23,30 @@ def teacache_forward_working(
        lora_scale = attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0
-    if USE_PEFT_BACKEND:
+    if USE_PEFT_BACKEND:
        scale_lora_layers(self, lora_scale)

    batch_size, _, height, width = hidden_states.shape
    temb, encoder_hidden_states_processed = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
    (image_patch_embeddings, context_rotary_emb, noise_rotary_emb, joint_rotary_emb,
     encoder_seq_lengths, seq_lengths) = self.rope_embedder(hidden_states, encoder_attention_mask)
    image_patch_embeddings = self.x_embedder(image_patch_embeddings)
-    for layer in self.context_refiner:
+    for layer in self.context_refiner:
        encoder_hidden_states_processed = layer(encoder_hidden_states_processed, encoder_attention_mask, context_rotary_emb)
-    for layer in self.noise_refiner:
+    for layer in self.noise_refiner:
        image_patch_embeddings = layer(image_patch_embeddings, None, noise_rotary_emb, temb)

    max_seq_len = max(seq_lengths)
    input_to_main_loop = image_patch_embeddings.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
    for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
        input_to_main_loop[i, :enc_len] = encoder_hidden_states_processed[i, :enc_len]
        input_to_main_loop[i, enc_len:seq_len_val] = image_patch_embeddings[i]
-
+
    use_mask = len(set(seq_lengths)) > 1
    attention_mask_for_main_loop_arg = None
    if use_mask:
        mask = input_to_main_loop.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
-        for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
+        for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
            mask[i, :seq_len_val] = True
        attention_mask_for_main_loop_arg = mask

@@ -59,9 +59,9 @@ def teacache_forward_working(
            "previous_modulated_input": None,
            "previous_residual": None,
        }
-
+
        current_cache = self.cache[cache_key]
-        modulated_inp, _, _, _ = self.layers[0].norm1(input_to_main_loop.clone(), temb.clone())
+        modulated_inp, _, _, _ = self.layers[0].norm1(input_to_main_loop, temb)

        if self.cnt == 0 or self.cnt == self.num_steps - 1:
            should_calc = True
@@ -72,42 +72,54 @@ def teacache_forward_working(
            rescale_func = np.poly1d(coefficients)
            prev_mod_input = current_cache["previous_modulated_input"]
            prev_mean = prev_mod_input.abs().mean()
-
+
            if prev_mean.item() > 1e-9:
                rel_l1_change = ((modulated_inp - prev_mod_input).abs().mean() / prev_mean).cpu().item()
            else:
                rel_l1_change = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float('inf')
-
+
            current_cache["accumulated_rel_l1_distance"] += rescale_func(rel_l1_change)

            if current_cache["accumulated_rel_l1_distance"] < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                current_cache["accumulated_rel_l1_distance"] = 0.0
-        else:
+        else:
            should_calc = True
            current_cache["accumulated_rel_l1_distance"] = 0.0

        current_cache["previous_modulated_input"] = modulated_inp.clone()
-
-        if not hasattr(self, 'uncond_seq_len'):
+
+        if self.uncond_seq_len is None:
            self.uncond_seq_len = cache_key
        if cache_key != self.uncond_seq_len:
-            self.cnt += 1
-            if self.cnt >= self.num_steps:
-                self.cnt = 0
+            self.cnt += 1
+            if self.cnt >= self.num_steps:
+                self.cnt = 0

    if self.enable_teacache and not should_calc:
-        processed_hidden_states = input_to_main_loop + self.cache[max_seq_len]["previous_residual"]
-    else:
-        ori_input = input_to_main_loop.clone()
+        if max_seq_len in self.cache and "previous_residual" in self.cache[max_seq_len] and self.cache[max_seq_len]["previous_residual"] is not None:
+            processed_hidden_states = input_to_main_loop + self.cache[max_seq_len]["previous_residual"]
+        else:
+            should_calc = True
+            current_processing_states = input_to_main_loop
+            for layer in self.layers:
+                current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb)
+            processed_hidden_states = current_processing_states
+
+
+    if not (self.enable_teacache and not should_calc):
        current_processing_states = input_to_main_loop
        for layer in self.layers:
            current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb)
-
+
        if self.enable_teacache:
-            self.cache[max_seq_len]["previous_residual"] = current_processing_states - ori_input
+            if max_seq_len in self.cache:
+                self.cache[max_seq_len]["previous_residual"] = current_processing_states - input_to_main_loop
+            else:
+                logger.warning(f"TeaCache: Cache key {max_seq_len} not found when trying to save residual.")
+
        processed_hidden_states = current_processing_states

    output_after_norm = self.norm_out(processed_hidden_states, temb)
@@ -123,10 +135,13 @@ def teacache_forward_working(
        final_output_list.append(reconstructed_image)

    final_output_tensor = torch.stack(final_output_list, dim=0)
-
-    if USE_PEFT_BACKEND:
+
+    if USE_PEFT_BACKEND:
        unscale_lora_layers(self, lora_scale)
-
+
+    if not return_dict:
+        return (final_output_tensor,)
+
    return Transformer2DModelOutput(sample=final_output_tensor)

@@ -137,8 +152,8 @@ def teacache_forward_working(
    ckpt_path, torch_dtype=torch.bfloat16
)
pipeline = Lumina2Pipeline.from_pretrained(
-    "Alpha-VLLM/Lumina-Image-2.0",
-    transformer=transformer,
+    "Alpha-VLLM/Lumina-Image-2.0",
+    transformer=transformer,
    torch_dtype=torch.bfloat16
).to("cuda")

@@ -151,9 +166,9 @@ def teacache_forward_working(
pipeline.transformer.__class__.enable_teacache = True
pipeline.transformer.__class__.cnt = 0
pipeline.transformer.__class__.num_steps = num_inference_steps
-pipeline.transformer.__class__.rel_l1_thresh = 0.3  # taken from teacache_lumina_next.py, 0.2 for 1.5x speedup, 0.3 for 1.9x speedup, 0.4 for 2.4x speedup, 0.5 for 2.8x speedup
+pipeline.transformer.__class__.rel_l1_thresh = 0.3
pipeline.transformer.__class__.cache = {}
-pipeline.transformer.__class__.uncond_seq_len = None
+pipeline.transformer.__class__.uncond_seq_len = None


pipeline.enable_model_cpu_offload()
@@ -163,4 +178,5 @@ def teacache_forward_working(
    generator=torch.Generator("cuda").manual_seed(seed)
).images[0]

-image.save(output_filename)
+image.save(output_filename)
+print(f"Image saved to {output_filename}")
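For reference, a minimal sketch of how the patched forward is typically attached before the pipeline is built. The actual binding happens outside the hunks shown above, so treat the exact call as an assumption; `Lumina2Transformer2DModel` is the diffusers class the `transformer` object is presumed to come from, and `ckpt_path` is the same checkpoint path used in the script.

import torch
from diffusers import Lumina2Transformer2DModel

# Assumption: monkey-patch the class-level forward so every instance of the
# transformer uses the TeaCache-aware implementation defined above.
Lumina2Transformer2DModel.forward = teacache_forward_working

# Load the transformer that is then passed into Lumina2Pipeline.from_pretrained(...).
transformer = Lumina2Transformer2DModel.from_pretrained(
    ckpt_path, torch_dtype=torch.bfloat16
)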