LLM: accelerate sample of gptneox and update quantize (#8262)

* update quantize & accelerate sample

* fix style check

* fix style error
This commit is contained in:
Ruonan Wang 2023-06-05 15:36:00 +08:00 committed by GitHub
parent 2bc0e7abbb
commit 8bd2992a8d
6 changed files with 70 additions and 29 deletions

View file

@@ -1,3 +1,19 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from bigdl.llm.ggml.convert import _convert_to_ggml
from bigdl.llm.ggml.quantize import quantize
from pathlib import Path
@@ -32,10 +48,10 @@ def convert_model(input_path: str,
outfile_dir=tmp_ggml_file_path,
model_family=model_family,
outtype="fp16")
tmp_ggml_file_path = next(Path(tmp_ggml_file_path).iterdir())
quantize(input_path=tmp_ggml_file_path,
output_path=output_path,
model_family=model_family,
dtype=dtype)
dtype=dtype)

View file

@@ -46,6 +46,7 @@
# only search the first bigdl package and end up finding only one sub-package.
from .bloom_cpp import bloom_load, bloom_free, bloom_run
from bigdl.llm.utils.common import invalidInputError
class Bloom:
@@ -81,8 +82,7 @@ class Bloom:
A Bloom instance.
"""
self.ctx = bloom_load(bytes(model_path, encoding='utf-8'), n_ctx, n_threads)
if not self.ctx:
raise RuntimeError(f"Failed to load model from {model_path}")
invalidInputError(self.ctx is not None, f"Failed to load model from {model_path}")
self.n_ctx = n_ctx
self.seed = seed
self.logits_all = logits_all

View file

@@ -62,6 +62,8 @@ from ctypes import (
)
import pathlib
from bigdl.llm.utils import get_avx_flags
from bigdl.llm.utils.common import invalidInputError
# Load the library
def _load_shared_library(lib_base_name: str):
@@ -71,8 +73,8 @@ def _load_shared_library(lib_base_name: str):
elif sys.platform == "win32":
lib_ext = ".dll"
else:
raise RuntimeError("Unsupported platform")
invalidInputError(False, "Unsupported platform")
avx = get_avx_flags()
# Construct the paths to the possible shared library names (python/llm/src/bigdl/llm/libs)
@@ -101,9 +103,10 @@ def _load_shared_library(lib_base_name: str):
try:
return ctypes.CDLL(str(_lib_path))
except Exception as e:
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
invalidInputError(False,
f"Failed to load shared library '{_lib_path}': {e}")
raise FileNotFoundError(f"Shared library with base name '{lib_base_name}' not found")
invalidInputError(False, f"Shared library with base name '{lib_base_name}' not found")
# Specify the base name of the shared library to load

View file

@@ -51,6 +51,7 @@ import uuid
import time
import math
import multiprocessing
import ctypes
from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
from collections import deque, OrderedDict
from bigdl.llm.utils.common import invalidInputError
@@ -342,22 +343,29 @@ class Gptneox:
"The attribute `eval_logits` of `Gptneox` object is None.")
n_vocab = int(gptneox_cpp.gptneox_n_vocab(self.ctx))
logits = self.eval_logits[-1]
data = (gptneox_cpp.gptneox_token_data * n_vocab)(
*[
gptneox_cpp.gptneox_token_data(
id=gptneox_cpp.gptneox_token(i),
logit=logits[i],
p=gptneox_cpp.c_float(0.0),
)
for i in range(n_vocab)
]
)
size = gptneox_cpp.c_size_t(n_vocab)
sorted = False
candidates = gptneox_cpp.gptneox_token_data_array(
data=data,
size=size,
sorted=sorted,
# accelerate below code by moving to cpp
# data = (gptneox_cpp.gptneox_token_data * n_vocab)(
# *[
# gptneox_cpp.gptneox_token_data(
# id=gptneox_cpp.gptneox_token(i),
# logit=logits[i],
# p=gptneox_cpp.c_float(0.0),
# )
# for i in range(n_vocab)
# ]
# )
# size = gptneox_cpp.c_size_t(n_vocab)
# sorted = False
# candidates = gptneox_cpp.gptneox_token_data_array(
# data=data,
# size=size,
# sorted=sorted,
# )
logits = (ctypes.c_float * n_vocab)(*logits)
candidates = gptneox_cpp.gptneox_get_candidates(
ctx=self.ctx,
n_vocab=n_vocab,
logits=logits
)
gptneox_cpp.gptneox_sample_repetition_penalty(
ctx=self.ctx,

View file

@@ -529,6 +529,24 @@ _lib.gptneox_token_eos.restype = gptneox_token
# Sampling functions
def gptneox_get_candidates(
    ctx: gptneox_context_p,
    n_vocab: c_int,
    logits: c_float_p,
):
    """Build the token-candidate array for sampling from raw logits.

    Thin ctypes wrapper around the native ``gptneox_get_candidates``
    symbol, which constructs the candidate array in C++ (replacing the
    slower per-token Python loop).

    :param ctx: opaque gptneox context pointer.
    :param n_vocab: vocabulary size (number of logits).
    :param logits: pointer to ``n_vocab`` float logits.
    :return: a ``gptneox_token_data_array`` holding the candidates.
    """
    return _lib.gptneox_get_candidates(ctx, n_vocab, logits)


# Register the native symbol's signature with ctypes.
_lib.gptneox_get_candidates.argtypes = [gptneox_context_p, c_int, c_float_p]
_lib.gptneox_get_candidates.restype = gptneox_token_data_array
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858,
# with negative logit fix.
def gptneox_sample_repetition_penalty(

View file

@@ -30,10 +30,9 @@ _llama_quantize_type = {"q4_0": 2,
"q5_1": 9,
"q8_0": 7}
_bloom_quantize_type = {"q4_0": 2,
"q4_1": 3}
"q4_1": 3}
_gptneox_quantize_type = {"q4_0": 2,
"q4_1": 3,
"q4_2": 5,
"q5_0": 8,
"q5_1": 9,
"q8_0": 7}
@@ -42,9 +41,6 @@ _quantize_type = {"llama": _llama_quantize_type,
"bloom": _bloom_quantize_type,
"gptneox": _gptneox_quantize_type}
_valid_types = set(list(_llama_quantize_type.keys()) + list(_bloomz_quantize_type.keys()) +
list(_gptneox_quantize_type.keys()))
def quantize(input_path: str, output_path: str=None,
model_family: str = 'llama', dtype: str='q4_0'):