Fix setting of use_quantize_kv_cache on different GPU in pipeline parallel (#11516)
parent 7cb09a8eac
commit 252426793b
1 changed file with 25 additions and 0 deletions
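Before the fix, each pipeline-parallel rank decided use_quantize_kv_cache on its own, so ranks running on different GPUs could disagree about whether the KV cache is quantized. A minimal sketch of that failure mode, with hypothetical names (the real decision lives in the helper added below):

# Hypothetical per-rank heuristic (names illustrative): the outcome depends on
# rank-local state, so two ranks of one pipeline can disagree.
def rank_local_decision(local_layer_qtype: str, batch_size: int) -> bool:
    return 1 < batch_size <= 8 and local_layer_qtype not in ("fp16", "bf16")

print(rank_local_decision("fp16", 4))   # a rank holding fp16 layers -> False
print(rank_local_decision("q4_0", 4))   # a rank holding int4 layers -> True

The commit aligns the ranks by computing the decision once per generate() call and pinning it in the IPEX_LLM_QUANTIZE_KV_CACHE environment variable that the KV-cache code reads.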
@@ -27,6 +27,7 @@ from typing import Callable, List, Optional
 from types import SimpleNamespace
 from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
 from ipex_llm.utils.common import invalidInputError
+from ipex_llm.ggml.quantize import ggml_tensor_qtype
 import logging
 logger = logging.getLogger(__name__)
 import asyncio
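The one-line import change pulls in ggml_tensor_qtype, the mapping from qtype names to the integer codes that ipex-llm stores on converted Linear modules; the helper added below compares a layer's qtype against its "fp16" and "bf16" entries. A short sketch of that check, assuming only what the patch itself uses (the helper name is illustrative):

from ipex_llm.ggml.quantize import ggml_tensor_qtype

def weights_kept_in_floating_point(linear) -> bool:
    # converted layers carry an integer `qtype` attribute; matching the
    # fp16/bf16 codes means the weights were not low-bit quantized
    return getattr(linear, "qtype", None) in (ggml_tensor_qtype["fp16"],
                                              ggml_tensor_qtype["bf16"])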
@@ -106,6 +107,29 @@ def init_pipeline_parallel():
     dist.init_process_group('ccl')


+def _check_quantize_kv_cache(model, idx, batch_size):
+    # align use_quantize_kv_cache setting for different GPU in pipeline parallel
+    pp_quantize_kv_cache = (os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) == "1") or \
+        (os.environ.get("IPEX_LLM_QUANTIZE_KV_CACHE", None) == "1") or \
+        (os.environ.get("IPEX_LLM_LOW_MEM", None) == "1")
+    if model.config.model_type == "qwen" and hasattr(model.config, "visual"):
+        # for Qwen-VL-Chat
+        linear = model._modules['transformer'].h[idx].mlp.c_proj
+    elif model.config.model_type == "chatglm":
+        # for chatglm3-6b, glm-4-9b-chat
+        linear = model._modules['transformer'].encoder.layers[idx].self_attention.query_key_value
+    else:
+        linear = model._modules['model'].layers[idx].mlp.up_proj
+    pp_quantize_kv_cache = pp_quantize_kv_cache or (1 < batch_size and batch_size <= 8 and
+                                                    hasattr(linear, "qtype") and
+                                                    linear.qtype != ggml_tensor_qtype["fp16"] and
+                                                    linear.qtype != ggml_tensor_qtype["bf16"])
+    if pp_quantize_kv_cache:
+        os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "1"
+    else:
+        os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "0"
+
+
 def pipeline_parallel(model, pipeline_parallel_stages):
     global num_layers
     if hasattr(model.config, 'num_hidden_layers'):
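The helper first honors the three environment switches, then falls back to a batch-size heuristic on the rank's own layer slice: it picks one representative Linear module per architecture (the c_proj projection for Qwen-VL-Chat, query_key_value for chatglm-style models, up_proj otherwise) and enables KV-cache quantization for batch sizes in (1, 8] whenever that layer is not kept in fp16/bf16. A standalone restatement of the decision flow, with the model-specific layer lookup abstracted into a boolean (function name is illustrative):

import os

def decide_and_pin_kv_cache_quant(weights_in_fp: bool, batch_size: int) -> bool:
    # 1. any of the explicit switches forces quantization on
    forced = any(os.environ.get(k) == "1" for k in
                 ("BIGDL_QUANTIZE_KV_CACHE", "IPEX_LLM_QUANTIZE_KV_CACHE",
                  "IPEX_LLM_LOW_MEM"))
    # 2. otherwise quantize for moderate batches on low-bit weights
    decision = forced or (1 < batch_size <= 8 and not weights_in_fp)
    # 3. pin the result so every rank and downstream check agrees
    os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "1" if decision else "0"
    return decision

Writing the result back into IPEX_LLM_QUANTIZE_KV_CACHE is what makes the setting consistent: whichever branch a rank takes, all later reads of the variable see the same value.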
@@ -255,6 +279,7 @@ def pipeline_parallel_generate(self,
     _past_key_values = None
     bs = inputs.shape[0]
     output_ids = inputs.clone()
+    _check_quantize_kv_cache(self, layer_start, bs)

     step = 0
     # keep track of which sequences are already finished
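The call site runs once per pipeline_parallel_generate() invocation, before the decoding loop, passing the rank's first local layer (layer_start) and the request batch size, so the environment variable is pinned before any KV cache is allocated. A hedged end-to-end usage sketch, following the shape of ipex-llm's pipeline-parallel examples (the checkpoint, launch command, and generate arguments are assumptions, not from this patch):

# Launch one process per pipeline stage, e.g.:
#   mpirun -n 2 python driver.py
import os
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel

init_pipeline_parallel()
local_rank = int(os.environ.get("LOCAL_RANK", 0))

model_path = "meta-llama/Llama-2-7b-chat-hf"  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             optimize_model=True,
                                             pipeline_parallel_stages=2)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# identical prompts keep the batch rectangular without padding; batch size 4
# falls in the 1 < bs <= 8 range the new heuristic inspects
inputs = tokenizer(["What is AI?"] * 4,
                   return_tensors="pt").input_ids.to(f"xpu:{local_rank}")
# generate() now pins IPEX_LLM_QUANTIZE_KV_CACHE identically on every rank
# before the first decoding step
output = model.generate(inputs, max_new_tokens=32)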