Update Eagle example to Eagle2+ipex-llm integration (#11717)
* update to e2 example
* update
* update
parent 26390f9213
commit 667f0db466
2 changed files with 261 additions and 122 deletions
@@ -17,8 +17,9 @@ Step 3, you also need to download and install [Intel® oneAPI Base Toolkit](http
 - Intel Data Center GPU Max Series
 - Intel Arc™ A-Series Graphics and Intel Data Center GPU Flex Series
 
-## Example - EAGLE Speculative Sampling with IPEX-LLM on MT-bench
+## Example - EAGLE-2 Speculative Sampling with IPEX-LLM on MT-bench
 In this example, we run inference for a Llama2 model to showcase the speed of EAGLE with IPEX-LLM on MT-bench data on Intel GPUs.
+We use EAGLE-2, which has better performance than EAGLE-1.
 
 ### 1. Install
 #### 1.1 Installation on Linux
@@ -28,8 +29,10 @@ conda create -n llm python=3.11
 conda activate llm
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install eagle-llm
+git clone https://github.com/SafeAILab/EAGLE.git
+cd EAGLE
 pip install -r requirements.txt
+pip install -e .
 ```
 
 #### 1.2 Installation on Windows
@@ -42,10 +45,10 @@ pip install dpcpp-cpp-rt==2024.0.2 mkl-dpcpp==2024.0.0 onednn==2024.0.0
 
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+git clone https://github.com/SafeAILab/EAGLE.git
+cd EAGLE
 pip install -r requirements.txt
-pip install transformers==4.36.2
-pip install gradio==3.50.2
-pip install eagle-llm
+pip install -e .
 ```
 
 ### 2. Configures OneAPI environment variables for Linux
@@ -89,7 +92,7 @@ export ENABLE_SDP_FUSION=1
 ### 4. Running Example
 You can test the speed of EAGLE speculative sampling with ipex-llm on MT-bench using the following command.
 ```bash
-python -m evaluation.gen_ea_answer_llama2chat\
+python -m evaluation.gen_ea_answer_llama2chat_e2_ipex_optimize\
 --ea-model-path [path of EAGLE weight]\
 --base-model-path [path of the original model]\
 --enable-ipex-llm\
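Before the diff of the evaluation script itself, here is a condensed sketch of the integration pattern the updated example follows when `--enable-ipex-llm` is set, assembled from the changes shown below; the two paths are placeholders and error handling is omitted.

```python
import torch
from eagle.model.ea_model import EaModel
from ipex_llm import optimize_model
from ipex_llm.transformers import AutoModelForCausalLM

base_model_path = "[path of the original model]"  # placeholder
ea_model_path = "[path of EAGLE weight]"          # placeholder

# Load the EAGLE-2 model (base model + draft ea_layer).
model = EaModel.from_pretrained(
    base_model_path=base_model_path,
    ea_model_path=ea_model_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Replace the base model with an ipex-llm INT4 model, optimize the draft model,
# and move everything (including the draft-tree buffers) to the Intel GPU.
model.base_model = AutoModelForCausalLM.from_pretrained(
    base_model_path, load_in_4bit=True, use_cache=True, torch_dtype=torch.float16
)
model.ea_layer = optimize_model(model.ea_layer)
model = model.to("xpu")
model.ea_layer.tree_mask_init = model.ea_layer.tree_mask_init.to("xpu")
model.ea_layer.position_ids = model.ea_layer.position_ids.to("xpu")
```

The benchmark script below does the same, but reads the paths and the EAGLE-2 draft-tree settings from its command-line arguments.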
@@ -34,106 +34,224 @@
 import argparse
 import json
 import os
-from accelerate.utils import set_seed
-set_seed(0)
 import time
 
+import torch
+
 import shortuuid
 from fastchat.llm_judge.common import load_questions
 from fastchat.model import get_conversation_template
 from tqdm import tqdm
 
 from eagle.model.ea_model import EaModel
-from eagle.model.utils import *
+from eagle.model.utils import prepare_logits_processor, evaluate_posterior
 from eagle.model.kv_cache import initialize_past_key_values
 from eagle.model.choices import *
+from eagle.modeling_eagle import forward_with_tree_mask
 from ipex_llm import optimize_model
+from ipex_llm.transformers import AutoModelForCausalLM
 
-def ea_forward(input_ids, model, tokenizer, tree_choices, logits_processor=None, max_steps=512):
-    assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
-    # Avoid modifying the input_ids in-place
-    input_ids = input_ids.clone()
-    model.ea_layer.reset_kv()
-
-    if hasattr(model, "tree_choices") and model.tree_choices == tree_choices:
-        tree_buffers = model.tree_buffers
-    else:
-        tree_buffers = generate_tree_buffers(
-            tree_choices, device=model.base_model.model.layers[-1].self_attn.q_proj.weight.device
-        )
-        tree_buffers["retrieve_indices_head"] = tree_buffers["retrieve_indices"].to(
-            model.base_model.lm_head.weight.device)
-    model.tree_buffers = tree_buffers
-    model.tree_choices = tree_choices
-
-    # Initialize the past key and value states
-    if hasattr(model, "past_key_values"):
-        past_key_values = model.past_key_values
-        past_key_values_data = model.past_key_values_data
-        current_length_data = model.current_length_data
-        # Reset the past key and value states
-        current_length_data.zero_()
-    else:
-        (
-            past_key_values,
-            past_key_values_data,
-            current_length_data,
-        ) = initialize_past_key_values(model.base_model)
-        model.past_key_values = past_key_values
-        model.past_key_values_data = past_key_values_data
-        model.current_length_data = current_length_data
-
-    input_len = input_ids.shape[1]
-    reset_tree_mode(model)
-    tree_logits, logits, hidden_state, sample_token = initialize_tree(
-        input_ids, model, tree_buffers["tree_attn_mask"], past_key_values, logits_processor
-    )
-    new_token = 0
-
-    for idx in range(max_steps):
-        candidates, cart_candidates_prob, tree_candidates = generate_candidates(
-            tree_logits,
-            tree_buffers["tree_indices"],
-            tree_buffers["retrieve_indices"],
-            sample_token,
-            logits_processor
-        )
-        logits, hidden_state_new, outputs = tree_decoding(
-            model,
-            tree_candidates,
-            past_key_values,
-            tree_buffers["tree_position_ids"],
-            input_ids,
-            tree_buffers["retrieve_indices_head"],
-        )
-        best_candidate, accept_length, sample_p = evaluate_posterior(
-            logits, candidates, logits_processor, cart_candidates_prob, tree_logits[2], tree_buffers["p_indices"],
-            tree_candidates, tree_buffers["b_indices"]
-        )
-        input_ids, tree_logits, new_token, hidden_state, sample_token = update_inference_inputs(
-            input_ids,
-            candidates,
-            best_candidate,
-            accept_length,
-            tree_buffers["retrieve_indices"],
-            logits_processor,
-            logits,
-            tree_logits,
-            new_token,
-            past_key_values_data,
-            current_length_data,
-            model,
-            hidden_state,
-            hidden_state_new,
-            sample_p
-        )
-        if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
-            break
-        if new_token > 1024:
-            break
-        if input_ids.shape[1] > 1960:
-            break
-    return input_ids, new_token, idx
+class Timer:
+    def __init__(self, name):
+        self.name = name
+
+    def __enter__(self):
+        torch.xpu.synchronize()
+        self.start = time.perf_counter()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        torch.xpu.synchronize()
+        elapsed = time.perf_counter() - self.start
+        print(f'{self.name} took {elapsed} seconds')
+
+
+def initialize_tree(input_ids, model, logits_processor):
+    # outputs, orig, hidden_states = model(
+    #     input_ids, past_key_values=past_key_values, output_orig=True
+    # )
+    hidden_states, past_key_values = forward_with_tree_mask(model.base_model.model, input_ids=input_ids)
+    orig = model.base_model.lm_head(hidden_states)
+
+    if logits_processor is not None:
+        logits = orig[:, -1]
+        logits = logits_processor(None, logits)
+        probabilities = torch.nn.functional.softmax(logits, dim=1)
+        token = torch.multinomial(probabilities, 1)
+    else:
+        token = torch.argmax(orig[:, -1])
+        token = token[None, None]
+    input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
+    # Clone the output hidden states
+
+    draft_tokens, retrieve_indices, tree_mask, tree_position_ids = model.ea_layer.topK_genrate(hidden_states, input_ids, model.base_model.lm_head, logits_processor)
+    return draft_tokens, retrieve_indices, tree_mask, tree_position_ids, orig, hidden_states, token, past_key_values
+
+
+def tree_decoding(
+    model,
+    tree_candidates,
+    past_key_values,
+    tree_position_ids,
+    input_ids,
+    retrieve_indices,
+    attention_mask=None,
+    tree_mask=None,
+):
+    position_ids = tree_position_ids + input_ids.shape[1]
+
+    hidden_states, past_key_values = forward_with_tree_mask(model.base_model.model, input_ids=tree_candidates, past_key_values=past_key_values,
+                                                            position_ids=position_ids, tree_mask=tree_mask, attention_mask=attention_mask)
+
+    tree_logits = model.base_model.lm_head(hidden_states)
+
+    logits = tree_logits[0, retrieve_indices]
+
+    return logits, hidden_states, past_key_values
+
+
+def update_inference_inputs(
+    input_ids,
+    candidates,
+    best_candidate,
+    accept_length,
+    retrieve_indices,
+    logits_processor,
+    new_token,
+    past_key_values,
+    # current_length_data,
+    model,
+    hidden_state_new,
+    sample_p
+):
+    prev_input_len = input_ids.shape[1]
+    # Map the best candidate indices to the original indices in the sequence
+    select_indices = (
+        retrieve_indices[best_candidate, : accept_length + 1] + prev_input_len
+    )
+    # Append the tokens from the best candidate to the input sequence
+    input_ids = torch.cat(
+        [input_ids, candidates[None, best_candidate, : accept_length + 1].to(input_ids.device)], dim=-1
+    )
+
+    new_kv = ()
+
+    for past_key_values_data in past_key_values:
+        layer_kv = ()
+        for korv in past_key_values_data:
+            tgt = korv[:, :, select_indices, :]
+            dst = korv[:, :, prev_input_len: prev_input_len + tgt.shape[-2], :]
+            dst.copy_(tgt, non_blocking=True)
+            layer_kv += (korv[:, :, : prev_input_len + tgt.shape[-2], :],)
+        new_kv += (layer_kv,)
+
+    retrieve_hidden_state_new = hidden_state_new[:, retrieve_indices]
+    accept_hidden_state_new = retrieve_hidden_state_new[:, best_candidate, : accept_length + 1]
+
+    prob = sample_p
+    if logits_processor is not None:
+        token = torch.multinomial(prob, 1)
+        token = token[None]
+    else:
+        token = torch.argmax(prob)
+        token = token[None, None]
+
+    draft_tokens, retrieve_indices, tree_mask, tree_position_ids = model.ea_layer.topK_genrate(accept_hidden_state_new,
+                                                                                               input_ids=torch.cat((input_ids, token.to(input_ids.device)), dim=1),
+                                                                                               head=model.base_model.lm_head, logits_processor=logits_processor)
+
+    new_token += accept_length + 1
+
+    return input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids, new_token, None, token, new_kv
+
+
+def eagenerate(
+    model,
+    input_ids,
+    temperature=0.0,
+    top_p=0.0,
+    top_k=0.0,
+    max_new_tokens=512,
+    max_length=2048,
+    log=False,
+    is_llama3=False,
+):
+    if is_llama3:
+        stop_token_id = model.tokenizer.convert_tokens_to_ids("<|eot_id|>")
+    max_length = max_length - model.ea_layer.total_tokens - 10
+
+    if temperature > 1e-5:
+        logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
+    else:
+        logits_processor = None
+    # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
+    # Avoid modifying the input_ids in-place
+
+    padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
+    input_ids = input_ids.clone()
+    model.ea_layer.reset_kv()
+
+    input_len = input_ids.shape[1]
+    # with Timer("initialize_tree"):
+    draft_tokens, retrieve_indices, tree_mask, tree_position_ids, logits, hidden_state, sample_token, past_key_values = initialize_tree(
+        input_ids, model, logits_processor
+    )
+
+    new_token = 0
+
+    for idx in range(max_length):
+        # with Timer("all"):
+        draft_tokens = draft_tokens.to(input_ids.device)
+        # with Timer("tree_decoding"):
+        logits, hidden_state_new, past_key_values = tree_decoding(
+            model,
+            draft_tokens,
+            past_key_values,
+            tree_position_ids,
+            input_ids,
+            retrieve_indices,
+            tree_mask=tree_mask,
+        )
+
+        draft_tokens = torch.cat((draft_tokens, padding), dim=1)
+        candidates = draft_tokens[0, retrieve_indices]
+        # with Timer("evaluate_posterior"):
+        best_candidate, accept_length, sample_p = evaluate_posterior(
+            logits, candidates, logits_processor
+        )
+
+        # print("new_token: ", (accept_length+1).item())
+        # with Timer("update_inference_inputs"):
+        input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids, new_token, hidden_state, sample_token, past_key_values = update_inference_inputs(
+            input_ids,
+            candidates,
+            best_candidate,
+            accept_length,
+            retrieve_indices,
+            logits_processor,
+            new_token,
+            past_key_values,
+            # current_length_data,
+            model,
+            hidden_state_new,
+            sample_p
+        )
+
+        if is_llama3:
+            if stop_token_id in input_ids[0, input_len:].tolist():
+                break
+
+        if model.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
+            break
+        if new_token > max_new_tokens:
+            break
+        if input_ids.shape[1] > max_length:
+            break
+    if not log:
+        return input_ids
+    else:
+        return input_ids, new_token, idx
+
+
 def run_eval(
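To show how the `eagenerate` loop added above is driven per prompt, here is a minimal, hypothetical single-prompt call mirroring the benchmark code further down; it assumes `model` was prepared as in the earlier sketch and that `eagenerate` is importable from the new script (e.g. `from evaluation.gen_ea_answer_llama2chat_e2_ipex_optimize import eagenerate`).

```python
import time
import torch

tokenizer = model.get_tokenizer()
prompt = "Explain speculative decoding in one paragraph."  # placeholder prompt

inputs = tokenizer([prompt], return_tensors="pt").to("xpu")
input_ids = inputs.input_ids

torch.xpu.synchronize()
start_time = time.time()
# With log=True, eagenerate also returns the accepted-token count and the number of decoding steps.
output_ids, new_token, idx = eagenerate(
    model,
    torch.as_tensor(input_ids),
    temperature=0.0,
    max_new_tokens=1024,
    log=True,
)
torch.xpu.synchronize()
total_time = time.time() - start_time

print(tokenizer.decode(output_ids[0][len(input_ids[0]):], skip_special_tokens=True))
print(f"{new_token} new tokens in {idx + 1} steps, {total_time:.2f} s")
```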
@@ -147,8 +265,8 @@ def run_eval(
     max_new_token,
     num_choices,
     temperature,
-    tree_choices,
     enable_ipex_llm,
+    args,
 ):
     questions = load_questions(question_file, question_begin, question_end)
     shuffled_ids = [q["question_id"] for q in questions]
@@ -168,8 +286,8 @@ def run_eval(
                 max_new_token,
                 num_choices,
                 temperature,
-                tree_choices,
                 enable_ipex_llm,
+                args
             )
         )
@@ -184,34 +302,32 @@ def get_model_answers(
     max_new_token,
     num_choices,
     temperature,
-    tree_choices,
     enable_ipex_llm,
+    args
 ):
-    try:
-        model = EaModel.from_pretrained(
-            base_model_path=base_model_path,
-            ea_model_path=ea_model_path,
-            #torch_dtype=torch.float16,
-            torch_dtype=torch.float32,
-            low_cpu_mem_usage=True,
-            # load_in_8bit=True,
-            device_map="auto"
-        )
-    except ValueError:
-        print("Using sequential device_map.")
-        model = EaModel.from_pretrained(
-            base_model_path=base_model_path,
-            ea_model_path=ea_model_path,
-            #torch_dtype=torch.float16,
-            torch_dtype=torch.float32,
-            low_cpu_mem_usage=True,
-            # load_in_8bit=True,
-            device_map="sequential"
-        )
+    model = EaModel.from_pretrained(
+        base_model_path=base_model_path,
+        ea_model_path=ea_model_path,
+        torch_dtype=torch.float16,
+        # torch_dtype=torch.float32,
+        low_cpu_mem_usage=True,
+        # load_in_8bit=True,
+        total_token=args.total_token,
+        depth=args.depth,
+        top_k=args.top_k,
+    )
     if enable_ipex_llm:
-        # single line of change to enable ipex-llm
-        model = optimize_model(model, low_bit='sym_int4', optimize_llm=False)
-        model.to("xpu")
+        # Use optimized int4 base model to replace base model in EaModel
+        base_model = AutoModelForCausalLM.from_pretrained(base_model_path, load_in_4bit=True, use_cache=True,
+                                                          torch_dtype=torch.float16)
+        model.base_model = base_model
+        # Also optimize draft model in EaModel
+        model.ea_layer = optimize_model(model.ea_layer)
+        model = model.to("xpu")
+        model.ea_layer.tree_mask_init = model.ea_layer.tree_mask_init.to("xpu")
+        model.ea_layer.position_ids = model.ea_layer.position_ids.to("xpu")
     tokenizer = model.get_tokenizer()
 
     if temperature > 1e-5:
@@ -246,15 +362,17 @@ def get_model_answers(
             inputs = tokenizer([prompt], return_tensors="pt").to("xpu")
             input_ids = inputs.input_ids
 
+            torch.xpu.synchronize()
             start_time = time.time()
-            output_ids, new_token, idx = ea_forward(
-                torch.as_tensor(input_ids),
-                model,
-                tokenizer,
-                tree_choices,
-                logits_processor,
-            )
+            output_ids, new_token, idx = eagenerate(
+                model,
+                torch.as_tensor(input_ids),
+                temperature=temperature,
+                max_new_tokens=max_new_token,
+                log=True
+            )
+            torch.xpu.synchronize()
             total_time = time.time() - start_time
             output_ids = output_ids[0][len(input_ids[0]):]
             # be consistent with the template's stop_token_ids
@@ -311,14 +429,16 @@ def get_model_answers(
                 input_ids = inputs.input_ids
 
                 try:
+                    torch.xpu.synchronize()
                     start_time = time.time()
-                    output_ids, new_token, idx = ea_forward(
-                        torch.as_tensor(input_ids),
+                    output_ids, new_token, idx = eagenerate(
                         model,
-                        tokenizer,
-                        tree_choices,
-                        logits_processor,
+                        torch.as_tensor(input_ids),
+                        temperature=temperature,
+                        max_new_tokens=max_new_token,
+                        log=True,
                     )
+                    torch.xpu.synchronize()
                     total_time = time.time() - start_time
                     output_ids = output_ids[0][len(input_ids[0]):]
 
@@ -417,6 +537,22 @@ if __name__ == "__main__":
         default=1024,
         help="The maximum number of new generated tokens.",
     )
+    parser.add_argument(
+        "--total-token",
+        type=int,
+        default=8,
+        help="The total number of nodes in the draft tree",
+    )
+    parser.add_argument(
+        "--depth",
+        type=int,
+        default=4,
+    )
+    parser.add_argument(
+        "--top-k",
+        type=int,
+        default=5,
+    )
     parser.add_argument(
         "--num-choices",
         type=int,
@@ -426,7 +562,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--temperature",
         type=float,
-        default=1.0,
+        default=0.0,
     )
 
     parser.add_argument(
@@ -443,7 +579,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     args.model_id = args.model_id + "-temperature-" + str(args.temperature)
-    args.tree_choices = eval(args.tree_choices)
+    # args.tree_choices = eval(args.tree_choices)
 
     question_file = f"data/{args.bench_name}/question.jsonl"
     if args.answer_file:
@@ -464,8 +600,8 @@ if __name__ == "__main__":
         args.max_new_token,
         args.num_choices,
         args.temperature,
-        args.tree_choices,
         args.enable_ipex_llm,
+        args
     )
 
     reorg_answer_file(answer_file)
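For reference, the three new draft-tree flags registered above carry these argparse defaults and are forwarded to `EaModel.from_pretrained` as `total_token`, `depth` and `top_k`; a small, self-contained illustration of the resulting values:

```python
import argparse

parser = argparse.ArgumentParser()
# Flags added by this commit, with their defaults.
parser.add_argument("--total-token", type=int, default=8,
                    help="The total number of nodes in the draft tree")
parser.add_argument("--depth", type=int, default=4)
parser.add_argument("--top-k", type=int, default=5)

args = parser.parse_args([])  # use the defaults
# get_model_answers passes these through as
#   EaModel.from_pretrained(..., total_token=args.total_token, depth=args.depth, top_k=args.top_k)
print(args.total_token, args.depth, args.top_k)  # -> 8 4 5
```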