Add llama2 gpu low memory example (#9514)
* Add low memory example
* Minor fixes
* Update readme.md
parent 06febb5fa7
commit d154b38bf9
5 changed files with 382 additions and 7 deletions
@@ -137,7 +137,7 @@ Over 20 models have been optimized/verified on `bigdl-llm`, including *LLaMA/LLa
 | Model | CPU Example | GPU Example |
 |------------|----------------------------------------------------------------|-----------------------------------------------------------------|
 | LLaMA *(such as Vicuna, Guanaco, Koala, Baize, WizardLM, etc.)* | [link1](python/llm/example/CPU/Native-Models), [link2](python/llm/example/CPU/HF-Transformers-AutoModels/Model/vicuna) |[link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/vicuna)|
-| LLaMA 2 | [link1](python/llm/example/CPU/Native-Models), [link2](python/llm/example/CPU/HF-Transformers-AutoModels/Model/llama2) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/llama2) |
+| LLaMA 2 | [link1](python/llm/example/CPU/Native-Models), [link2](python/llm/example/CPU/HF-Transformers-AutoModels/Model/llama2) | [link1](python/llm/example/GPU/HF-Transformers-AutoModels/Model/llama2), [link2-low GPU memory example](python/llm/example/GPU/PyTorch-Models/Model/llama2#example-2---low-memory-version-predict-tokens-using-generate-api) |
 | ChatGLM | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/chatglm) | |
 | ChatGLM2 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/chatglm2) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/chatglm2) |
 | ChatGLM3 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/chatglm3) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/chatglm3) |
@@ -39,7 +39,7 @@ Over 20 models have been optimized/verified on `bigdl-llm`, including *LLaMA/LLa
 | Model | CPU Example | GPU Example |
 |------------|----------------------------------------------------------------|-----------------------------------------------------------------|
 | LLaMA *(such as Vicuna, Guanaco, Koala, Baize, WizardLM, etc.)* | [link1](example/CPU/Native-Models), [link2](example/CPU/HF-Transformers-AutoModels/Model/vicuna) |[link](example/GPU/HF-Transformers-AutoModels/Model/vicuna)|
-| LLaMA 2 | [link1](example/CPU/Native-Models), [link2](example/CPU/HF-Transformers-AutoModels/Model/llama2) | [link](example/GPU/HF-Transformers-AutoModels/Model/llama2) |
+| LLaMA 2 | [link1](example/CPU/Native-Models), [link2](example/CPU/HF-Transformers-AutoModels/Model/llama2) | [link1](example/GPU/HF-Transformers-AutoModels/Model/llama2), [link2-low GPU memory example](example/GPU/PyTorch-Models/Model/llama2#example-2---low-memory-version-predict-tokens-using-generate-api) |
 | ChatGLM | [link](example/CPU/HF-Transformers-AutoModels/Model/chatglm) | |
 | ChatGLM2 | [link](example/CPU/HF-Transformers-AutoModels/Model/chatglm2) | [link](example/GPU/HF-Transformers-AutoModels/Model/chatglm2) |
 | ChatGLM3 | [link](example/CPU/HF-Transformers-AutoModels/Model/chatglm3) | [link](example/GPU/HF-Transformers-AutoModels/Model/chatglm3) |
@@ -1,10 +1,10 @@
 # Llama2
-In this directory, you will find examples on how you could use BigDL-LLM `optimize_model` API to accelerate Llama2 models. For illustration purposes, we utilize the [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) and [meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) as reference Llama2 models.
+In this directory, you will find examples on how you could use BigDL-LLM `optimize_model` API to accelerate Llama2 models. For illustration purposes, we utilize the [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf), [meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) and [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) as reference Llama2 models.
 
 ## Requirements
 To run these examples with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
 
-## Example: Predict Tokens using `generate()` API
+## Example 1 - Basic Version: Predict Tokens using `generate()` API
 In the example [generate.py](./generate.py), we show a basic use case for a Llama2 model to predict the next N tokens using `generate()` API, with BigDL-LLM INT4 optimizations on Intel GPUs.
 ### 1. Install
 We suggest using conda to manage the Python environment. For more information about conda installation, please refer to [here](https://docs.conda.io/en/latest/miniconda.html#).
@@ -43,7 +43,7 @@ In the example, several arguments can be passed to satisfy your requirements:
 - `--prompt PROMPT`: argument defining the prompt to be infered (with integrated prompt format for chat). It is default to be `'What is AI?'`.
 - `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `32`.
 
-#### 2.3 Sample Output
+#### Sample Output
 #### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
 ```log
 Inference time: xxxx s
@@ -67,3 +67,24 @@ What is AI?
 
 AI, or artificial intelligence, refers to the ability of machines to perform tasks that would typically require human intelligence, such as learning, problem-solving,
 ```
+
+## Example 2 - Low Memory Version: Predict Tokens using `generate()` API
+
+If you're not able to load the full 4-bit model (e.g. [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf)) in one GPU as shown in [Example 1](#example-1---basic-version-predict-tokens-using-generate-api), you may try this example instead.
+
+In [low_memory_generate.py](./low_memory_generate.py), we show a way to load very large models with a very low GPU memory footprint. However, this could be much slower than the standard way. The implementation is adapted from [here](https://www.kaggle.com/code/simjeg/platypus2-70b-without-wikipedia-rag).
+
+### 1. Environment setup
+Please refer to [Example 1](#example-1---basic-version-predict-tokens-using-generate-api) for more information.
+
+### 2. Run
+
+```bash
+python ./low_memory_generate.py --split-weight --splitted-weights-path ${SPLITTED_WEIGHTS_PATH}
+```
+
+In the example, besides the arguments in [Example 1](#3-run), several other arguments can be passed to satisfy your requirements:
+
+- `--splitted-weights-path`: argument defining the folder where per-layer weights are saved.
+- `--split-weight`: argument defining whether to split the weights by layer. If this argument is enabled, per-layer weights will be generated and saved to `--splitted-weights-path`. This argument only needs to be enabled once for the same model.
+- `--max-cache-num`: argument defining the maximum number of weights saved in the cache. You can adjust this argument based on your GPU memory. It is default to be `200`. For [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf), GPU peak memory is around 3G when it is set to `0` and 15G when it is set to `200`.
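For readers who prefer to drive the same two-step flow from Python rather than the CLI, here is a minimal sketch built only from the public pieces of `low_memory_generate.py` (`split_and_convert_to_cpu_int4_weights`, `LowMemoryLlama`, `LLAMA2_PROMPT_FORMAT`). It assumes the script is importable from the current directory; the model id and the weights folder below are illustrative placeholders, not values prescribed by the example.

```python
# Sketch of the two-step flow behind `--split-weight` / `--splitted-weights-path`.
# Assumes low_memory_generate.py sits in the working directory; paths are placeholders.
import os

import torch

from low_memory_generate import LLAMA2_PROMPT_FORMAT, LowMemoryLlama

model_path = "meta-llama/Llama-2-70b-chat-hf"    # HF repo id or local checkpoint folder
weights_dir = "./llama2-70b-int4-layer-weights"  # placeholder folder for per-layer int4 safetensors

# Step 1 (one-off per model): split and convert the weights to per-layer int4 safetensors on CPU.
if not os.path.isdir(weights_dir):
    os.makedirs(weights_dir, exist_ok=True)
    LowMemoryLlama.split_and_convert_to_cpu_int4_weights(model_path, weights_dir)

# Step 2: build the meta model and stream layer weights onto the Intel GPU during generation.
model = LowMemoryLlama(model_path, weights_dir, max_cache_num=200)
with torch.inference_mode():
    prompt = LLAMA2_PROMPT_FORMAT.format(prompt="What is AI?")
    input_ids = model.tokenizer.encode(prompt, return_tensors="pt").to('xpu')
    output = model.generate(input_ids, max_new_tokens=32)
    print(model.tokenizer.decode(output[0].cpu(), skip_special_tokens=True))
```

Because the split step runs the full model through `optimize_model` on the CPU once and saves the result, later runs can drop `--split-weight` and simply point `--splitted-weights-path` at the saved folder.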
@@ -0,0 +1,354 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# The code is adapted from: https://www.kaggle.com/code/simjeg/platypus2-70b-without-wikipedia-rag.
#

import argparse
import gc
import os
import time
import warnings
from typing import Dict, List, Optional, Union

import intel_extension_for_pytorch as ipex
import torch
import torch.nn as nn
from accelerate import init_empty_weights
from safetensors.torch import load_file, save_file
from tqdm.auto import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

from bigdl.llm import optimize_model
from bigdl.llm.transformers.low_bit_linear import FP4Params, LowBitLinear

MAX_LENGTH = 4096
# you could tune the prompt based on your own model,
# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
LLAMA2_PROMPT_FORMAT = """### HUMAN:
{prompt}

### RESPONSE:
"""


# Modified based on https://github.com/huggingface/accelerate/blob/d25efa71ce76a5f5911a1fc6c039979d7248596f/src/accelerate/utils/modeling.py#L238
def set_module_tensor_to_device_with_cache(
    module: nn.Module,
    tensor_name: str,
    device: Union[int, str, torch.device],
    value: Optional[torch.Tensor],
    dtype: Optional[Union[str, torch.dtype]] = None,
    cache_dict: Optional[Dict[str, FP4Params]] = None,
    max_cache_num: Optional[int] = 100,
):
    """
    A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing
    `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function).

    Args:
        module (`torch.nn.Module`):
            The module in which the tensor we want to move lives.
        tensor_name (`str`):
            The full name of the parameter/buffer.
        device (`int`, `str` or `torch.device`):
            The device on which to set the tensor.
        value (`torch.Tensor`):
            The value of the tensor (useful when going from the meta device to any other device).
        dtype (`torch.dtype`, *optional*):
            If passed along the value of the parameter will be cast to this `dtype`. Otherwise, `value` will be cast to
            the dtype of the existing parameter in the model.
        cache_dict (`Dict`, *optional*):
            The cache dict to save layer weights. This can improve the loading speed.
        max_cache_num (`int`, *optional*):
            The maximum number of weights saved in the cache_dict. You can adjust this number based on your GPU memory.
            Default is 100.
    """
    original_tensor_name = tensor_name
    assert value is not None

    # Recurse if needed
    if "." in tensor_name:
        splits = tensor_name.split(".")
        for split in splits[:-1]:
            new_module = getattr(module, split)
            if new_module is None:
                raise ValueError(f"{module} has no attribute {split}.")
            module = new_module
        tensor_name = splits[-1]

    # Use cache to load weights
    if original_tensor_name in cache_dict:
        module._parameters[tensor_name] = cache_dict[original_tensor_name]
        return

    if tensor_name not in module._parameters and tensor_name not in module._buffers:
        raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
    is_buffer = tensor_name in module._buffers
    old_value = getattr(module, tensor_name)

    if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
        raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")

    if dtype is None:
        # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
        value = value.to(old_value.dtype)
    elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
        value = value.to(dtype)

    param = module._parameters[tensor_name] if tensor_name in module._parameters else None
    param_cls = type(param)

    with torch.no_grad():
        if isinstance(module, LowBitLinear):
            # load cpu int4 weights
            new_value = FP4Params(data=value,
                                  requires_grad=False,
                                  quantized=True,
                                  _shape=(module.out_features, module.in_features),
                                  convert_shape_only=False,
                                  qtype=2).to(device)
            if len(cache_dict) < max_cache_num:
                cache_dict.update({original_tensor_name: new_value})
        elif isinstance(value, torch.Tensor):
            new_value = value.to(device)
        else:
            new_value = torch.tensor(value, device=device)
        if is_buffer:
            module._buffers[tensor_name] = new_value
        elif isinstance(module, LowBitLinear):
            module._parameters[tensor_name] = new_value
        elif value is not None or torch.device(device) != module._parameters[tensor_name].device:
            param_cls = type(module._parameters[tensor_name])
            new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device)
            module._parameters[tensor_name] = new_value
        else:
            raise NotImplementedError


class LowMemoryLlama(GenerationMixin):
    def __init__(self, model_path: str, splitted_weights_path: str, max_cache_num: Optional[int] = 100):
        """
        Low memory version of LlamaForCausalLM: the model is split into layer shards to reduce GPU memory usage.
        During the forward pass, the inputs are processed layer by layer, and the GPU memory is freed after each layer.

        Parameters
        ----------
        model_path (`str`):
            The huggingface repo id or path to the huggingface checkpoint folder that includes config.json and tokenizer.model.
        splitted_weights_path (`str`):
            The folder including int4 weights per layer. You can use `LowMemoryLlama.split_and_convert_to_cpu_int4_weights` to generate those weights.
        max_cache_num (`int`, *optional*):
            The maximum number of weights saved in the cache_dict. You can adjust this number based on your GPU memory.
            Default is 100.
        """
        super().__init__()

        # Save parameters
        self.model_path = model_path
        self.splitted_weights_path = splitted_weights_path
        self.device = torch.device('xpu:0')
        self.layer_weight_cache = {}  # initialize weight cache dict
        self.max_cache_num = max_cache_num

        # Create model
        self._create_model()
        # Check if `self.splitted_weights_path` exists
        self._check_split_weight_path()

        # Initialize attention mask and position ids to further improve the inference speed
        self._generate_att_mask_and_pos_id()

    @classmethod
    def split_and_convert_to_cpu_int4_weights(cls, model_path, safetensor_per_layer_path):
        model = AutoModelForCausalLM.from_pretrained(model_path, use_safetensors=True)
        model = optimize_model(model)

        state_dict = model.state_dict()
        layer_names = ["model.embed_tokens."] + [f"model.layers.{i}." for i in range(len(model.model.layers))] + ["model.norm.", "lm_head."]
        for layer_name in tqdm(layer_names):
            local_state_dict = {k: v.contiguous() for k, v in state_dict.items() if k.startswith(layer_name)}
            save_name = os.path.join(safetensor_per_layer_path, f"{layer_name}safetensors")
            save_file(local_state_dict, save_name)
        print(f'Save splitted safetensor weights to {safetensor_per_layer_path}')

    def _create_model(self):
        # Load config
        self.config = AutoConfig.from_pretrained(self.model_path)
        if self.config.pretraining_tp > 1:
            warnings.warn("Set config.pretraining_tp = 1 to use int4 inference")
            self.config.pretraining_tp = 1
        if not self.config.use_cache:
            warnings.warn("Set config.use_cache to further improve performance")
            self.config.use_cache = True

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        # Model initialization
        self._init_model()
        # Tie with self.model
        self.layer_names = ["model.embed_tokens"] + [f"model.layers.{i}" for i in range(len(self.model.model.layers))] + ["model.norm", "lm_head"]
        self.generation_config = self.model.generation_config
        self.main_input_name = self.model.main_input_name

    def _check_split_weight_path(self):
        for layer_name in self.layer_names:
            split_weight_path = os.path.join(self.splitted_weights_path, f"{layer_name}.safetensors")
            if not os.path.exists(split_weight_path):
                raise FileNotFoundError(f"Weight file {split_weight_path} is missing. "
                                        f"You can generate it using `LowMemoryLlama.split_and_convert_to_cpu_int4_weights`.")

    def _generate_att_mask_and_pos_id(self):
        self.attention_mask = torch.full((MAX_LENGTH, MAX_LENGTH), torch.finfo(torch.float16).min, device=self.device)
        mask_cond = torch.arange(self.attention_mask.size(-1), device=self.device)
        self.attention_mask.masked_fill_(mask_cond < (mask_cond + 1).view(self.attention_mask.size(-1), 1), 0)
        self.attention_mask = self.attention_mask.to(torch.float16)[None, None, :, :]
        self.position_ids = torch.arange(MAX_LENGTH, dtype=torch.long, device=self.device)[None, :]

    def _init_model(self):
        # Load meta model (no memory used)
        with init_empty_weights():
            self.model = AutoModelForCausalLM.from_config(self.config)
            self.model = optimize_model(self.model)
            self.model.tie_weights()

        self.layers = [self.model.model.embed_tokens] + list(self.model.model.layers) \
            + [self.model.model.norm, self.model.lm_head]

        # Move buffers to device (not that much GPU memory used)
        for buffer_name, buffer in self.model.named_buffers():
            set_module_tensor_to_device_with_cache(self.model, buffer_name, self.device, value=buffer, dtype=buffer.dtype,
                                                   cache_dict=self.layer_weight_cache, max_cache_num=self.max_cache_num)

    def move_layer_to_device(self, state_dict):
        for param_name, param in state_dict.items():
            set_module_tensor_to_device_with_cache(self.model, param_name, self.device, value=param, dtype=param.dtype,
                                                   cache_dict=self.layer_weight_cache, max_cache_num=self.max_cache_num)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self,
                input_ids: torch.LongTensor = None,
                past_key_values: Optional[List[torch.FloatTensor]] = None,
                *args,
                **kwargs,
                ):

        # Reinit model and clean memory
        del self.model
        gc.collect()
        torch.xpu.empty_cache()
        self._init_model()

        # Send batch to device
        inputs = input_ids.to(self.device)

        # Set up kv cache
        kv_cache = {}
        if past_key_values is None:
            past_key_values = {}

        with torch.inference_mode():
            # Generate attention mask and position ids
            current_shape = inputs.shape[1]
            if past_key_values.get(self.layer_names[1], None):
                pre_shape = past_key_values[self.layer_names[1]][0].size(2)
            else:
                pre_shape = 0
            pos = self.position_ids[:, pre_shape : current_shape + pre_shape]
            attn = self.attention_mask[:, :, -current_shape:, - current_shape - pre_shape:]

            for (layer_name, layer) in tqdm(zip(self.layer_names, self.layers), total=len(self.layers)):

                # Load current layer to device
                state_dict = load_file(os.path.join(self.splitted_weights_path, f"{layer_name}.safetensors"), device="cpu")
                self.move_layer_to_device(state_dict)

                # Run layer
                if layer_name in ("model.embed_tokens", "model.norm", "lm_head"):
                    inputs = layer(inputs)
                else:
                    inputs, new_kv_cache = layer(inputs, use_cache=True, past_key_value=past_key_values.get(layer_name, None),
                                                 position_ids=pos, attention_mask=attn)
                    kv_cache[layer_name] = new_kv_cache

                # Delete weight before moving to('meta')
                for module in layer.modules():
                    if hasattr(module, "weight"):
                        del module.weight

                # Remove previous layer from memory (including buffers)
                layer.to("meta")

        result = CausalLMOutputWithPast(
            logits=inputs.detach(),
            past_key_values=kv_cache,
        )
        return result

    def can_generate(self):
        return True

    def prepare_inputs_for_generation(self, *args, **kwargs):
        return self.model.prepare_inputs_for_generation(*args, **kwargs)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`,'
                             ' `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--prompt', type=str, default="What is AI?",
                        help='Prompt to infer')
    parser.add_argument('--n-predict', type=int, default=32,
                        help='Max tokens to predict')
    parser.add_argument('--splitted-weights-path', type=str, required=True,
                        help='The folder saving per-layer weights. You can use'
                             ' LowMemoryLlama.split_and_convert_to_cpu_int4_weights() to generate those weights.')
    parser.add_argument('--split-weight', action='store_true',
                        help='Whether to split weights by layer. If this argument is enabled, per-layer weights will'
                             ' be generated and saved to `--splitted-weights-path`. This argument only needs to be'
                             ' enabled once for the same model.')
    parser.add_argument('--max-cache-num', type=int, default=200,
                        help='The maximum number of weights saved in the cache_dict. You can adjust this'
                             ' number based on your GPU memory. Default is 200.')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
    splitted_weights_path = args.splitted_weights_path

    if args.split_weight:
        os.makedirs(splitted_weights_path, exist_ok=True)
        LowMemoryLlama.split_and_convert_to_cpu_int4_weights(model_path, splitted_weights_path)

    model = LowMemoryLlama(model_path, splitted_weights_path, args.max_cache_num)

    with torch.inference_mode():
        prompt = LLAMA2_PROMPT_FORMAT.format(prompt=args.prompt)
        input_ids = model.tokenizer.encode(prompt, return_tensors="pt").to('xpu')
        st = time.time()
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict)
        torch.xpu.synchronize()
        end = time.time()
        output = output.cpu()
        output_str = model.tokenizer.decode(output[0], skip_special_tokens=True)
        print(f'Inference time: {end-st} s')
        print('-'*20, 'Output', '-'*20)
        print(output_str)
@@ -226,8 +226,8 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
     invalidInputError(isinstance(model, torch.nn.Module),
                       "model should be an instance of "
                       f"`torch.nn.Module`, but got {type(model)} at last.")
-    invalidInputError(model.device.type == 'cpu',
-                      "Expect model on device `cpu`, "
+    invalidInputError(model.device.type in ('cpu', 'meta'),
+                      "Expect model on device `cpu` or `meta`, "
                       f"but got device type {model.device.type}")
     if kwargs.pop("replace_embedding", False):
         warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
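The check is relaxed because the new low-memory example applies `optimize_model` to a model whose parameters still live on the `meta` device. A condensed sketch of that call pattern, mirroring `_init_model` in `low_memory_generate.py` (the 70B model id is only an illustration):

```python
# Build an empty Llama2 skeleton on the `meta` device and apply BigDL-LLM
# INT4 optimization before any real weights are materialized.
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from bigdl.llm import optimize_model

config = AutoConfig.from_pretrained("meta-llama/Llama-2-70b-chat-hf")
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)  # all parameters on the `meta` device
    model = optimize_model(model)                     # previously rejected: device type was required to be 'cpu'
    model.tie_weights()
```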