LLM: Support gpt-j in speculative decoding (#10067)

* gptj

* support gptj in speculative decoding

* fix

* update readme

* small fix
This commit is contained in:
Yina Chen 2024-02-02 14:54:55 +08:00 committed by GitHub
parent 2927c77d7f
commit 77be19bb97
4 changed files with 161 additions and 27 deletions

View file

@ -0,0 +1,47 @@
# GPT-J
In this directory, you will find examples on how you could run GPT-J FP16 infernece with self-speculative decoding using BigDL-LLM on [Intel GPUs](../README.md). For illustration purposes,we utilize the [EleutherAI/gpt-j-6b](https://huggingface.co/EleutherAI/gpt-j-6b) as reference GPT-J models.
## 0. Requirements
To run these examples with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
## Example: Predict Tokens using `generate()` API
In the example [speculative.py](./speculative.py), we show a basic use case for a GPT-J model to predict the next N tokens using `generate()` API, with BigDL-LLM speculative decoding optimizations on Intel GPUs.
### 1. Install
We suggest using conda to manage environment:
```bash
conda create -n llm python=3.9
conda activate llm
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
```
### 2. Configures OneAPI environment variables
```bash
source /opt/intel/oneapi/setvars.sh
```
### 3. Run
For optimal performance on Intel Data Center GPU Max Series, it is recommended to set several environment variables.
```bash
export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
export ENABLE_SDP_FUSION=1
```
```
python ./speculative.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
```
Arguments info:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the GPT-J model (e.g. `EleutherAI/gpt-j-6b`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `'EleutherAI/gpt-j-6b'`.
- `--prompt PROMPT`: argument defining the prompt to be infered (with integrated prompt format for chat). A default prompt is provided.
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `128`.
#### Sample Output
#### [EleutherAI/gpt-j-6b](https://huggingface.co/EleutherAI/gpt-j-6b)
```log
It is done, and submitted. You can play 'Survival of the Tastiest' on Android, and on the web. Playing on the web works, but you have to simulate multiple touch for table moving and that can be a bit confusing. There is a lot I'd like to talk about. I will go through every topic, insted of making the typical what went right/wrong list. Concept Working over the theme was probably one of the hardest tasks which I had to face. Originally, I had an idea of what kind of game I wanted to develop, gameplay wise - something with a lot of enemies/actors, simple graphics, maybe set in space, controlled from a top-down view. I was confident that I could fit any theme around it. In the end, the problem with a theme like 'Evolution' in a game is that evolution is unassisted. It happens through several seemingly random mutations over time, with the most apt permutation surviving. This genetic car simulator is, in my opinion, a great example of actual evolution of a species facing a challenge. But is it a game? In a game, you need to control something to reach an objective. That control goes against what evolution is supposed to be like. If you allow the user to pick how to evolve something, it's not evolution anymore - it's the equivalent of intelligent design, the fable invented by creationists to combat the idea of evolution. Being agnostic and a Pastafarian, that's not something that rubbed me the right way. Hence, my biggest dillema when deciding what to create was not with what I wanted to create, but with what I did not. I didn't want to create an 'intelligent design' simulator and wrongly call it evolution. This is a problem, of course, every other contestant also had to face. And judging by the entries submitted, not many managed to work around it. I'd say the only real solution was through the use of artificial selection, somehow. So far, I have not seen any entry using this at its core gameplay. Alas, this is just a fun competition and after a while I decided not to be as strict with the game idea, and allowed myself to pick whatever I thought would work out. My initial idea was to create something where humanity tried to evolve to a next level but had some kind of foe trying to stop them from doing so. I kind of had this image of human souls flying in space towards a monolith or a space baby (all based in 2001: A Space Odyssey of course) but I couldn't think of compelling (read: serious) mechanics for that. Borgs were my next inspiration, as their whole hypothesis fit pretty well into the evolution theme. But how to make it work? Are you the borg, or fighting the Borg? The third and final idea came to me through my girlfriend, who somehow gave me the idea of making something about the evolution of Pasta. The more I thought about it the more it sounded like it would work, so I decided to go with it. Conversations with my inspiring co-worker Roushey (who also created the 'Mechanical Underdogs' signature logo for my intros) further matured the concept, as it involved into the idea of having individual pieces of pasta flying around and trying to evolve until they became all-powerful. A secondary idea here was that the game would work to explain how the Flying Spaghetti Monster came to exist - by evolving from a normal dinner table. So the idea evolved more or less into this: you are sitting a table. You have your own plate, with is your 'base'. There are 5 other guests at the table, each with their own plate. Your plate can spawn little pieces of pasta. You do so by 'ordering' them through a menu. Some pastas are better than others; some are faster, some are stronger. They have varying 'costs', which are debited from your credits (you start with a number of credits). Once spawned, your pastas start flying around. Their instinct is to fly to other plates, in order to conquer them (the objective of the game is having your pasta conquer all the plates on the table). But they are really autonomous, so after being spawned, you have no control over your pasta (think DotA or LoL creeps). Your pasta doesn't like other people's pasta, so if they meet, they shoot sauce at each other until one dies. You get credits for other pastas your own pasta kill. Once a pasta is in the vicinity of a plate, it will try to conquer it. If it succeeds, it will spawn a new pasta. If it fails, it will die. The more you order, the more pastas you can spawn. The more pastas you spawn, the more you can order. The more you order, the more you can evolve. The more you evolve, the more you can order. The more you order, the more you can evolve. The more you evolve, the more you can order. The more you order, the more you can evolve. The more you evolve, the more you can order. The more you order, the more you can evolve
Tokens generated 128
E2E Generation time xx.xxxxs
First token latency xx.xxxxs
```

View file

@ -0,0 +1,82 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
import argparse
import time
import numpy as np
torch.nn.Linear.reset_parameters = lambda x: None
seed=42
torch.manual_seed(seed)
np.random.seed(seed)
long_input = "It is done, and submitted. You can play 'Survival of the Tastiest' on Android, and on the web. Playing on the web works, but you have to simulate multiple touch for table moving and that can be a bit confusing. There is a lot I'd like to talk about. I will go through every topic, insted of making the typical what went right/wrong list. Concept Working over the theme was probably one of the hardest tasks which I had to face. Originally, I had an idea of what kind of game I wanted to develop, gameplay wise - something with a lot of enemies/actors, simple graphics, maybe set in space, controlled from a top-down view. I was confident that I could fit any theme around it. In the end, the problem with a theme like 'Evolution' in a game is that evolution is unassisted. It happens through several seemingly random mutations over time, with the most apt permutation surviving. This genetic car simulator is, in my opinion, a great example of actual evolution of a species facing a challenge. But is it a game? In a game, you need to control something to reach an objective. That control goes against what evolution is supposed to be like. If you allow the user to pick how to evolve something, it's not evolution anymore - it's the equivalent of intelligent design, the fable invented by creationists to combat the idea of evolution. Being agnostic and a Pastafarian, that's not something that rubbed me the right way. Hence, my biggest dillema when deciding what to create was not with what I wanted to create, but with what I did not. I didn't want to create an 'intelligent design' simulator and wrongly call it evolution. This is a problem, of course, every other contestant also had to face. And judging by the entries submitted, not many managed to work around it. I'd say the only real solution was through the use of artificial selection, somehow. So far, I have not seen any entry using this at its core gameplay. Alas, this is just a fun competition and after a while I decided not to be as strict with the game idea, and allowed myself to pick whatever I thought would work out. My initial idea was to create something where humanity tried to evolve to a next level but had some kind of foe trying to stop them from doing so. I kind of had this image of human souls flying in space towards a monolith or a space baby (all based in 2001: A Space Odyssey of course) but I couldn't think of compelling (read: serious) mechanics for that. Borgs were my next inspiration, as their whole hypothesis fit pretty well into the evolution theme. But how to make it work? Are you the borg, or fighting the Borg? The third and final idea came to me through my girlfriend, who somehow gave me the idea of making something about the evolution of Pasta. The more I thought about it the more it sounded like it would work, so I decided to go with it. Conversations with my inspiring co-worker Roushey (who also created the 'Mechanical Underdogs' signature logo for my intros) further matured the concept, as it involved into the idea of having individual pieces of pasta flying around and trying to evolve until they became all-powerful. A secondary idea here was that the game would work to explain how the Flying Spaghetti Monster came to exist - by evolving from a normal dinner table. So the idea evolved more or less into this: you are sitting a table. You have your own plate, with is your 'base'. There are 5 other guests at the table, each with their own plate. Your plate can spawn little pieces of pasta. You do so by 'ordering' them through a menu. Some pastas are better than others; some are faster, some are stronger. They have varying 'costs', which are debited from your credits (you start with a number of credits). Once spawned, your pastas start flying around. Their instinct is to fly to other plates, in order to conquer them (the objective of the game is having your pasta conquer all the plates on the table). But they are really autonomous, so after being spawned, you have no control over your pasta (think DotA or LoL creeps). Your pasta doesn't like other people's pasta, so if they meet, they shoot sauce at each other until one dies. You get credits for other pastas your own pasta kill. Once a pasta is in the vicinity of a plate,"
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for GPT-J model')
parser.add_argument('--repo-id-or-model-path', type=str, default="EleutherAI/gpt-j-6b",
help='The huggingface repo id for the GPT-J (e.g. `EleutherAI/gpt-j-6b`) to be downloaded'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--prompt', type=str, default=long_input,
help='Prompt to infer')
parser.add_argument('--n-predict', type=int, default=128,
help='Max tokens to predict')
args = parser.parse_args()
model_path = args.repo_id_or_model_path
# Load model in optimized fp16 here.
# Set `speculative=True`` to enable speculative decoding,
# it only works when load_in_low_bit="fp16" on Intel GPU or load_in_low_bit="bf16" on latest Intel Xeon CPU
model = AutoModelForCausalLM.from_pretrained(model_path,
optimize_model=True,
torch_dtype=torch.float16,
load_in_low_bit="fp16",
speculative=True,
trust_remote_code=True,
use_cache=True)
model = model.to('xpu')
tokenizer = AutoTokenizer.from_pretrained(model_path)
with torch.inference_mode():
prompt = args.prompt
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
# warmup
output = model.generate(input_ids,
max_new_tokens=args.n_predict,
do_sample=False,
th_stop_draft=0.6)
output_str = tokenizer.decode(output[0])
# speculative decoding
st = time.perf_counter()
output = model.generate(input_ids,
max_new_tokens=args.n_predict,
do_sample=False,
th_stop_draft=0.6)
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
torch.xpu.synchronize()
end = time.perf_counter()
print(output_str)
print(f"Tokens generated {model.n_token_generated}")
print(f"E2E Generation time {(end - st):.4f}s")
print(f"First token latency {model.first_token_time:.4f}s")

View file

@ -148,7 +148,7 @@ def gptj_attention_forward(
self.head_dim,
past_length,
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
dtype=cache_v.dtype,
device=device)
new_cache_k[:] = cache_k
new_cache_v[:] = cache_v
@ -162,7 +162,7 @@ def gptj_attention_forward(
self.head_dim,
kv_seq_len,
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=key.dtype,
dtype=value.dtype,
device=device)
key_cache[:] = key
value_cache[:] = value

View file

@ -456,21 +456,22 @@ def speculative_generate(self,
appended_len = step_draft + step
ones_to_append = torch.ones(attention_mask.size(0), appended_len)
draft_attention_mask = torch.cat((attention_mask, ones_to_append), dim=1)
forward_args = {
"input_ids": draft_current_input_ids,
"past_key_values": draft_past_key_values,
"attention_mask": draft_attention_mask,
"return_dict": True,
"use_cache": True,
}
if self.config.model_type == "chatglm":
past_key_value_len = past_key_values[0][0].shape[0]
position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
draft_output = draft_model(input_ids=draft_current_input_ids,
past_key_values=draft_past_key_values,
attention_mask=draft_attention_mask,
return_dict=True,
use_cache=True,
position_ids=position_ids)
else:
draft_output = draft_model(input_ids=draft_current_input_ids,
past_key_values=draft_past_key_values,
attention_mask=draft_attention_mask,
return_dict=True,
use_cache=True)
forward_args["position_ids"] = position_ids
elif self.config.model_type == "gptj":
past_length = draft_past_key_values[0][0].size(1)
position_ids = torch.Tensor([[past_length]]).long().to(self.device)
forward_args["position_ids"] = position_ids
draft_output = draft_model(**forward_args)
temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
draft_generate_ids[:, 1:step_draft+1]), dim=-1)
logits = draft_output['logits']
@ -548,23 +549,27 @@ def speculative_generate(self,
logits = output[0]
past_key_values = output[1]
else:
forward_args = {
"input_ids": drafted_input_ids,
"past_key_values": past_key_values,
"attention_mask": cur_attention_mask,
"return_dict": True,
"use_cache": True,
}
if self.config.model_type == "chatglm":
past_key_value_len = past_key_values[0][0].shape[0]
position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
device=drafted_input_ids.device)
position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
output = self(input_ids=drafted_input_ids,
past_key_values=past_key_values,
attention_mask=cur_attention_mask,
return_dict=True,
use_cache=True,
position_ids=position_ids)
else:
output = self(input_ids=drafted_input_ids,
past_key_values=past_key_values,
attention_mask=cur_attention_mask,
return_dict=True,
use_cache=True)
forward_args["position_ids"] = position_ids
elif self.config.model_type == "gptj":
past_length = past_key_values[0][0].size(1)
input_len = drafted_input_ids.shape[1]
position_ids = torch.arange(past_length, input_len + past_length,
dtype=torch.long, device=drafted_input_ids.device)
position_ids = position_ids.unsqueeze(0).view(-1, input_len)
forward_args["position_ids"] = position_ids
output = self(**forward_args)
if isinstance(output, dict):
logits = output['logits']
past_key_values = output['past_key_values']
@ -639,7 +644,7 @@ def speculative_generate(self,
past_key_values = [[tmp, key_cache, value_cache, beam_idx]
for _, key_cache, value_cache, beam_idx in past_key_values]
else:
if self.config.model_type == "qwen":
if self.config.model_type in ["qwen", "gptj"]:
past_key_values = [
(k[:, :-(max_of_max_matched - max_matched), :],
v[:, :-(max_of_max_matched - max_matched), :])