
Got "The size of tensor a (xxx) must match the size of tensor b (xxx) at non-singleton dimension 0" error when I use rescale to calculate the rel_l1_distance #84

@AuZhoomLee


Hi, thanks for your work! It would be great if I could get it working as intended.

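I'm trying to use TeaCache to accelerate a special image-generation model. It looks like when inference reaches the last step, the two tensors being compared have different lengths.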

```
100%|██████████| 20/20 [00:12<00:00,  1.65it/s]
  0%|          | 0/20 [00:00<?, ?it/s]
tea cache forward runing...
  0%|          | 0/20 [00:00<?, ?it/s]
Rank 0: task execution failed: The size of tensor a (5996) must match the size of tensor b (2198) at non-singleton dimension 0

...

    self.accumulated_rel_l1_distance += rescale_func(((modulated_input-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
                                                       ~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (5996) must match the size of tensor b (2198) at non-singleton dimension 0
```
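The failing line computes the relative L1 distance between the current and previous modulated inputs, mean(|current - previous|) / mean(|previous|), and passes it through `rescale_func`. The plain-Python sketch below reproduces that quantity on flattened values (so it runs without PyTorch) and makes the length check explicit, which is exactly the condition the traceback trips over. It assumes, as in the TeaCache example scripts, that `rescale_func` is a polynomial with the highest-degree coefficient first (the `np.poly1d` convention); the coefficients used here are placeholders, not fitted values.

```python
from typing import Callable, Sequence


def rel_l1_distance(current: Sequence[float], previous: Sequence[float]) -> float:
    """mean(|current - previous|) / mean(|previous|) over flattened values."""
    if len(current) != len(previous):
        # PyTorch raises "The size of tensor a (...) must match the size of tensor b (...)" here.
        raise ValueError(
            f"length mismatch: current has {len(current)} elements, previous has {len(previous)}"
        )
    n = len(previous)
    mean_abs_diff = sum(abs(c - p) for c, p in zip(current, previous)) / n
    mean_abs_prev = sum(abs(p) for p in previous) / n
    return mean_abs_diff / mean_abs_prev


def make_rescale_func(coefficients: Sequence[float]) -> Callable[[float], float]:
    """Polynomial evaluated with Horner's rule, highest degree first (np.poly1d convention)."""
    def rescale(x: float) -> float:
        result = 0.0
        for c in coefficients:
            result = result * x + c
        return result
    return rescale


# Placeholder coefficients for illustration only: rescale(x) = 2*x.
rescale_func = make_rescale_func([2.0, 0.0])

prev = [1.0, -2.0, 3.0, -4.0]
curr = [1.1, -2.0, 2.9, -4.0]
print(rescale_func(rel_l1_distance(curr, prev)))
```

Calling `rel_l1_distance` with 5996 and 2198 elements raises instead of silently comparing mismatched tokens, which is the same situation as in the log.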

Could this be because modulated_input is computed incorrectly? I referred to the flux and lumina2 implementations:

```python
modulated_inp, gate_msa, scale_mlp, gate_mlp = self.layers[0].norm1(inp, temb_)  # lumina_next
modulated_inp, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.transformer_blocks[0].norm1(inp, emb=temb_)  # flux
modulated_inp, _, _, _ = self.layers[0].norm1(input_to_main_loop, temb)  # lumina2
```
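In the flux variant, `norm1` is an AdaLN-Zero layer; in diffusers' `AdaLayerNormZero` the modulated input is `LayerNorm(x) * (1 + scale_msa) + shift_msa`, with shift and scale projected from the timestep embedding. The plain-Python sketch below shows only that modulation step (the linear projection is omitted, and `adaln_modulate` is a hypothetical helper, not a library function). What it illustrates: the output keeps one row per input token, so `modulated_inp` inherits the sequence length of `hidden_states`, and two calls with different token counts cannot be compared element-wise.

```python
import math
from typing import List, Sequence


def layer_norm(row: Sequence[float], eps: float = 1e-6) -> List[float]:
    """LayerNorm over the feature dimension, without learnable affine parameters."""
    mean = sum(row) / len(row)
    var = sum((v - mean) ** 2 for v in row) / len(row)
    return [(v - mean) / math.sqrt(var + eps) for v in row]


def adaln_modulate(
    hidden_states: Sequence[Sequence[float]],  # [num_tokens][dim]
    shift: Sequence[float],                     # [dim], assumed to come from the timestep embedding
    scale: Sequence[float],                     # [dim], assumed to come from the timestep embedding
) -> List[List[float]]:
    """Hypothetical helper: norm(x) * (1 + scale) + shift, applied to every token."""
    out = []
    for token in hidden_states:
        normed = layer_norm(token)
        out.append([n * (1.0 + s) + b for n, s, b in zip(normed, scale, shift)])
    return out


tokens = [[1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 1.5, 1.5], [2.0, 0.0, 2.0, 0.0]]
shift = [0.0, 0.1, 0.0, -0.1]
scale = [0.0, 0.5, 1.0, 0.0]
modulated = adaln_modulate(tokens, shift, scale)
print(len(modulated), len(modulated[0]))  # one row per input token, same feature width
```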

I'd like to confirm: how should modulated_inp be constructed from hidden_states?
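One thing that may be worth checking (a possible reading of the log, not a confirmed diagnosis): the traceback appears under a second `0/20` progress bar, that is, at the first step of a new sampling run after the first run completed. If the cached previous input and the step counter survive from the earlier call, the first comparison of the new run uses a tensor with a different token count. The sketch below is a hypothetical stateful helper, loosely modeled on the counter-based logic in the TeaCache examples: it always computes on the first and last step, resets its state when a run ends, and additionally treats a length change as a reason to recompute. `StepCache`, `num_steps`, `threshold`, and `should_calc` are illustrative names, not the project's API.

```python
from typing import Callable, List, Optional, Sequence


class StepCache:
    """Hypothetical TeaCache-style skip/compute decision helper (illustration only)."""

    def __init__(self, num_steps: int, threshold: float,
                 rescale: Callable[[float], float] = lambda x: x) -> None:
        self.num_steps = num_steps
        self.threshold = threshold
        self.rescale = rescale
        self.reset()

    def reset(self) -> None:
        self.cnt = 0
        self.accumulated = 0.0
        self.previous: Optional[List[float]] = None

    @staticmethod
    def rel_l1(current: Sequence[float], previous: Sequence[float]) -> float:
        n = len(previous)
        diff = sum(abs(c - p) for c, p in zip(current, previous)) / n
        return diff / (sum(abs(p) for p in previous) / n)

    def should_calc(self, modulated_inp: Sequence[float]) -> bool:
        first_or_last = self.cnt == 0 or self.cnt == self.num_steps - 1
        shape_changed = self.previous is not None and len(self.previous) != len(modulated_inp)
        if first_or_last or shape_changed or self.previous is None:
            calc = True
            self.accumulated = 0.0
        else:
            self.accumulated += self.rescale(self.rel_l1(modulated_inp, self.previous))
            calc = self.accumulated >= self.threshold
            if calc:
                self.accumulated = 0.0
        self.previous = list(modulated_inp)
        self.cnt += 1
        if self.cnt == self.num_steps:
            self.reset()  # clear cached state so the next run starts fresh
        return calc


cache = StepCache(num_steps=4, threshold=0.5)
run1 = [cache.should_calc([1.0] * 6) for _ in range(4)]
run2 = [cache.should_calc([1.0] * 3) for _ in range(4)]  # new run, different token count
print(run1, run2)
```

If `num_steps` does not match the number of times the transformer is actually called per run, the counter never wraps to 0 at the start of the next run; the extra length check is a defensive fallback for that case.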
