From b3b2cd64b437f2c91659ef4a02290ef15b48b0e8 Mon Sep 17 00:00:00 2001
From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com>
Date: Thu, 5 Sep 2024 09:25:08 +0800
Subject: [PATCH] Support lightweight-serving glm-4v-9b  (#11994)
* enable glm-4v-9b serving
* update readme
* update for no image input
---
 .../example/GPU/Lightweight-Serving/README.md |  8 ++-
 .../ipex_llm/serving/fastapi/api_server.py    |  5 +-
 .../ipex_llm/serving/fastapi/model_worker.py  | 57 +++++++++++++++++--
 3 files changed, 61 insertions(+), 9 deletions(-)
diff --git a/python/llm/example/GPU/Lightweight-Serving/README.md b/python/llm/example/GPU/Lightweight-Serving/README.md
index c21aa880..1a1f7f5c 100644
--- a/python/llm/example/GPU/Lightweight-Serving/README.md
+++ b/python/llm/example/GPU/Lightweight-Serving/README.md
@@ -40,6 +40,9 @@ pip install fastapi uvicorn openai
 pip install gradio # for gradio web UI
 conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
 
+# for glm-4v-9b
+pip install transformers==4.42.4 trl
+
 # for internlm-xcomposer2-vl-7b
 pip install transformers==4.31.0
 pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops
@@ -190,9 +193,8 @@ curl http://localhost:8000/v1/chat/completions \
 
 ##### Image input
 
-image input only supports [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b) now, and it must install transformers==4.31.0 to run.
+image input only supports [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b) and [glm-4v-9b](https://huggingface.co/THUDM/glm-4v-9b) now. And they should both install specific transformers version to run.
 ```bash
-wget -O /llm/lightweight_serving/test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
 curl http://localhost:8000/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
@@ -208,7 +210,7 @@ curl http://localhost:8000/v1/chat/completions \
           {
             "type": "image_url",
             "image_url": {
-              "url": "./test.jpg"
+              "url": "http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg"
             }
           }
         ]
diff --git a/python/llm/src/ipex_llm/serving/fastapi/api_server.py b/python/llm/src/ipex_llm/serving/fastapi/api_server.py
index 86fc6bce..387d47e5 100644
--- a/python/llm/src/ipex_llm/serving/fastapi/api_server.py
+++ b/python/llm/src/ipex_llm/serving/fastapi/api_server.py
@@ -317,7 +317,10 @@ def get_prompt(messages) -> str:
                 if role == "system":
                     prompt += f"<>\n{content}\n<>\n\n"
                 elif role == "user":
-                    prompt += f"[INST] {content} [/INST] "
+                    if "glm" in local_model.model_name.lower():
+                        prompt += f"<|user|>\n{content}\n<|assistant|>"
+                    else:
+                        prompt += f"[INST] {content} [/INST] "
                 elif role == "assistant":
                     prompt += f"{content} "
                 else:
diff --git a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
index 9a7b2b0b..e339bcc1 100644
--- a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
+++ b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
@@ -16,8 +16,11 @@
 
 import torch
 from transformers.utils import logging
+import os
 import time
 import asyncio
+from PIL import Image
+import requests
 from transformers import TextIteratorStreamer
 logger = logging.get_logger(__name__)
 
@@ -30,8 +33,12 @@ class ModelWorker:
             self.model = self.load_model(checkpoint, low_bit, "audio")
         else:
             model = self.load_model(checkpoint, low_bit)
-            from ipex_llm.utils import BenchmarkWrapper
-            self.model = BenchmarkWrapper(model, do_print=True)
+            if "glm-4v" not in checkpoint.lower():
+                from ipex_llm.utils import BenchmarkWrapper
+                self.model = BenchmarkWrapper(model, do_print=True)
+            else:
+                # glm-4v-9b does not support benchmark_util now
+                self.model = model
         end = time.perf_counter()
         logger.info(f"Time to load weights: {end - start:.2f}s")
         self.waiting_requests = asyncio.Queue()
@@ -49,12 +56,18 @@ class ModelWorker:
                                                               use_cache=True)
         else:
             from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
+            modules = None
+            if "glm-4" in model_path.lower():
+                modules = ["encoder.layers.35.mlp", "encoder.layers.36.mlp",
+                           "encoder.layers.37.mlp", "encoder.layers.38.mlp",
+                           "encoder.layers.39.mlp"]
             try:
                 model = AutoModelForCausalLM.from_pretrained(model_path,
                                                              load_in_low_bit=low_bit,
                                                              torch_dtype=self.dtype,
                                                              optimize_model=True,
                                                              trust_remote_code=True,
+                                                             modules_to_not_convert=modules,
                                                              use_cache=True,)
             except:
                 model = AutoModel.from_pretrained(model_path,
@@ -62,10 +75,25 @@ class ModelWorker:
                                                   torch_dtype=self.dtype,
                                                   optimize_model=True,
                                                   trust_remote_code=True,
+                                                  modules_to_not_convert=modules,
                                                   use_cache=True,)
         model = model.eval().to("xpu")
         return model
 
+    def get_local_image_path(self, image_path):
+        local_dir = './local_images/'
+        local_path = local_dir + os.path.basename(image_path)
+        if os.path.exists(image_path) or os.path.exists(local_path):
+            pass
+        else:
+            response = requests.get(image_path)
+            if response.status_code == 200:
+                if not os.path.exists(local_dir):
+                    os.makedirs(local_dir)
+                with open(local_path, 'wb') as file:
+                    file.write(response.content)
+        return local_path
+
     async def add_asr_request(self, processor):
         if self.waiting_requests.empty():
             return
@@ -94,6 +122,7 @@ class ModelWorker:
         plain_texts = prompt_request.inputs
         input_ids = None
         inputs_embeds = None
+        inputs = None
         if "internlm-xcomposer2-vl-7b" in self.model_name.lower():
             lines = [
                 "You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).",
@@ -111,16 +140,30 @@ class ModelWorker:
                 im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
                 input_ids = inputs["input_ids"].to('xpu')
             else:
-                image = self.model.encode_img(prompt_request.image_list[0])
+                # only process the first image now
+                local_path = self.get_local_image_path(prompt_request.image_list[0])
+                image = self.model.encode_img(local_path)
                 plain_texts = "" + plain_texts
                 inputs, im_mask = self.model.interleav_wrap_chat(tokenizer, plain_texts,
                                                                  image, [], meta_instruction)
                 inputs_embeds = inputs["inputs_embeds"].to('xpu').to(self.dtype)
+        elif "glm-4v" in self.model_name.lower() and prompt_request.image_list is not None:
+            # only process the first image now
+            local_path = self.get_local_image_path(prompt_request.image_list[0])
+            image = Image.open(local_path)
+
+            inputs = tokenizer.apply_chat_template([{"role": "user", "image": image,
+                                                   "content": plain_texts}],
+                                                   add_generation_prompt=True,
+                                                   tokenize=True,
+                                                   return_tensors="pt",
+                                                   return_dict=True)
+            inputs = inputs.to('xpu')
         else:
             inputs = tokenizer(plain_texts, return_tensors="pt", padding=True)
             input_ids = inputs.input_ids.to('xpu')
         parameters = prompt_request.parameters
-        return input_ids, parameters, request_id, inputs_embeds
+        return input_ids, parameters, request_id, inputs_embeds, inputs
 
     @torch.no_grad()
     async def process_step(self, tokenizer, result_dict, processor=None):
@@ -134,7 +177,8 @@ class ModelWorker:
                                         streamer=self.streamer[request_id],
                                         forced_decoder_ids=decoder_ids)
             else:
-                input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
+                input_ids, parameters, request_id, inputs_embeds, inputs = \
+                    await self.add_request(tokenizer)
                 self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
 
                 def model_generate():
@@ -156,6 +200,9 @@ class ModelWorker:
                     elif inputs_embeds is not None:
                         self.model.generate(inputs_embeds=inputs_embeds,
                                             streamer=self.streamer[request_id], **generate_kwargs)
+                    else:
+                        self.model.generate(**inputs,
+                                            streamer=self.streamer[request_id], **generate_kwargs)
             torch.xpu.empty_cache()
             torch.xpu.synchronize()
             from threading import Thread