LLM: reduce GPU 1st token latency and update example (#8763)

* reduce 1st token latency

* update example

* fix

* fix style

* update readme of gpu benchmark
This commit is contained in:
Ruonan Wang 2023-08-16 18:01:23 +08:00 committed by GitHub
parent 06609d9260
commit e9aa2bd890
4 changed files with 15 additions and 6 deletions

View file

@ -54,6 +54,13 @@ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
prompt = "今天睡不着怎么办"
with torch.inference_mode():
# wamup two times as use ipex
for i in range(2):
input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu')
output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
# collect performance data now
for i in range(5):
input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu')
output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
output_str = tokenizer.decode(output[0], skip_special_tokens=True)

View file

@ -46,7 +46,7 @@ if __name__ == '__main__':
load_in_4bit=True,
optimize_model=False,
trust_remote_code=True)
model = model.half().to('xpu')
model = model.to('xpu')
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path,

View file

@ -49,7 +49,7 @@ if __name__ == '__main__':
load_in_4bit=True,
optimize_model=False,
trust_remote_code=True)
model = model.half().to('xpu')
model = model.to('xpu')
# Load tokenizer
tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)

View file

@ -169,7 +169,6 @@ class ParamsQuant(torch.nn.Parameter):
quantized=self.quantized,
_shape=self._shape,
qtype=self.qtype)
return new_param
@ -244,6 +243,9 @@ class LinearQuant(nn.Linear):
if x_2d.is_contiguous() is False:
x_2d = x_2d.contiguous()
# current workaround to reduce first token latency of fp32 input
if x_2d.shape[0] > 1 and x_2d.dtype == torch.float32:
x_2d = x_2d.half()
# input format of linear_q4.forward is 1: input, 2: weight
result = linear_q4_0.forward(x_2d, x0)
new_shape = x_shape[:-1] + (self.out_len,)