LLM: fix model check before attention optimization (#9149)
parent 1a1ddc4144
commit 69942d3826

1 changed file with 24 additions and 23 deletions
@@ -181,8 +181,9 @@ def optimize(model):
         # todo implement 4.28.0 ~ 4.30.2
         pass
 
-    if "chatglm-18b" in model.config._name_or_path or "chatglm2" in model.config._name_or_path:
-        # chatglm-18b or chatglm2-6b
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
+    if model.config.architectures[0] == "ChatGLMModel":
+        if hasattr(model.config, "padded_vocab_size") and model.config.padded_vocab_size == 65024:
+            # chatglm2-6b
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
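The reason for the change: model.config._name_or_path only records the path or repo id the model was loaded from, so a substring match breaks as soon as a user loads the same weights from an arbitrary local directory. The new check keys off config.architectures and padded_vocab_size, which ship inside the model's own config.json. A minimal sketch of the difference, using a stand-in config object with hypothetical values:

    # Stand-in for a chatglm2-6b config loaded from a local folder;
    # the field values mirror the checks in the diff above.
    from types import SimpleNamespace

    config = SimpleNamespace(
        _name_or_path="./my-local-checkpoint",   # hypothetical local path
        architectures=["ChatGLMModel"],          # comes from config.json
        padded_vocab_size=65024,                 # chatglm2-6b vocabulary size
    )

    # Old check: substring match on the load path -> misses local copies.
    old_match = "chatglm2" in config._name_or_path            # False

    # New check: fields that travel with the weights -> still matches.
    new_match = (config.architectures[0] == "ChatGLMModel"
                 and getattr(config, "padded_vocab_size", None) == 65024)  # True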
@@ -194,7 +195,7 @@ def optimize(model):
-        convert_forward(model,
-                        module.CoreAttention,
-                        core_attn_forward_8eb45c)
-    elif "chatglm" in model.config._name_or_path:
-        # chatglm-6b
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
+            convert_forward(model,
+                            module.CoreAttention,
+                            core_attn_forward_8eb45c)
+        elif model.config.vocab_size == 130528:
+            # chatglm-6b
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
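Both ChatGLM generations report the same ChatGLMModel architecture, so vocabulary size is what tells them apart: chatglm2-6b pads its vocabulary to 65024, while the original chatglm-6b uses 130528. A sketch of the dispatch the new checks imply; the helper below is hypothetical, not part of bigdl-llm:

    def detect_chatglm_generation(config):
        # Hypothetical helper that condenses the branch structure of the
        # new checks into one function.
        archs = getattr(config, "architectures", None) or []
        if not archs or archs[0] != "ChatGLMModel":
            return None                        # not a ChatGLM model
        if getattr(config, "padded_vocab_size", None) == 65024:
            return "chatglm2-6b"
        if getattr(config, "vocab_size", None) == 130528:
            return "chatglm-6b"
        return "unknown"                       # no optimization applied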
@@ -203,7 +204,7 @@ def optimize(model):
-                        module.SelfAttention,
-                        chatglm_attention_forward
-                        )
-    elif "mpt" in model.config._name_or_path:
+                            module.SelfAttention,
+                            chatglm_attention_forward
+                            )
+    elif "mpt" in model.config.model_type:
         modeling_module_name = model.__class__.__module__
         attention_module_name = '.'.join(modeling_module_name.split('.')[:-1]) + ".attention"
         module = importlib.import_module(attention_module_name)
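For MPT the fix swaps the path substring test for config.model_type, which Hugging Face configs set explicitly. The unchanged context lines also show how the patch locates MPT's attention code: models loaded with trust_remote_code expose their code as a package, so the attention module sits next to the modeling module and is reached by rewriting the dotted path. A sketch with a hypothetical module name, exercising only the string handling:

    # The dotted-path rewrite from the mpt branch, run on a hypothetical
    # remote-code module name.
    modeling_module_name = "transformers_modules.mpt-7b.modeling_mpt"
    attention_module_name = '.'.join(modeling_module_name.split('.')[:-1]) + ".attention"
    assert attention_module_name == "transformers_modules.mpt-7b.attention"
    # importlib.import_module(attention_module_name) would then load the
    # sibling module that defines MPT's attention classes.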