diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index 708c4033..779ccd4b 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -357,6 +357,21 @@ def merge_qkv(module: torch.nn.Module):
         del module.q_proj, module.k_proj, module.v_proj
 
+        # Qwen2 uses a pre-computed rope table (`cos_cached` / `sin_cached`) to
+        # accelerate rope. The original tables are registered via `register_buffer`,
+        # so they would be moved to xpu during `model.to('xpu')`. The gpu fused rope
+        # kernel does not need this table; only the cpu path uses it. We therefore
+        # delete the buffers and re-assign them with plain `=` so they stay pinned
+        # on CPU, which saves about 0.5 GB of gpu memory when running Qwen2.
+        if hasattr(module.rotary_emb, "cos_cached"):
+            cos_cached = module.rotary_emb.cos_cached
+            del module.rotary_emb.cos_cached
+            module.rotary_emb.cos_cached = cos_cached
+        if hasattr(module.rotary_emb, "sin_cached"):
+            sin_cached = module.rotary_emb.sin_cached
+            del module.rotary_emb.sin_cached
+            module.rotary_emb.sin_cached = sin_cached
+
 
 def padding_mlp(module: torch.nn.Module):
     # for qwen 1.5 14B
@@ -433,6 +448,7 @@ def qwen2_attention_forward(
                                                            query_states, key_states)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = cos.to(device), sin.to(device)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                         cos, sin, position_ids)
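For reference, here is a minimal, self-contained sketch (not part of the PR) of the mechanism the first hunk relies on: tensors added with `register_buffer` follow `module.to(device)`, while tensors re-assigned as plain attributes stay on their original device. The `ToyRotaryEmbedding` module and `pin_rope_table_on_cpu` helper below are illustrative names, not ipex-llm APIs.

```python
# Sketch only: demonstrates that deleting a registered buffer and re-assigning
# it with `=` keeps the tensor on CPU even after `module.to(device)`.
import torch


class ToyRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim: int = 64, max_len: int = 2048):
        super().__init__()
        t = torch.arange(max_len, dtype=torch.float32)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Registered buffers travel with the module on `.to(device)`.
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)


def pin_rope_table_on_cpu(rotary_emb: torch.nn.Module):
    # Re-bind the cached tables as ordinary attributes so a later
    # `model.to('xpu')` (or `.to('cuda')`) leaves them on CPU.
    for name in ("cos_cached", "sin_cached"):
        if hasattr(rotary_emb, name):
            cached = getattr(rotary_emb, name)
            delattr(rotary_emb, name)          # removes the registered buffer
            setattr(rotary_emb, name, cached)  # plain attribute, ignored by .to()


if __name__ == "__main__":
    rope = ToyRotaryEmbedding()
    pin_rope_table_on_cpu(rope)
    device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"
    rope.to(device)
    # The tables stay on CPU regardless of where the module was moved.
    print(rope.cos_cached.device)
```

This is also why the second hunk adds `cos, sin = cos.to(device), sin.to(device)` in the attention forward: once the tables are pinned on CPU, the slices returned by the rotary embedding must be moved to the compute device before `apply_rotary_pos_emb` runs on the non-fused path.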