Fix Llama transformers 4.36 support (#9852)

* support 4.36

* style

* update

* update

* update

* fix merge

* update
Jiao Wang 2024-01-08 00:32:23 -08:00 committed by GitHub
parent 1b585b0d40
commit 3b6372ab12
2 changed files with 13 additions and 11 deletions

@@ -531,17 +531,9 @@ def llama_attention_forward_4_36(
     device = hidden_states.device
     # for flash attention
     original_dtype = hidden_states.dtype
-    if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(hidden_states)
-    else:
-        fsdp_flag = False
-    if fsdp_flag:
-        attention_dtype = torch.float16  # use fp16 for flash attention
-    else:
-        attention_dtype = original_dtype
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     qtype = getattr(self.q_proj, "qtype", None)
     is_q4_0 = qtype == SYM_INT4
     no_tp = not self.config.pretraining_tp > 1
@@ -664,6 +656,15 @@ def llama_attention_forward_4_36(
             past_key_value.key_cache[self.layer_idx] = key_states
             past_key_value.value_cache[self.layer_idx] = value_states
 
+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = use_flash_attention(query_states, key_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype
+
     # repeat k/v heads if n_kv_heads < n_heads
     key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
                                                                      dtype=attention_dtype)
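Taken together, the two hunks above move the flash-attention eligibility check so that it runs after RoPE has been applied and the transformers-4.36 KV cache has been updated, and base the decision on the final query/key tensors instead of the raw hidden_states. The following is a condensed sketch of that reordered decision, not code from the repo: pick_attention_dtype is a hypothetical helper used only for illustration, and use_flash_attention is passed in as an opaque callable rather than imported.

import torch

def pick_attention_dtype(query_states, key_states, hidden_states,
                         original_dtype, training, use_flash_attention):
    # Flash attention is only considered at inference time (no grad needed),
    # and the check now sees the finished q/k tensors, i.e. after RoPE and
    # the KV-cache update, rather than the raw hidden_states.
    if not training and not hidden_states.requires_grad:
        fsdp_flag = use_flash_attention(query_states, key_states)
    else:
        fsdp_flag = False
    # The flash path runs in fp16; otherwise keep the original dtype so the
    # later repeat_kv / matmul sees unchanged precision.
    attention_dtype = torch.float16 if fsdp_flag else original_dtype
    return fsdp_flag, attention_dtype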

@@ -171,10 +171,11 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
                          f"{model_family} is not supported.")


-def is_enough_kv_cache_room_4_36(past_key_value, idx):
+def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
     # to determinate if is enough kv cache room in transformers==4.36
     return past_key_value is not None and len(past_key_value.key_cache) > idx and \
-        past_key_value.key_cache[idx].stride()[1] > past_key_value.key_cache[idx].size(2) * \
+        past_key_value.key_cache[idx].stride()[1] > \
+        (past_key_value.key_cache[idx].size(2) + seq_len - 1) * \
         past_key_value.key_cache[idx].size(3)


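The helper change is what makes the seq_len=q_len call in the attention code meaningful: with the new argument, is_enough_kv_cache_room_4_36 reports room only if the preallocated cache can absorb all seq_len incoming tokens, not just one. Below is a small self-contained check of that arithmetic; the SimpleNamespace object is a toy stand-in for the transformers-4.36 DynamicCache, with shapes (batch, kv_heads, seq_len, head_dim) assumed for illustration.

import torch
from types import SimpleNamespace

def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
    # Same arithmetic as the patched helper: the per-head stride of the cached
    # key tensor encodes the allocated sequence capacity, so there is room
    # only when capacity >= current_length + seq_len.
    return past_key_value is not None and len(past_key_value.key_cache) > idx and \
        past_key_value.key_cache[idx].stride()[1] > \
        (past_key_value.key_cache[idx].size(2) + seq_len - 1) * \
        past_key_value.key_cache[idx].size(3)

# Preallocated buffer with capacity 16 along the sequence dim, 14 slots used.
buffer = torch.zeros(1, 8, 16, 64)
cache = SimpleNamespace(key_cache=[buffer[:, :, :14, :]],
                        value_cache=[buffer[:, :, :14, :]])

print(is_enough_kv_cache_room_4_36(cache, 0, seq_len=1))  # True:  14 + 1 <= 16
print(is_enough_kv_cache_room_4_36(cache, 0, seq_len=4))  # False: 14 + 4 > 16

With the previous single-token check (seq_len implicitly 1), the second call would also have returned True, so a multi-token prefill could be routed to the in-place cache path without enough room, which appears to be the situation this commit fixes.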