[NPU ] fix load logic of glm-edge models (#12698)
parent 584c1c5373
commit da8bcb7db1

1 changed file with 7 additions and 3 deletions
@@ -182,13 +182,17 @@ class _BaseAutoModelClass:
         if hasattr(model, "config") and model.config.model_type == "glm":
             # convert to llama structure
             from .npu_models.glm_edge import convert_config, load_weights, convert_state_dict
-            import json
             original_path = model.config._name_or_path
             del model
 
-            with open(os.path.join(original_path, "config.json")) as f:
-                original_config = json.load(f)
+            original_config, _ = PretrainedConfig.get_config_dict(original_path)
             config = convert_config(original_config)
+
+            if not os.path.isdir(original_path):
+                # all model files are already cached
+                from transformers.utils.hub import cached_file
+                resolved_file = cached_file(original_path, "config.json")
+                original_path = os.path.dirname(resolved_file)
             original_state_dict = load_weights(original_path)
             new_dict, _ = convert_state_dict(original_state_dict, config,
                                              original_config.get("partial_rotary_factor", 1.0),
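The point of the change: the old code opened config.json straight from model.config._name_or_path, which only works when that path is a local directory; when the model is referenced by a Hugging Face Hub ID, the files live in the local cache instead. Below is a minimal sketch of the new resolution flow, outside the patched class; the Hub ID is purely illustrative and not taken from the patch.

import os

from transformers import PretrainedConfig
from transformers.utils.hub import cached_file

name_or_path = "THUDM/glm-edge-1.5b-chat"  # illustrative Hub ID, could also be a local directory

# get_config_dict accepts either a local directory or a Hub model ID, so the
# config no longer has to be read from a local config.json by hand.
original_config, _ = PretrainedConfig.get_config_dict(name_or_path)

if not os.path.isdir(name_or_path):
    # Not a local checkout: the model files sit in the Hugging Face cache.
    # Resolve a file known to exist (config.json) and take its parent directory
    # to obtain the cached snapshot folder.
    resolved_file = cached_file(name_or_path, "config.json")
    name_or_path = os.path.dirname(resolved_file)

print(name_or_path)  # e.g. .../models--THUDM--glm-edge-1.5b-chat/snapshots/<revision>

Pointing name_or_path at the directory of a resolved config.json is what lets the subsequent load_weights(original_path) call find the already-cached weight files without requiring a local checkout.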