diff --git a/.github/workflows/llm-harness-evaluation.yml b/.github/workflows/llm-harness-evaluation.yml
index b5b08d82..b9fb6123 100644
--- a/.github/workflows/llm-harness-evaluation.yml
+++ b/.github/workflows/llm-harness-evaluation.yml
@@ -10,7 +10,6 @@ on:
   schedule:
     - cron: "00 13 * * 5" # GMT time, 13:00 GMT == 21:00 China
   pull_request:
-    types: ready_for_review
     branches: [main]
     paths:
       - ".github/workflows/llm-harness-evaluation.yml"
@@ -139,10 +138,7 @@ jobs:
         working-directory: ${{ github.workspace }}/python/llm/dev/benchmark/harness/
         shell: bash
         run: |
-          git clone https://github.com/EleutherAI/lm-evaluation-harness.git
-          cd lm-evaluation-harness
-          git checkout e81d3cc
-          pip install -e .
+          pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@e81d3cc

       - name: Download models and datasets
         shell: bash
@@ -226,4 +222,4 @@ jobs:
         shell: bash
         run: |
           ls results
-          python ${{ github.workspace }}/python/llm/dev/benchmark/harness/make_table_results.py results
\ No newline at end of file
+          python ${{ github.workspace }}/python/llm/dev/benchmark/harness/make_table_results.py results
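(Not part of the patch.) For a local run that mirrors this CI step, the pinned-commit install can be reproduced directly; a minimal sketch, assuming the CPU build of bigdl-llm (the `[all]` extra) — the package extra and any extra wheel index for XPU are assumptions taken from BigDL's install docs, not from this workflow:

# Illustrative local environment; adjust the bigdl-llm extra to the target device.
pip install --pre --upgrade "bigdl-llm[all]"
pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@e81d3cc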
diff --git a/python/llm/dev/benchmark/harness/bigdl_llm.py b/python/llm/dev/benchmark/harness/bigdl_llm.py
index c9a27b79..c6bd843d 100644
--- a/python/llm/dev/benchmark/harness/bigdl_llm.py
+++ b/python/llm/dev/benchmark/harness/bigdl_llm.py
@@ -13,116 +13,44 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import os
-import multiprocessing
-from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
+from bigdl.llm.transformers import AutoModelForCausalLM

-import torch
-from typing import Optional, Union
-from lm_eval.base import BaseLM
+import inspect
+from lm_eval.models.huggingface import AutoCausalLM
+from lm_eval import utils
+from functools import partial

-from transformers import AutoTokenizer, LlamaTokenizer

-def _get_dtype(
-    dtype: Union[str, torch.dtype]
-) -> torch.dtype:
-    """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
-    if isinstance(dtype, str) and dtype != "auto":
-        # Convert `str` args torch dtype: `float16` -> `torch.float16`
-        _torch_dtype = getattr(torch, dtype)
-    else:
-        _torch_dtype = dtype
-    return _torch_dtype
+# wrap and force the Reorderer to be in a decrease order
+# This is a workaround to avoid frequent memory allocation which may cause OOM
+def force_decrease_order(Reorderer):
+    def DecreaseReorderer(arr, fn):
+        def _collate(x):
+            len, tokens = fn(x)
+            len = - abs(len)
+            return len, tokens
+        return Reorderer(arr, _collate)
+    return DecreaseReorderer
+utils.Reorderer = force_decrease_order(utils.Reorderer)

-class BigDLLM(BaseLM):
-    def __init__(
-        self,
-        device="xpu",
-        pretrained="gpt2",
-        revision="main",
-        low_cpu_mem_usage=None,
-        subfolder=None,
-        tokenizer=None,
-        batch_size=1,
-        load_in_8bit: Optional[bool] = False,
-        trust_remote_code: Optional[bool] = True,
-        load_in_low_bit=None,
-        dtype: Optional[Union[str, torch.dtype]] = "auto",
-        **kwargs
-    ):
-        super().__init__()
-        assert isinstance(pretrained, str)
-        assert isinstance(batch_size, (int,str))
-        if 'xpu' in device:
-            import intel_extension_for_pytorch as ipex
-        model = AutoModelForCausalLM.from_pretrained(pretrained,
-                                                     load_in_low_bit=load_in_low_bit,
-                                                     optimize_model=kwargs.get('optimize_model', True),
-                                                     trust_remote_code=trust_remote_code,
-                                                     use_cache=True,
-                                                     torch_dtype=_get_dtype(dtype))
-        print(model) # print model to check precision
-        self._device = device
-        self.model = model.to(device)
-
-        self.tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
-
-        # setup for automatic batch size detection
-        if batch_size == 'auto':
-            self.batch_size_per_gpu = batch_size
-        else:
-            self.batch_size_per_gpu = int(batch_size)
+class BigDLLM(AutoCausalLM):
+    AUTO_MODEL_CLASS = AutoModelForCausalLM
+    AutoCausalLM_ARGS = inspect.getfullargspec(AutoCausalLM.__init__).args
+    def __init__(self, *args, **kwargs):
+        if 'device' in kwargs and 'xpu' in kwargs['device']:
+            import intel_extension_for_pytorch
+        self.bigdl_llm_kwargs = {}
+        keys = list(kwargs.keys())
+        for k in keys:
+            if k not in self.AutoCausalLM_ARGS:
+                self.bigdl_llm_kwargs[k] = kwargs[k]
+                kwargs.pop(k)
+        AutoModelForCausalLM.from_pretrained = partial(AutoModelForCausalLM.from_pretrained, **self.bigdl_llm_kwargs)
+        kwargs['trust_remote_code'] = kwargs.get('trust_remote_code', True)
+        super().__init__(*args, **kwargs)

     @property
-    def eot_token_id(self):
-        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
-        return self.tokenizer.eos_token_id
-
-    @property
-    def max_length(self):
-        return 2048 # TODO: how to get this from config
-
-    @property
-    def max_gen_toks(self):
-        return 256
-
-    @property
-    def batch_size(self):
-        # TODO: fix multi-gpu
-        return self.batch_size_per_gpu # * gpus
-
-    @property
-    def device(self):
-        # TODO: fix multi-gpu
-        return torch.device(self._device)
-
-    def tok_encode(self, string: str):
-        input_ids = self.tokenizer.encode(string)
-        return input_ids
-
-    def tok_decode(self, tokens):
-        return self.tokenizer.decode(tokens, skip_special_tokens=True)
-
-    def _model_call(self, inps):
-        """
-        inps: a torch tensor of shape [batch, sequence]
-        the size of sequence may vary from call to call
-
-        returns: a torch tensor of shape [batch, sequence, vocab] with the
-        logits returned from the model
-        """
-        with torch.inference_mode():
-            inps = inps.to(self.device)
-            res = self.model(inps)[0]
-            return res
-
-    def _model_generate(self, context, max_length, eos_token_id):
-        generation_kwargs = {"do_sample": False, "max_length": max_length}
-        if eos_token_id is not None:
-            generation_kwargs["eos_token_id"] = eos_token_id
-            generation_kwargs[
-                "pad_token_id"
-            ] = eos_token_id # setting eos_token_id as pad token
-        return self.model.generate(context, **generation_kwargs)
\ No newline at end of file
+    def add_special_tokens(self) -> bool:
+        return False
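(Not part of the patch.) A minimal usage sketch of the new wrapper, assuming lm-evaluation-harness at the pinned commit e81d3cc (pre-0.4 API) and execution from python/llm/dev/benchmark/harness/; the model id, task, and low-bit format are placeholders. Keyword arguments that AutoCausalLM.__init__ does not declare (such as load_in_low_bit) are stripped from kwargs and bound onto bigdl-llm's AutoModelForCausalLM.from_pretrained via functools.partial, which is what the kwargs-splitting loop in the patch implements.

# Illustrative only: drive BigDLLM through lm-eval's evaluator directly.
from lm_eval import evaluator
from bigdl_llm import BigDLLM  # the module modified by this patch

lm = BigDLLM(
    pretrained="meta-llama/Llama-2-7b-chat-hf",  # placeholder model id/path
    device="cpu",                # "xpu" additionally triggers the intel_extension_for_pytorch import
    batch_size=1,
    load_in_low_bit="sym_int4",  # not an AutoCausalLM argument -> forwarded to from_pretrained
)

results = evaluator.simple_evaluate(
    model=lm,                    # pre-0.4 simple_evaluate also accepts an LM instance
    tasks=["hellaswag"],
    num_fewshot=0,
    no_cache=True,               # skip CachingLM so the sketch stays self-contained
)
print(evaluator.make_table(results))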