Support lightweight-serving glm-4v-9b (#11994)

* enable glm-4v-9b serving

* update readme

* update for no image input
Wang, Jian4 2024-09-05 09:25:08 +08:00 committed by GitHub
parent 75b19f8522
commit b3b2cd64b4
3 changed files with 61 additions and 9 deletions


@@ -40,6 +40,9 @@ pip install fastapi uvicorn openai
 pip install gradio # for gradio web UI
 conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
 
+# for glm-4v-9b
+pip install transformers==4.42.4 trl
+
 # for internlm-xcomposer2-vl-7b
 pip install transformers==4.31.0
 pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops
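Because glm-4v-9b and internlm-xcomposer2-vl-7b pin different transformers versions, a mismatched environment only fails later at model-load time. A minimal pre-flight check is sketched below; the pinned versions simply mirror the install lines above, and the helper itself is illustrative, not part of this commit.

```python
# Illustrative sketch: fail fast if the installed transformers version does not
# match the pin listed in the install instructions for the requested model.
import transformers

# Version pins copied from the README lines above (assumption: these are the
# only two multimodal models with hard pins in lightweight-serving).
PINNED_TRANSFORMERS = {
    "glm-4v-9b": "4.42.4",
    "internlm-xcomposer2-vl-7b": "4.31.0",
}

def check_transformers_pin(model_name: str) -> None:
    for key, pinned in PINNED_TRANSFORMERS.items():
        if key in model_name.lower() and transformers.__version__ != pinned:
            raise RuntimeError(
                f"{model_name} expects transformers=={pinned}, "
                f"but transformers=={transformers.__version__} is installed")

check_transformers_pin("THUDM/glm-4v-9b")
```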
@@ -190,9 +193,8 @@ curl http://localhost:8000/v1/chat/completions \
 ##### Image input
-image input only supports [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b) now, and it must install transformers==4.31.0 to run.
+Image input currently only supports [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b) and [glm-4v-9b](https://huggingface.co/THUDM/glm-4v-9b); each requires its specific transformers version (see the install instructions above).
 ```bash
-wget -O /llm/lightweight_serving/test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
 curl http://localhost:8000/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
@@ -208,7 +210,7 @@ curl http://localhost:8000/v1/chat/completions \
         {
           "type": "image_url",
           "image_url": {
-            "url": "./test.jpg"
+            "url": "http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg"
           }
         }
       ]
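Besides curl, the same image request can be sent with the OpenAI Python client. The sketch below assumes the lightweight-serving endpoint is reachable at http://localhost:8000/v1 and that the served model is named glm-4v-9b; both are deployment-specific assumptions rather than values fixed by this commit.

```python
# Minimal sketch: send an image-URL chat completion to the lightweight-serving
# endpoint with the OpenAI client. base_url, api_key and model name are
# assumptions about the local deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")

response = client.chat.completions.create(
    model="glm-4v-9b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in the image?"},
            {"type": "image_url",
             "image_url": {"url": "http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg"}},
        ],
    }],
    max_tokens=128,
)
print(response.choices[0].message.content)
```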


@@ -317,7 +317,10 @@ def get_prompt(messages) -> str:
         if role == "system":
             prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
         elif role == "user":
-            prompt += f"[INST] {content} [/INST] "
+            if "glm" in local_model.model_name.lower():
+                prompt += f"<|user|>\n{content}\n<|assistant|>"
+            else:
+                prompt += f"[INST] {content} [/INST] "
         elif role == "assistant":
             prompt += f"{content} "
         else:
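To make the new branch concrete, the sketch below shows the two user-turn formats side by side. It inlines the model-name check as a function argument; in the server the check reads local_model.model_name instead.

```python
# Standalone sketch of the two user-turn prompt formats used above.
def user_turn(model_name: str, content: str) -> str:
    if "glm" in model_name.lower():
        # glm-4(v) chat format: explicit role tags, generation starts after <|assistant|>
        return f"<|user|>\n{content}\n<|assistant|>"
    # default llama-2 style instruction wrapping
    return f"[INST] {content} [/INST] "

print(user_turn("THUDM/glm-4v-9b", "What is AI?"))
# <|user|>
# What is AI?
# <|assistant|>
print(user_turn("meta-llama/Llama-2-7b-chat-hf", "What is AI?"))
# [INST] What is AI? [/INST]
```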


@@ -16,8 +16,11 @@
 import torch
 from transformers.utils import logging
+import os
 import time
 import asyncio
+from PIL import Image
+import requests
 from transformers import TextIteratorStreamer
 
 logger = logging.get_logger(__name__)
@@ -30,8 +33,12 @@ class ModelWorker:
             self.model = self.load_model(checkpoint, low_bit, "audio")
         else:
             model = self.load_model(checkpoint, low_bit)
-            from ipex_llm.utils import BenchmarkWrapper
-            self.model = BenchmarkWrapper(model, do_print=True)
+            if "glm-4v" not in checkpoint.lower():
+                from ipex_llm.utils import BenchmarkWrapper
+                self.model = BenchmarkWrapper(model, do_print=True)
+            else:
+                # glm-4v-9b does not support benchmark_util now
+                self.model = model
         end = time.perf_counter()
         logger.info(f"Time to load weights: {end - start:.2f}s")
         self.waiting_requests = asyncio.Queue()
@@ -49,12 +56,18 @@ class ModelWorker:
                                                           use_cache=True)
         else:
             from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
+            modules = None
+            if "glm-4" in model_path.lower():
+                modules = ["encoder.layers.35.mlp", "encoder.layers.36.mlp",
+                           "encoder.layers.37.mlp", "encoder.layers.38.mlp",
+                           "encoder.layers.39.mlp"]
             try:
                 model = AutoModelForCausalLM.from_pretrained(model_path,
                                                              load_in_low_bit=low_bit,
                                                              torch_dtype=self.dtype,
                                                              optimize_model=True,
                                                              trust_remote_code=True,
+                                                             modules_to_not_convert=modules,
                                                              use_cache=True,)
             except:
                 model = AutoModel.from_pretrained(model_path,
@@ -62,10 +75,25 @@ class ModelWorker:
                                                   torch_dtype=self.dtype,
                                                   optimize_model=True,
                                                   trust_remote_code=True,
+                                                  modules_to_not_convert=modules,
                                                   use_cache=True,)
             model = model.eval().to("xpu")
         return model
 
+    def get_local_image_path(self, image_path):
+        local_dir = './local_images/'
+        local_path = local_dir + os.path.basename(image_path)
+        if os.path.exists(image_path) or os.path.exists(local_path):
+            pass
+        else:
+            response = requests.get(image_path)
+            if response.status_code == 200:
+                if not os.path.exists(local_dir):
+                    os.makedirs(local_dir)
+                with open(local_path, 'wb') as file:
+                    file.write(response.content)
+        return local_path
+
     async def add_asr_request(self, processor):
         if self.waiting_requests.empty():
             return
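Note that get_local_image_path always returns the ./local_images/ path, even when image_path already points to an existing local file; that is fine now that the README example passes a URL, but a resolver that also covers local paths could look like the sketch below (same cache directory assumed, behaviour otherwise unchanged).

```python
# Sketch of an image-path resolver that also handles pre-existing local files
# (the helper above returns the ./local_images/ path in every case).
import os
import requests

def resolve_image_path(image_path: str, local_dir: str = "./local_images/") -> str:
    if os.path.exists(image_path):
        # already a local file, use it directly
        return image_path
    local_path = os.path.join(local_dir, os.path.basename(image_path))
    if not os.path.exists(local_path):
        # assume image_path is a URL and cache the download
        response = requests.get(image_path)
        if response.status_code == 200:
            os.makedirs(local_dir, exist_ok=True)
            with open(local_path, "wb") as file:
                file.write(response.content)
    return local_path
```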
@@ -94,6 +122,7 @@ class ModelWorker:
         plain_texts = prompt_request.inputs
         input_ids = None
         inputs_embeds = None
+        inputs = None
         if "internlm-xcomposer2-vl-7b" in self.model_name.lower():
             lines = [
                 "You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).",
@@ -111,16 +140,30 @@ class ModelWorker:
                 im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
                 input_ids = inputs["input_ids"].to('xpu')
             else:
-                image = self.model.encode_img(prompt_request.image_list[0])
+                # only process the first image now
+                local_path = self.get_local_image_path(prompt_request.image_list[0])
+                image = self.model.encode_img(local_path)
                 plain_texts = "<ImageHere>" + plain_texts
                 inputs, im_mask = self.model.interleav_wrap_chat(tokenizer, plain_texts,
                                                                  image, [], meta_instruction)
                 inputs_embeds = inputs["inputs_embeds"].to('xpu').to(self.dtype)
+        elif "glm-4v" in self.model_name.lower() and prompt_request.image_list is not None:
+            # only process the first image now
+            local_path = self.get_local_image_path(prompt_request.image_list[0])
+            image = Image.open(local_path)
+            inputs = tokenizer.apply_chat_template([{"role": "user", "image": image,
+                                                     "content": plain_texts}],
+                                                   add_generation_prompt=True,
+                                                   tokenize=True,
+                                                   return_tensors="pt",
+                                                   return_dict=True)
+            inputs = inputs.to('xpu')
         else:
             inputs = tokenizer(plain_texts, return_tensors="pt", padding=True)
             input_ids = inputs.input_ids.to('xpu')
         parameters = prompt_request.parameters
-        return input_ids, parameters, request_id, inputs_embeds
+        return input_ids, parameters, request_id, inputs_embeds, inputs
 
     @torch.no_grad()
     async def process_step(self, tokenizer, result_dict, processor=None):
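The glm-4v path builds its generate() inputs entirely through the tokenizer's chat template. A standalone sketch of the same call outside the server follows; the checkpoint path, the plain transformers load, and the exact contents of the returned dict (token ids plus image tensors) are assumptions based on the glm-4v-9b remote code rather than on this diff.

```python
# Standalone sketch of glm-4v-9b input construction as used in add_request.
# The server loads the model through ipex_llm with load_in_low_bit; here only
# the tokenizer side is shown.
from PIL import Image
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4v-9b", trust_remote_code=True)
image = Image.open("./local_images/test.jpg")

inputs = tokenizer.apply_chat_template(
    [{"role": "user", "image": image, "content": "What is in the image?"}],
    add_generation_prompt=True,
    tokenize=True,
    return_tensors="pt",
    return_dict=True,
)
# inputs is a BatchEncoding carrying the token ids plus the image tensor(s),
# which is why the worker later calls model.generate(**inputs, ...).
inputs = inputs.to("xpu")  # or "cpu" when no XPU device is available
```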
@@ -134,7 +177,8 @@ class ModelWorker:
                                     streamer=self.streamer[request_id],
                                     forced_decoder_ids=decoder_ids)
         else:
-            input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
+            input_ids, parameters, request_id, inputs_embeds, inputs = \
+                await self.add_request(tokenizer)
             self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
 
         def model_generate():
@@ -156,6 +200,9 @@ class ModelWorker:
             elif inputs_embeds is not None:
                 self.model.generate(inputs_embeds=inputs_embeds,
                                     streamer=self.streamer[request_id], **generate_kwargs)
+            else:
+                self.model.generate(**inputs,
+                                    streamer=self.streamer[request_id], **generate_kwargs)
             torch.xpu.empty_cache()
             torch.xpu.synchronize()
         from threading import Thread