Support lightweight-serving glm-4v-9b (#11994)
* enable glm-4v-9b serving
* update readme
* update for no image input
parent 75b19f8522
commit b3b2cd64b4

3 changed files with 61 additions and 9 deletions
```diff
@@ -40,6 +40,9 @@ pip install fastapi uvicorn openai
 pip install gradio # for gradio web UI
 conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
 
+# for glm-4v-9b
+pip install transformers==4.42.4 trl
+
 # for internlm-xcomposer2-vl-7b
 pip install transformers==4.31.0
 pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops
```
```diff
@@ -190,9 +193,8 @@ curl http://localhost:8000/v1/chat/completions \
 
 ##### Image input
 
-image input only supports [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b) now, and it must install transformers==4.31.0 to run.
+Image input is only supported for [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b) and [glm-4v-9b](https://huggingface.co/THUDM/glm-4v-9b) for now; each model requires its specific transformers version (see the install notes above) to run.
 ```bash
-wget -O /llm/lightweight_serving/test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
 curl http://localhost:8000/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
```
```diff
@@ -208,7 +210,7 @@ curl http://localhost:8000/v1/chat/completions \
       {
         "type": "image_url",
         "image_url": {
-          "url": "./test.jpg"
+          "url": "http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg"
        }
      }
    ]
```
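The same image request can also be issued from Python through the OpenAI-compatible endpoint the server exposes. A minimal sketch using the `openai` client the README already installs; the `model` id and the dummy `api_key` are assumptions (use whatever name the server registers for the loaded checkpoint):

```python
# Sketch: image request via the OpenAI-compatible client.
# Assumptions: server at http://localhost:8000/v1, model id "glm-4v-9b",
# placeholder api_key (the lightweight server may not validate it).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")

response = client.chat.completions.create(
    model="glm-4v-9b",  # assumed model id
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in the image?"},
            {"type": "image_url",
             "image_url": {"url": "http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg"}},
        ],
    }],
)
print(response.choices[0].message.content)
```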
```diff
@@ -317,6 +317,9 @@ def get_prompt(messages) -> str:
         if role == "system":
             prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
         elif role == "user":
-            prompt += f"[INST] {content} [/INST] "
+            if "glm" in local_model.model_name.lower():
+                prompt += f"<|user|>\n{content}\n<|assistant|>"
+            else:
+                prompt += f"[INST] {content} [/INST] "
         elif role == "assistant":
             prompt += f"{content} "
```
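To make the branching above concrete, here is a self-contained sketch of the two prompt formats `get_prompt` now emits; `model_name` stands in for `local_model.model_name`, and the example conversation is illustrative:

```python
# Sketch of the two prompt templates after this change.
def get_prompt_sketch(messages, model_name: str) -> str:
    prompt = ""
    for msg in messages:
        role, content = msg["role"], msg["content"]
        if role == "system":
            prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
        elif role == "user":
            if "glm" in model_name.lower():
                # ChatGLM-style turn markers used by glm-4(v) checkpoints
                prompt += f"<|user|>\n{content}\n<|assistant|>"
            else:
                # Llama-2-style [INST] template used by the other models
                prompt += f"[INST] {content} [/INST] "
        elif role == "assistant":
            prompt += f"{content} "
    return prompt

print(get_prompt_sketch([{"role": "user", "content": "Hi"}], "glm-4v-9b"))
# -> <|user|>
#    Hi
#    <|assistant|>
```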
```diff
@@ -16,8 +16,11 @@
 
 import torch
 from transformers.utils import logging
+import os
 import time
 import asyncio
+from PIL import Image
+import requests
 from transformers import TextIteratorStreamer
 logger = logging.get_logger(__name__)
 
```
@ -30,8 +33,12 @@ class ModelWorker:
|
||||||
self.model = self.load_model(checkpoint, low_bit, "audio")
|
self.model = self.load_model(checkpoint, low_bit, "audio")
|
||||||
else:
|
else:
|
||||||
model = self.load_model(checkpoint, low_bit)
|
model = self.load_model(checkpoint, low_bit)
|
||||||
|
if "glm-4v" not in checkpoint.lower():
|
||||||
from ipex_llm.utils import BenchmarkWrapper
|
from ipex_llm.utils import BenchmarkWrapper
|
||||||
self.model = BenchmarkWrapper(model, do_print=True)
|
self.model = BenchmarkWrapper(model, do_print=True)
|
||||||
|
else:
|
||||||
|
# glm-4v-9b does not support benchmark_util now
|
||||||
|
self.model = model
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
logger.info(f"Time to load weights: {end - start:.2f}s")
|
logger.info(f"Time to load weights: {end - start:.2f}s")
|
||||||
self.waiting_requests = asyncio.Queue()
|
self.waiting_requests = asyncio.Queue()
|
||||||
|
|
```diff
@@ -49,12 +56,18 @@ class ModelWorker:
                                    use_cache=True)
         else:
             from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
+            modules = None
+            if "glm-4" in model_path.lower():
+                modules = ["encoder.layers.35.mlp", "encoder.layers.36.mlp",
+                           "encoder.layers.37.mlp", "encoder.layers.38.mlp",
+                           "encoder.layers.39.mlp"]
             try:
                 model = AutoModelForCausalLM.from_pretrained(model_path,
                                                              load_in_low_bit=low_bit,
                                                              torch_dtype=self.dtype,
                                                              optimize_model=True,
                                                              trust_remote_code=True,
+                                                             modules_to_not_convert=modules,
                                                              use_cache=True,)
             except:
                 model = AutoModel.from_pretrained(model_path,
```
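For glm-4 checkpoints the worker now passes `modules_to_not_convert`, keeping the last five MLP blocks out of low-bit conversion, presumably to preserve output quality. A standalone sketch of this loading path, assuming ipex_llm is installed, `"sym_int4"` as the low-bit format, and that glm-4v-9b resolves through `AutoModel` as in the `except` branch of the diff:

```python
# Sketch: loading glm-4v-9b with ipex_llm, skipping the last five MLP
# blocks from low-bit conversion (same list as in the diff).
# Assumptions: "sym_int4" low-bit format; Intel GPU + IPEX runtime for .to("xpu").
from ipex_llm.transformers import AutoModel

modules = ["encoder.layers.35.mlp", "encoder.layers.36.mlp",
           "encoder.layers.37.mlp", "encoder.layers.38.mlp",
           "encoder.layers.39.mlp"]
model = AutoModel.from_pretrained("THUDM/glm-4v-9b",
                                  load_in_low_bit="sym_int4",
                                  optimize_model=True,
                                  trust_remote_code=True,
                                  modules_to_not_convert=modules,
                                  use_cache=True)
model = model.eval().to("xpu")
```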
```diff
@@ -62,10 +75,25 @@ class ModelWorker:
                                            torch_dtype=self.dtype,
                                            optimize_model=True,
                                            trust_remote_code=True,
+                                           modules_to_not_convert=modules,
                                            use_cache=True,)
         model = model.eval().to("xpu")
         return model
 
+    def get_local_image_path(self, image_path):
+        local_dir = './local_images/'
+        local_path = local_dir + os.path.basename(image_path)
+        if os.path.exists(image_path) or os.path.exists(local_path):
+            pass
+        else:
+            response = requests.get(image_path)
+            if response.status_code == 200:
+                if not os.path.exists(local_dir):
+                    os.makedirs(local_dir)
+                with open(local_path, 'wb') as file:
+                    file.write(response.content)
+        return local_path
+
     async def add_asr_request(self, processor):
         if self.waiting_requests.empty():
             return
```
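Note that the helper as committed returns `local_path` even when the caller passed an existing local file that was never copied into the cache directory, and it continues silently on a failed download. A slightly more defensive variant, as a sketch only (the early return, timeout, and `raise_for_status` are assumptions, not part of this commit):

```python
# Sketch: defensive variant of get_local_image_path.
# Differences from the committed code (assumptions, not in the diff):
# returns the original path if it is already local, raises on download
# failure, and uses os.path.join / exist_ok for the cache directory.
import os
import requests

def get_local_image_path(image_path: str, local_dir: str = "./local_images/") -> str:
    if os.path.exists(image_path):
        return image_path  # already a local file; nothing to download
    local_path = os.path.join(local_dir, os.path.basename(image_path))
    if not os.path.exists(local_path):
        response = requests.get(image_path, timeout=30)
        response.raise_for_status()  # surface download errors instead of ignoring them
        os.makedirs(local_dir, exist_ok=True)
        with open(local_path, "wb") as file:
            file.write(response.content)
    return local_path
```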
```diff
@@ -94,6 +122,7 @@ class ModelWorker:
             plain_texts = prompt_request.inputs
             input_ids = None
             inputs_embeds = None
+            inputs = None
             if "internlm-xcomposer2-vl-7b" in self.model_name.lower():
                 lines = [
                     "You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).",
```
```diff
@@ -111,16 +140,30 @@ class ModelWorker:
                     im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
                     input_ids = inputs["input_ids"].to('xpu')
                 else:
-                    image = self.model.encode_img(prompt_request.image_list[0])
+                    # only process the first image now
+                    local_path = self.get_local_image_path(prompt_request.image_list[0])
+                    image = self.model.encode_img(local_path)
                     plain_texts = "<ImageHere>" + plain_texts
                     inputs, im_mask = self.model.interleav_wrap_chat(tokenizer, plain_texts,
                                                                      image, [], meta_instruction)
                     inputs_embeds = inputs["inputs_embeds"].to('xpu').to(self.dtype)
+            elif "glm-4v" in self.model_name.lower() and prompt_request.image_list is not None:
+                # only process the first image now
+                local_path = self.get_local_image_path(prompt_request.image_list[0])
+                image = Image.open(local_path)
+
+                inputs = tokenizer.apply_chat_template([{"role": "user", "image": image,
+                                                         "content": plain_texts}],
+                                                       add_generation_prompt=True,
+                                                       tokenize=True,
+                                                       return_tensors="pt",
+                                                       return_dict=True)
+                inputs = inputs.to('xpu')
             else:
                 inputs = tokenizer(plain_texts, return_tensors="pt", padding=True)
                 input_ids = inputs.input_ids.to('xpu')
             parameters = prompt_request.parameters
-            return input_ids, parameters, request_id, inputs_embeds
+            return input_ids, parameters, request_id, inputs_embeds, inputs
 
     @torch.no_grad()
     async def process_step(self, tokenizer, result_dict, processor=None):
```
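Outside the worker, the glm-4v branch boils down to the following minimal sketch; the tokenizer setup, prompt text, and `max_new_tokens` value are assumptions, and `model` is assumed to be loaded as in the earlier load sketch:

```python
# Sketch: the glm-4v image path of add_request, end to end.
# Assumptions: `model` is glm-4v-9b loaded low-bit on "xpu" (see the load
# sketch above), and ./test.jpg exists locally.
from PIL import Image
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4v-9b", trust_remote_code=True)
image = Image.open("./test.jpg")

# Build model inputs the same way the worker does: chat template with the
# image attached to the user turn, tokenized to tensors.
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "image": image, "content": "What is in the image?"}],
    add_generation_prompt=True,
    tokenize=True,
    return_tensors="pt",
    return_dict=True,
).to("xpu")

output = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:],
                       skip_special_tokens=True))
```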
```diff
@@ -134,7 +177,8 @@ class ModelWorker:
                                     streamer=self.streamer[request_id],
                                     forced_decoder_ids=decoder_ids)
             else:
-                input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
+                input_ids, parameters, request_id, inputs_embeds, inputs = \
+                    await self.add_request(tokenizer)
                 self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
 
                 def model_generate():
```
```diff
@@ -156,6 +200,9 @@ class ModelWorker:
                     elif inputs_embeds is not None:
                         self.model.generate(inputs_embeds=inputs_embeds,
                                             streamer=self.streamer[request_id], **generate_kwargs)
+                    else:
+                        self.model.generate(**inputs,
+                                            streamer=self.streamer[request_id], **generate_kwargs)
                     torch.xpu.empty_cache()
                     torch.xpu.synchronize()
                 from threading import Thread
```