Fix optimize lm head error (#11899)

Guancheng Fu 2024-08-22 17:45:26 +08:00 committed by GitHub
parent c5b51d41fb
commit 278b191dc1

@@ -225,6 +225,8 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
     from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
     from ipex_llm.transformers.low_bit_linear import LowBitLinear, \
         FP16Linear, BF16Linear, vLLMLowBitLinear, vLLMFP16Linear, vLLMBF16Linear
+    # Currently, vLLM does not support optimize_lm_head = True
+    optimize_lm_head = False
     if isinstance(module, ParallelLMHead):
         if qtype == ggml_tensor_qtype["fp16"]:
             new_linear = FP16Linear(
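
For readers without the surrounding file, a minimal, self-contained sketch of the control flow this change touches. ParallelLMHead and FP16Linear below are placeholder stand-ins for the real vllm / ipex_llm classes (whose actual constructors take more arguments), and convert_vllm_sketch is a hypothetical simplification of convert_vllm; only the forced optimize_lm_head = False guard mirrors the commit itself.

# Illustrative sketch only: placeholder classes stand in for the real
# vllm / ipex_llm types so the control flow is runnable on its own.
class ParallelLMHead:
    # placeholder for vllm.model_executor.layers.vocab_parallel_embedding.ParallelLMHead
    pass

class FP16Linear:
    # placeholder for ipex_llm.transformers.low_bit_linear.FP16Linear
    def __init__(self, in_features, out_features, optimize_lm_head=False):
        self.in_features = in_features
        self.out_features = out_features
        self.optimize_lm_head = optimize_lm_head

def convert_vllm_sketch(module, qtype, in_features, out_features):
    # The fix from this commit: vLLM does not support optimize_lm_head = True,
    # so the flag is forced off before any replacement layer is constructed.
    optimize_lm_head = False
    if isinstance(module, ParallelLMHead):
        if qtype == "fp16":
            return FP16Linear(in_features, out_features,
                              optimize_lm_head=optimize_lm_head)
    return None

# Usage: the converted LM head always comes back with the optimization disabled.
head = convert_vllm_sketch(ParallelLMHead(), "fp16", 4096, 32000)
assert head is not None and head.optimize_lm_head is False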