
Commit e945259

Optimize redundant code
1 parent ca1c215 commit e945259

File tree: 1 file changed (+44 −28 lines)


TeaCache4Lumina2/teacache_lumina2.py

Lines changed: 44 additions & 28 deletions
@@ -23,30 +23,30 @@ def teacache_forward_working(
         lora_scale = attention_kwargs.pop("scale", 1.0)
     else:
         lora_scale = 1.0
-    if USE_PEFT_BACKEND:
+    if USE_PEFT_BACKEND:
         scale_lora_layers(self, lora_scale)

     batch_size, _, height, width = hidden_states.shape
     temb, encoder_hidden_states_processed = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
     (image_patch_embeddings, context_rotary_emb, noise_rotary_emb, joint_rotary_emb,
      encoder_seq_lengths, seq_lengths) = self.rope_embedder(hidden_states, encoder_attention_mask)
     image_patch_embeddings = self.x_embedder(image_patch_embeddings)
-    for layer in self.context_refiner:
+    for layer in self.context_refiner:
         encoder_hidden_states_processed = layer(encoder_hidden_states_processed, encoder_attention_mask, context_rotary_emb)
-    for layer in self.noise_refiner:
+    for layer in self.noise_refiner:
         image_patch_embeddings = layer(image_patch_embeddings, None, noise_rotary_emb, temb)

     max_seq_len = max(seq_lengths)
     input_to_main_loop = image_patch_embeddings.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
     for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
         input_to_main_loop[i, :enc_len] = encoder_hidden_states_processed[i, :enc_len]
         input_to_main_loop[i, enc_len:seq_len_val] = image_patch_embeddings[i]
-
+
     use_mask = len(set(seq_lengths)) > 1
     attention_mask_for_main_loop_arg = None
     if use_mask:
         mask = input_to_main_loop.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
-        for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
+        for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
             mask[i, :seq_len_val] = True
         attention_mask_for_main_loop_arg = mask

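For context, the hunk above packs each sample's prompt tokens and image-patch tokens into one padded sequence and builds a boolean mask when sequence lengths differ. A minimal, self-contained sketch of that packing step, using made-up toy sizes rather than anything from this file:

import torch

batch_size, hidden_size = 2, 8
encoder_seq_lengths = [3, 5]                      # hypothetical per-sample prompt lengths
image_len = 4                                     # hypothetical image-patch tokens per sample
seq_lengths = [enc + image_len for enc in encoder_seq_lengths]
max_seq_len = max(seq_lengths)                    # 9

encoder_tokens = torch.randn(batch_size, max(encoder_seq_lengths), hidden_size)
image_tokens = torch.randn(batch_size, image_len, hidden_size)

packed = encoder_tokens.new_zeros(batch_size, max_seq_len, hidden_size)
mask = packed.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
for i, (enc_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
    packed[i, :enc_len] = encoder_tokens[i, :enc_len]   # prompt tokens first
    packed[i, enc_len:seq_len] = image_tokens[i]         # image tokens right after
    mask[i, :seq_len] = True                              # mark only valid positions

print(packed.shape, mask.sum(dim=1))  # torch.Size([2, 9, 8]) tensor([7, 9])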
@@ -59,9 +59,9 @@ def teacache_forward_working(
             "previous_modulated_input": None,
             "previous_residual": None,
         }
-
+
     current_cache = self.cache[cache_key]
-    modulated_inp, _, _, _ = self.layers[0].norm1(input_to_main_loop.clone(), temb.clone())
+    modulated_inp, _, _, _ = self.layers[0].norm1(input_to_main_loop, temb)

     if self.cnt == 0 or self.cnt == self.num_steps - 1:
         should_calc = True
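The dropped .clone() calls are the "redundant code" the commit title refers to: they only matter if norm1 mutated its arguments in place. A tiny illustration of that reasoning, using a plain torch.nn.LayerNorm as a stand-in for norm1 (an assumption about norm1's behaviour, not code from this repo):

import torch
import torch.nn as nn

norm = nn.LayerNorm(8)            # stand-in for self.layers[0].norm1
x = torch.randn(2, 4, 8)
x_before = x.clone()

y = norm(x)                        # returns a fresh tensor
assert torch.equal(x, x_before)    # the input is left untouched
assert y.data_ptr() != x.data_ptr()  # so the defensive clone only cost memory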
@@ -72,42 +72,54 @@ def teacache_forward_working(
         rescale_func = np.poly1d(coefficients)
         prev_mod_input = current_cache["previous_modulated_input"]
         prev_mean = prev_mod_input.abs().mean()
-
+
         if prev_mean.item() > 1e-9:
             rel_l1_change = ((modulated_inp - prev_mod_input).abs().mean() / prev_mean).cpu().item()
         else:
             rel_l1_change = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float('inf')
-
+
         current_cache["accumulated_rel_l1_distance"] += rescale_func(rel_l1_change)

         if current_cache["accumulated_rel_l1_distance"] < self.rel_l1_thresh:
             should_calc = False
         else:
             should_calc = True
             current_cache["accumulated_rel_l1_distance"] = 0.0
-    else:
+    else:
         should_calc = True
         current_cache["accumulated_rel_l1_distance"] = 0.0

     current_cache["previous_modulated_input"] = modulated_inp.clone()
-
-    if not hasattr(self, 'uncond_seq_len'):
+
+    if self.uncond_seq_len is None:
         self.uncond_seq_len = cache_key
     if cache_key != self.uncond_seq_len:
-        self.cnt += 1
-        if self.cnt >= self.num_steps:
-            self.cnt = 0
+        self.cnt += 1
+        if self.cnt >= self.num_steps:
+            self.cnt = 0

     if self.enable_teacache and not should_calc:
-        processed_hidden_states = input_to_main_loop + self.cache[max_seq_len]["previous_residual"]
-    else:
-        ori_input = input_to_main_loop.clone()
+        if max_seq_len in self.cache and "previous_residual" in self.cache[max_seq_len] and self.cache[max_seq_len]["previous_residual"] is not None:
+            processed_hidden_states = input_to_main_loop + self.cache[max_seq_len]["previous_residual"]
+        else:
+            should_calc = True
+            current_processing_states = input_to_main_loop
+            for layer in self.layers:
+                current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb)
+            processed_hidden_states = current_processing_states
+
+
+    if not (self.enable_teacache and not should_calc):
         current_processing_states = input_to_main_loop
         for layer in self.layers:
             current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb)
-
+
         if self.enable_teacache:
-            self.cache[max_seq_len]["previous_residual"] = current_processing_states - ori_input
+            if max_seq_len in self.cache:
+                self.cache[max_seq_len]["previous_residual"] = current_processing_states - input_to_main_loop
+            else:
+                logger.warning(f"TeaCache: Cache key {max_seq_len} not found when trying to save residual.")
+
         processed_hidden_states = current_processing_states

     output_after_norm = self.norm_out(processed_hidden_states, temb)
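The hunk above is the core TeaCache logic: accumulate a rescaled relative-L1 change between consecutive modulated inputs, and when it stays below the threshold, skip the transformer layers by re-adding the cached residual. A standalone sketch of that mechanism with toy numbers; the identity rescaling polynomial and the stand-in "layers" are placeholders, not the values shipped in this file:

import numpy as np
import torch

rel_l1_thresh = 0.3
rescale_func = np.poly1d([1.0, 0.0])   # placeholder rescaling polynomial
cache = {"accumulated_rel_l1_distance": 0.0,
         "previous_modulated_input": None,
         "previous_residual": None}

def run_layers(x):
    return x * 1.01 + 0.01             # stand-in for the expensive transformer layers

def step(modulated_inp, hidden):
    prev = cache["previous_modulated_input"]
    if prev is None:
        should_calc = True             # nothing to compare against yet
    else:
        rel_l1 = ((modulated_inp - prev).abs().mean() / prev.abs().mean()).item()
        cache["accumulated_rel_l1_distance"] += rescale_func(rel_l1)
        should_calc = cache["accumulated_rel_l1_distance"] >= rel_l1_thresh
        if should_calc:
            cache["accumulated_rel_l1_distance"] = 0.0
    cache["previous_modulated_input"] = modulated_inp.clone()

    if not should_calc and cache["previous_residual"] is not None:
        return hidden + cache["previous_residual"]   # cheap path: reuse residual
    out = run_layers(hidden)                          # expensive path: full pass
    cache["previous_residual"] = out - hidden         # residual = output - input
    return out

x = torch.randn(2, 4)
for t in range(4):
    x = step(x * (1 + 0.001 * t), x)   # small step-to-step changes accumulate below the threshold and get skipped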
@@ -123,10 +135,13 @@ def teacache_forward_working(
         final_output_list.append(reconstructed_image)

     final_output_tensor = torch.stack(final_output_list, dim=0)
-
-    if USE_PEFT_BACKEND:
+
+    if USE_PEFT_BACKEND:
         unscale_lora_layers(self, lora_scale)
-
+
+    if not return_dict:
+        return (final_output_tensor,)
+
     return Transformer2DModelOutput(sample=final_output_tensor)


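The added return_dict branch follows the usual diffusers convention: a plain tuple when return_dict is False, otherwise a Transformer2DModelOutput whose .sample field holds the tensor. A minimal, self-contained toy mirroring that convention (the real class is diffusers' Transformer2DModelOutput; this stand-in only copies the .sample field):

from dataclasses import dataclass
import torch

@dataclass
class ToyOutput:
    sample: torch.Tensor

def forward(x: torch.Tensor, return_dict: bool = True):
    y = x * 2.0                      # stand-in for the real computation
    if not return_dict:
        return (y,)                  # plain tuple, indexed as out[0]
    return ToyOutput(sample=y)       # structured output, accessed as out.sample

x = torch.ones(1)
assert forward(x, return_dict=False)[0].item() == 2.0
assert forward(x).sample.item() == 2.0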
@@ -137,8 +152,8 @@ def teacache_forward_working(
     ckpt_path, torch_dtype=torch.bfloat16
 )
 pipeline = Lumina2Pipeline.from_pretrained(
-    "Alpha-VLLM/Lumina-Image-2.0",
-    transformer=transformer,
+    "Alpha-VLLM/Lumina-Image-2.0",
+    transformer=transformer,
     torch_dtype=torch.bfloat16
 ).to("cuda")

@@ -151,9 +166,9 @@ def teacache_forward_working(
 pipeline.transformer.__class__.enable_teacache = True
 pipeline.transformer.__class__.cnt = 0
 pipeline.transformer.__class__.num_steps = num_inference_steps
-pipeline.transformer.__class__.rel_l1_thresh = 0.3 # taken from teacache_lumina_next.py, 0.2 for 1.5x speedup, 0.3 for 1.9x speedup, 0.4 for 2.4x speedup, 0.5 for 2.8x speedup
+pipeline.transformer.__class__.rel_l1_thresh = 0.3
 pipeline.transformer.__class__.cache = {}
-pipeline.transformer.__class__.uncond_seq_len = None
+pipeline.transformer.__class__.uncond_seq_len = None


 pipeline.enable_model_cpu_offload()
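The inline comment removed in this hunk documented the threshold/speedup trade-off taken from teacache_lumina_next.py: 0.2 for about 1.5x, 0.3 for 1.9x, 0.4 for 2.4x, 0.5 for 2.8x. As an illustration only, a hypothetical helper (not part of this file) that bundles the class-level setup shown above; the forward patch on the first line is an assumption about how teacache_forward_working is wired in elsewhere in the script:

def enable_teacache(pipeline, num_inference_steps, rel_l1_thresh=0.3):
    cls = pipeline.transformer.__class__
    cls.forward = teacache_forward_working   # assumed: install the cached forward
    cls.enable_teacache = True
    cls.cnt = 0
    cls.num_steps = num_inference_steps
    cls.rel_l1_thresh = rel_l1_thresh        # higher threshold = more skipped steps, more speedup
    cls.cache = {}
    cls.uncond_seq_len = None
    return pipeline

# hypothetical usage: pipeline = enable_teacache(pipeline, num_inference_steps=30, rel_l1_thresh=0.3)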
@@ -163,4 +178,5 @@ def teacache_forward_working(
     generator=torch.Generator("cuda").manual_seed(seed)
 ).images[0]

-image.save(output_filename)
+image.save(output_filename)
+print(f"Image saved to {output_filename}")
