Fix setting of use_quantize_kv_cache on different GPU in pipeline parallel (#11516)
This commit is contained in:
parent
7cb09a8eac
commit
252426793b
1 changed files with 25 additions and 0 deletions
|
|
@ -27,6 +27,7 @@ from typing import Callable, List, Optional
|
|||
from types import SimpleNamespace
|
||||
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
|
||||
from ipex_llm.utils.common import invalidInputError
|
||||
from ipex_llm.ggml.quantize import ggml_tensor_qtype
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
import asyncio
|
||||
|
|
@ -106,6 +107,29 @@ def init_pipeline_parallel():
|
|||
dist.init_process_group('ccl')
|
||||
|
||||
|
||||
def _check_quantize_kv_cache(model, idx, batch_size):
|
||||
# align use_quantize_kv_cache setting for different GPU in pipeline parallel
|
||||
pp_quantize_kv_cache = (os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) == "1") or \
|
||||
(os.environ.get("IPEX_LLM_QUANTIZE_KV_CACHE", None) == "1") or \
|
||||
(os.environ.get("IPEX_LLM_LOW_MEM", None) == "1")
|
||||
if model.config.model_type == "qwen" and hasattr(model.config, "visual"):
|
||||
# for Qwen-VL-Chat
|
||||
linear = model._modules['transformer'].h[idx].mlp.c_proj
|
||||
elif model.config.model_type == "chatglm":
|
||||
# for chatglm3-6b, glm-4-9b-chat
|
||||
linear = model._modules['transformer'].encoder.layers[idx].self_attention.query_key_value
|
||||
else:
|
||||
linear = model._modules['model'].layers[idx].mlp.up_proj
|
||||
pp_quantize_kv_cache = pp_quantize_kv_cache or (1 < batch_size and batch_size <= 8 and
|
||||
hasattr(linear, "qtype") and
|
||||
linear.qtype != ggml_tensor_qtype["fp16"] and
|
||||
linear.qtype != ggml_tensor_qtype["bf16"])
|
||||
if pp_quantize_kv_cache:
|
||||
os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "1"
|
||||
else:
|
||||
os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "0"
|
||||
|
||||
|
||||
def pipeline_parallel(model, pipeline_parallel_stages):
|
||||
global num_layers
|
||||
if hasattr(model.config, 'num_hidden_layers'):
|
||||
|
|
@ -255,6 +279,7 @@ def pipeline_parallel_generate(self,
|
|||
_past_key_values = None
|
||||
bs = inputs.shape[0]
|
||||
output_ids = inputs.clone()
|
||||
_check_quantize_kv_cache(self, layer_start, bs)
|
||||
|
||||
step = 0
|
||||
# keep track of which sequences are already finished
|
||||
|
|
|
|||
Loading…
Reference in a new issue