Fix Llama transformers 4.36 support (#9852)
* support 4.36
* style
* update
* update
* update
* fix merge
* update
parent 1b585b0d40
commit 3b6372ab12

2 changed files with 13 additions and 11 deletions
@@ -531,17 +531,9 @@ def llama_attention_forward_4_36(
     device = hidden_states.device
     # for flash attention
     original_dtype = hidden_states.dtype
-    if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(hidden_states)
-    else:
-        fsdp_flag = False
-    if fsdp_flag:
-        attention_dtype = torch.float16  # use fp16 for flash attention
-    else:
-        attention_dtype = original_dtype
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     qtype = getattr(self.q_proj, "qtype", None)
     is_q4_0 = qtype == SYM_INT4
     no_tp = not self.config.pretraining_tp > 1

@@ -664,6 +656,15 @@ def llama_attention_forward_4_36(
                 past_key_value.key_cache[self.layer_idx] = key_states
                 past_key_value.value_cache[self.layer_idx] = value_states
 
+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = use_flash_attention(query_states, key_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype
+
     # repeat k/v heads if n_kv_heads < n_heads
     key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
                                                                      dtype=attention_dtype)
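
For reference, a drop-in forward like llama_attention_forward_4_36 is usually wired in by rebinding the class-level forward of Hugging Face's LlamaAttention. The snippet below is a hypothetical illustration of that wiring, not this repo's own conversion code; it assumes the patched function keeps the transformers 4.36 LlamaAttention.forward signature.

    # Hypothetical wiring sketch, not part of this commit: rebind the class-level
    # forward so every LlamaAttention instance uses the patched code path.
    import transformers.models.llama.modeling_llama as llama_modeling

    def patch_llama_attention_forward(new_forward):
        # new_forward is assumed to accept the same arguments as
        # LlamaAttention.forward in transformers 4.36 (hidden_states,
        # attention_mask, position_ids, past_key_value, ...).
        llama_modeling.LlamaAttention.forward = new_forward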

@@ -171,10 +171,11 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
                           f"{model_family} is not supported.")
 
 
-def is_enough_kv_cache_room_4_36(past_key_value, idx):
+def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
     # to determinate if is enough kv cache room in transformers==4.36
     return past_key_value is not None and len(past_key_value.key_cache) > idx and \
-        past_key_value.key_cache[idx].stride()[1] > past_key_value.key_cache[idx].size(2) * \
+        past_key_value.key_cache[idx].stride()[1] > \
+        (past_key_value.key_cache[idx].size(2) + seq_len - 1) * \
         past_key_value.key_cache[idx].size(3)
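
The updated check asks whether the pre-allocated cache buffer still has room for the incoming seq_len tokens, by comparing the cache tensor's stride along the head dimension (capacity × head_dim for a contiguous buffer) against the space needed by the tokens already cached plus the new ones. A minimal standalone sketch of the same arithmetic, assuming the per-layer cache tensor is a narrowed view into a larger pre-allocated buffer of shape [batch, num_kv_heads, max_len, head_dim]:

    import torch

    # Standalone sketch (not from the commit) of the stride-based room check.
    def has_room(cache_view: torch.Tensor, seq_len: int) -> bool:
        # stride()[1] == max_len * head_dim for a contiguous pre-allocated buffer,
        # so this asks: capacity >= tokens already cached + incoming seq_len.
        return cache_view.stride()[1] > \
            (cache_view.size(2) + seq_len - 1) * cache_view.size(3)

    buffer = torch.empty(1, 32, 256, 128)   # room for 256 tokens
    view = buffer[:, :, :250, :]            # 250 tokens already cached
    print(has_room(view, 6))                # True:  250 + 6 <= 256
    print(has_room(view, 7))                # False: 250 + 7 >  256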