LLM: accelerate sample of gptneox and update quantize (#8262)
* update quantize & accelerate sample * fix style check * fix style error
This commit is contained in:
parent
2bc0e7abbb
commit
8bd2992a8d
6 changed files with 70 additions and 29 deletions
|
|
@ -1,3 +1,19 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from bigdl.llm.ggml.convert import _convert_to_ggml
|
||||
from bigdl.llm.ggml.quantize import quantize
|
||||
from pathlib import Path
|
||||
|
|
@ -32,10 +48,10 @@ def convert_model(input_path: str,
|
|||
outfile_dir=tmp_ggml_file_path,
|
||||
model_family=model_family,
|
||||
outtype="fp16")
|
||||
|
||||
|
||||
tmp_ggml_file_path = next(Path(tmp_ggml_file_path).iterdir())
|
||||
|
||||
quantize(input_path=tmp_ggml_file_path,
|
||||
output_path=output_path,
|
||||
model_family=model_family,
|
||||
dtype=dtype)
|
||||
dtype=dtype)
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@
|
|||
# only search the first bigdl package and end up finding only one sub-package.
|
||||
|
||||
from .bloom_cpp import bloom_load, bloom_free, bloom_run
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
|
||||
|
||||
class Bloom:
|
||||
|
|
@ -81,8 +82,7 @@ class Bloom:
|
|||
A Bloom instance.
|
||||
"""
|
||||
self.ctx = bloom_load(bytes(model_path, encoding='utf-8'), n_ctx, n_threads)
|
||||
if not self.ctx:
|
||||
raise RuntimeError(f"Failed to load model from {model_path}")
|
||||
invalidInputError(self.ctx is not None, f"Failed to load model from {model_path}")
|
||||
self.n_ctx = n_ctx
|
||||
self.seed = seed
|
||||
self.logits_all = logits_all
|
||||
|
|
|
|||
|
|
@ -62,6 +62,8 @@ from ctypes import (
|
|||
)
|
||||
import pathlib
|
||||
from bigdl.llm.utils import get_avx_flags
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
|
||||
|
||||
# Load the library
|
||||
def _load_shared_library(lib_base_name: str):
|
||||
|
|
@ -71,8 +73,8 @@ def _load_shared_library(lib_base_name: str):
|
|||
elif sys.platform == "win32":
|
||||
lib_ext = ".dll"
|
||||
else:
|
||||
raise RuntimeError("Unsupported platform")
|
||||
|
||||
invalidInputError(False, "Unsupported platform")
|
||||
|
||||
avx = get_avx_flags()
|
||||
|
||||
# Construct the paths to the possible shared library names (python/llm/src/bigdl/llm/libs)
|
||||
|
|
@ -101,9 +103,10 @@ def _load_shared_library(lib_base_name: str):
|
|||
try:
|
||||
return ctypes.CDLL(str(_lib_path))
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
|
||||
invalidInputError(False,
|
||||
f"Failed to load shared library '{_lib_path}': {e}")
|
||||
|
||||
raise FileNotFoundError(f"Shared library with base name '{lib_base_name}' not found")
|
||||
invalidInputError(False, f"Shared library with base name '{lib_base_name}' not found")
|
||||
|
||||
|
||||
# Specify the base name of the shared library to load
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ import uuid
|
|||
import time
|
||||
import math
|
||||
import multiprocessing
|
||||
import ctypes
|
||||
from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
|
||||
from collections import deque, OrderedDict
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
|
|
@ -342,22 +343,29 @@ class Gptneox:
|
|||
"The attribute `eval_logits` of `Gptneox` object is None.")
|
||||
n_vocab = int(gptneox_cpp.gptneox_n_vocab(self.ctx))
|
||||
logits = self.eval_logits[-1]
|
||||
data = (gptneox_cpp.gptneox_token_data * n_vocab)(
|
||||
*[
|
||||
gptneox_cpp.gptneox_token_data(
|
||||
id=gptneox_cpp.gptneox_token(i),
|
||||
logit=logits[i],
|
||||
p=gptneox_cpp.c_float(0.0),
|
||||
)
|
||||
for i in range(n_vocab)
|
||||
]
|
||||
)
|
||||
size = gptneox_cpp.c_size_t(n_vocab)
|
||||
sorted = False
|
||||
candidates = gptneox_cpp.gptneox_token_data_array(
|
||||
data=data,
|
||||
size=size,
|
||||
sorted=sorted,
|
||||
# accelerate below code by moving to cpp
|
||||
# data = (gptneox_cpp.gptneox_token_data * n_vocab)(
|
||||
# *[
|
||||
# gptneox_cpp.gptneox_token_data(
|
||||
# id=gptneox_cpp.gptneox_token(i),
|
||||
# logit=logits[i],
|
||||
# p=gptneox_cpp.c_float(0.0),
|
||||
# )
|
||||
# for i in range(n_vocab)
|
||||
# ]
|
||||
# )
|
||||
# size = gptneox_cpp.c_size_t(n_vocab)
|
||||
# sorted = False
|
||||
# candidates = gptneox_cpp.gptneox_token_data_array(
|
||||
# data=data,
|
||||
# size=size,
|
||||
# sorted=sorted,
|
||||
# )
|
||||
logits = (ctypes.c_float * n_vocab)(*logits)
|
||||
candidates = gptneox_cpp.gptneox_get_candidates(
|
||||
ctx=self.ctx,
|
||||
n_vocab=n_vocab,
|
||||
logits=logits
|
||||
)
|
||||
gptneox_cpp.gptneox_sample_repetition_penalty(
|
||||
ctx=self.ctx,
|
||||
|
|
|
|||
|
|
@ -529,6 +529,24 @@ _lib.gptneox_token_eos.restype = gptneox_token
|
|||
# Sampling functions
|
||||
|
||||
|
||||
def gptneox_get_candidates(
|
||||
ctx: gptneox_context_p,
|
||||
n_vocab: c_int,
|
||||
logits: c_float_p,
|
||||
):
|
||||
return _lib.gptneox_get_candidates(
|
||||
ctx, n_vocab, logits
|
||||
)
|
||||
|
||||
|
||||
_lib.gptneox_get_candidates.argtypes = [
|
||||
gptneox_context_p,
|
||||
c_int,
|
||||
c_float_p
|
||||
]
|
||||
_lib.gptneox_get_candidates.restype = gptneox_token_data_array
|
||||
|
||||
|
||||
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858,
|
||||
# with negative logit fix.
|
||||
def gptneox_sample_repetition_penalty(
|
||||
|
|
|
|||
|
|
@ -30,10 +30,9 @@ _llama_quantize_type = {"q4_0": 2,
|
|||
"q5_1": 9,
|
||||
"q8_0": 7}
|
||||
_bloom_quantize_type = {"q4_0": 2,
|
||||
"q4_1": 3}
|
||||
"q4_1": 3}
|
||||
_gptneox_quantize_type = {"q4_0": 2,
|
||||
"q4_1": 3,
|
||||
"q4_2": 5,
|
||||
"q5_0": 8,
|
||||
"q5_1": 9,
|
||||
"q8_0": 7}
|
||||
|
|
@ -42,9 +41,6 @@ _quantize_type = {"llama": _llama_quantize_type,
|
|||
"bloom": _bloom_quantize_type,
|
||||
"gptneox": _gptneox_quantize_type}
|
||||
|
||||
_valid_types = set(list(_llama_quantize_type.keys()) + list(_bloomz_quantize_type.keys()) +
|
||||
list(_gptneox_quantize_type.keys()))
|
||||
|
||||
|
||||
def quantize(input_path: str, output_path: str=None,
|
||||
model_family: str = 'llama', dtype: str='q4_0'):
|
||||
|
|
|
|||
Loading…
Reference in a new issue