diff --git a/python/llm/example/GPU/Lightweight-Serving/README.md b/python/llm/example/GPU/Lightweight-Serving/README.md
index c21aa880..1a1f7f5c 100644
--- a/python/llm/example/GPU/Lightweight-Serving/README.md
+++ b/python/llm/example/GPU/Lightweight-Serving/README.md
@@ -40,6 +40,9 @@ pip install fastapi uvicorn openai
pip install gradio # for gradio web UI
conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
+# for glm-4v-9b
+pip install transformers==4.42.4 trl
+
# for internlm-xcomposer2-vl-7b
pip install transformers==4.31.0
pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops
@@ -190,9 +193,8 @@ curl http://localhost:8000/v1/chat/completions \
##### Image input
-image input only supports [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b) now, and it must install transformers==4.31.0 to run.
+image input only supports [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b) and [glm-4v-9b](https://huggingface.co/THUDM/glm-4v-9b) now. And they should both install specific transformers version to run.
```bash
-wget -O /llm/lightweight_serving/test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
@@ -208,7 +210,7 @@ curl http://localhost:8000/v1/chat/completions \
{
"type": "image_url",
"image_url": {
- "url": "./test.jpg"
+ "url": "http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg"
}
}
]
diff --git a/python/llm/src/ipex_llm/serving/fastapi/api_server.py b/python/llm/src/ipex_llm/serving/fastapi/api_server.py
index 86fc6bce..387d47e5 100644
--- a/python/llm/src/ipex_llm/serving/fastapi/api_server.py
+++ b/python/llm/src/ipex_llm/serving/fastapi/api_server.py
@@ -317,7 +317,10 @@ def get_prompt(messages) -> str:
if role == "system":
prompt += f"<>\n{content}\n<>\n\n"
elif role == "user":
- prompt += f"[INST] {content} [/INST] "
+ if "glm" in local_model.model_name.lower():
+ prompt += f"<|user|>\n{content}\n<|assistant|>"
+ else:
+ prompt += f"[INST] {content} [/INST] "
elif role == "assistant":
prompt += f"{content} "
else:
diff --git a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
index 9a7b2b0b..e339bcc1 100644
--- a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
+++ b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
@@ -16,8 +16,11 @@
import torch
from transformers.utils import logging
+import os
import time
import asyncio
+from PIL import Image
+import requests
from transformers import TextIteratorStreamer
logger = logging.get_logger(__name__)
@@ -30,8 +33,12 @@ class ModelWorker:
self.model = self.load_model(checkpoint, low_bit, "audio")
else:
model = self.load_model(checkpoint, low_bit)
- from ipex_llm.utils import BenchmarkWrapper
- self.model = BenchmarkWrapper(model, do_print=True)
+ if "glm-4v" not in checkpoint.lower():
+ from ipex_llm.utils import BenchmarkWrapper
+ self.model = BenchmarkWrapper(model, do_print=True)
+ else:
+ # glm-4v-9b does not support benchmark_util now
+ self.model = model
end = time.perf_counter()
logger.info(f"Time to load weights: {end - start:.2f}s")
self.waiting_requests = asyncio.Queue()
@@ -49,12 +56,18 @@ class ModelWorker:
use_cache=True)
else:
from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
+ modules = None
+ if "glm-4" in model_path.lower():
+ modules = ["encoder.layers.35.mlp", "encoder.layers.36.mlp",
+ "encoder.layers.37.mlp", "encoder.layers.38.mlp",
+ "encoder.layers.39.mlp"]
try:
model = AutoModelForCausalLM.from_pretrained(model_path,
load_in_low_bit=low_bit,
torch_dtype=self.dtype,
optimize_model=True,
trust_remote_code=True,
+ modules_to_not_convert=modules,
use_cache=True,)
except:
model = AutoModel.from_pretrained(model_path,
@@ -62,10 +75,25 @@ class ModelWorker:
torch_dtype=self.dtype,
optimize_model=True,
trust_remote_code=True,
+ modules_to_not_convert=modules,
use_cache=True,)
model = model.eval().to("xpu")
return model
+ def get_local_image_path(self, image_path):
+ local_dir = './local_images/'
+ local_path = local_dir + os.path.basename(image_path)
+ if os.path.exists(image_path) or os.path.exists(local_path):
+ pass
+ else:
+ response = requests.get(image_path)
+ if response.status_code == 200:
+ if not os.path.exists(local_dir):
+ os.makedirs(local_dir)
+ with open(local_path, 'wb') as file:
+ file.write(response.content)
+ return local_path
+
async def add_asr_request(self, processor):
if self.waiting_requests.empty():
return
@@ -94,6 +122,7 @@ class ModelWorker:
plain_texts = prompt_request.inputs
input_ids = None
inputs_embeds = None
+ inputs = None
if "internlm-xcomposer2-vl-7b" in self.model_name.lower():
lines = [
"You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).",
@@ -111,16 +140,30 @@ class ModelWorker:
im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
input_ids = inputs["input_ids"].to('xpu')
else:
- image = self.model.encode_img(prompt_request.image_list[0])
+ # only process the first image now
+ local_path = self.get_local_image_path(prompt_request.image_list[0])
+ image = self.model.encode_img(local_path)
plain_texts = "" + plain_texts
inputs, im_mask = self.model.interleav_wrap_chat(tokenizer, plain_texts,
image, [], meta_instruction)
inputs_embeds = inputs["inputs_embeds"].to('xpu').to(self.dtype)
+ elif "glm-4v" in self.model_name.lower() and prompt_request.image_list is not None:
+ # only process the first image now
+ local_path = self.get_local_image_path(prompt_request.image_list[0])
+ image = Image.open(local_path)
+
+ inputs = tokenizer.apply_chat_template([{"role": "user", "image": image,
+ "content": plain_texts}],
+ add_generation_prompt=True,
+ tokenize=True,
+ return_tensors="pt",
+ return_dict=True)
+ inputs = inputs.to('xpu')
else:
inputs = tokenizer(plain_texts, return_tensors="pt", padding=True)
input_ids = inputs.input_ids.to('xpu')
parameters = prompt_request.parameters
- return input_ids, parameters, request_id, inputs_embeds
+ return input_ids, parameters, request_id, inputs_embeds, inputs
@torch.no_grad()
async def process_step(self, tokenizer, result_dict, processor=None):
@@ -134,7 +177,8 @@ class ModelWorker:
streamer=self.streamer[request_id],
forced_decoder_ids=decoder_ids)
else:
- input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
+ input_ids, parameters, request_id, inputs_embeds, inputs = \
+ await self.add_request(tokenizer)
self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
def model_generate():
@@ -156,6 +200,9 @@ class ModelWorker:
elif inputs_embeds is not None:
self.model.generate(inputs_embeds=inputs_embeds,
streamer=self.streamer[request_id], **generate_kwargs)
+ else:
+ self.model.generate(**inputs,
+ streamer=self.streamer[request_id], **generate_kwargs)
torch.xpu.empty_cache()
torch.xpu.synchronize()
from threading import Thread