Fix error while using pipeline parallism (#11434)
This commit is contained in:
		
							parent
							
								
									a45ceac4e4
								
							
						
					
					
						commit
						99cd16ef9f
					
				
					 1 changed files with 13 additions and 14 deletions
				
			
		| 
						 | 
					@ -146,6 +146,14 @@ def is_linear_module(module):
 | 
				
			||||||
        global _VLLM_VERSION
 | 
					        global _VLLM_VERSION
 | 
				
			||||||
        if _VLLM_VERSION is None:
 | 
					        if _VLLM_VERSION is None:
 | 
				
			||||||
            _VLLM_VERSION = get_package_version('vllm')
 | 
					            _VLLM_VERSION = get_package_version('vllm')
 | 
				
			||||||
 | 
					        from vllm.model_executor.layers.linear import (
 | 
				
			||||||
 | 
					            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        VLLM_LINEAR_LIST = [
 | 
				
			||||||
 | 
					            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					        if is_module_in_classes(module, VLLM_LINEAR_LIST):
 | 
				
			||||||
            if 'xpu' in _VLLM_VERSION:
 | 
					            if 'xpu' in _VLLM_VERSION:
 | 
				
			||||||
                # For vllm xpu
 | 
					                # For vllm xpu
 | 
				
			||||||
                from vllm.model_executor.parallel_utils.parallel_state import (
 | 
					                from vllm.model_executor.parallel_utils.parallel_state import (
 | 
				
			||||||
| 
						 | 
					@ -159,15 +167,6 @@ def is_linear_module(module):
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                # For vllm cpu
 | 
					                # For vllm cpu
 | 
				
			||||||
                tp_size = 1
 | 
					                tp_size = 1
 | 
				
			||||||
 | 
					 | 
				
			||||||
        from vllm.model_executor.layers.linear import (
 | 
					 | 
				
			||||||
            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        VLLM_LINEAR_LIST = [
 | 
					 | 
				
			||||||
            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
 | 
					 | 
				
			||||||
        ]
 | 
					 | 
				
			||||||
        if is_module_in_classes(module, VLLM_LINEAR_LIST):
 | 
					 | 
				
			||||||
            in_features = module.input_size
 | 
					            in_features = module.input_size
 | 
				
			||||||
            out_features = module.output_size
 | 
					            out_features = module.output_size
 | 
				
			||||||
            result = True
 | 
					            result = True
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue