parent 31ea2f9a9f
commit 952e517db9

1 changed file with 15 additions and 5 deletions

@@ -414,10 +414,12 @@ def llama_attention_forward_4_31_quantized(
             kv_seq_len += past_key_value[0].shape[-2]
 
         if use_fuse_rope:
+            rope_theta = self.rotary_emb.base
             query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                           key_states,
                                                                           position_ids,
-                                                                          "llama")
+                                                                          "llama",
+                                                                          rope_theta=rope_theta)
         else:
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
             query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -603,10 +605,12 @@ def llama_attention_forward_4_31_original(
             kv_seq_len += past_key_value[0].shape[-2]
 
         if use_fuse_rope:
+            rope_theta = self.rotary_emb.base
             query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                           key_states,
                                                                           position_ids,
-                                                                          "llama")
+                                                                          "llama",
+                                                                          rope_theta=rope_theta)
         else:
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
             query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -795,10 +799,12 @@ def llama_attention_selective_batching_forward_4_31(
             kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value)
 
         if use_fuse_rope:
+            rope_theta = self.rotary_emb.base
             query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                           key_states,
                                                                           position_ids,
-                                                                          "llama")
+                                                                          "llama",
+                                                                          rope_theta=rope_theta)
         else:
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
             query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -1006,10 +1012,12 @@ def llama_attention_forward_4_36_quantized(
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         if use_fuse_rope:
+            rope_theta = self.rotary_emb.base
             query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                           key_states,
                                                                           position_ids,
-                                                                          "llama")
+                                                                          "llama",
+                                                                          rope_theta=rope_theta)
         else:
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
             query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -1266,10 +1274,12 @@ def llama_attention_forward_4_36_original(
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
         if use_fuse_rope:
+            rope_theta = self.rotary_emb.base
             query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                           key_states,
                                                                           position_ids,
-                                                                          "llama")
+                                                                          "llama",
+                                                                          rope_theta=rope_theta)
         else:
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
             query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
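
Note on the change: all five attention paths now read the model's configured rotary base from self.rotary_emb.base and forward it to apply_rotary_pos_emb_no_cache_xpu as rope_theta, instead of letting the fused XPU rope path assume a default base. Below is a minimal illustrative sketch (not the project's fused kernel) of why the base matters, assuming the standard LLaMA-style rotary formula; the 1e6 value is only a hypothetical non-default base such as the one CodeLlama uses.

import torch

def rotary_inv_freq(head_dim: int, rope_theta: float = 10000.0) -> torch.Tensor:
    # Standard LLaMA-style rotary frequencies: inv_freq[i] = 1 / rope_theta**(2i / head_dim).
    return 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

def rope_cos_sin(position_ids: torch.Tensor, head_dim: int, rope_theta: float = 10000.0):
    # cos/sin tables for the given positions, analogous in spirit to the slow path
    # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len).
    inv_freq = rotary_inv_freq(head_dim, rope_theta)
    freqs = position_ids.to(torch.float32)[..., None] * inv_freq   # [..., seq_len, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)                        # [..., seq_len, head_dim]
    return emb.cos(), emb.sin()

# With the default base the tables match; with a non-default rope_theta they do not,
# which is why the fused path has to receive the configured value explicitly.
pos = torch.arange(8)[None, :]
cos_default, _ = rope_cos_sin(pos, head_dim=128)
cos_custom, _ = rope_cos_sin(pos, head_dim=128, rope_theta=1_000_000.0)
print(torch.allclose(cos_default, cos_custom))   # False: the rotation angles differ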