diff --git a/README.md b/README.md index 0b3ccc20..75798b92 100644 --- a/README.md +++ b/README.md @@ -137,7 +137,7 @@ Over 20 models have been optimized/verified on `bigdl-llm`, including *LLaMA/LLa | Model | CPU Example | GPU Example | |------------|----------------------------------------------------------------|-----------------------------------------------------------------| | LLaMA *(such as Vicuna, Guanaco, Koala, Baize, WizardLM, etc.)* | [link1](python/llm/example/CPU/Native-Models), [link2](python/llm/example/CPU/HF-Transformers-AutoModels/Model/vicuna) |[link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/vicuna)| -| LLaMA 2 | [link1](python/llm/example/CPU/Native-Models), [link2](python/llm/example/CPU/HF-Transformers-AutoModels/Model/llama2) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/llama2) | +| LLaMA 2 | [link1](python/llm/example/CPU/Native-Models), [link2](python/llm/example/CPU/HF-Transformers-AutoModels/Model/llama2) | [link1](python/llm/example/GPU/HF-Transformers-AutoModels/Model/llama2), [link2-low GPU memory example](python/llm/example/GPU/PyTorch-Models/Model/llama2#example-2---low-memory-version-predict-tokens-using-generate-api) | | ChatGLM | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/chatglm) | | | ChatGLM2 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/chatglm2) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/chatglm2) | | ChatGLM3 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/chatglm3) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/chatglm3) | diff --git a/python/llm/README.md b/python/llm/README.md index 1d70f9f8..c2500b2f 100644 --- a/python/llm/README.md +++ b/python/llm/README.md @@ -39,7 +39,7 @@ Over 20 models have been optimized/verified on `bigdl-llm`, including *LLaMA/LLa | Model | CPU Example | GPU Example | |------------|----------------------------------------------------------------|-----------------------------------------------------------------| | LLaMA *(such as Vicuna, Guanaco, Koala, Baize, WizardLM, etc.)* | [link1](example/CPU/Native-Models), [link2](example/CPU/HF-Transformers-AutoModels/Model/vicuna) |[link](example/GPU/HF-Transformers-AutoModels/Model/vicuna)| -| LLaMA 2 | [link1](example/CPU/Native-Models), [link2](example/CPU/HF-Transformers-AutoModels/Model/llama2) | [link](example/GPU/HF-Transformers-AutoModels/Model/llama2) | +| LLaMA 2 | [link1](example/CPU/Native-Models), [link2](example/CPU/HF-Transformers-AutoModels/Model/llama2) | [link1](example/GPU/HF-Transformers-AutoModels/Model/llama2), [link2-low GPU memory example](example/GPU/PyTorch-Models/Model/llama2#example-2---low-memory-version-predict-tokens-using-generate-api) | | ChatGLM | [link](example/CPU/HF-Transformers-AutoModels/Model/chatglm) | | | ChatGLM2 | [link](example/CPU/HF-Transformers-AutoModels/Model/chatglm2) | [link](example/GPU/HF-Transformers-AutoModels/Model/chatglm2) | | ChatGLM3 | [link](example/CPU/HF-Transformers-AutoModels/Model/chatglm3) | [link](example/GPU/HF-Transformers-AutoModels/Model/chatglm3) | diff --git a/python/llm/example/GPU/PyTorch-Models/Model/llama2/README.md b/python/llm/example/GPU/PyTorch-Models/Model/llama2/README.md index f36d373d..325f6488 100644 --- a/python/llm/example/GPU/PyTorch-Models/Model/llama2/README.md +++ b/python/llm/example/GPU/PyTorch-Models/Model/llama2/README.md @@ -1,10 +1,10 @@ # Llama2 -In this directory, you will find examples on how you could use BigDL-LLM `optimize_model` API to 
accelerate Llama2 models. For illustration purposes, we utilize the [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) and [meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) as reference Llama2 models. +In this directory, you will find examples on how you could use BigDL-LLM `optimize_model` API to accelerate Llama2 models. For illustration purposes, we utilize the [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf), [meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) and [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) as reference Llama2 models. ## Requirements To run these examples with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information. -## Example: Predict Tokens using `generate()` API +## Example 1 - Basic Version: Predict Tokens using `generate()` API In the example [generate.py](./generate.py), we show a basic use case for a Llama2 model to predict the next N tokens using `generate()` API, with BigDL-LLM INT4 optimizations on Intel GPUs. ### 1. Install We suggest using conda to manage the Python environment. For more information about conda installation, please refer to [here](https://docs.conda.io/en/latest/miniconda.html#). @@ -43,7 +43,7 @@ In the example, several arguments can be passed to satisfy your requirements: - `--prompt PROMPT`: argument defining the prompt to be infered (with integrated prompt format for chat). It is default to be `'What is AI?'`. - `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `32`. -#### 2.3 Sample Output +#### Sample Output #### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) ```log Inference time: xxxx s @@ -67,3 +67,24 @@ What is AI? AI, or artificial intelligence, refers to the ability of machines to perform tasks that would typically require human intelligence, such as learning, problem-solving, ``` + +## Example 2 - Low Memory Version: Predict Tokens using `generate()` API + +If you're not able to load the full 4-bit model (e.g. [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf)) on one GPU as shown in [Example 1](#example-1---basic-version-predict-tokens-using-generate-api), you may try this example instead. + +In [low_memory_generate.py](./low_memory_generate.py), we show a way to load very large models with a very low GPU memory footprint. However, this can be much slower than the standard approach. The implementation is adapted from [here](https://www.kaggle.com/code/simjeg/platypus2-70b-without-wikipedia-rag). + +### 1. Environment setup +Please refer to [Example 1](#example-1---basic-version-predict-tokens-using-generate-api) for more information. + +### 2. Run + +```bash +python ./low_memory_generate.py --split-weight --splitted-weights-path ${SPLITTED_WEIGHTS_PATH} +``` + +In the example, besides the arguments in [Example 1](#3-run), several other arguments can be passed to satisfy your requirements: + +- `--splitted-weights-path`: argument defining the folder where per-layer weights are saved. +- `--split-weight`: argument defining whether to split weights by layer. If this argument is enabled, per-layer weights will be generated and saved to `--splitted-weights-path`. This argument only needs to be enabled once for the same model.
+- `--max-cache-num`: argument defining the maximum number of weights saved in the cache. You can adjust this argument based on your GPU memory. It is default to be 200. For [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf), GPU peak memory is around 3G when it is set to 0 and 15G when it is set to 200. diff --git a/python/llm/example/GPU/PyTorch-Models/Model/llama2/low_memory_generate.py b/python/llm/example/GPU/PyTorch-Models/Model/llama2/low_memory_generate.py new file mode 100644 index 00000000..32f17329 --- /dev/null +++ b/python/llm/example/GPU/PyTorch-Models/Model/llama2/low_memory_generate.py @@ -0,0 +1,354 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# The code is adapted from: https://www.kaggle.com/code/simjeg/platypus2-70b-without-wikipedia-rag. +# + +import argparse +import gc +import os +import time +import warnings +from typing import Dict, List, Optional, Union + +import intel_extension_for_pytorch as ipex +import torch +import torch.nn as nn +from accelerate import init_empty_weights +from safetensors.torch import load_file, save_file +from tqdm.auto import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import CausalLMOutputWithPast + +from bigdl.llm import optimize_model +from bigdl.llm.transformers.low_bit_linear import FP4Params, LowBitLinear + +MAX_LENGTH = 4096 +# you could tune the prompt based on your own model, +# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style +LLAMA2_PROMPT_FORMAT = """### HUMAN: +{prompt} + +### RESPONSE: +""" + + +# Modified based on https://github.com/huggingface/accelerate/blob/d25efa71ce76a5f5911a1fc6c039979d7248596f/src/accelerate/utils/modeling.py#L238 +def set_module_tensor_to_device_with_cache( + module: nn.Module, + tensor_name: str, + device: Union[int, str, torch.device], + value: Optional[torch.Tensor], + dtype: Optional[Union[str, torch.dtype]] = None, + cache_dict: Optional[Dict[str, FP4Params]] = None, + max_cache_num: Optional[int] = 100, +): + """ + A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing + `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). + + Args: + module (`torch.nn.Module`): + The module in which the tensor we want to move lives. + param_name (`str`): + The full name of the parameter/buffer. + device (`int`, `str` or `torch.device`): + The device on which to set the tensor. + value (`torch.Tensor`): + The value of the tensor (useful when going from the meta device to any other device). + dtype (`torch.dtype`, *optional*): + If passed along the value of the parameter will be cast to this `dtype`. Otherwise, `value` will be cast to + the dtype of the existing parameter in the model. 
+ cache_dict (`Dict`, *optional*): + The cache dict to save layer weights. This can improve the loading speed. + max_cache_num (`int`, *optional*): + The maximum number of weights saved in the cache_dict. You can adjust this number based on your GPU memory. + Default is 100. + """ + original_tensor_name = tensor_name + assert value is not None + + # Recurse if needed + if "." in tensor_name: + splits = tensor_name.split(".") + for split in splits[:-1]: + new_module = getattr(module, split) + if new_module is None: + raise ValueError(f"{module} has no attribute {split}.") + module = new_module + tensor_name = splits[-1] + + # Use cache to load weights + if original_tensor_name in cache_dict: + module._parameters[tensor_name] = cache_dict[original_tensor_name] + return + + if tensor_name not in module._parameters and tensor_name not in module._buffers: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + is_buffer = tensor_name in module._buffers + old_value = getattr(module, tensor_name) + + if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None: + raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.") + + if dtype is None: + # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model + value = value.to(old_value.dtype) + elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): + value = value.to(dtype) + + param = module._parameters[tensor_name] if tensor_name in module._parameters else None + param_cls = type(param) + + with torch.no_grad(): + if isinstance(module, LowBitLinear): + # load cpu int4 weights + new_value = FP4Params(data=value, + requires_grad=False, + quantized=True, + _shape=(module.out_features, module.in_features), + convert_shape_only=False, + qtype=2).to(device) + if len(cache_dict) < max_cache_num: + cache_dict.update({original_tensor_name: new_value}) + elif isinstance(value, torch.Tensor): + new_value = value.to(device) + else: + new_value = torch.tensor(value, device=device) + if is_buffer: + module._buffers[tensor_name] = new_value + elif isinstance(module, LowBitLinear): + module._parameters[tensor_name] = new_value + elif value is not None or torch.device(device) != module._parameters[tensor_name].device: + param_cls = type(module._parameters[tensor_name]) + new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device) + module._parameters[tensor_name] = new_value + else: + raise NotImplementedError + + +class LowMemoryLlama(GenerationMixin): + def __init__(self, model_path: str, splitted_weights_path: str, max_cache_num: Optional[int] = 100): + """ + Low memory version of LlamaForCausalLM : the model is splitted into layer shards to reduce GPU memory usage. + During the forward pass, the inputs are processed layer by layer, and the GPU memory is freed after each layer. + + Parameters + ---------- + model_path (`str`): + The huggingface repo id or path to the huggingface checkpoint folder that including config.json and tokenizer.model. + splitted_weights_path (`str`): + The folder including int4 weights per layer. You can use `LowMemoryLlama.split_and_convert_to_cpu_int4_weights` to generate those weights. + max_cache_num (`int`, *optional*): + The maximum number of weights saved in the cache_dict. You can adjust this number based on your GPU memory. + Default is 100. 
+ """ + super().__init__() + + # Save parameters + self.model_path = model_path + self.splitted_weights_path = splitted_weights_path + self.device = torch.device('xpu:0') + self.layer_weight_cache = {} # initialize weight cache dict + self.max_cache_num = max_cache_num + + # Create model + self._create_model() + # Check if `self.splitted_weights_path` exists + self._check_split_weight_path() + + # Initialize attention mask and position ids to further improve the inference speed + self._generate_att_mask_and_pos_id() + + @classmethod + def split_and_convert_to_cpu_int4_weights(cls, model_path, safetensor_per_layer_path): + model = AutoModelForCausalLM.from_pretrained(model_path, use_safetensors=True) + model = optimize_model(model) + + state_dict = model.state_dict() + layer_names = ["model.embed_tokens."] + [f"model.layers.{i}." for i in range(len(model.model.layers))] + ["model.norm.", "lm_head."] + for layer_name in tqdm(layer_names): + local_state_dict = {k: v.contiguous() for k, v in state_dict.items() if k.startswith(layer_name)} + save_name = os.path.join(safetensor_per_layer_path, f"{layer_name}safetensors") + save_file(local_state_dict, save_name) + print(f'Save splitted safetensor weights to {safetensor_per_layer_path}') + + def _create_model(self): + # Load config + self.config = AutoConfig.from_pretrained(self.model_path) + if self.config.pretraining_tp > 1: + warnings.warn("Set config.pretraining_tp = 1 to use int4 inference") + self.config.pretraining_tp = 1 + if not self.config.use_cache: + warnings.warn("Set config.use_cache to further improve performance") + self.config.use_cache = True + + # Load tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) + # Model initialization + self._init_model() + # Tie with self.model + self.layer_names = ["model.embed_tokens"] + [f"model.layers.{i}" for i in range(len(self.model.model.layers))] + ["model.norm", "lm_head"] + self.generation_config = self.model.generation_config + self.main_input_name = self.model.main_input_name + + def _check_split_weight_path(self): + for layer_name in self.layer_names: + split_weight_path = os.path.join(self.splitted_weights_path, f"{layer_name}.safetensors") + if not os.path.exists(split_weight_path): + raise FileNotFoundError(f"Weight file {split_weight_path} is missing." 
+ f"You can generate it using `LowMemoryLlama.split_and_convert_to_cpu_int4_weights`.") + + def _generate_att_mask_and_pos_id(self): + self.attention_mask = torch.full((MAX_LENGTH, MAX_LENGTH), torch.finfo(torch.float16).min, device=self.device) + mask_cond = torch.arange(self.attention_mask.size(-1), device=self.device) + self.attention_mask.masked_fill_(mask_cond < (mask_cond + 1).view(self.attention_mask.size(-1), 1), 0) + self.attention_mask = self.attention_mask.to(torch.float16)[None, None, :, :] + self.position_ids = torch.arange(MAX_LENGTH, dtype=torch.long, device=self.device)[None, :] + + def _init_model(self): + # Load meta model (no memory used) + with init_empty_weights(): + self.model = AutoModelForCausalLM.from_config(self.config) + self.model = optimize_model(self.model) + self.model.tie_weights() + + self.layers = [self.model.model.embed_tokens] + list(self.model.model.layers) \ + + [self.model.model.norm, self.model.lm_head] + + # Move buffers to device (not that much GPU memory used) + for buffer_name, buffer in self.model.named_buffers(): + set_module_tensor_to_device_with_cache(self.model, buffer_name, self.device, value=buffer, dtype=buffer.dtype, + cache_dict=self.layer_weight_cache, max_cache_num=self.max_cache_num) + + def move_layer_to_device(self, state_dict): + for param_name, param in state_dict.items(): + set_module_tensor_to_device_with_cache(self.model, param_name, self.device, value=param, dtype=param.dtype, + cache_dict=self.layer_weight_cache, max_cache_num=self.max_cache_num) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def forward(self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + *args, + **kwargs, + ): + + # Reinit model and clean memory + del self.model + gc.collect() + torch.xpu.empty_cache() + self._init_model() + + # Send batch to device + inputs = input_ids.to(self.device) + + # Set up kv cache + kv_cache = {} + if past_key_values is None: + past_key_values = {} + + with torch.inference_mode(): + # Generate attention mask and position ids + current_shape = inputs.shape[1] + if past_key_values.get(self.layer_names[1], None): + pre_shape = past_key_values[self.layer_names[1]][0].size(2) + else: + pre_shape = 0 + pos = self.position_ids[:, pre_shape : current_shape + pre_shape] + attn = self.attention_mask[:, :, -current_shape:, - current_shape - pre_shape:] + + for (layer_name, layer) in tqdm(zip(self.layer_names, self.layers), total=len(self.layers)): + + # Load current layer to device + state_dict = load_file(os.path.join(self.splitted_weights_path, f"{layer_name}.safetensors"), device="cpu") + self.move_layer_to_device(state_dict) + + # Run layer + if layer_name in ("model.embed_tokens", "model.norm", "lm_head"): + inputs = layer(inputs) + else: + inputs, new_kv_cache = layer(inputs, use_cache=True, past_key_value=past_key_values.get(layer_name, None), + position_ids=pos, attention_mask=attn) + kv_cache[layer_name] = new_kv_cache + + # Delete weight before moving to('meta') + for module in layer.modules(): + if hasattr(module, "weight"): + del module.weight + + # Remove previous layer from memory (including buffers) + layer.to("meta") + + result = CausalLMOutputWithPast( + logits=inputs.detach(), + past_key_values=kv_cache, + ) + return result + + def can_generate(self): + return True + + def prepare_inputs_for_generation(self, *args, **kwargs): + return self.model.prepare_inputs_for_generation(*args, **kwargs) + + +if __name__ == '__main__': + parser 
= argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model') + parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf", + help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`,' + ' `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded' + ', or the path to the huggingface checkpoint folder') + parser.add_argument('--prompt', type=str, default="What is AI?", + help='Prompt to infer') + parser.add_argument('--n-predict', type=int, default=32, + help='Max tokens to predict') + parser.add_argument('--splitted-weights-path', type=str, required=True, + help='The folder saving per-layer weights. You can use' + ' LowMemoryLlama.split_and_convert_to_cpu_int4_weights() to generate those weights.') + parser.add_argument('--split-weight', action='store_true', + help='Whether to split weights by layer. If this argument is enabled, per-layer weights will' + ' be generated and saved to `--splitted-weights-path`. This argument only needs to be' + ' enabled once for the same model.') + parser.add_argument('--max-cache-num', type=int, default=200, + help='The maximum number of weights saved in the cache_dict. You can adjust this' + ' number based on your GPU memory. Default is 200.') + + args = parser.parse_args() + model_path = args.repo_id_or_model_path + splitted_weights_path = args.splitted_weights_path + + if args.split_weight: + os.makedirs(splitted_weights_path, exist_ok=True) + LowMemoryLlama.split_and_convert_to_cpu_int4_weights(model_path, splitted_weights_path) + + model = LowMemoryLlama(model_path, splitted_weights_path, args.max_cache_num) + + with torch.inference_mode(): + prompt = LLAMA2_PROMPT_FORMAT.format(prompt=args.prompt) + input_ids = model.tokenizer.encode(prompt, return_tensors="pt").to('xpu') + st = time.time() + output = model.generate(input_ids, + max_new_tokens=args.n_predict) + torch.xpu.synchronize() + end = time.time() + output = output.cpu() + output_str = model.tokenizer.decode(output[0], skip_special_tokens=True) + print(f'Inference time: {end-st} s') + print('-'*20, 'Output', '-'*20) + print(output_str) diff --git a/python/llm/src/bigdl/llm/optimize.py b/python/llm/src/bigdl/llm/optimize.py index 39b750f7..71fce857 100644 --- a/python/llm/src/bigdl/llm/optimize.py +++ b/python/llm/src/bigdl/llm/optimize.py @@ -226,8 +226,8 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_ invalidInputError(isinstance(model, torch.nn.Module), "model should be an instance of " f"`torch.nn.Module`, but got {type(model)} at last.") - invalidInputError(model.device.type == 'cpu', - "Expect model on device `cpu`, " + invalidInputError(model.device.type in ('cpu', 'meta'), + "Expect model on device `cpu` or `meta`, " f"but got device type {model.device.type}") if kwargs.pop("replace_embedding", False): warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
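
The relaxed device check above is what lets `LowMemoryLlama._init_model` call `optimize_model` on a model whose parameters still live on the `meta` device. Below is a minimal sketch of that flow on its own, separate from the example script; it assumes `transformers`, `accelerate` and `bigdl-llm` are installed and that the Llama2 config is accessible (the gated `meta-llama/Llama-2-7b-chat-hf` repo is used purely for illustration).

```python
# Sketch: running optimize_model() on a meta-device model, the pattern used by
# LowMemoryLlama._init_model before the per-layer INT4 weights are streamed in.
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

from bigdl.llm import optimize_model

# Only the config is needed to build the model skeleton.
config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

with init_empty_weights():
    # All parameters are allocated on the `meta` device, so no real memory is used.
    model = AutoModelForCausalLM.from_config(config)

# Before this change, optimize_model() rejected any model not on `cpu`.
# With the relaxed check, the nn.Linear layers are converted to LowBitLinear
# while their weights remain empty, ready to be filled layer by layer from the
# splitted safetensors files at inference time.
model = optimize_model(model)
print(model.device.type)  # "meta"
```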