optimize qwen2 gpu memory usage again (#11435)
This commit is contained in:
parent ab9f7f3ac5
commit 2a0f8087e3
1 changed file with 16 additions and 0 deletions
@@ -357,6 +357,21 @@ def merge_qkv(module: torch.nn.Module):
        del module.q_proj, module.k_proj, module.v_proj

        # Qwen2 uses pre-computed rope table to accelerate rope,
        # original `cos_cached` and `sin_cached` are added by `register_buffer`,
        # so they will move to xpu during `model.to('xpu')`.
        # But gpu fuse kernel doesn't need this rope table, only cpu needs them,
        # so delete them then add them with `=`, so that they will be pinned on CPU,
        # this can save about 0.5GB gpu memory usage when running Qwen2
        if hasattr(module.rotary_emb, "cos_cached"):
            cos_cached = module.rotary_emb.cos_cached
            del module.rotary_emb.cos_cached
            module.rotary_emb.cos_cached = cos_cached
        if hasattr(module.rotary_emb, "sin_cached"):
            sin_cached = module.rotary_emb.sin_cached
            del module.rotary_emb.sin_cached
            module.rotary_emb.sin_cached = sin_cached


def padding_mlp(module: torch.nn.Module):
    # for qwen 1.5 14B
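The memory saving relies on how PyTorch modules treat buffers versus plain attributes: a tensor added with `register_buffer` is tracked by the module, so `model.to('xpu')` moves it to device memory, while a tensor re-attached with a plain `=` after `del` is an ordinary Python attribute that `.to()` ignores. Below is a minimal self-contained sketch of that behavior; the module class, buffer size, and helper name are illustrative, and "cuda" stands in for the xpu device targeted by this commit.

import torch

class RopeCache(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # registered buffer: tracked in module._buffers, so module.to(device)
        # would normally move this table (size here is illustrative) to the GPU
        self.register_buffer("cos_cached", torch.randn(4096, 128), persistent=False)

def pin_buffer_on_cpu(module: torch.nn.Module, name: str = "cos_cached"):
    # same pattern as merge_qkv above: read the tensor, delete the buffer
    # entry, then re-attach it as a plain attribute that .to() will not touch
    if hasattr(module, name):
        cached = getattr(module, name)
        delattr(module, name)          # removes it from module._buffers
        setattr(module, name, cached)  # plain attribute now, stays on CPU

m = RopeCache()
pin_buffer_on_cpu(m)
if torch.cuda.is_available():
    m.to("cuda")
    print(m.cos_cached.device)  # prints "cpu": the rope table stayed behind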
@@ -433,6 +448,7 @@ def qwen2_attention_forward(
                                                        query_states, key_states)
    else:
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        cos, sin = cos.to(device), sin.to(device)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                        cos, sin, position_ids)
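Because the rope table is now pinned on the CPU, the non-fused path has to move `cos`/`sin` onto the device of the activations before using them, which is what the single added `cos.to(device), sin.to(device)` line does. A simplified sketch of the rotary-embedding application it feeds into is below; the helpers are an illustrative re-implementation of the standard pattern, not the `apply_rotary_pos_emb` imported by this file.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # standard rope helper: negate the second half and swap the halves
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin, position_ids):
    # cos/sin of shape [seq_len, head_dim]; the caller is expected to have
    # moved them to q.device already, since they may still live on CPU
    cos = cos[position_ids].unsqueeze(1)  # -> [bs, 1, seq_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin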