diff --git a/python/llm/src/ipex_llm/transformers/models/stablelm.py b/python/llm/src/ipex_llm/transformers/models/stablelm.py index 00845966..bfcb50ec 100644 --- a/python/llm/src/ipex_llm/transformers/models/stablelm.py +++ b/python/llm/src/ipex_llm/transformers/models/stablelm.py @@ -141,34 +141,24 @@ def stablelm_attention_forward( kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # Partial rotary embedding - query_rot, query_pass = ( - query_states[..., : self.rotary_emb.dim], - query_states[..., self.rotary_emb.dim:], - ) - key_rot, key_pass = ( - key_states[..., : self.rotary_emb.dim], - key_states[..., self.rotary_emb.dim:], - ) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - # [batch_size, num_heads, seq_length, head_dim // config.partial_rotary_factor] + # [batch_size, num_heads, seq_length, head_dim * config.partial_rotary_factor] + rot_dim = self.rotary_emb.dim if should_use_fuse_rope(hidden_states, position_ids, self.training): - query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot, - key_rot, - sin, - cos, - "stablelm", - position_ids) + import xe_addons + xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids, + query_states[..., :rot_dim], key_states[..., :rot_dim]) else: + query_rot, query_pass = query_states[..., :rot_dim], query_states[..., rot_dim:] + key_rot, key_pass = key_states[..., :rot_dim], key_states[..., rot_dim:] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids, "stablelm") - - # [batch_size, num_heads, seq_length, head_dim] - query_states = torch.cat((query_rot, query_pass), dim=-1) - key_states = torch.cat((key_rot, key_pass), dim=-1) + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, None)