Support lightweight-serving with internlm-xcomposer2-vl-7b multimodal input (#11703)
* init image_list
* enable internlm-xcomposer2 image input
* update style
* add readme
* update model
* update readme
parent aa98ef96fe
commit 493cbd9a36

4 changed files with 111 additions and 22 deletions
````diff
@@ -18,6 +18,10 @@ pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-exte
 pip install fastapi uvicorn openai
 pip install gradio # for gradio web UI
 conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
+
+# for internlm-xcomposer2-vl-7b
+pip install transformers==4.31.0
+pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops
 ```

 #### 1.2 Installation on Windows
````
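
The transformers pin is load-bearing: the README states internlm-xcomposer2-vl-7b must run on transformers==4.31.0. A minimal sanity check for the environment set up above (a sketch, not part of the diff):

```python
# Verify the pinned transformers release before starting the server;
# internlm-xcomposer2-vl-7b requires exactly 4.31.0 per the pip step above.
import transformers

assert transformers.__version__ == "4.31.0", transformers.__version__
```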
````diff
@@ -172,10 +176,39 @@ curl http://localhost:8000/v1/chat/completions \
 }'
 ```
+
+##### Image input
+
+Image input is currently only supported by [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b), and transformers==4.31.0 must be installed to run it.
+```bash
+wget -O ./test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
+curl http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "internlm-xcomposer2-vl-7b",
+    "messages": [
+      {
+        "role": "user",
+        "content": [
+          {
+            "type": "text",
+            "text": "What'\''s in this image?"
+          },
+          {
+            "type": "image_url",
+            "image_url": {
+              "url": "./test.jpg"
+            }
+          }
+        ]
+      }
+    ],
+    "max_tokens": 128
+  }'
+```

 #### /v1/completions

 ```bash
 curl http://localhost:8000/v1/completions \
   -H "Content-Type: application/json" \
   -d '{
````
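
The same request can also be issued from Python with the `openai` package installed earlier. This is a client-side sketch assuming a local deployment at `http://localhost:8000/v1`; the `api_key` value is a placeholder, and the image path is resolved by the server (see the `wget` step above):

```python
# Client-side sketch of the image request above, using the openai package
# installed in the setup step. base_url and api_key are assumptions for a
# local deployment; "./test.jpg" must exist on the serving machine.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="placeholder")

response = client.chat.completions.create(
    model="internlm-xcomposer2-vl-7b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            {"type": "image_url", "image_url": {"url": "./test.jpg"}},
        ],
    }],
    max_tokens=128,
)
print(response.choices[0].message.content)
```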
````diff
@@ -47,12 +47,17 @@ logger = logging.get_logger(__name__)
 class InputsRequest(BaseModel):
     inputs: str
     parameters: Optional[Parameters] = None
+    image_list: Optional[list] = None
     stream: Optional[bool] = False
     req_type: str = 'completion'


 class ChatCompletionRequest(BaseModel):
-    messages: List[ChatMessage]
+    messages: Union[
+        str,
+        List[Dict[str, str]],
+        List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]],
+    ]
     model: str
     max_tokens: Optional[int] = None
     min_tokens: Optional[int] = None
````
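
To see what the widened `messages` type buys, here is a standalone validation sketch; the class is trimmed to the fields shown above, and the payloads are illustrative:

```python
# Standalone sketch of the widened request schema: `messages` now accepts a
# raw string, plain chat turns, or the multimodal content-list form.
from typing import Dict, List, Optional, Union
from pydantic import BaseModel

class ChatCompletionRequest(BaseModel):
    messages: Union[
        str,
        List[Dict[str, str]],
        List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]],
    ]
    model: str
    max_tokens: Optional[int] = None

# A plain text turn still validates (second Union member):
ChatCompletionRequest(model="m", messages=[{"role": "user", "content": "hi"}])

# A multimodal turn validates against the third Union member:
ChatCompletionRequest(model="m", messages=[{
    "role": "user",
    "content": [
        {"type": "text", "text": "What's in this image?"},
        {"type": "image_url", "image_url": {"url": "./test.jpg"}},
    ],
}])
```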
````diff
@@ -266,7 +271,7 @@ async def generate_stream(inputs_request: InputsRequest):

 def get_prompt(messages) -> str:
     if "codegeex" in local_model.model_name.lower():
-        query = messages[-1].content
+        query = messages[-1]["content"]
         if len(messages) <= 1:
             history = []
         else:
````
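
This one-line change follows from the schema change above: messages now arrive as plain dicts rather than pydantic objects, so attribute access becomes key access:

```python
# messages are now plain dicts, so attribute access no longer applies.
msg = {"role": "user", "content": "hi"}
query = msg["content"]   # new code path
# query = msg.content    # old code path: AttributeError on a dict
```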
````diff
@@ -277,18 +282,33 @@ def get_prompt(messages) -> str:
         return inputs
     else:
         prompt = ""
+        image_list = []
         for msg in messages:
-            role = msg.role
-            content = msg.content
-            if role == "system":
-                prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
-            elif role == "user":
-                prompt += f"[INST] {content} [/INST] "
-            elif role == "assistant":
-                prompt += f"{content} "
-            else:
-                invalidInputError(False, f"Unknown role: {role}")
-        return prompt.strip()
+            role = msg["role"]
+            content = msg["content"]
+            if type(content) == list:
+                image_list1 = [
+                    item["image_url"]["url"]
+                    for item in content
+                    if item["type"] == "image_url"
+                ]
+                image_list.extend(image_list1)
+                text_list = [
+                    item["text"]
+                    for item in content
+                    if item["type"] == "text"
+                ]
+                prompt = "".join(text_list)
+            else:
+                if role == "system":
+                    prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
+                elif role == "user":
+                    prompt += f"[INST] {content} [/INST] "
+                elif role == "assistant":
+                    prompt += f"{content} "
+                else:
+                    invalidInputError(False, f"Unknown role: {role}")
+        return prompt.strip(), image_list


 def set_parameters(req):
````
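
The effect of the new branch is easiest to see in isolation. A minimal sketch of the content-list handling, with an illustrative payload:

```python
# Isolated sketch of the extraction added to get_prompt: split a content
# list into image URLs and concatenated text (the payload is illustrative).
content = [
    {"type": "text", "text": "What's in this image?"},
    {"type": "image_url", "image_url": {"url": "./test.jpg"}},
]
image_list = [item["image_url"]["url"] for item in content
              if item["type"] == "image_url"]
prompt = "".join(item["text"] for item in content if item["type"] == "text")
print(image_list)  # ['./test.jpg']
print(prompt)      # What's in this image?
```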
````diff
@@ -313,11 +333,12 @@ def set_parameters(req):

 @app.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest):
-    print(request)
     model_name = local_model.model_name
+    prompt, image_list = get_prompt(request.messages)
     inputs_request = InputsRequest(
-        inputs=get_prompt(request.messages),
+        inputs=prompt,
         parameters=set_parameters(request),
+        image_list=image_list if len(image_list) >= 1 else None,
         stream=request.stream,
         req_type="chat"
     )
````
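
Note the `len(image_list) >= 1` guard: text-only chats keep `image_list` at `None`, so the worker can later tell them apart from multimodal requests. A tiny sketch of the normalization:

```python
# The guard keeps text-only requests on the old code path: an empty image
# list is normalized to None before building the InputsRequest.
for image_list in (["./test.jpg"], []):
    print(image_list if len(image_list) >= 1 else None)
# ['./test.jpg']
# None
```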
````diff
@@ -60,15 +60,40 @@ class ModelWorker:
         tmp_result = await self.waiting_requests.get()
         request_id, prompt_request = tmp_result
         plain_texts = prompt_request.inputs
-        inputs = tokenizer(plain_texts, return_tensors="pt", padding=True)
-        input_ids = inputs.input_ids.to('xpu')
+        input_ids = None
+        inputs_embeds = None
+        if "internlm-xcomposer2-vl-7b" in self.model_name.lower():
+            lines = [
+                "You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).",
+                "- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language "
+                "model that is developed by Shanghai AI Laboratory (上海人工智能实验室). "
+                "It is designed to be helpful, honest, and harmless.",
+                "- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in "
+                "the language chosen by the user such as English and 中文.",
+                "- InternLM-XComposer (浦语·灵笔) is capable of comprehending and "
+                "articulating responses effectively based on the provided image."
+            ]
+            meta_instruction = "\n".join(lines)
+            if prompt_request.image_list is None:
+                inputs = self.model.build_inputs(tokenizer, plain_texts, [], meta_instruction)
+                im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
+                input_ids = inputs["input_ids"].to('xpu')
+            else:
+                image = self.model.encode_img(prompt_request.image_list[0])
+                plain_texts = "<ImageHere>" + plain_texts
+                inputs, im_mask = self.model.interleav_wrap_chat(tokenizer, plain_texts,
+                                                                 image, [], meta_instruction)
+                inputs_embeds = inputs["inputs_embeds"].to('xpu').to(self.dtype)
+        else:
+            inputs = tokenizer(plain_texts, return_tensors="pt", padding=True)
+            input_ids = inputs.input_ids.to('xpu')
         parameters = prompt_request.parameters
-        return input_ids, parameters, request_id
+        return input_ids, parameters, request_id, inputs_embeds

     @torch.no_grad()
     async def process_step(self, tokenizer, result_dict):
         if not self.waiting_requests.empty():
-            input_ids, parameters, request_id = await self.add_request(tokenizer)
+            input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
             self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)

             def model_generate():
````
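
For the text-only path, the boolean mask built above simply marks every position as non-image. A small sketch of that tensor, assuming only `torch` and a stand-in prompt:

```python
# Sketch of the text-only im_mask above: a [batch, seq_len] boolean tensor
# of False, i.e. no positions correspond to image embeddings.
import torch

input_ids = torch.tensor([[1, 42, 7, 9]])         # stand-in tokenized prompt
im_mask = torch.zeros(input_ids.shape[:2]).bool()
print(im_mask)  # tensor([[False, False, False, False]])
```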
````diff
@@ -78,8 +103,18 @@ class ModelWorker:
                                      tokenizer.convert_tokens_to_ids("<|user|>"),
                                      tokenizer.convert_tokens_to_ids("<|observation|>")]
                     generate_kwargs["eos_token_id"] = eos_token_id
-                self.model.generate(input_ids,
-                                    streamer=self.streamer[request_id], **generate_kwargs)
+                elif "internlm-xcomposer2-vl-7b" in self.model_name.lower():
+                    eos_token_id = [
+                        tokenizer.eos_token_id,
+                        tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
+                    ]
+                    generate_kwargs["eos_token_id"] = eos_token_id
+                if input_ids is not None:
+                    self.model.generate(input_ids,
+                                        streamer=self.streamer[request_id], **generate_kwargs)
+                elif inputs_embeds is not None:
+                    self.model.generate(inputs_embeds=inputs_embeds,
+                                        streamer=self.streamer[request_id], **generate_kwargs)
                 torch.xpu.empty_cache()
                 torch.xpu.synchronize()
````
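
After `add_request`, exactly one of `input_ids` and `inputs_embeds` is non-None, so the two `generate` branches are mutually exclusive. An illustrative dispatch helper (not in the diff) capturing that contract:

```python
# Illustrative dispatch mirroring the branch above: generate from token ids
# when the prompt was tokenized, otherwise from precomputed multimodal
# embeddings; add_request guarantees exactly one of the two is set.
def run_generate(model, streamer, generate_kwargs, input_ids=None, inputs_embeds=None):
    if input_ids is not None:
        return model.generate(input_ids, streamer=streamer, **generate_kwargs)
    elif inputs_embeds is not None:
        return model.generate(inputs_embeds=inputs_embeds,
                              streamer=streamer, **generate_kwargs)
```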
````diff
@@ -574,7 +574,7 @@ class BenchmarkWrapper:
         if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
             if not self.config.is_encoder_decoder:
                 has_inputs_embeds_forwarding = "inputs_embeds" in set(
-                    inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
+                    inspect.signature(self.model.prepare_inputs_for_generation).parameters.keys()
                 )
                 if not has_inputs_embeds_forwarding:
                     raise ValueError(
````
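
This one-word fix matters because `BenchmarkWrapper` wraps the model rather than being one: the `inputs_embeds` capability check must inspect the wrapped model's `prepare_inputs_for_generation`. A reduced sketch of the check with hypothetical stand-in classes:

```python
# Reduced sketch of the fixed capability check: inspect the *wrapped*
# model's prepare_inputs_for_generation, not the wrapper's own (which may
# not exist or may have a different signature).
import inspect

class DummyModel:
    def prepare_inputs_for_generation(self, input_ids=None, inputs_embeds=None):
        pass

class Wrapper:
    def __init__(self, model):
        self.model = model

wrapper = Wrapper(DummyModel())
params = inspect.signature(wrapper.model.prepare_inputs_for_generation).parameters
print("inputs_embeds" in set(params.keys()))  # True
```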