Add internlm_xcomposer cpu examples (#9337)

* add internlm-xcomposer cpu examples

* use chat

* some fixes

* add license

* address shengsheng's comments

* use demo.jpg
This commit is contained in:
dingbaorong 2023-11-02 15:50:02 +08:00 committed by GitHub
parent 97a38958bd
commit 2e3bfbfe1f
8 changed files with 881 additions and 2 deletions

View file

@ -159,7 +159,7 @@ Over 20 models have been optimized/verified on `bigdl-llm`, including *LLaMA/LLa
| LLaVA | [link](python/llm/example/CPU/PyTorch-Models/Model/llava) | [link](python/llm/example/GPU/PyTorch-Models/Model/llava) |
| CodeLlama | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/codellama) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/codellama) |
| Skywork | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/skywork) | |
| InternLM-XComposer | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/internlm-xcomposer) | |
***For more details, please refer to the `bigdl-llm` [Document](https://test-bigdl-llm.readthedocs.io/en/main/doc/LLM/index.html), [Readme](python/llm), [Tutorial](https://github.com/intel-analytics/bigdl-llm-tutorial) and [API Doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/LLM/index.html).***

View file

@ -66,7 +66,7 @@ Over 20 models have been optimized/verified on `bigdl-llm`, including *LLaMA/LLa
| LLaVA | [link](example/CPU/PyTorch-Models/Model/llava) | [link](example/GPU/PyTorch-Models/Model/llava) |
| CodeLlama | [link](example/CPU/HF-Transformers-AutoModels/Model/codellama) | [link](example/GPU/HF-Transformers-AutoModels/Model/codellama) |
| Skywork | [link](example/CPU/HF-Transformers-AutoModels/Model/skywork) | |
| InternLM-XComposer | [link](example/CPU/HF-Transformers-AutoModels/Model/internlm-xcomposer) | |
### Working with `bigdl-llm`

View file

@ -0,0 +1,93 @@
# InternLM_XComposer
In this directory, you will find examples on how you could apply BigDL-LLM INT4 optimizations on InternLM_XComposer models. For illustration purposes, we utilize the [internlm/internlm-xcomposer-vl-7b](https://huggingface.co/internlm/internlm-xcomposer-vl-7b) as a reference InternLM_XComposer model.
## Requirements
To run these examples with BigDL-LLM, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
## Example: Multi-turn chat centered around an image using `chat()` API
In the example [chat.py](./chat.py), we show a basic use case for an InternLM_XComposer model to start a multi-turn chat centered around an image using `chat()` API, with BigDL-LLM INT4 optimizations.
### 1. Install
We suggest using conda to manage the Python environment. For more information about conda installation, please refer to [here](https://docs.conda.io/en/latest/miniconda.html#).
After installing conda, create a Python environment for BigDL-LLM:
```bash
conda create -n llm python=3.9 # recommend to use Python 3.9
conda activate llm
pip install --pre --upgrade bigdl-llm[all] # install the latest bigdl-llm nightly build with 'all' option
pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops # additional package required for InternLM_XComposer to conduct generation
```
### 2. Download Model and Replace File
If you select the InternLM_XComposer model ([internlm/internlm-xcomposer-vl-7b](https://huggingface.co/internlm/internlm-xcomposer-vl-7b)), please note that their code (`modeling_InternLM_XComposer.py`) does not support inference on CPU. To address this issue, we have provided the updated file ([internlm-xcomposer-vl-7b/modeling_InternLM_XComposer.py](./internlm-xcomposer-vl-7b/modeling_InternLM_XComposer.py), which can be used to conduct inference on CPU.
#### 2.1 Download Model
You could use the following code to download [internlm/internlm-xcomposer-vl-7b](https://huggingface.co/internlm/internlm-xcomposer-vl-7b) with a specific snapshot id. Please note that the `modeling_InternLM_XComposer.py` file that we provide are based on these specific commits.
```
from huggingface_hub import snapshot_download
# for internlm/internlm-xcomposer-vl-7b
model_path = snapshot_download(repo_id='internlm/internlm-xcomposer-vl-7b',
revision="b06eb0c11653fe1568b6c5614b6b7be407ef8660",
cache_dir="dir/path/where/model/files/are/downloaded")
print(f'internlm/internlm-xcomposer-vl-7b checkpoint is downloaded to {model_path}')
```
#### 2.2 Replace `modeling_InternLM_XComposer.py`
For `internlm/internlm-xcomposer-vl-7b`, you should replace the `modeling_InternLM_XComposer.py` with [internlm-xcomposer-vl-7b/modeling_InternLM_XComposer.py](./internlm-xcomposer-vl-7b/modeling_InternLM_XComposer.py).
### 3. Run
After setting up the Python environment, you could run the example by following steps.
> **Note**: When loading the model in 4-bit, BigDL-LLM converts linear layers in the model into INT4 format. In theory, a *X*B model saved in 16-bit will requires approximately 2*X* GB of memory for loading, and ~0.5*X* GB memory for further inference.
>
> Please select the appropriate size of the LLaVA model based on the capabilities of your machine.
#### 3.1 Client
On client Windows machines, it is recommended to run directly with full utilization of all cores:
```powershell
python ./chat.py --image-path demo.jpg
```
More information about arguments can be found in [Arguments Info](#33-arguments-info) section. The expected output can be found in [Sample Output](#34-sample-output) section.
#### 3.2 Server
For optimal performance on server, it is recommended to set several environment variables (refer to [here](../README.md#best-known-configuration-on-linux) for more information), and run the example with all the physical cores of a single socket.
E.g. on Linux,
```bash
# set BigDL-Nano env variables
source bigdl-nano-init
# e.g. for a server with 48 cores per socket
export OMP_NUM_THREADS=48
numactl -C 0-47 -m 0 python ./chat.py --image-path demo.jpg
```
More information about arguments can be found in [Arguments Info](#33-arguments-info) section. The expected output can be found in [Sample Output](#34-sample-output) section.
#### 3.3 Arguments Info
In the example, several arguments can be passed to satisfy your requirements:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the LLaVA model (e.g. `internlm/internlm-xcomposer-vl-7b`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `'internlm/internlm-xcomposer-vl-7b'`.
- `--image-path IMAGE_PATH`: argument defining the input image that the chat will focus on. It is required and should be a local path (not url).
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `512`.
#### 3.4 Sample Chat
#### [internlm/internlm-xcomposer-vl-7b](https://huggingface.co/internlm/internlm-xcomposer-vl-7b)
```log
User: 这是什么?
Bot: bus
User: 它可以用来干什么
Bot: transport people
```
The sample input image is (which is fetched from [COCO dataset](https://cocodataset.org/#explore?id=178242)):
[demo.jpg](https://cocodataset.org/#explore?id=178242)
<a href="http://farm6.staticflickr.com/5331/8954873157_539393fece_z.jpg"><img width=400px src="http://farm6.staticflickr.com/5331/8954873157_539393fece_z.jpg" ></a>

View file

@ -0,0 +1,61 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers.generation import GenerationConfig
import torch
import time
import os
import argparse
from bigdl.llm import optimize_model
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Predict Tokens using `chat()` API for InternLM-XComposer model')
parser.add_argument('--repo-id-or-model-path', type=str, default="internlm/internlm-xcomposer-vl-7b",
help='The huggingface repo id for the InternLM-XComposer model to be downloaded'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--image-path', type=str, required=True,
help='Image path for the input image that the chat will focus on')
parser.add_argument('--n-predict', type=int, default=512, help='Max tokens to predict')
args = parser.parse_args()
model_path = args.repo_id_or_model_path
image = args.image_path
# Load model
# For successful BigDL-LLM optimization on InternLM-XComposer, skip the 'qkv' module during optimization
model = AutoModelForCausalLM.from_pretrained(model_path, device='cpu', load_in_4bit=True,
trust_remote_code=True, modules_to_not_convert=['qkv'])
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.tokenizer = tokenizer
history = None
while True:
try:
user_input = input("User: ")
except EOFError:
user_input = ""
if not user_input:
print("exit...")
break
response, history = model.chat(text=user_input, image=image , history = history)
print(f'Bot: {response}', end="")
image = None

View file

@ -0,0 +1,284 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===========================================================================
#
# This file is adapted from
# https://huggingface.co/internlm/internlm-xcomposer-vl-7b/blob/b06eb0c11653fe1568b6c5614b6b7be407ef8660/modeling_InternLM_XComposer.py
#
# Apache 2.0 license
# We change the dtype from float16 to float32 to enable inference on CPU.
import copy
import os
import sys
dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, dir_path)
import contextlib
import torch.utils.checkpoint
from torch.nn import LayerNorm
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
from .modeling_perceive_sampler import BertConfig, BertLMHeadModel
from .modeling_vit import *
from .modeling_InternLM import *
from .modeling_utils import *
from transformers.utils import logging
logger = logging.get_logger(__name__)
class InternLMXComposerForCausalLM(PreTrainedModel):
config_class = InternLMXComposerConfig
_auto_class = "AutoModelForCausalLM"
gen_config = dict(
num_beams=5,
do_sample=False,
min_length=1,
repetition_penalty=1.5,
length_penalty=1.0,
temperature=1.0,
max_new_tokens=200,
)
def __init__(self, config):
super().__init__(config)
print('Init VIT ... ', end='')
# self.visual_encoder = create_eva_vit_g()
self.visual_encoder = create_eva_vit_g(precision="fp32")
self.ln_vision = LayerNorm(self.visual_encoder.num_features)
print('Done')
print('Init Perceive Sampler ... ', end='')
with all_logging_disabled():
self.Qformer, self.query_tokens = self.init_qformer(
config.num_query_token, self.visual_encoder.num_features)
self.Qformer.bert.embeddings.word_embeddings = None
self.Qformer.bert.embeddings.position_embeddings = None
for layer in self.Qformer.bert.encoder.layer:
layer.output = None
layer.intermediate = None
self.Qformer.cls = None
print('Done')
print('Init InternLM ... ', end='')
self.flag_image_start = nn.Parameter(torch.zeros([1, 1, 4096]))
self.flag_image_end = nn.Parameter(torch.zeros([1, 1, 4096]))
self.flag_image_start.requires_grad = False
self.flag_image_end.requires_grad = False
internlm_lora = config.internlm_lora
self.internlm_lora = internlm_lora
setattr(InternLMForCausalLM, 'lora_cfg', internlm_lora)
if int(torch.__version__[0]) == 1:
# self.internlm_model = InternLMForCausalLM._from_config(config).to(
# torch.float16)
self.internlm_model = InternLMForCausalLM._from_config(config).to(
torch.float32)
else:
assert int(torch.__version__[0]) == 2
# speed up init llm
with torch.device('meta'):
self.internlm_model = InternLMForCausalLM._from_config(config)
# self.internlm_model.to_empty(device=config.device).to(torch.float16)
self.internlm_model.to_empty(device=config.device).to(torch.float32)
for n, m in self.internlm_model.named_modules():
if 'lora' in n:
m.float()
self.internlm_proj = nn.Linear(self.Qformer.config.hidden_size,
self.internlm_model.config.hidden_size)
print('Done')
self.vis_processor = transforms.Compose([
transforms.Resize((224, 224),
interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])
self.tokenizer = None
@property
def eoh(self):
return self.tokenizer.decode(torch.Tensor([103027]),
skip_special_tokens=True)
@property
def eoa(self):
return self.tokenizer.decode(torch.Tensor([103028]),
skip_special_tokens=True)
def maybe_autocast(self, dtype=torch.float16):
# if on cpu, don't use autocast
# if on gpu, use autocast with dtype if provided, otherwise use torch.float16
enable_autocast = self.device != torch.device("cpu")
if enable_autocast:
return torch.cuda.amp.autocast(dtype=dtype)
else:
return contextlib.nullcontext()
@classmethod
def init_qformer(cls,
num_query_token,
vision_width,
cross_attention_freq=2,
pretrain=True):
encoder_config = BertConfig()
encoder_config.encoder_width = vision_width
# insert cross-attention layer every other block
encoder_config.add_cross_attention = True
encoder_config.cross_attention_freq = cross_attention_freq
encoder_config.query_length = num_query_token
Qformer = BertLMHeadModel(config=encoder_config)
query_tokens = nn.Parameter(
torch.zeros(1, num_query_token, encoder_config.hidden_size))
query_tokens.data.normal_(mean=0.0,
std=encoder_config.initializer_range)
return Qformer, query_tokens
def encode_img(self, image):
if image is None:
return None
if isinstance(image, str):
image = Image.open(image).convert("RGB")
image = self.vis_processor(image).unsqueeze(0).to(self.device)
else:
assert isinstance(image, torch.Tensor)
device = image.device
with self.maybe_autocast():
image_embeds = self.ln_vision(
self.visual_encoder(image)).to(device)
image_atts = torch.ones(image_embeds.size()[:-1],
dtype=torch.long).to(device)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1,
-1)
query_output = self.Qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
inputs_internlm = self.internlm_proj(query_output.last_hidden_state)
inputs_internlm = torch.cat([
self.flag_image_start.expand(inputs_internlm.shape[0], -1, -1),
inputs_internlm,
self.flag_image_end.expand(inputs_internlm.shape[0], -1, -1)
],
dim=1)
return inputs_internlm
def encode_text(self, text, add_special_tokens=False):
text_token_ids = self.tokenizer(
text,
return_tensors='pt',
add_special_tokens=add_special_tokens,
).input_ids.to(self.device)
text_embeds = self.internlm_model.model.embed_tokens(text_token_ids)
return text_embeds
def decode_text(self, out_embeds):
out_text = self.tokenizer.batch_decode(out_embeds,
skip_special_tokens=True)[0]
out_text = out_text.split(self.eoa)[0]
return out_text
def wrap_text(self, user_text, bot_text='', add_special=True):
if add_special:
eoh = self.eoh
else:
eoh = ''
text = f' <|User|>:{user_text} \n{eoh} <|Bot|>:{bot_text}'
return text
def get_gen_args(self, **kwargs):
new_kargs = copy.deepcopy(self.gen_config)
new_kargs.update(kwargs)
return new_kargs
def generate(self, text, image=None, **kwargs):
text_embeds = self.encode_text(text)
img_embeds = self.encode_img(image)
prompt_embeds = self.wrap_prompt(text_embeds, img_embeds)
out_embeds = self.internlm_model.generate(inputs_embeds=prompt_embeds,
**self.get_gen_args(**kwargs))
out_text = self.decode_text(out_embeds)
return out_text
def chat(self, text, image=None, history=None, **kwargs):
text_embeds = self.encode_text(text)
img_embeds = self.encode_img(image)
prompt_embeds = self.wrap_prompt(text_embeds,
img_embeds,
history=history)
out_embeds = self.internlm_model.generate(inputs_embeds=prompt_embeds,
**self.get_gen_args(**kwargs))
out_text = self.decode_text(out_embeds)
# trunc at eoh and eoa
clean_out_text_token_ids = self.tokenizer(
out_text, return_tensors='pt').input_ids.to(self.device)
clean_out_text_embeds = self.internlm_model.model.embed_tokens(
clean_out_text_token_ids)
clean_prompt_embeds = self.wrap_prompt(text_embeds,
img_embeds,
add_special=False)
cur_history = torch.cat([clean_prompt_embeds, clean_out_text_embeds],
dim=1)
if history is None:
history = []
history.append(cur_history)
return out_text, history
def wrap_prompt(self,
text_embeds,
img_embeds=None,
history=None,
add_special=True):
if add_special:
prompt_segs = [' <|User|>:', f'\n{self.eoh} <|Bot|>:']
else:
prompt_segs = [' <|User|>:', ' <|Bot|>:'] # used in wrap history
prompt_seg_embeds = []
for i, seg in enumerate(prompt_segs):
if history is not None:
add_special_tokens = False
else:
add_special_tokens = i == 0
seg_embeds = self.encode_text(
seg, add_special_tokens=add_special_tokens)
prompt_seg_embeds.append(seg_embeds)
if img_embeds is None:
img_embeds = text_embeds.new_empty(text_embeds.size(0), 0,
text_embeds.size(-1))
prompt_seg_embeds = [
prompt_seg_embeds[0], img_embeds, text_embeds, prompt_seg_embeds[1]
]
prompt_embeds = torch.cat(prompt_seg_embeds, dim=1)
if history is not None:
prompt_embeds = torch.cat([*history, prompt_embeds], dim=1)
return prompt_embeds

View file

@ -0,0 +1,93 @@
# InternLM_XComposer
In this directory, you will find examples on how you could use BigDL-LLM `optimize_model` API to accelerate InternLM_XComposer models. For illustration purposes, we utilize the [internlm/internlm-xcomposer-vl-7b](https://huggingface.co/internlm/internlm-xcomposer-vl-7b) as a reference InternLM_XComposer model.
## Requirements
To run these examples with BigDL-LLM, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
## Example: Multi-turn chat centered around an image using `chat()` API
In the example [chat.py](./chat.py), we show a basic use case for an InternLM_XComposer model to start a multi-turn chat centered around an image using `chat()` API, with BigDL-LLM 'optimize_model' API.
### 1. Install
We suggest using conda to manage the Python environment. For more information about conda installation, please refer to [here](https://docs.conda.io/en/latest/miniconda.html#).
After installing conda, create a Python environment for BigDL-LLM:
```bash
conda create -n llm python=3.9 # recommend to use Python 3.9
conda activate llm
pip install --pre --upgrade bigdl-llm[all] # install the latest bigdl-llm nightly build with 'all' option
pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops # additional package required for InternLM_XComposer to conduct generation
```
### 2. Download Model and Replace File
If you select the InternLM_XComposer model ([internlm/internlm-xcomposer-vl-7b](https://huggingface.co/internlm/internlm-xcomposer-vl-7b)), please note that their code (`modeling_InternLM_XComposer.py`) does not support inference on CPU. To address this issue, we have provided the updated file ([internlm-xcomposer-vl-7b/modeling_InternLM_XComposer.py](./internlm-xcomposer-vl-7b/modeling_InternLM_XComposer.py), which can be used to conduct inference on CPU.
#### 2.1 Download Model
You could use the following code to download [internlm/internlm-xcomposer-vl-7b](https://huggingface.co/internlm/internlm-xcomposer-vl-7b) with a specific snapshot id. Please note that the `modeling_InternLM_XComposer.py` file that we provide are based on these specific commits.
```
from huggingface_hub import snapshot_download
# for internlm/internlm-xcomposer-vl-7b
model_path = snapshot_download(repo_id='internlm/internlm-xcomposer-vl-7b',
revision="b06eb0c11653fe1568b6c5614b6b7be407ef8660",
cache_dir="dir/path/where/model/files/are/downloaded")
print(f'internlm/internlm-xcomposer-vl-7b checkpoint is downloaded to {model_path}')
```
#### 2.2 Replace `modeling_InternLM_XComposer.py`
For `internlm/internlm-xcomposer-vl-7b`, you should replace the `modeling_InternLM_XComposer.py` with [internlm-xcomposer-vl-7b/modeling_InternLM_XComposer.py](./internlm-xcomposer-vl-7b/modeling_InternLM_XComposer.py).
### 3. Run
After setting up the Python environment, you could run the example by following steps.
> **Note**: When loading the model in 4-bit, BigDL-LLM converts linear layers in the model into INT4 format. In theory, a *X*B model saved in 16-bit will requires approximately 2*X* GB of memory for loading, and ~0.5*X* GB memory for further inference.
>
> Please select the appropriate size of the LLaVA model based on the capabilities of your machine.
#### 3.1 Client
On client Windows machines, it is recommended to run directly with full utilization of all cores:
```powershell
python ./chat.py --image-path demo.jpg
```
More information about arguments can be found in [Arguments Info](#33-arguments-info) section. The expected output can be found in [Sample Output](#34-sample-output) section.
#### 3.2 Server
For optimal performance on server, it is recommended to set several environment variables (refer to [here](../README.md#best-known-configuration-on-linux) for more information), and run the example with all the physical cores of a single socket.
E.g. on Linux,
```bash
# set BigDL-Nano env variables
source bigdl-nano-init
# e.g. for a server with 48 cores per socket
export OMP_NUM_THREADS=48
numactl -C 0-47 -m 0 python ./chat.py --image-path demo.jpg
```
More information about arguments can be found in [Arguments Info](#33-arguments-info) section. The expected output can be found in [Sample Output](#34-sample-output) section.
#### 3.3 Arguments Info
In the example, several arguments can be passed to satisfy your requirements:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the LLaVA model (e.g. `internlm/internlm-xcomposer-vl-7b`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `'internlm/internlm-xcomposer-vl-7b'`.
- `--image-path IMAGE_PATH`: argument defining the input image that the chat will focus on. It is required and should be a local path (not url).
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `512`.
#### 3.4 Sample Chat
#### [internlm/internlm-xcomposer-vl-7b](https://huggingface.co/internlm/internlm-xcomposer-vl-7b)
```log
User: 这是什么?
Bot: bus
User: 它可以用来干什么
Bot: transport people
```
The sample input image is (which is fetched from [COCO dataset](https://cocodataset.org/#explore?id=178242)):
[demo.jpg](https://cocodataset.org/#explore?id=178242)
<a href="http://farm6.staticflickr.com/5331/8954873157_539393fece_z.jpg"><img width=400px src="http://farm6.staticflickr.com/5331/8954873157_539393fece_z.jpg" ></a>

View file

@ -0,0 +1,64 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import GenerationConfig
import torch
import time
import os
import argparse
from bigdl.llm import optimize_model
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Predict Tokens using `chat()` API for InternLM-XComposer model')
parser.add_argument('--repo-id-or-model-path', type=str, default="internlm/internlm-xcomposer-vl-7b",
help='The huggingface repo id for the InternLM-XComposer model to be downloaded'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--image-path', type=str, required=True,
help='Image path for the input image that the chat will focus on')
parser.add_argument('--n-predict', type=int, default=512, help='Max tokens to predict')
args = parser.parse_args()
model_path = args.repo_id_or_model_path
image = args.image_path
# Load model
model = AutoModelForCausalLM.from_pretrained(model_path, device='cpu', trust_remote_code=True)
# With only one line to enable BigDL-LLM optimization on model
# For successful BigDL-LLM optimization on InternLM-XComposer, skip the 'qkv' module during optimization
model = optimize_model(model,
low_bit='sym_int4',
modules_to_not_convert=['qkv'])
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.tokenizer = tokenizer
history = None
while True:
try:
user_input = input("User: ")
except EOFError:
user_input = ""
if not user_input:
print("exit...")
break
response, history = model.chat(text=user_input, image=image , history = history)
print(f'Bot: {response}', end="")
image = None

View file

@ -0,0 +1,284 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===========================================================================
#
# This file is adapted from
# https://huggingface.co/internlm/internlm-xcomposer-vl-7b/blob/b06eb0c11653fe1568b6c5614b6b7be407ef8660/modeling_InternLM_XComposer.py
#
# Apache 2.0 license
# We change the dtype from float16 to float32 to enable inference on CPU.
import copy
import os
import sys
dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, dir_path)
import contextlib
import torch.utils.checkpoint
from torch.nn import LayerNorm
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
from .modeling_perceive_sampler import BertConfig, BertLMHeadModel
from .modeling_vit import *
from .modeling_InternLM import *
from .modeling_utils import *
from transformers.utils import logging
logger = logging.get_logger(__name__)
class InternLMXComposerForCausalLM(PreTrainedModel):
config_class = InternLMXComposerConfig
_auto_class = "AutoModelForCausalLM"
gen_config = dict(
num_beams=5,
do_sample=False,
min_length=1,
repetition_penalty=1.5,
length_penalty=1.0,
temperature=1.0,
max_new_tokens=200,
)
def __init__(self, config):
super().__init__(config)
print('Init VIT ... ', end='')
# self.visual_encoder = create_eva_vit_g()
self.visual_encoder = create_eva_vit_g(precision="fp32")
self.ln_vision = LayerNorm(self.visual_encoder.num_features)
print('Done')
print('Init Perceive Sampler ... ', end='')
with all_logging_disabled():
self.Qformer, self.query_tokens = self.init_qformer(
config.num_query_token, self.visual_encoder.num_features)
self.Qformer.bert.embeddings.word_embeddings = None
self.Qformer.bert.embeddings.position_embeddings = None
for layer in self.Qformer.bert.encoder.layer:
layer.output = None
layer.intermediate = None
self.Qformer.cls = None
print('Done')
print('Init InternLM ... ', end='')
self.flag_image_start = nn.Parameter(torch.zeros([1, 1, 4096]))
self.flag_image_end = nn.Parameter(torch.zeros([1, 1, 4096]))
self.flag_image_start.requires_grad = False
self.flag_image_end.requires_grad = False
internlm_lora = config.internlm_lora
self.internlm_lora = internlm_lora
setattr(InternLMForCausalLM, 'lora_cfg', internlm_lora)
if int(torch.__version__[0]) == 1:
# self.internlm_model = InternLMForCausalLM._from_config(config).to(
# torch.float16)
self.internlm_model = InternLMForCausalLM._from_config(config).to(
torch.float32)
else:
assert int(torch.__version__[0]) == 2
# speed up init llm
with torch.device('meta'):
self.internlm_model = InternLMForCausalLM._from_config(config)
# self.internlm_model.to_empty(device=config.device).to(torch.float16)
self.internlm_model.to_empty(device=config.device).to(torch.float32)
for n, m in self.internlm_model.named_modules():
if 'lora' in n:
m.float()
self.internlm_proj = nn.Linear(self.Qformer.config.hidden_size,
self.internlm_model.config.hidden_size)
print('Done')
self.vis_processor = transforms.Compose([
transforms.Resize((224, 224),
interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])
self.tokenizer = None
@property
def eoh(self):
return self.tokenizer.decode(torch.Tensor([103027]),
skip_special_tokens=True)
@property
def eoa(self):
return self.tokenizer.decode(torch.Tensor([103028]),
skip_special_tokens=True)
def maybe_autocast(self, dtype=torch.float16):
# if on cpu, don't use autocast
# if on gpu, use autocast with dtype if provided, otherwise use torch.float16
enable_autocast = self.device != torch.device("cpu")
if enable_autocast:
return torch.cuda.amp.autocast(dtype=dtype)
else:
return contextlib.nullcontext()
@classmethod
def init_qformer(cls,
num_query_token,
vision_width,
cross_attention_freq=2,
pretrain=True):
encoder_config = BertConfig()
encoder_config.encoder_width = vision_width
# insert cross-attention layer every other block
encoder_config.add_cross_attention = True
encoder_config.cross_attention_freq = cross_attention_freq
encoder_config.query_length = num_query_token
Qformer = BertLMHeadModel(config=encoder_config)
query_tokens = nn.Parameter(
torch.zeros(1, num_query_token, encoder_config.hidden_size))
query_tokens.data.normal_(mean=0.0,
std=encoder_config.initializer_range)
return Qformer, query_tokens
def encode_img(self, image):
if image is None:
return None
if isinstance(image, str):
image = Image.open(image).convert("RGB")
image = self.vis_processor(image).unsqueeze(0).to(self.device)
else:
assert isinstance(image, torch.Tensor)
device = image.device
with self.maybe_autocast():
image_embeds = self.ln_vision(
self.visual_encoder(image)).to(device)
image_atts = torch.ones(image_embeds.size()[:-1],
dtype=torch.long).to(device)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1,
-1)
query_output = self.Qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
inputs_internlm = self.internlm_proj(query_output.last_hidden_state)
inputs_internlm = torch.cat([
self.flag_image_start.expand(inputs_internlm.shape[0], -1, -1),
inputs_internlm,
self.flag_image_end.expand(inputs_internlm.shape[0], -1, -1)
],
dim=1)
return inputs_internlm
def encode_text(self, text, add_special_tokens=False):
text_token_ids = self.tokenizer(
text,
return_tensors='pt',
add_special_tokens=add_special_tokens,
).input_ids.to(self.device)
text_embeds = self.internlm_model.model.embed_tokens(text_token_ids)
return text_embeds
def decode_text(self, out_embeds):
out_text = self.tokenizer.batch_decode(out_embeds,
skip_special_tokens=True)[0]
out_text = out_text.split(self.eoa)[0]
return out_text
def wrap_text(self, user_text, bot_text='', add_special=True):
if add_special:
eoh = self.eoh
else:
eoh = ''
text = f' <|User|>:{user_text} \n{eoh} <|Bot|>:{bot_text}'
return text
def get_gen_args(self, **kwargs):
new_kargs = copy.deepcopy(self.gen_config)
new_kargs.update(kwargs)
return new_kargs
def generate(self, text, image=None, **kwargs):
text_embeds = self.encode_text(text)
img_embeds = self.encode_img(image)
prompt_embeds = self.wrap_prompt(text_embeds, img_embeds)
out_embeds = self.internlm_model.generate(inputs_embeds=prompt_embeds,
**self.get_gen_args(**kwargs))
out_text = self.decode_text(out_embeds)
return out_text
def chat(self, text, image=None, history=None, **kwargs):
text_embeds = self.encode_text(text)
img_embeds = self.encode_img(image)
prompt_embeds = self.wrap_prompt(text_embeds,
img_embeds,
history=history)
out_embeds = self.internlm_model.generate(inputs_embeds=prompt_embeds,
**self.get_gen_args(**kwargs))
out_text = self.decode_text(out_embeds)
# trunc at eoh and eoa
clean_out_text_token_ids = self.tokenizer(
out_text, return_tensors='pt').input_ids.to(self.device)
clean_out_text_embeds = self.internlm_model.model.embed_tokens(
clean_out_text_token_ids)
clean_prompt_embeds = self.wrap_prompt(text_embeds,
img_embeds,
add_special=False)
cur_history = torch.cat([clean_prompt_embeds, clean_out_text_embeds],
dim=1)
if history is None:
history = []
history.append(cur_history)
return out_text, history
def wrap_prompt(self,
text_embeds,
img_embeds=None,
history=None,
add_special=True):
if add_special:
prompt_segs = [' <|User|>:', f'\n{self.eoh} <|Bot|>:']
else:
prompt_segs = [' <|User|>:', ' <|Bot|>:'] # used in wrap history
prompt_seg_embeds = []
for i, seg in enumerate(prompt_segs):
if history is not None:
add_special_tokens = False
else:
add_special_tokens = i == 0
seg_embeds = self.encode_text(
seg, add_special_tokens=add_special_tokens)
prompt_seg_embeds.append(seg_embeds)
if img_embeds is None:
img_embeds = text_embeds.new_empty(text_embeds.size(0), 0,
text_embeds.size(-1))
prompt_seg_embeds = [
prompt_seg_embeds[0], img_embeds, text_embeds, prompt_seg_embeds[1]
]
prompt_embeds = torch.cat(prompt_seg_embeds, dim=1)
if history is not None:
prompt_embeds = torch.cat([*history, prompt_embeds], dim=1)
return prompt_embeds