Modify Dockerfile

commit a50c11d326
Author: Wang
Date: 2023-09-26 11:19:13 +08:00
11 changed files with 321 additions and 14 deletions


@@ -1,16 +1,16 @@
## Run BF16-Optimized Lora Finetuning on Kubernetes with OneCCL
[Alpaca Lora](https://github.com/tloen/alpaca-lora/tree/main) uses [low-rank adaptation](https://arxiv.org/pdf/2106.09685.pdf) to speed up the finetuning of the base model [Llama 7b](https://huggingface.co/decapoda-research/llama-7b-hf), and tries to reproduce the standard Alpaca, a general finetuned LLM. It runs on top of Hugging Face transformers with a PyTorch backend, which natively requires expensive GPU resources and significant training time.
[Alpaca Lora](https://github.com/tloen/alpaca-lora/tree/main) uses [low-rank adaptation](https://arxiv.org/pdf/2106.09685.pdf) to speed up the finetuning of the base model [Llama2-7b](https://huggingface.co/meta-llama/Llama-2-7b), and tries to reproduce the standard Alpaca, a general finetuned LLM. It runs on top of Hugging Face transformers with a PyTorch backend, which natively requires expensive GPU resources and significant training time.
By contrast, BigDL provides a CPU optimization to accelerate the LoRA finetuning of Llama 7b, powered by mixed-precision and distributed training. Specifically, [Intel OneCCL](https://www.intel.com/content/www/us/en/developer/tools/oneapi/oneccl.html), an available Hugging Face backend, speeds up the PyTorch computation with the BF16 data type on CPUs, while [Intel MPI](https://www.intel.com/content/www/us/en/developer/tools/oneapi/mpi-library.html) enables parallel processing on Kubernetes.
By contrast, BigDL provides a CPU optimization to accelerate the LoRA finetuning of Llama2-7b, powered by mixed-precision and distributed training. Specifically, [Intel OneCCL](https://www.intel.com/content/www/us/en/developer/tools/oneapi/oneccl.html), an available Hugging Face backend, speeds up the PyTorch computation with the BF16 data type on CPUs, while [Intel MPI](https://www.intel.com/content/www/us/en/developer/tools/oneapi/mpi-library.html) enables parallel processing on Kubernetes.
The architecture is illustrated in the following diagram:
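To make the two ingredients concrete, here is a rough flavor of what the MPI launcher ends up running on each worker. The script name and flags are hypothetical placeholders; the real command is wired up by the Helm chart described below:
```bash
# Hypothetical sketch: oneCCL/Intel MPI environment plus a BF16 finetuning entrypoint
source /opt/intel/oneapi/setvars.sh   # exposes oneCCL and Intel MPI
mpirun -n 2 -ppn 1 \
       python finetune.py --base-model /nfs/Llama-2-7b --bf16   # finetune.py is a placeholder
```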
![image](https://github.com/Uxito-Ada/BigDL/assets/60865256/139cf9be-10e6-48df-bc84-8872457e83dd)
![image](https://github.com/Jasonzzt/BigDL/assets/60865256/b66416bc-ad07-49af-8cb0-8967dffb5f58)
As shown above, BigDL implements its MPI training on top of the [Kubeflow MPI operator](https://github.com/kubeflow/mpi-operator/tree/master), which encapsulates the deployment as an MPIJob CRD and helps users construct an MPI worker cluster on Kubernetes, handling details such as public key distribution, SSH connections, and log collection.
Now, let's deploy a Lora finetuning job to create an LLM from Llama 7b.
Now, let's deploy a Lora finetuning job to create an LLM from Llama2-7b.
**Note: Please make sure you already have an available Kubernetes infrastructure and NFS shared storage, and have installed the [Helm CLI](https://helm.sh/docs/helm/helm_install/) for Kubernetes job submission.**
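Before submitting anything, it can help to sanity-check these prerequisites from a machine with cluster access. A minimal sketch (the `mpi-operator` namespace is an assumption; it depends on how the operator was installed):
```bash
# Cluster reachable and Helm installed?
kubectl get nodes -o wide
helm version

# MPI operator running? (namespace is an assumption)
kubectl -n mpi-operator get pods
```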
@@ -22,7 +22,7 @@ Follow [here](https://github.com/kubeflow/mpi-operator/tree/master#installation)
Follow the guide [here](https://github.com/intel-analytics/BigDL/tree/main/docker/llm/finetune/lora/docker#prepare-bigdl-image-for-lora-finetuning) to prepare the BigDL Lora finetuning image in your cluster.
As finetuning starts from a base model, first download the [Llama 7b hf model from the public download site of Hugging Face](https://huggingface.co/decapoda-research/llama-7b-hf/tree/main). Then, download the [cleaned alpaca data](https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json), which covers a broad range of general knowledge and has already been cleaned. Next, move the downloaded files to a shared directory on your NFS server.
As finetuning starts from a base model, first download the [Llama2-7b model from the public download site of Hugging Face](https://huggingface.co/meta-llama/Llama-2-7b). Then, download the [cleaned alpaca data](https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json), which covers a broad range of general knowledge and has already been cleaned. Next, move the downloaded files to a shared directory on your NFS server.
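A sketch of this step from a machine that mounts the NFS share (the mount point below is an assumption, and Llama2-7b is a gated model, so the clone requires a Hugging Face account with approved access):
```bash
#!/bin/bash
NFS_DIR=/mnt/nfs/finetune   # assumption: wherever your NFS share is mounted

# Base model (gated: log in with your Hugging Face credentials when prompted)
git lfs install
git clone https://huggingface.co/meta-llama/Llama-2-7b "$NFS_DIR/Llama-2-7b"

# Cleaned Alpaca instruction data
wget -P "$NFS_DIR" https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json
```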
### 3. Deploy through Helm Chart
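A deployment of this shape applies here; the release name, chart path, and value keys below are assumptions, so take the real ones from the finetuning chart in the BigDL repository:
```bash
# Run from the directory containing the Lora finetuning Helm chart (path is an assumption)
helm install bigdl-lora-finetuning ./kubernetes \
  --set imageName=intelanalytics/bigdl-llm-finetune-cpu:2.4.0-SNAPSHOT \
  --set nfsServerIp=<your-nfs-server-ip> \
  --set nfsPath=<shared-dir-holding-model-and-data>
```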


@@ -2,6 +2,7 @@ FROM ubuntu:20.04
ARG http_proxy
ARG https_proxy
ARG PIP_NO_CACHE_DIR=false
# Install PYTHON 3.9
RUN env DEBIAN_FRONTEND=noninteractive apt-get update && \
@@ -12,12 +13,16 @@ RUN env DEBIAN_FRONTEND=noninteractive apt-get update && \
ln -s /usr/bin/python3.9 /usr/bin/python3 && \
ln -s /usr/bin/python3 /usr/bin/python && \
apt-get install -y python3-pip python3.9-dev python3-wheel python3.9-distutils && \
pip3 install --no-cache --upgrade requests argparse urllib3 && \
pip3 install --pre --upgrade bigdl-llm[all] && \
pip3 install --pre --upgrade bigdl-nano && \
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && \
# upgrade pip first, since installing FastChat from source requires PEP 660 support
python3 get-pip.py && \
rm get-pip.py && \
pip install --upgrade requests argparse urllib3 && \
pip3 install --no-cache-dir --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu && \
pip install --pre --upgrade bigdl-llm[all] && \
pip install --pre --upgrade bigdl-nano && \
# Download chat.py script
cd /root && \
wget https://raw.githubusercontent.com/intel-analytics/BigDL/main/python/llm/portable-executable/chat.py && \
wget -P /root https://raw.githubusercontent.com/intel-analytics/BigDL/main/python/llm/portable-executable/chat.py && \
export PYTHONUNBUFFERED=1
ENTRYPOINT ["/bin/bash"]
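Once built, the image can be smoke-tested interactively; a minimal sketch, assuming the tag that the serving image below builds from:
```bash
docker build -t intelanalytics/bigdl-llm-cpu:2.4.0-SNAPSHOT .
docker run -it --name bigdl-llm-cpu intelanalytics/bigdl-llm-cpu:2.4.0-SNAPSHOT
# inside the container, /root/chat.py is available for a quick chat test
```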


@@ -5,6 +5,9 @@ ENV https_proxy $HTTP_PROXY
ENV TZ=Asia/Shanghai
# Disable pip's cache behavior
ARG PIP_NO_CACHE_DIR=false
RUN apt-get update && \
apt-get install -y curl wget git gnupg gpg-agent && \
wget -qO - https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor --output /usr/share/keyrings/intel-graphics.gpg && \
@@ -20,8 +23,12 @@ RUN apt-get update && \
ln -s /usr/bin/python3.9 /usr/bin/python3 && \
ln -s /usr/bin/python3 /usr/bin/python && \
apt-get install -y python3-pip python3.9-dev python3-wheel python3.9-distutils && \
pip3 install --no-cache --upgrade requests argparse urllib3 && \
pip3 install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu && \
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && \
# upgrade pip first, since installing FastChat from source requires PEP 660 support
python3 get-pip.py && \
rm get-pip.py && \
pip install --upgrade requests argparse urllib3 && \
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu && \
# Install opencl-related repos
apt-get update && \
apt-get install -y intel-opencl-icd intel-level-zero-gpu level-zero level-zero-dev


@@ -12,7 +12,7 @@ docker build \
### Use the image for doing xpu inference
To map the `xpu` into the cotainer, you need to specify `--device=/dev/dri` when booting the container.
To map the `xpu` into the container, you need to specify `--device=/dev/dri` when booting the container.
An example could be:
```bash
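#!/bin/bash
# Sketch: mirrors the serving-xpu run command later in this document, with the
# image tag assumed to be the inference image built above; adjust as needed.
export DOCKER_IMAGE=intelanalytics/bigdl-llm-xpu:2.4.0-SNAPSHOT
sudo docker run -itd \
        --net=host \
        --device=/dev/dri \
        --memory="32G" \
        --name=CONTAINER_NAME \
        --shm-size="16g" \
        $DOCKER_IMAGE
```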


@@ -0,0 +1,19 @@
FROM intelanalytics/bigdl-llm-cpu:2.4.0-SNAPSHOT
ARG http_proxy
ARG https_proxy
# Disable pip's cache behavior
ARG PIP_NO_CACHE_DIR=false
# Install Serving Dependencies
RUN mkdir /llm && \
cd /llm && \
git clone https://github.com/analytics-zoo/FastChat.git && \
cd FastChat && \
git checkout dev-2023-09-22 && \
pip3 install -e ".[model_worker,webui]" && \
cd /llm
WORKDIR /llm/


@@ -0,0 +1,35 @@
## Build/Use BigDL-LLM-serving cpu image
### Build Image
```bash
docker build \
--build-arg http_proxy=.. \
--build-arg https_proxy=.. \
--build-arg no_proxy=.. \
--rm --no-cache -t intelanalytics/bigdl-llm-serving-cpu:2.4.0-SNAPSHOT .
```
### Use the image for doing cpu serving
You can use the following bash script to start the container. Note that the CPU configuration below assumes a Xeon CPU; change it accordingly if you are using a different CPU.
```bash
#!/bin/bash
export DOCKER_IMAGE=intelanalytics/bigdl-llm-serving-cpu:2.4.0-SNAPSHOT
sudo docker run -itd \
--net=host \
--cpuset-cpus="0-47" \
--cpuset-mems="0" \
--memory="32G" \
--name=CONTAINER_NAME \
--shm-size="16g" \
$DOCKER_IMAGE
```
After the container is booted, you can get into it through `docker exec`.
To run model serving using `BigDL-LLM` as the backend, refer to this [document](https://github.com/intel-analytics/BigDL/tree/main/python/llm/src/bigdl/llm/serving).
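As a sketch of what serving looks like inside the container: the controller and web server below are standard FastChat commands, while the exact worker invocation and any BigDL-specific flags should be taken from the linked document (the model path is an assumption):
```bash
# Standard FastChat bring-up; run each process in its own shell
python3 -m fastchat.serve.controller

# Worker loading the model on CPU (path is an assumption; see the linked
# BigDL-LLM serving document for the exact worker module and flags)
python3 -m fastchat.serve.model_worker --model-path /llm/models/vicuna-7b-v1.5 --device cpu

# Web UI on top of the controller
python3 -m fastchat.serve.gradio_web_server
```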


@@ -0,0 +1,19 @@
FROM intelanalytics/bigdl-llm-xpu:2.4.0-SNAPSHOT
ARG http_proxy
ARG https_proxy
# Disable pip's cache behavior
ARG PIP_NO_CACHE_DIR=false
# Install Serving Dependencies
RUN mkdir /llm && \
cd /llm && \
git clone https://github.com/analytics-zoo/FastChat.git && \
cd FastChat && \
git checkout dev-2023-09-22 && \
pip3 install -e ".[model_worker,webui]" && \
cd /llm
WORKDIR /llm/


@@ -0,0 +1,46 @@
## Build/Use BigDL-LLM-serving xpu image
### Build Image
```bash
docker build \
--build-arg http_proxy=.. \
--build-arg https_proxy=.. \
--build-arg no_proxy=.. \
--rm --no-cache -t intelanalytics/bigdl-llm-serving-xpu:2.4.0-SNAPSHOT .
```
### Use the image for doing xpu serving
To map the `xpu` into the container, you need to specify `--device=/dev/dri` when booting the container.
An example could be:
```bash
#!/bin/bash
export DOCKER_IMAGE=intelanalytics/bigdl-llm-serving-xpu:2.4.0-SNAPSHOT
sudo docker run -itd \
--net=host \
--device=/dev/dri \
--memory="32G" \
--name=CONTAINER_NAME \
--shm-size="16g" \
$DOCKER_IMAGE
```
After the container is booted, you can get into it through `docker exec`.
To verify that the device is successfully mapped into the container, run `sycl-ls` and check the result. On a machine with an Arc A770, a sample output is:
```bash
root@arda-arc12:/# sycl-ls
[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device 1.2 [2023.16.7.0.21_160000]
[opencl:cpu:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i9-13900K 3.0 [2023.16.7.0.21_160000]
[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics 3.0 [23.17.26241.33]
[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26241]
```
To run model serving using `BigDL-LLM` as the backend, refer to this [document](https://github.com/intel-analytics/BigDL/tree/main/python/llm/src/bigdl/llm/serving).
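The bring-up mirrors the CPU serving case above; the xpu-specific part is pointing the worker at the GPU. A sketch (model path is an assumption; the exact worker module and flags are in the linked document):
```bash
# Worker pinned to the Intel GPU; controller and web server start as in the CPU case
python3 -m fastchat.serve.model_worker --model-path /llm/models/vicuna-7b-v1.5 --device xpu
```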


@@ -272,4 +272,12 @@ def optimize(model):
                        transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention,
                        gptneox_attention_forward
                        )
    elif model.config.model_type == "internlm":
        # InternLM ships its modeling code with the checkpoint (trust_remote_code),
        # so locate the loaded module dynamically before patching its attention.
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
        from bigdl.llm.transformers.models.internlm import internlm_attention_forward
        convert_forward(model,
                        module.InternLMAttention,
                        internlm_attention_forward
                        )
    return model


@@ -0,0 +1,168 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://huggingface.co/internlm/internlm-chat-7b/blob/659ed911eec1e26810f9854f19c5ec27854e9cf3/modeling_internlm.py
# which is licensed under Apache License 2.0:
#
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch InternLM model."""
import math
from typing import Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
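# Extra headroom (in tokens) reserved when (re)allocating the KV cache below,
# so the cache tensors do not need to grow on every decoding step.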
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
def internlm_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor]=None,
    position_ids: Optional[torch.LongTensor]=None,
    past_key_value: Optional[Tuple[torch.Tensor]]=None,
    output_attentions: bool=False,
    use_cache: bool=False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()
    device = hidden_states.device
    # project to per-head query/key/value of shape [bsz, num_heads, q_len, head_dim]
    query_states = self.q_proj(hidden_states) \
        .view(bsz, q_len, self.num_heads, self.head_dim) \
        .transpose(1, 2)
    key_states = self.k_proj(hidden_states) \
        .view(bsz, q_len, self.num_heads, self.head_dim) \
        .transpose(1, 2)
    value_states = self.v_proj(hidden_states) \
        .view(bsz, q_len, self.num_heads, self.head_dim) \
        .transpose(1, 2)
    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states,
        key_states,
        cos,
        sin,
        position_ids,
        "internlm"
    )
    # [bsz, nh, t, hd]
    if past_key_value is not None:
        # reuse k, v, self_attention
        cache_k = past_key_value[0]
        cache_v = past_key_value[1]
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            # the pre-allocated cache buffer is full: allocate a larger one with
            # KV_CACHE_ALLOC_BLOCK_LENGTH extra slots and copy the old content over
            new_cache_k, new_cache_v = extend_kv_cache(
                bsz,
                self.num_heads,
                self.head_dim,
                cache_k.size(2),
                kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
                dtype=cache_k.dtype,
                device=device
            )
            new_cache_k[:] = cache_k
            new_cache_v[:] = cache_v
            cache_k = new_cache_k
            cache_v = new_cache_v
        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
    elif use_cache:
        # first step: initialize the cache with headroom for future decoding steps
        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
        new_key_states, new_value_states = init_kv_cache(
            bsz,
            self.num_heads,
            self.head_dim,
            kv_seq_len,
            max_cache_length,
            dtype=key_states.dtype,
            device=device
        )
        new_key_states[:] = key_states
        new_value_states[:] = value_states
        key_states = new_key_states
        value_states = new_value_states
    past_key_value = (key_states, value_states) if use_cache else None
    attn_weights = torch.matmul(query_states,
                                key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        invalidInputError(
            False,
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, "
            f"but is {attn_weights.size()}"
        )
    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            invalidInputError(
                False,
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
                f"but is {attention_mask.size()}"
            )
        attn_weights = attn_weights + attention_mask
        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights,
                                         dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)
    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        invalidInputError(
            False,
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, "
            f"but is {attn_output.size()}"
        )
    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)
    if not output_attentions:
        attn_weights = None
    return attn_output, attn_weights, past_key_value


@@ -71,7 +71,7 @@ def rotate_every_two(x):
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
    if model_family in ["llama", "baichuan"]:
    if model_family in ["llama", "baichuan", "internlm"]:
        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]