support new model (#12523)
parent 922958c018
commit 77404d2a63

2 changed files with 15 additions and 1 deletion
@@ -1049,6 +1049,10 @@ def _optimize_pre(model, qtype=None):
             model.llm.config.model_type = "qwen2"
         elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
             model.llm.config.model_type = "llama"
+        elif model.config.hidden_size == 1536 and model.config.vocab_size == 73464:
+            from ipex_llm.transformers.models.minicpm3 import pre_compute_inv_freq
+            model.llm.apply(pre_compute_inv_freq)
+            model.llm.config.model_type = "minicpm"
         _optimize_pre(model.llm, qtype=qtype)
         model.llm.config.model_type = "minicpmv"
     elif model.config.model_type == "chatglm":
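For reference, `pre_compute_inv_freq` from `ipex_llm.transformers.models.minicpm3` is broadcast over the inner LLM with `nn.Module.apply`, so LongRoPE rotary-embedding modules get their long/short inverse frequencies cached up front. A rough sketch of that pattern follows; `long_inv_freq` and `short_inv_freq` are the names the attention change further down actually reads, while `dim`, `base`, `short_factor`, and `long_factor` are assumed attribute names used only for illustration.

import torch

def pre_compute_inv_freq_sketch(module: torch.nn.Module):
    # Illustrative sketch (assumed attribute names): cache long/short inverse
    # frequencies on LongRoPE rotary-embedding modules so the fused rope kernel
    # can pick one at runtime without recomputing them every forward pass.
    if module.__class__.__name__ == "MiniCPMLongRoPE":
        inv_freq = 1.0 / (module.base ** (
            torch.arange(0, module.dim, 2, dtype=torch.float32) / module.dim))
        # short_factor / long_factor are assumed per-dimension rescale lists,
        # as in Phi-3-style LongRoPE scaling.
        module.short_inv_freq = inv_freq / torch.tensor(module.short_factor,
                                                        dtype=torch.float32)
        module.long_inv_freq = inv_freq / torch.tensor(module.long_factor,
                                                       dtype=torch.float32)

# nn.Module.apply walks the whole module tree, so a single call covers every layer:
# model.llm.apply(pre_compute_inv_freq_sketch)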
@@ -2137,6 +2141,9 @@ def _optimize_post(model, lightweight_bmm=False):
         elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
             # MiniCPM-V 2.5
             model.llm.config.model_type = "llama"
+        elif model.config.hidden_size == 1536 and model.config.vocab_size == 73464:
+            # MiniCPM-V ?
+            model.llm.config.model_type = "minicpm"
         _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
         model.llm.config.model_type = "minicpmv"
 
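The pattern in both hunks above is the same: the MiniCPM-V wrapper reports `model_type == "minicpmv"`, so the inner language model at `model.llm` is identified by its config shapes, relabelled with the matching base architecture so the recursive optimize call reuses the existing per-architecture paths, and then tagged back to "minicpmv". A minimal sketch, assuming an `optimize_fn` standing in for `_optimize_pre`/`_optimize_post` (the qwen2 branch condition sits above the first hunk and is omitted here):

def relabel_and_optimize_llm(model, optimize_fn):
    # Sketch of the dispatch above: infer the inner LLM's architecture from
    # config shapes, since the wrapper's own model_type is "minicpmv".
    cfg = model.config
    if cfg.hidden_size == 4096 and cfg.vocab_size == 128256:
        model.llm.config.model_type = "llama"      # MiniCPM-V 2.5 (Llama-based LLM)
    elif cfg.hidden_size == 1536 and cfg.vocab_size == 73464:
        model.llm.config.model_type = "minicpm"
    optimize_fn(model.llm)                         # reuse the existing per-arch path
    model.llm.config.model_type = "minicpmv"       # tag the inner config back afterwards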
@@ -100,8 +100,15 @@ def minicpm_attention_forward(
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
     if should_use_fuse_rope(hidden_states, position_ids, self.training):
+        if self.rotary_emb.__class__.__name__ == "MiniCPMLongRoPE":
+            if kv_seq_len > self.rotary_emb.original_max_position_embeddings:
+                inv_freq = self.rotary_emb.long_inv_freq
+            else:
+                inv_freq = self.rotary_emb.short_inv_freq
+        else:
+            inv_freq = self.rotary_emb.inv_freq
         import xe_addons
-        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+        xe_addons.rotary_half_inplaced(inv_freq, position_ids,
                                        query_states, key_states)
     else:
         cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
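The only functional change in the fused path is which inverse-frequency table is handed to `xe_addons.rotary_half_inplaced` (the in-place rotate-half RoPE kernel used on XPU): the long table once `kv_seq_len` exceeds `original_max_position_embeddings`, the short table otherwise, and the plain `inv_freq` for non-LongRoPE models. As a rough plain-PyTorch reference for the rotate-half math that kernel fuses (not in-place; shapes assumed as `position_ids [bsz, seq]` and `q/k [bsz, heads, seq, head_dim]`):

import torch

def rotary_half_reference(inv_freq, position_ids, query_states, key_states):
    # Reference sketch of rotate-half RoPE with a caller-selected inv_freq table.
    freqs = torch.outer(position_ids.reshape(-1).float(), inv_freq.float())
    emb = torch.cat((freqs, freqs), dim=-1)                           # [bsz*seq, head_dim]
    cos = emb.cos().view(*position_ids.shape, 1, -1).transpose(1, 2)  # [bsz, 1, seq, head_dim]
    sin = emb.sin().view(*position_ids.shape, 1, -1).transpose(1, 2)

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    q = query_states * cos + rotate_half(query_states) * sin
    k = key_states * cos + rotate_half(key_states) * sin
    return q, k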