diff --git a/python/llm/example/GPU/Lightweight-Serving/README.md b/python/llm/example/GPU/Lightweight-Serving/README.md
index 104032c6..4cb29db1 100644
--- a/python/llm/example/GPU/Lightweight-Serving/README.md
+++ b/python/llm/example/GPU/Lightweight-Serving/README.md
@@ -18,6 +18,10 @@ pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-exte
 pip install fastapi uvicorn openai
 pip install gradio # for gradio web UI
 conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
+
+# for internlm-xcomposer2-vl-7b
+pip install transformers==4.31.0
+pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops
 ```
 
 #### 1.2 Installation on Windows
@@ -172,10 +176,39 @@ curl http://localhost:8000/v1/chat/completions \
   }'
 ```
 
+##### Image input
+
+image input only supports [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b) now, and it must install transformers==4.31.0 to run.
+```bash
+wget -O ./test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
+curl http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "internlm-xcomposer2-vl-7b",
+    "messages": [
+      {
+        "role": "user",
+        "content": [
+          {
+            "type": "text",
+            "text": "What'\''s in this image?"
+          },
+          {
+            "type": "image_url",
+            "image_url": {
+              "url": "./test.jpg"
+            }
+          }
+        ]
+      }
+    ],
+    "max_tokens": 128
+  }'
+```
+
 #### /v1/completions
 
 ```bash
-
 curl http://localhost:8000/v1/completions \
   -H "Content-Type: application/json" \
   -d '{
diff --git a/python/llm/src/ipex_llm/serving/fastapi/api_server.py b/python/llm/src/ipex_llm/serving/fastapi/api_server.py
index 0cc12bf3..88c85618 100644
--- a/python/llm/src/ipex_llm/serving/fastapi/api_server.py
+++ b/python/llm/src/ipex_llm/serving/fastapi/api_server.py
@@ -47,12 +47,17 @@ logger = logging.get_logger(__name__)
 class InputsRequest(BaseModel):
     inputs: str
     parameters: Optional[Parameters] = None
+    image_list: Optional[list] = None
     stream: Optional[bool] = False
     req_type: str = 'completion'
 
 
 class ChatCompletionRequest(BaseModel):
-    messages: List[ChatMessage]
+    messages: Union[
+        str,
+        List[Dict[str, str]],
+        List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]],
+    ]
     model: str
     max_tokens: Optional[int] = None
     min_tokens: Optional[int] = None
@@ -266,7 +271,7 @@ async def generate_stream(inputs_request: InputsRequest):
 
 def get_prompt(messages) -> str:
     if "codegeex" in local_model.model_name.lower():
-        query = messages[-1].content
+        query = messages[-1]["content"]
         if len(messages) <= 1:
             history = []
         else:
@@ -277,18 +282,33 @@ def get_prompt(messages) -> str:
         return inputs
     else:
         prompt = ""
+        image_list = []
         for msg in messages:
-            role = msg.role
-            content = msg.content
-            if role == "system":
-                prompt += f"<>\n{content}\n<>\n\n"
-            elif role == "user":
-                prompt += f"[INST] {content} [/INST] "
-            elif role == "assistant":
-                prompt += f"{content} "
+            role = msg["role"]
+            content = msg["content"]
+            if type(content) == list:
+                image_list1 = [
+                    item["image_url"]["url"]
+                    for item in content
+                    if item["type"] == "image_url"
+                ]
+                image_list.extend(image_list1)
+                text_list = [
+                    item["text"]
+                    for item in content
+                    if item["type"] == "text"
+                ]
+                prompt = "".join(text_list)
             else:
-                invalidInputError(False, f"Unknown role: {role}")
-        return prompt.strip()
+                if role == "system":
+                    prompt += f"<>\n{content}\n<>\n\n"
+                elif role == "user":
+                    prompt += f"[INST] {content} [/INST] "
+                elif role == "assistant":
+                    prompt += f"{content} "
+                else:
+                    invalidInputError(False, f"Unknown role: {role}")
+        return prompt.strip(), image_list
 
 
 def set_parameters(req):
@@ -313,11 +333,12 @@ def set_parameters(req):
 
 @app.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest):
-    print(request)
     model_name = local_model.model_name
+    prompt, image_list = get_prompt(request.messages)
     inputs_request = InputsRequest(
-        inputs=get_prompt(request.messages),
+        inputs=prompt,
         parameters=set_parameters(request),
+        image_list=image_list if len(image_list) >= 1 else None,
         stream=request.stream,
         req_type="chat"
     )
diff --git a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
index 099cfc0c..0fe0f88c 100644
--- a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
+++ b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
@@ -60,15 +60,40 @@ class ModelWorker:
         tmp_result = await self.waiting_requests.get()
         request_id, prompt_request = tmp_result
         plain_texts = prompt_request.inputs
-        inputs = tokenizer(plain_texts, return_tensors="pt", padding=True)
-        input_ids = inputs.input_ids.to('xpu')
+        input_ids = None
+        inputs_embeds = None
+        if "internlm-xcomposer2-vl-7b" in self.model_name.lower():
+            lines = [
+                "You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).",
+                "- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language "
+                "model that is developed by Shanghai AI Laboratory (上海人工智能实验室). "
+                "It is designed to be helpful, honest, and harmless.",
+                "- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in "
+                "the language chosen by the user such as English and 中文.",
+                "- InternLM-XComposer (浦语·灵笔) is capable of comprehending and "
+                "articulating responses effectively based on the provided image."
+            ]
+            meta_instruction = "\n".join(lines)
+            if prompt_request.image_list is None:
+                inputs = self.model.build_inputs(tokenizer, plain_texts, [], meta_instruction)
+                im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
+                input_ids = inputs["input_ids"].to('xpu')
+            else:
+                image = self.model.encode_img(prompt_request.image_list[0])
+                plain_texts = "" + plain_texts
+                inputs, im_mask = self.model.interleav_wrap_chat(tokenizer, plain_texts,
+                                                                 image, [], meta_instruction)
+                inputs_embeds = inputs["inputs_embeds"].to('xpu').to(self.dtype)
+        else:
+            inputs = tokenizer(plain_texts, return_tensors="pt", padding=True)
+            input_ids = inputs.input_ids.to('xpu')
         parameters = prompt_request.parameters
-        return input_ids, parameters, request_id
+        return input_ids, parameters, request_id, inputs_embeds
 
     @torch.no_grad()
     async def process_step(self, tokenizer, result_dict):
         if not self.waiting_requests.empty():
-            input_ids, parameters, request_id = await self.add_request(tokenizer)
+            input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
             self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
 
             def model_generate():
@@ -78,8 +103,18 @@ class ModelWorker:
                                     tokenizer.convert_tokens_to_ids("<|user|>"),
                                     tokenizer.convert_tokens_to_ids("<|observation|>")]
                     generate_kwargs["eos_token_id"] = eos_token_id
-                self.model.generate(input_ids,
-                                    streamer=self.streamer[request_id], **generate_kwargs)
+                elif "internlm-xcomposer2-vl-7b" in self.model_name.lower():
+                    eos_token_id = [
+                        tokenizer.eos_token_id,
+                        tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
+                    ]
+                    generate_kwargs["eos_token_id"] = eos_token_id
+                if input_ids is not None:
+                    self.model.generate(input_ids,
+                                        streamer=self.streamer[request_id], **generate_kwargs)
+                elif inputs_embeds is not None:
+                    self.model.generate(inputs_embeds=inputs_embeds,
+                                        streamer=self.streamer[request_id], **generate_kwargs)
                 torch.xpu.empty_cache()
                 torch.xpu.synchronize()
 
diff --git a/python/llm/src/ipex_llm/utils/benchmark_util.py b/python/llm/src/ipex_llm/utils/benchmark_util.py
index cbbd6c6a..d64631f1 100644
--- a/python/llm/src/ipex_llm/utils/benchmark_util.py
+++ b/python/llm/src/ipex_llm/utils/benchmark_util.py
@@ -574,7 +574,7 @@ class BenchmarkWrapper:
         if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
             if not self.config.is_encoder_decoder:
                 has_inputs_embeds_forwarding = "inputs_embeds" in set(
-                    inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
+                    inspect.signature(self.model.prepare_inputs_for_generation).parameters.keys()
                 )
                 if not has_inputs_embeds_forwarding:
                     raise ValueError(