Fix setting of use_quantize_kv_cache on different GPU in pipeline parallel (#11516)
				
					
				
			This commit is contained in:
		
							parent
							
								
									7cb09a8eac
								
							
						
					
					
						commit
						252426793b
					
				
					 1 changed files with 25 additions and 0 deletions
				
			
		| 
						 | 
					@ -27,6 +27,7 @@ from typing import Callable, List, Optional
 | 
				
			||||||
from types import SimpleNamespace
 | 
					from types import SimpleNamespace
 | 
				
			||||||
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
 | 
					from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
 | 
				
			||||||
from ipex_llm.utils.common import invalidInputError
 | 
					from ipex_llm.utils.common import invalidInputError
 | 
				
			||||||
 | 
					from ipex_llm.ggml.quantize import ggml_tensor_qtype
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
logger = logging.getLogger(__name__)
 | 
					logger = logging.getLogger(__name__)
 | 
				
			||||||
import asyncio
 | 
					import asyncio
 | 
				
			||||||
| 
						 | 
					@ -106,6 +107,29 @@ def init_pipeline_parallel():
 | 
				
			||||||
    dist.init_process_group('ccl')
 | 
					    dist.init_process_group('ccl')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _check_quantize_kv_cache(model, idx, batch_size):
 | 
				
			||||||
 | 
					    # align use_quantize_kv_cache setting for different GPU in pipeline parallel
 | 
				
			||||||
 | 
					    pp_quantize_kv_cache = (os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) == "1") or \
 | 
				
			||||||
 | 
					        (os.environ.get("IPEX_LLM_QUANTIZE_KV_CACHE", None) == "1") or \
 | 
				
			||||||
 | 
					        (os.environ.get("IPEX_LLM_LOW_MEM", None) == "1")
 | 
				
			||||||
 | 
					    if model.config.model_type == "qwen" and hasattr(model.config, "visual"):
 | 
				
			||||||
 | 
					        # for Qwen-VL-Chat
 | 
				
			||||||
 | 
					        linear = model._modules['transformer'].h[idx].mlp.c_proj
 | 
				
			||||||
 | 
					    elif model.config.model_type == "chatglm":
 | 
				
			||||||
 | 
					        # for chatglm3-6b, glm-4-9b-chat
 | 
				
			||||||
 | 
					        linear = model._modules['transformer'].encoder.layers[idx].self_attention.query_key_value
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        linear = model._modules['model'].layers[idx].mlp.up_proj
 | 
				
			||||||
 | 
					    pp_quantize_kv_cache = pp_quantize_kv_cache or (1 < batch_size and batch_size <= 8 and
 | 
				
			||||||
 | 
					                                                    hasattr(linear, "qtype") and
 | 
				
			||||||
 | 
					                                                    linear.qtype != ggml_tensor_qtype["fp16"] and
 | 
				
			||||||
 | 
					                                                    linear.qtype != ggml_tensor_qtype["bf16"])
 | 
				
			||||||
 | 
					    if pp_quantize_kv_cache:
 | 
				
			||||||
 | 
					        os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "1"
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "0"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def pipeline_parallel(model, pipeline_parallel_stages):
 | 
					def pipeline_parallel(model, pipeline_parallel_stages):
 | 
				
			||||||
    global num_layers
 | 
					    global num_layers
 | 
				
			||||||
    if hasattr(model.config, 'num_hidden_layers'):
 | 
					    if hasattr(model.config, 'num_hidden_layers'):
 | 
				
			||||||
| 
						 | 
					@ -255,6 +279,7 @@ def pipeline_parallel_generate(self,
 | 
				
			||||||
    _past_key_values = None
 | 
					    _past_key_values = None
 | 
				
			||||||
    bs = inputs.shape[0]
 | 
					    bs = inputs.shape[0]
 | 
				
			||||||
    output_ids = inputs.clone()
 | 
					    output_ids = inputs.clone()
 | 
				
			||||||
 | 
					    _check_quantize_kv_cache(self, layer_start, bs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    step = 0
 | 
					    step = 0
 | 
				
			||||||
    # keep track of which sequences are already finished
 | 
					    # keep track of which sequences are already finished
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue