Support lightweight-serving glm-4v-9b (#11994)

* enable glm-4v-9b serving
* update readme
* update for no image input
parent 75b19f8522
commit b3b2cd64b4

3 changed files with 61 additions and 9 deletions

@@ -40,6 +40,9 @@ pip install fastapi uvicorn openai
 pip install gradio # for gradio web UI
 conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
 
+# for glm-4v-9b
+pip install transformers==4.42.4 trl
+
 # for internlm-xcomposer2-vl-7b
 pip install transformers==4.31.0
 pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops

@@ -190,9 +193,8 @@ curl http://localhost:8000/v1/chat/completions \
 
 ##### Image input
 
-image input only supports [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b) now, and it must install transformers==4.31.0 to run.
+Image input is currently only supported by [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b) and [glm-4v-9b](https://huggingface.co/THUDM/glm-4v-9b), and each model requires the specific `transformers` version listed above.
 ```bash
-wget -O /llm/lightweight_serving/test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
 curl http://localhost:8000/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{

@@ -208,7 +210,7 @@ curl http://localhost:8000/v1/chat/completions \
           {
             "type": "image_url",
             "image_url": {
-              "url": "./test.jpg"
+              "url": "http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg"
             }
           }
         ]

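For convenience, the same image-input request can also be sent with the `openai` Python client that the README already installs (`pip install fastapi uvicorn openai`). The sketch below is illustrative only: the model id, API key and prompt text are placeholders, and it assumes the server is reachable at `http://localhost:8000/v1`.

```python
# Minimal sketch mirroring the curl example above, using the openai client.
# "glm-4v-9b" and the api_key value are placeholders; adjust to your deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")

response = client.chat.completions.create(
    model="glm-4v-9b",  # placeholder model id
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": {
                "url": "http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg"
            }},
        ],
    }],
)
print(response.choices[0].message.content)
```
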
@@ -317,7 +317,10 @@ def get_prompt(messages) -> str:
                 if role == "system":
                     prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
                 elif role == "user":
-                    prompt += f"[INST] {content} [/INST] "
+                    if "glm" in local_model.model_name.lower():
+                        prompt += f"<|user|>\n{content}\n<|assistant|>"
+                    else:
+                        prompt += f"[INST] {content} [/INST] "
                 elif role == "assistant":
                     prompt += f"{content} "
                 else:
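To make the prompt change above concrete, here is a minimal standalone sketch of the new branching in `get_prompt`; `model_name` stands in for `local_model.model_name`, and only a single user turn is shown.

```python
# Standalone sketch of the user-turn formatting added above.
def render_user_turn(model_name: str, content: str) -> str:
    if "glm" in model_name.lower():
        # GLM-4 style chat markers
        return f"<|user|>\n{content}\n<|assistant|>"
    else:
        # Llama-2 style [INST] markers used for the other models
        return f"[INST] {content} [/INST] "

print(render_user_turn("glm-4v-9b", "What is AI?"))
# <|user|>
# What is AI?
# <|assistant|>
```
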
@@ -16,8 +16,11 @@
 import torch
 from transformers.utils import logging
+import os
 import time
 import asyncio
+from PIL import Image
+import requests
 from transformers import TextIteratorStreamer
 logger = logging.get_logger(__name__)
 

@@ -30,8 +33,12 @@ class ModelWorker:
             self.model = self.load_model(checkpoint, low_bit, "audio")
         else:
             model = self.load_model(checkpoint, low_bit)
-            from ipex_llm.utils import BenchmarkWrapper
-            self.model = BenchmarkWrapper(model, do_print=True)
+            if "glm-4v" not in checkpoint.lower():
+                from ipex_llm.utils import BenchmarkWrapper
+                self.model = BenchmarkWrapper(model, do_print=True)
+            else:
+                # glm-4v-9b is not supported by benchmark_util yet
+                self.model = model
         end = time.perf_counter()
         logger.info(f"Time to load weights: {end - start:.2f}s")
         self.waiting_requests = asyncio.Queue()

@@ -49,12 +56,18 @@ class ModelWorker:
                                                               use_cache=True)
         else:
             from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
+            modules = None
+            if "glm-4" in model_path.lower():
+                modules = ["encoder.layers.35.mlp", "encoder.layers.36.mlp",
+                           "encoder.layers.37.mlp", "encoder.layers.38.mlp",
+                           "encoder.layers.39.mlp"]
             try:
                 model = AutoModelForCausalLM.from_pretrained(model_path,
                                                              load_in_low_bit=low_bit,
                                                              torch_dtype=self.dtype,
                                                              optimize_model=True,
                                                              trust_remote_code=True,
+                                                             modules_to_not_convert=modules,
                                                              use_cache=True,)
             except:
                 model = AutoModel.from_pretrained(model_path,

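The new `modules` list excludes the last five MLP blocks of glm-4 from low-bit conversion via `modules_to_not_convert` (typically done to preserve accuracy for layers that are more sensitive to quantization). Below is a standalone sketch of the same load call; the checkpoint id and the `sym_int4` low-bit format are placeholders, not the worker's exact arguments.

```python
# Sketch of loading glm-4v-9b the way load_model() now does it: the last five
# MLP blocks are kept in their original precision via modules_to_not_convert.
from ipex_llm.transformers import AutoModelForCausalLM

modules = ["encoder.layers.35.mlp", "encoder.layers.36.mlp",
           "encoder.layers.37.mlp", "encoder.layers.38.mlp",
           "encoder.layers.39.mlp"]

model = AutoModelForCausalLM.from_pretrained(
    "THUDM/glm-4v-9b",               # or a local checkpoint path
    load_in_low_bit="sym_int4",      # placeholder low-bit format
    optimize_model=True,
    trust_remote_code=True,
    modules_to_not_convert=modules,  # skip low-bit conversion for these layers
    use_cache=True,
).eval().to("xpu")
```
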
@@ -62,10 +75,25 @@ class ModelWorker:
                                                   torch_dtype=self.dtype,
                                                   optimize_model=True,
                                                   trust_remote_code=True,
+                                                  modules_to_not_convert=modules,
                                                   use_cache=True,)
         model = model.eval().to("xpu")
         return model
 
+    def get_local_image_path(self, image_path):
+        local_dir = './local_images/'
+        local_path = local_dir + os.path.basename(image_path)
+        if os.path.exists(image_path) or os.path.exists(local_path):
+            pass
+        else:
+            response = requests.get(image_path)
+            if response.status_code == 200:
+                if not os.path.exists(local_dir):
+                    os.makedirs(local_dir)
+                with open(local_path, 'wb') as file:
+                    file.write(response.content)
+        return local_path
+
     async def add_asr_request(self, processor):
         if self.waiting_requests.empty():
             return

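The new `get_local_image_path` helper caches remote image URLs under `./local_images/` and returns a local path for the loaders to use. For reference, a condensed standalone replica of its behavior is sketched below (using `os.makedirs(..., exist_ok=True)`); the URL is the one from the README example.

```python
# Condensed standalone replica of the helper above, for illustration only.
import os
import requests

def get_local_image_path(image_path: str, local_dir: str = "./local_images/") -> str:
    # Download a remote image once into local_dir; reuse existing paths.
    local_path = local_dir + os.path.basename(image_path)
    if not (os.path.exists(image_path) or os.path.exists(local_path)):
        response = requests.get(image_path)
        if response.status_code == 200:
            os.makedirs(local_dir, exist_ok=True)
            with open(local_path, "wb") as f:
                f.write(response.content)
    return local_path

print(get_local_image_path(
    "http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg"))
# ./local_images/5602445367_3504763978_z.jpg
```
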
@@ -94,6 +122,7 @@ class ModelWorker:
         plain_texts = prompt_request.inputs
         input_ids = None
         inputs_embeds = None
+        inputs = None
         if "internlm-xcomposer2-vl-7b" in self.model_name.lower():
             lines = [
                 "You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).",

@@ -111,16 +140,30 @@ class ModelWorker:
                 im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
                 input_ids = inputs["input_ids"].to('xpu')
             else:
-                image = self.model.encode_img(prompt_request.image_list[0])
+                # only process the first image for now
+                local_path = self.get_local_image_path(prompt_request.image_list[0])
+                image = self.model.encode_img(local_path)
                 plain_texts = "<ImageHere>" + plain_texts
                 inputs, im_mask = self.model.interleav_wrap_chat(tokenizer, plain_texts,
                                                                  image, [], meta_instruction)
                 inputs_embeds = inputs["inputs_embeds"].to('xpu').to(self.dtype)
+        elif "glm-4v" in self.model_name.lower() and prompt_request.image_list is not None:
+            # only process the first image for now
+            local_path = self.get_local_image_path(prompt_request.image_list[0])
+            image = Image.open(local_path)
+
+            inputs = tokenizer.apply_chat_template([{"role": "user", "image": image,
+                                                   "content": plain_texts}],
+                                                   add_generation_prompt=True,
+                                                   tokenize=True,
+                                                   return_tensors="pt",
+                                                   return_dict=True)
+            inputs = inputs.to('xpu')
         else:
             inputs = tokenizer(plain_texts, return_tensors="pt", padding=True)
             input_ids = inputs.input_ids.to('xpu')
         parameters = prompt_request.parameters
-        return input_ids, parameters, request_id, inputs_embeds
+        return input_ids, parameters, request_id, inputs_embeds, inputs
 
     @torch.no_grad()
     async def process_step(self, tokenizer, result_dict, processor=None):

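For reference, the glm-4v path added above can be exercised outside the async worker roughly as follows. This is a sketch, not the worker's exact code: the checkpoint id, `sym_int4` low-bit setting, image path and `max_new_tokens` are placeholders, and the worker additionally passes `modules_to_not_convert` and streams tokens through `TextIteratorStreamer`.

```python
# End-to-end sketch of the new glm-4v path: build inputs with the tokenizer's
# chat template (passing the PIL image alongside the text) and generate on XPU.
from PIL import Image
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

checkpoint = "THUDM/glm-4v-9b"  # placeholder checkpoint id
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(checkpoint,
                                             load_in_low_bit="sym_int4",
                                             trust_remote_code=True,
                                             use_cache=True).eval().to("xpu")

image = Image.open("./local_images/test.jpg")  # placeholder image
inputs = tokenizer.apply_chat_template([{"role": "user", "image": image,
                                         "content": "What is in the image?"}],
                                       add_generation_prompt=True, tokenize=True,
                                       return_tensors="pt",
                                       return_dict=True).to("xpu")

output = model.generate(**inputs, max_new_tokens=128)
# Decode only the newly generated tokens, skipping the prompt.
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:],
                       skip_special_tokens=True))
```
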
@@ -134,7 +177,8 @@ class ModelWorker:
                                         streamer=self.streamer[request_id],
                                         forced_decoder_ids=decoder_ids)
             else:
-                input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
+                input_ids, parameters, request_id, inputs_embeds, inputs = \
+                    await self.add_request(tokenizer)
                 self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
 
                 def model_generate():

@@ -156,6 +200,9 @@ class ModelWorker:
                     elif inputs_embeds is not None:
                         self.model.generate(inputs_embeds=inputs_embeds,
                                             streamer=self.streamer[request_id], **generate_kwargs)
+                    else:
+                        self.model.generate(**inputs,
+                                            streamer=self.streamer[request_id], **generate_kwargs)
             torch.xpu.empty_cache()
             torch.xpu.synchronize()
             from threading import Thread