From 3b6372ab120c9c0e7f51bf6e200b23f74b3ea775 Mon Sep 17 00:00:00 2001
From: Jiao Wang
Date: Mon, 8 Jan 2024 00:32:23 -0800
Subject: [PATCH] Fix Llama transformers 4.36 support (#9852)

* supoort 4.36

* style

* update

* update

* update

* fix merge

* update
---
 .../bigdl/llm/transformers/models/llama.py | 19 ++++++++++---------
 .../bigdl/llm/transformers/models/utils.py |  5 +++--
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 4b7dc3ad..11473803 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -531,17 +531,9 @@ def llama_attention_forward_4_36(
     device = hidden_states.device
     # for flash attention
     original_dtype = hidden_states.dtype
-    if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(hidden_states)
-    else:
-        fsdp_flag = False
-    if fsdp_flag:
-        attention_dtype = torch.float16  # use fp16 for flash attention
-    else:
-        attention_dtype = original_dtype
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     qtype = getattr(self.q_proj, "qtype", None)
     is_q4_0 = qtype == SYM_INT4
     no_tp = not self.config.pretraining_tp > 1
@@ -664,6 +656,15 @@ def llama_attention_forward_4_36(
         past_key_value.key_cache[self.layer_idx] = key_states
         past_key_value.value_cache[self.layer_idx] = value_states
 
+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = use_flash_attention(query_states, key_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype
+
     # repeat k/v heads if n_kv_heads < n_heads
     key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
                                                                      dtype=attention_dtype)
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 8502ec63..2bee81f7 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -171,10 +171,11 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
                           f"{model_family} is not supported.")
 
 
-def is_enough_kv_cache_room_4_36(past_key_value, idx):
+def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
     # to determinate if is enough kv cache room in transformers==4.36
     return past_key_value is not None and len(past_key_value.key_cache) > idx and \
-        past_key_value.key_cache[idx].stride()[1] > past_key_value.key_cache[idx].size(2) * \
+        past_key_value.key_cache[idx].stride()[1] > \
+        (past_key_value.key_cache[idx].size(2) + seq_len - 1) * \
         past_key_value.key_cache[idx].size(3)
 
 
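
Note (illustration, not part of the patch): the reworked is_enough_kv_cache_room_4_36 decides whether the pre-allocated KV cache can still absorb the incoming tokens. The cache tensor is laid out as (batch, num_kv_heads, kv_len, head_dim) and is allocated with spare space along the kv_len dimension, then narrowed to the length actually in use, so stride()[1] reflects the allocated length times head_dim while size(2) is the in-use length. The check passes only when the allocation has room for seq_len more tokens, which is why llama_attention_forward_4_36 now passes seq_len=q_len. The sketch below replays that comparison on a plain tensor; has_room is a hypothetical stand-in for the patched expression (which reads past_key_value.key_cache[idx]), and the shapes are made up.

import torch


def has_room(cache: torch.Tensor, seq_len: int) -> bool:
    # cache layout: (batch, num_kv_heads, kv_len, head_dim).
    # stride()[1] reflects the allocated kv_len (times head_dim), size(2) the
    # kv_len in use, so the allocation fits seq_len more tokens exactly when
    # stride()[1] > (size(2) + seq_len - 1) * size(3), mirroring the patched check.
    return cache.stride()[1] > (cache.size(2) + seq_len - 1) * cache.size(3)


# Hypothetical cache: space for 256 tokens allocated, 10 currently filled.
storage = torch.empty(1, 32, 256, 128)
cache = storage[:, :, :10, :]        # narrowed view keeps the parent's stride

print(has_room(cache, seq_len=1))    # True  -> 246 free slots remain
print(has_room(cache, seq_len=246))  # True  -> exactly fills the allocation
print(has_room(cache, seq_len=247))  # False -> would overflow the allocation

tight = torch.empty(1, 32, 10, 128)  # allocated with no spare room at all
print(has_room(tight, seq_len=1))    # False -> more cache must be allocated

With the default seq_len=1 the new expression reduces to the old one (stride()[1] > size(2) * size(3)), so existing callers keep their behavior while llama_attention_forward_4_36 can now account for a whole q_len-token chunk instead of a single token.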