use config rope_theta (#10787)

* use config rope_theta

* fix style

Yang Wang 2024-04-17 20:39:11 -07:00 committed by GitHub
parent 31ea2f9a9f
commit 952e517db9
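
Note on the change: in Hugging Face transformers' Llama attention, self.rotary_emb.base holds the rotary base in use (set from the config's rope_theta in versions that expose it). The hunks below read that value and forward it to the fused XPU RoPE path, which previously received no base and could only assume a default. A minimal sketch of why the base matters (illustrative only; rope_inv_freq is a hypothetical helper, not part of ipex-llm):

# Illustrative sketch: RoPE inverse frequencies depend on the rotary base
# ("rope_theta"), so a fused RoPE path should receive the configured value
# (here via self.rotary_emb.base) rather than assume the default of 10000.0.
import torch

def rope_inv_freq(head_dim: int, rope_theta: float) -> torch.Tensor:
    # inv_freq[i] = 1 / rope_theta ** (2 * i / head_dim), i = 0 .. head_dim // 2 - 1
    return 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))

print(rope_inv_freq(8, 10000.0))    # default Llama base
print(rope_inv_freq(8, 1000000.0))  # a config that ships a larger rope_theta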

@@ -414,10 +414,12 @@ def llama_attention_forward_4_31_quantized(
         kv_seq_len += past_key_value[0].shape[-2]
     if use_fuse_rope:
+        rope_theta = self.rotary_emb.base
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                      key_states,
                                                                      position_ids,
-                                                                     "llama")
+                                                                     "llama",
+                                                                     rope_theta=rope_theta)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -603,10 +605,12 @@ def llama_attention_forward_4_31_original(
         kv_seq_len += past_key_value[0].shape[-2]
     if use_fuse_rope:
+        rope_theta = self.rotary_emb.base
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                      key_states,
                                                                      position_ids,
-                                                                     "llama")
+                                                                     "llama",
+                                                                     rope_theta=rope_theta)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -795,10 +799,12 @@ def llama_attention_selective_batching_forward_4_31(
         kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value)
     if use_fuse_rope:
+        rope_theta = self.rotary_emb.base
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                      key_states,
                                                                      position_ids,
-                                                                     "llama")
+                                                                     "llama",
+                                                                     rope_theta=rope_theta)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -1006,10 +1012,12 @@ def llama_attention_forward_4_36_quantized(
             )
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     if use_fuse_rope:
+        rope_theta = self.rotary_emb.base
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                      key_states,
                                                                      position_ids,
-                                                                     "llama")
+                                                                     "llama",
+                                                                     rope_theta=rope_theta)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -1266,10 +1274,12 @@ def llama_attention_forward_4_36_original(
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     if use_fuse_rope:
+        rope_theta = self.rotary_emb.base
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                      key_states,
                                                                      position_ids,
-                                                                     "llama")
+                                                                     "llama",
+                                                                     rope_theta=rope_theta)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
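
The hunks above only thread rope_theta through on the caller side; apply_rotary_pos_emb_no_cache_xpu itself dispatches to a fused XPU path. As a rough, hypothetical stand-in for what such an entry point can do with the keyword (fused_rope_sketch, its default value, and the [batch, heads, seq, head_dim] layout are assumptions for illustration, not the ipex-llm implementation):

# Hypothetical eager-mode stand-in, NOT ipex-llm's apply_rotary_pos_emb_no_cache_xpu.
# The point is that the cos/sin tables are built from the rope_theta that was
# passed in, instead of a hard-coded 10000.0.
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard Hugging Face style rotation: swap and negate the two halves.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def fused_rope_sketch(q, k, position_ids, model_family="llama", rope_theta=10000.0):
    # model_family is accepted only for signature parity with the call sites above.
    head_dim = q.size(-1)
    inv_freq = 1.0 / (rope_theta ** (
        torch.arange(0, head_dim, 2, device=q.device, dtype=torch.float32) / head_dim))
    freqs = position_ids.to(torch.float32)[..., None] * inv_freq  # [batch, seq, head_dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)                       # [batch, seq, head_dim]
    cos = emb.cos()[:, None, :, :].to(q.dtype)                    # broadcast over heads
    sin = emb.sin()[:, None, :, :].to(q.dtype)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin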