use new rope in phi3 (#11047)
This commit is contained in:
		
							parent
							
								
									00d4410746
								
							
						
					
					
						commit
						8cae897643
					
				
					 2 changed files with 36 additions and 91 deletions
				
			
		| 
						 | 
				
			
			@ -706,6 +706,8 @@ def _optimize_pre(model):
 | 
			
		|||
        from ipex_llm.transformers.models.phi import merge_qkv
 | 
			
		||||
        model.apply(merge_qkv)
 | 
			
		||||
    if model.config.model_type == "phi3":
 | 
			
		||||
        from ipex_llm.transformers.models.phi3 import pre_compute_inv_freq
 | 
			
		||||
        model.apply(pre_compute_inv_freq)
 | 
			
		||||
        from ipex_llm.transformers.models.phi3 import split_mlp
 | 
			
		||||
        model.apply(split_mlp)
 | 
			
		||||
    if model.config.model_type == "qwen":
 | 
			
		||||
| 
						 | 
				
			
			@ -1525,8 +1527,6 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
        # for phi-3
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
        from ipex_llm.transformers.models.phi3 import su_scaled_rope_forward
 | 
			
		||||
        convert_forward(model, module.Phi3SuScaledRotaryEmbedding, su_scaled_rope_forward)
 | 
			
		||||
        from ipex_llm.transformers.models.phi3 import attention_forward
 | 
			
		||||
        convert_forward(model, module.Phi3Attention, attention_forward)
 | 
			
		||||
        from ipex_llm.transformers.models.phi3 import mlp_forward
 | 
			
		||||
| 
						 | 
				
			
			@ -1534,13 +1534,8 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
        from ipex_llm.transformers.models.phi3 import model_forward_wrapper
 | 
			
		||||
        model_forward = model_forward_wrapper(module.Phi3Model.forward)
 | 
			
		||||
        convert_forward(model, module.Phi3Model, model_forward)
 | 
			
		||||
        from ipex_llm.transformers.models.phi3 import Phi3RotaryEmbeddingCached
 | 
			
		||||
        replace_RotaryEmbed(model, module.Phi3RotaryEmbedding, Phi3RotaryEmbeddingCached)
 | 
			
		||||
        from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward
 | 
			
		||||
        convert_forward(
 | 
			
		||||
            model,
 | 
			
		||||
            module.Phi3RMSNorm,
 | 
			
		||||
            phi3_rms_norm_forward)
 | 
			
		||||
        convert_forward(model, module.Phi3RMSNorm, phi3_rms_norm_forward)
 | 
			
		||||
    elif model.config.model_type == 'yuan':
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -38,7 +38,6 @@ from torch import nn
 | 
			
		|||
 | 
			
		||||
from ipex_llm.transformers.models.utils import (
 | 
			
		||||
    rotate_half, should_use_fuse_rope,
 | 
			
		||||
    apply_rotary_pos_emb_cache_freq_xpu
 | 
			
		||||
)
 | 
			
		||||
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 | 
			
		||||
| 
						 | 
				
			
			@ -58,46 +57,29 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 | 
			
		|||
    return q_embed, k_embed
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def su_scaled_rope_forward(self, x: torch.Tensor, position_ids: torch.Tensor, seq_len=None):
 | 
			
		||||
    if self.inv_freq is None:
 | 
			
		||||
        short_ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
 | 
			
		||||
        inv_freq_shape = torch.arange(0, self.dim, 2,
 | 
			
		||||
                                      dtype=torch.int64, device=x.device).float() / self.dim
 | 
			
		||||
        self.inv_freq = 1.0 / (short_ext_factors * self.base**inv_freq_shape)
 | 
			
		||||
def pre_compute_inv_freq(module: torch.nn.Module):
 | 
			
		||||
    if module.__class__.__name__ == "Phi3RotaryEmbedding":
 | 
			
		||||
        module.inv_freq = 1.0 / (
 | 
			
		||||
            module.base **
 | 
			
		||||
            (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim)
 | 
			
		||||
        )
 | 
			
		||||
    elif module.__class__.__name__ == "Phi3SuScaledRotaryEmbedding":
 | 
			
		||||
        inv_freq_shape = torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim
 | 
			
		||||
        short_ext_factors = torch.tensor(module.short_factor, dtype=torch.float32)
 | 
			
		||||
        module.inv_freq = 1.0 / (short_ext_factors * module.base ** inv_freq_shape)
 | 
			
		||||
 | 
			
		||||
        long_ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
 | 
			
		||||
        self.register_buffer("long_inv_freq", None, persistent=False)
 | 
			
		||||
        self.long_inv_freq = 1.0 / (long_ext_factors * self.base**inv_freq_shape)
 | 
			
		||||
        long_ext_factors = torch.tensor(module.long_factor, dtype=torch.float32)
 | 
			
		||||
        module.register_buffer("long_inv_freq", None, persistent=False)
 | 
			
		||||
        module.long_inv_freq = 1.0 / (long_ext_factors * module.base ** inv_freq_shape)
 | 
			
		||||
 | 
			
		||||
    seq_len = seq_len if seq_len is not None else torch.max(position_ids) + 1
 | 
			
		||||
    if seq_len > self.original_max_position_embeddings:
 | 
			
		||||
        inv_freq = self.long_inv_freq
 | 
			
		||||
    else:
 | 
			
		||||
        inv_freq = self.inv_freq
 | 
			
		||||
 | 
			
		||||
    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
 | 
			
		||||
    position_ids_expanded = position_ids[:, None, :].float()
 | 
			
		||||
 | 
			
		||||
    # Force float32 since bfloat16 loses precision on long contexts
 | 
			
		||||
    # See https://github.com/huggingface/transformers/pull/29285
 | 
			
		||||
    device_type = x.device.type
 | 
			
		||||
    device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
 | 
			
		||||
    with torch.autocast(device_type=device_type, enabled=False):
 | 
			
		||||
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 | 
			
		||||
        emb = torch.cat((freqs, freqs), dim=-1)
 | 
			
		||||
 | 
			
		||||
        scale = self.max_position_embeddings / self.original_max_position_embeddings
 | 
			
		||||
        if scale <= 1.0:
 | 
			
		||||
            scaling_factor = 1.0
 | 
			
		||||
        if module.max_position_embeddings <= module.original_max_position_embeddings:
 | 
			
		||||
            module.scaling_factor = 1.0
 | 
			
		||||
        else:
 | 
			
		||||
            scaling_factor = math.sqrt(
 | 
			
		||||
                1 + math.log(scale) / math.log(self.original_max_position_embeddings)
 | 
			
		||||
            scale = module.max_position_embeddings / module.original_max_position_embeddings
 | 
			
		||||
            module.scaling_factor = math.sqrt(
 | 
			
		||||
                1 + math.log(scale) / math.log(module.original_max_position_embeddings)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        cos = emb.cos() * scaling_factor
 | 
			
		||||
        sin = emb.sin() * scaling_factor
 | 
			
		||||
    return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def attention_forward(
 | 
			
		||||
    self,
 | 
			
		||||
| 
						 | 
				
			
			@ -124,12 +106,24 @@ def attention_forward(
 | 
			
		|||
    if past_key_value is not None:
 | 
			
		||||
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 | 
			
		||||
 | 
			
		||||
    cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
 | 
			
		||||
    # IPEX-LLM OPT: fuse rope
 | 
			
		||||
    if should_use_fuse_rope(hidden_states, position_ids, self.training):
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
 | 
			
		||||
                                                                       sin, cos, "phi3")
 | 
			
		||||
        import linear_q4_0
 | 
			
		||||
        if self.rotary_emb.__class__.__name__ == "Phi3RotaryEmbedding":     # 4k
 | 
			
		||||
            linear_q4_0.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
 | 
			
		||||
                                             query_states, key_states)
 | 
			
		||||
        else:   # 128k
 | 
			
		||||
            if kv_seq_len > self.rotary_emb.original_max_position_embeddings:
 | 
			
		||||
                linear_q4_0.rotary_half_inplaced(self.rotary_emb.long_inv_freq, position_ids,
 | 
			
		||||
                                                 query_states, key_states)
 | 
			
		||||
            else:
 | 
			
		||||
                linear_q4_0.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
 | 
			
		||||
                                                 query_states, key_states)
 | 
			
		||||
            # todo: fuse scaling_factor
 | 
			
		||||
            query_states *= self.rotary_emb.scaling_factor
 | 
			
		||||
            key_states *= self.rotary_emb.scaling_factor
 | 
			
		||||
    else:
 | 
			
		||||
        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
			
		||||
                                                        cos, sin, position_ids)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -257,50 +251,6 @@ def model_forward_wrapper(origin_model_forward):
 | 
			
		|||
    return model_forward
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Phi3RotaryEmbeddingCached(nn.Module):
 | 
			
		||||
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.dim = dim
 | 
			
		||||
        self.max_position_embeddings = max_position_embeddings
 | 
			
		||||
        self.base = base
 | 
			
		||||
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2,
 | 
			
		||||
                                                     dtype=torch.int64,
 | 
			
		||||
                                                     device=device).float() / self.dim))
 | 
			
		||||
        self.register_buffer("inv_freq", inv_freq, persistent=False)
 | 
			
		||||
        # self.gen_seq_len = None
 | 
			
		||||
 | 
			
		||||
        # Build here to make `torch.jit.trace` work.
 | 
			
		||||
        self._set_cos_sin_cache(
 | 
			
		||||
            seq_len=max_position_embeddings,
 | 
			
		||||
            device=self.inv_freq.device, dtype=torch.get_default_dtype()
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _set_cos_sin_cache(self, seq_len, device, dtype):
 | 
			
		||||
        self.max_seq_len_cached = seq_len
 | 
			
		||||
        position_ids_expanded = torch.arange(self.max_seq_len_cached,
 | 
			
		||||
                                             device=device,
 | 
			
		||||
                                             dtype=self.inv_freq.dtype).reshape(1, 1, -1)
 | 
			
		||||
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(1, -1, 1)
 | 
			
		||||
        with torch.autocast(device_type=device.type, enabled=False):
 | 
			
		||||
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 | 
			
		||||
            # Different from paper,
 | 
			
		||||
            # but it uses a different permutation in order to obtain the same calculation
 | 
			
		||||
            emb = torch.cat((freqs, freqs), dim=-1)
 | 
			
		||||
            self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
 | 
			
		||||
            self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, position_ids, seq_len=None):
 | 
			
		||||
        # x: [bs, num_attention_heads, seq_len, head_size]
 | 
			
		||||
        if seq_len > self.max_seq_len_cached:
 | 
			
		||||
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 | 
			
		||||
 | 
			
		||||
        return (
 | 
			
		||||
            self.cos_cached[:, seq_len-position_ids.shape[-1]:seq_len, :].to(dtype=x.dtype),
 | 
			
		||||
            self.sin_cached[:, seq_len-position_ids.shape[-1]:seq_len, :].to(dtype=x.dtype),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def phi3_rms_norm_forward(self, hidden_states):
 | 
			
		||||
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
 | 
			
		||||
        import linear_q4_0
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue