Add phi3 cached RotaryEmbedding (#11013)
* phi3cachedrotaryembed * pep8
This commit is contained in:
		
							parent
							
								
									0b7e78b592
								
							
						
					
					
						commit
						0a732bebe7
					
				
					 2 changed files with 56 additions and 0 deletions
				
			
		| 
						 | 
					@ -848,6 +848,15 @@ def convert_forward(m, target_m, new_forward):
 | 
				
			||||||
        convert_forward(sub_m, target_m, new_forward)
 | 
					        convert_forward(sub_m, target_m, new_forward)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def replace_RotaryEmbed(m, target_m,  replace_embed):
 | 
				
			||||||
 | 
					    for attr_name, sub_m in m.named_children():
 | 
				
			||||||
 | 
					        if isinstance(sub_m, target_m):
 | 
				
			||||||
 | 
					            setattr(m, attr_name, replace_embed(sub_m.dim,
 | 
				
			||||||
 | 
					                                                sub_m.max_position_embeddings,
 | 
				
			||||||
 | 
					                                                sub_m.base))
 | 
				
			||||||
 | 
					        replace_RotaryEmbed(sub_m, target_m, replace_embed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def replace_func(m, target_m, func_name, new_func):
 | 
					def replace_func(m, target_m, func_name, new_func):
 | 
				
			||||||
    for _, sub_m in m.named_children():
 | 
					    for _, sub_m in m.named_children():
 | 
				
			||||||
        if isinstance(sub_m, target_m):
 | 
					        if isinstance(sub_m, target_m):
 | 
				
			||||||
| 
						 | 
					@ -1517,6 +1526,8 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
				
			||||||
        from ipex_llm.transformers.models.phi3 import model_forward_wrapper
 | 
					        from ipex_llm.transformers.models.phi3 import model_forward_wrapper
 | 
				
			||||||
        model_forward = model_forward_wrapper(module.Phi3Model.forward)
 | 
					        model_forward = model_forward_wrapper(module.Phi3Model.forward)
 | 
				
			||||||
        convert_forward(model, module.Phi3Model, model_forward)
 | 
					        convert_forward(model, module.Phi3Model, model_forward)
 | 
				
			||||||
 | 
					        from ipex_llm.transformers.models.phi3 import Phi3RotaryEmbeddingCached
 | 
				
			||||||
 | 
					        replace_RotaryEmbed(model, module.Phi3RotaryEmbedding, Phi3RotaryEmbeddingCached)
 | 
				
			||||||
        from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward
 | 
					        from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward
 | 
				
			||||||
        convert_forward(
 | 
					        convert_forward(
 | 
				
			||||||
            model,
 | 
					            model,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -34,6 +34,7 @@
 | 
				
			||||||
import math
 | 
					import math
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import warnings
 | 
					import warnings
 | 
				
			||||||
 | 
					from torch import nn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ipex_llm.transformers.models.utils import (
 | 
					from ipex_llm.transformers.models.utils import (
 | 
				
			||||||
    rotate_half, should_use_fuse_rope,
 | 
					    rotate_half, should_use_fuse_rope,
 | 
				
			||||||
| 
						 | 
					@ -255,6 +256,50 @@ def model_forward_wrapper(origin_model_forward):
 | 
				
			||||||
    return model_forward
 | 
					    return model_forward
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Phi3RotaryEmbeddingCached(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.dim = dim
 | 
				
			||||||
 | 
					        self.max_position_embeddings = max_position_embeddings
 | 
				
			||||||
 | 
					        self.base = base
 | 
				
			||||||
 | 
					        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2,
 | 
				
			||||||
 | 
					                                                     dtype=torch.int64,
 | 
				
			||||||
 | 
					                                                     device=device).float() / self.dim))
 | 
				
			||||||
 | 
					        self.register_buffer("inv_freq", inv_freq, persistent=False)
 | 
				
			||||||
 | 
					        # self.gen_seq_len = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Build here to make `torch.jit.trace` work.
 | 
				
			||||||
 | 
					        self._set_cos_sin_cache(
 | 
				
			||||||
 | 
					            seq_len=max_position_embeddings,
 | 
				
			||||||
 | 
					            device=self.inv_freq.device, dtype=torch.get_default_dtype()
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _set_cos_sin_cache(self, seq_len, device, dtype):
 | 
				
			||||||
 | 
					        self.max_seq_len_cached = seq_len
 | 
				
			||||||
 | 
					        position_ids_expanded = torch.arange(self.max_seq_len_cached,
 | 
				
			||||||
 | 
					                                             device=device,
 | 
				
			||||||
 | 
					                                             dtype=self.inv_freq.dtype).reshape(1, 1, -1)
 | 
				
			||||||
 | 
					        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(1, -1, 1)
 | 
				
			||||||
 | 
					        with torch.autocast(device_type=device.type, enabled=False):
 | 
				
			||||||
 | 
					            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 | 
				
			||||||
 | 
					            # Different from paper,
 | 
				
			||||||
 | 
					            # but it uses a different permutation in order to obtain the same calculation
 | 
				
			||||||
 | 
					            emb = torch.cat((freqs, freqs), dim=-1)
 | 
				
			||||||
 | 
					            self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
 | 
				
			||||||
 | 
					            self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x, position_ids, seq_len=None):
 | 
				
			||||||
 | 
					        # x: [bs, num_attention_heads, seq_len, head_size]
 | 
				
			||||||
 | 
					        if seq_len > self.max_seq_len_cached:
 | 
				
			||||||
 | 
					            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return (
 | 
				
			||||||
 | 
					            self.cos_cached[:, seq_len-position_ids.shape[-1]:seq_len, :].to(dtype=x.dtype),
 | 
				
			||||||
 | 
					            self.sin_cached[:, seq_len-position_ids.shape[-1]:seq_len, :].to(dtype=x.dtype),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def phi3_rms_norm_forward(self, hidden_states):
 | 
					def phi3_rms_norm_forward(self, hidden_states):
 | 
				
			||||||
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
 | 
					    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
 | 
				
			||||||
        import linear_q4_0
 | 
					        import linear_q4_0
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue