Update Eagle example to Eagle2+ipex-llm integration (#11717)

* update to e2 example

* update

* update
Jiao Wang 2024-10-17 14:16:14 +08:00 committed by GitHub
parent 26390f9213
commit 667f0db466
2 changed files with 261 additions and 122 deletions


@@ -17,8 +17,9 @@ Step 3, you also need to download and install [Intel® oneAPI Base Toolkit](http
- Intel Data Center GPU Max Series
- Intel Arc™ A-Series Graphics and Intel Data Center GPU Flex Series
## Example - EAGLE Speculative Sampling with IPEX-LLM on MT-bench
## Example - EAGLE-2 Speculative Sampling with IPEX-LLM on MT-bench
In this example, we run inference for a Llama2 model to showcase the speed of EAGLE with IPEX-LLM on MT-bench data on Intel GPUs.
We use EAGLE-2, which has better performance than EAGLE-1.
### 1. Install
#### 1.1 Installation on Linux
@@ -28,8 +29,10 @@ conda create -n llm python=3.11
conda activate llm
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
pip install eagle-llm
git clone https://github.com/SafeAILab/EAGLE.git
cd EAGLE
pip install -r requirements.txt
pip install -e .
```
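After installation, a quick sanity check can confirm that PyTorch sees the Intel GPU (a minimal sketch assuming the `ipex-llm[xpu]` stack above installed correctly; the check itself is not part of this commit):
```python
# Sanity check for the XPU environment (illustrative only).
import torch
import intel_extension_for_pytorch as ipex  # registers the 'xpu' device with PyTorch

print("torch:", torch.__version__, "ipex:", ipex.__version__)
print("XPU available:", torch.xpu.is_available())  # expect True on a supported Intel GPU
```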
#### 1.2 Installation on Windows
@@ -42,10 +45,10 @@ pip install dpcpp-cpp-rt==2024.0.2 mkl-dpcpp==2024.0.0 onednn==2024.0.0
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
git clone https://github.com/SafeAILab/EAGLE.git
cd EAGLE
pip install -r requirements.txt
pip install transformers==4.36.2
pip install gradio==3.50.2
pip install eagle-llm
pip install -e .
```
### 2. Configure OneAPI environment variables for Linux
@@ -89,7 +92,7 @@ export ENABLE_SDP_FUSION=1
### 4. Running Example
You can test the speed of EAGLE-2 speculative sampling with IPEX-LLM on MT-bench using the following command.
```bash
python -m evaluation.gen_ea_answer_llama2chat\
python -m evaluation.gen_ea_answer_llama2chat_e2_ipex_optimize\
--ea-model-path [path of EAGLE weight]\
--base-model-path [path of the original model]\
--enable-ipex-llm\
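In addition to `--enable-ipex-llm`, the updated script exposes the EAGLE-2 draft-tree options added in this commit: `--total-token` (the total number of nodes in the draft tree), `--depth`, and `--top-k`; see the argument-parser changes below.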


@@ -34,105 +34,223 @@
import argparse
import json
import os
from accelerate.utils import set_seed
set_seed(0)
import time
import torch
import shortuuid
from fastchat.llm_judge.common import load_questions
from fastchat.model import get_conversation_template
from tqdm import tqdm
from eagle.model.ea_model import EaModel
from eagle.model.utils import *
from eagle.model.utils import prepare_logits_processor, evaluate_posterior
from eagle.model.kv_cache import initialize_past_key_values
from eagle.model.choices import *
from eagle.modeling_eagle import forward_with_tree_mask
from ipex_llm import optimize_model
from ipex_llm.transformers import AutoModelForCausalLM
def ea_forward(input_ids, model, tokenizer, tree_choices, logits_processor=None, max_steps=512):
assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
# Avoid modifying the input_ids in-place
input_ids = input_ids.clone()
model.ea_layer.reset_kv()
class Timer:
def __init__(self,name):
self.name = name
if hasattr(model, "tree_choices") and model.tree_choices == tree_choices:
tree_buffers = model.tree_buffers
def __enter__(self):
torch.xpu.synchronize()
self.start = time.perf_counter()
def __exit__(self, exc_type, exc_value, traceback):
torch.xpu.synchronize()
elapsed = time.perf_counter() - self.start
print(f'{self.name} took {elapsed} seconds')
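# Usage sketch (illustrative; mirrors the commented-out timers further below):
#   with Timer("tree_decoding"):
#       logits, hidden_state_new, past_key_values = tree_decoding(...)
# The torch.xpu.synchronize() calls ensure asynchronous XPU kernels finish
# before timestamps are taken, so the reported wall-clock time is accurate.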
def initialize_tree(input_ids, model, logits_processor):
# outputs, orig, hidden_states = model(
# input_ids, past_key_values=past_key_values, output_orig=True
# )
hidden_states, past_key_values = forward_with_tree_mask(model.base_model.model, input_ids=input_ids)
orig = model.base_model.lm_head(hidden_states)
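# Select the first token: sample through the logits processor when
# temperature > 0, otherwise take the greedy argmax below.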
if logits_processor is not None:
logits = orig[:, -1]
logits = logits_processor(None, logits)
probabilities = torch.nn.functional.softmax(logits, dim=1)
token = torch.multinomial(probabilities, 1)
else:
tree_buffers = generate_tree_buffers(
tree_choices, device=model.base_model.model.layers[-1].self_attn.q_proj.weight.device
)
tree_buffers["retrieve_indices_head"] = tree_buffers["retrieve_indices"].to(
model.base_model.lm_head.weight.device)
model.tree_buffers = tree_buffers
model.tree_choices = tree_choices
token = torch.argmax(orig[:, -1])
token = token[None, None]
input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
# Clone the output hidden states
# Initialize the past key and value states
if hasattr(model, "past_key_values"):
past_key_values = model.past_key_values
past_key_values_data = model.past_key_values_data
current_length_data = model.current_length_data
# Reset the past key and value states
current_length_data.zero_()
else:
(
past_key_values,
past_key_values_data,
current_length_data,
) = initialize_past_key_values(model.base_model)
model.past_key_values = past_key_values
model.past_key_values_data = past_key_values_data
model.current_length_data = current_length_data
draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_genrate(hidden_states, input_ids, model.base_model.lm_head,logits_processor)
return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, orig, hidden_states, token, past_key_values
input_len = input_ids.shape[1]
reset_tree_mode(model)
tree_logits, logits, hidden_state, sample_token = initialize_tree(
input_ids, model, tree_buffers["tree_attn_mask"], past_key_values, logits_processor
)
new_token = 0
for idx in range(max_steps):
candidates, cart_candidates_prob, tree_candidates = generate_candidates(
tree_logits,
tree_buffers["tree_indices"],
tree_buffers["retrieve_indices"],
sample_token,
logits_processor
)
logits, hidden_state_new, outputs = tree_decoding(
def tree_decoding(
model,
tree_candidates,
past_key_values,
tree_buffers["tree_position_ids"],
tree_position_ids,
input_ids,
tree_buffers["retrieve_indices_head"],
)
best_candidate, accept_length, sample_p = evaluate_posterior(
logits, candidates, logits_processor, cart_candidates_prob, tree_logits[2], tree_buffers["p_indices"],
tree_candidates, tree_buffers["b_indices"]
)
input_ids, tree_logits, new_token, hidden_state, sample_token = update_inference_inputs(
retrieve_indices,
attention_mask=None,
tree_mask=None,
):
position_ids = tree_position_ids + input_ids.shape[1]
hidden_states, past_key_values = forward_with_tree_mask(model.base_model.model, input_ids=tree_candidates, past_key_values=past_key_values,
position_ids=position_ids, tree_mask=tree_mask, attention_mask=attention_mask)
tree_logits = model.base_model.lm_head(hidden_states)
logits = tree_logits[0, retrieve_indices]
return logits, hidden_states, past_key_values
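# update_inference_inputs: append the accepted draft tokens to input_ids,
# compact the KV cache down to the accepted prefix, and ask the EAGLE head
# for the next draft tree.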
def update_inference_inputs(
input_ids,
candidates,
best_candidate,
accept_length,
tree_buffers["retrieve_indices"],
retrieve_indices,
logits_processor,
logits,
tree_logits,
new_token,
past_key_values_data,
current_length_data,
past_key_values,
# current_length_data,
model,
hidden_state_new,
sample_p
):
prev_input_len = input_ids.shape[1]
# Map the best candidate indices to the original indices in the sequence
select_indices = (
retrieve_indices[best_candidate, : accept_length + 1] + prev_input_len
)
# Append the tokens from the best candidate to the input sequence
input_ids = torch.cat(
[input_ids, candidates[None, best_candidate, : accept_length + 1].to(input_ids.device)], dim=-1
)
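# Gather the KV-cache entries of the accepted tokens and compact them in
# place, keeping only the prefix up to the newly accepted length.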
new_kv = ()
for past_key_values_data in past_key_values:
layer_kv = ()
for korv in past_key_values_data:
tgt = korv[:, :, select_indices, :]
dst = korv[:, :, prev_input_len: prev_input_len + tgt.shape[-2], :]
dst.copy_(tgt, non_blocking=True)
layer_kv += (korv[:, :, : prev_input_len + tgt.shape[-2], :],)
new_kv += (layer_kv,)
retrieve_hidden_state_new = hidden_state_new[:, retrieve_indices]
accept_hidden_state_new = retrieve_hidden_state_new[:, best_candidate, : accept_length + 1]
prob = sample_p
if logits_processor is not None:
token = torch.multinomial(prob, 1)
token = token[None]
else:
token = torch.argmax(prob)
token = token[None, None]
draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_genrate(accept_hidden_state_new,
input_ids=torch.cat((input_ids, token.to(input_ids.device)), dim=1),
head=model.base_model.lm_head,logits_processor=logits_processor)
new_token += accept_length + 1
return input_ids, draft_tokens, retrieve_indices,tree_mask,tree_position_ids, new_token, None, token, new_kv
def eagenerate(
model,
input_ids,
temperature=0.0,
top_p=0.0,
top_k=0.0,
max_new_tokens=512,
max_length=2048,
log=False,
is_llama3=False,
):
if is_llama3:
stop_token_id = model.tokenizer.convert_tokens_to_ids("<|eot_id|>")
max_length=max_length-model.ea_layer.total_tokens-10
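# Reserve head-room below max_length for the draft tree
# (ea_layer.total_tokens) plus a small safety margin.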
if temperature > 1e-5:
logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
else:
logits_processor = None
#assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
# Avoid modifying the input_ids in-place
padding=(torch.zeros(1,1,dtype=torch.long)-1).to(input_ids.device)
input_ids = input_ids.clone()
model.ea_layer.reset_kv()
input_len = input_ids.shape[1]
# with Timer("initialize_tree"):
draft_tokens, retrieve_indices,tree_mask,tree_position_ids, logits, hidden_state, sample_token, past_key_values = initialize_tree(
input_ids, model, logits_processor
)
new_token = 0
for idx in range(max_length):
# with Timer("all"):
draft_tokens=draft_tokens.to(input_ids.device)
#with Timer("tree_decoding"):
logits, hidden_state_new, past_key_values = tree_decoding(
model,
draft_tokens,
past_key_values,
tree_position_ids,
input_ids,
retrieve_indices,
tree_mask=tree_mask,
)
draft_tokens=torch.cat((draft_tokens,padding),dim=1)
candidates=draft_tokens[0,retrieve_indices]
# with Timer("evaluate_posterior"):
best_candidate, accept_length, sample_p = evaluate_posterior(
logits, candidates, logits_processor
)
# print("new_token: ", (accept_length+1).item())
# with Timer("update_inference_inputs"):
input_ids, draft_tokens, retrieve_indices,tree_mask,tree_position_ids, new_token, hidden_state, sample_token, past_key_values = update_inference_inputs(
input_ids,
candidates,
best_candidate,
accept_length,
retrieve_indices,
logits_processor,
new_token,
past_key_values,
# current_length_data,
model,
hidden_state,
hidden_state_new,
sample_p
)
if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
if is_llama3:
if stop_token_id in input_ids[0, input_len:].tolist():
break
if new_token > 1024:
if model.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
break
if input_ids.shape[1] > 1960:
if new_token > max_new_tokens:
break
if input_ids.shape[1] > max_length:
break
if not log:
return input_ids
else:
return input_ids, new_token, idx
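# Example call (sketch; matches the invocation in get_model_answers below):
#   output_ids, new_token, idx = eagenerate(
#       model, input_ids, temperature=0.0, max_new_tokens=1024, log=True)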
@@ -147,8 +265,8 @@ def run_eval(
max_new_token,
num_choices,
temperature,
tree_choices,
enable_ipex_llm,
args,
):
questions = load_questions(question_file, question_begin, question_end)
shuffled_ids = [q["question_id"] for q in questions]
@@ -168,8 +286,8 @@ def run_eval(
max_new_token,
num_choices,
temperature,
tree_choices,
enable_ipex_llm,
args
)
)
@@ -184,34 +302,32 @@ def get_model_answers(
max_new_token,
num_choices,
temperature,
tree_choices,
enable_ipex_llm,
args
):
try:
model = EaModel.from_pretrained(
base_model_path=base_model_path,
ea_model_path=ea_model_path,
#torch_dtype=torch.float16,
torch_dtype=torch.float32,
torch_dtype=torch.float16,
# torch_dtype=torch.float32,
low_cpu_mem_usage=True,
# load_in_8bit=True,
device_map="auto"
)
except ValueError:
print("Using sequential device_map.")
model = EaModel.from_pretrained(
base_model_path=base_model_path,
ea_model_path=ea_model_path,
#torch_dtype=torch.float16,
torch_dtype=torch.float32,
low_cpu_mem_usage=True,
# load_in_8bit=True,
device_map="sequential"
total_token=args.total_token,
depth=args.depth,
top_k=args.top_k,
)
if enable_ipex_llm:
# single line of change to enable ipex-llm
model = optimize_model(model, low_bit='sym_int4', optimize_llm=False)
model.to("xpu")
# Use optimized int4 base model to replace base model in EaModel
base_model = AutoModelForCausalLM.from_pretrained(base_model_path, load_in_4bit=True, use_cache=True,
torch_dtype=torch.float16)
model.base_model = base_model
# Also optimize draft model in EaModel
model.ea_layer = optimize_model(model.ea_layer)
model = model.to("xpu")
model.ea_layer.tree_mask_init = model.ea_layer.tree_mask_init.to("xpu")
model.ea_layer.position_ids = model.ea_layer.position_ids.to("xpu")
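# With IPEX-LLM enabled, the base model runs in 4-bit (sym_int4 by default)
# on the XPU, the EAGLE draft head is optimized separately, and the
# draft-tree buffers are moved to the same device.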
tokenizer = model.get_tokenizer()
if temperature > 1e-5:
@@ -246,15 +362,17 @@ def get_model_answers(
inputs = tokenizer([prompt], return_tensors="pt").to("xpu")
input_ids = inputs.input_ids
torch.xpu.synchronize()
start_time = time.time()
output_ids, new_token, idx = ea_forward(
torch.as_tensor(input_ids),
model,
tokenizer,
tree_choices,
logits_processor,
)
output_ids, new_token, idx = eagenerate(
model,
torch.as_tensor(input_ids),
temperature=temperature,
max_new_tokens=max_new_token,
log=True
)
torch.xpu.synchronize()
total_time = time.time() - start_time
output_ids = output_ids[0][len(input_ids[0]):]
# be consistent with the template's stop_token_ids
@@ -311,14 +429,16 @@ def get_model_answers(
input_ids = inputs.input_ids
try:
torch.xpu.synchronize()
start_time = time.time()
output_ids, new_token, idx = ea_forward(
torch.as_tensor(input_ids),
output_ids, new_token, idx = eagenerate(
model,
tokenizer,
tree_choices,
logits_processor,
torch.as_tensor(input_ids),
temperature=temperature,
max_new_tokens=max_new_token,
log=True,
)
torch.xpu.synchronize()
total_time = time.time() - start_time
output_ids = output_ids[0][len(input_ids[0]):]
@@ -417,6 +537,22 @@ if __name__ == "__main__":
default=1024,
help="The maximum number of new generated tokens.",
)
parser.add_argument(
"--total-token",
type=int,
default=8,
help="The total number of nodes in the draft tree",
)
parser.add_argument(
"--depth",
type=int,
default=4,
)
parser.add_argument(
"--top-k",
type=int,
default=5,
)
parser.add_argument(
"--num-choices",
type=int,
@@ -426,7 +562,7 @@ if __name__ == "__main__":
parser.add_argument(
"--temperature",
type=float,
default=1.0,
default=0.0,
)
parser.add_argument(
@@ -443,7 +579,7 @@ if __name__ == "__main__":
args = parser.parse_args()
args.model_id = args.model_id + "-temperature-" + str(args.temperature)
args.tree_choices = eval(args.tree_choices)
# args.tree_choices = eval(args.tree_choices)
question_file = f"data/{args.bench_name}/question.jsonl"
if args.answer_file:
@@ -464,8 +600,8 @@ if __name__ == "__main__":
args.max_new_token,
args.num_choices,
args.temperature,
args.tree_choices,
args.enable_ipex_llm,
args
)
reorg_answer_file(answer_file)