diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
index bced1ec9..2e6e99b5 100644
--- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
+++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -462,11 +462,15 @@ class BatchTask(BaseModel):
     partial_prefilling: int


-def make_attention_mask(prompt_lengths):
+def make_attention_mask(prompt_lengths, device):
     max_length = max(prompt_lengths)
-    attention_mask = torch.zeros((len(prompt_lengths), max_length), dtype=torch.int64)
-    for i, length in enumerate(prompt_lengths):
-        attention_mask[i, max_length - length:] = 1
+    batch_size = len(prompt_lengths)
+
+    range_tensor = torch.arange(max_length, device=device).expand(batch_size, max_length)
+    prompt_lengths_tensor = torch.tensor(prompt_lengths, device=device).unsqueeze(1)
+    attention_mask = range_tensor >= max_length - prompt_lengths_tensor
+    attention_mask = attention_mask.to(torch.int64)
+
     return attention_mask


@@ -501,6 +505,8 @@ class PPModelWorker:
         self.print_len = {}
         self.is_finish = {}
         self.model_name = checkpoint
+
+        self.device = f"xpu:{self.rank}"

         # self.layer_start = 0
         # self.layer_end = 0
@@ -543,7 +549,7 @@ class PPModelWorker:
         return cur_batch

     def cat_kv_cache(self, model_type, kv_cache_1, kv_cache_2):
-        if model_type in ["baichuan", "chatglm"]:
+        if model_type in ["baichuan", "chatglm", "mixtral"]:
             result = []
             for sub_tuple1, sub_tuple2 in zip(kv_cache_1, kv_cache_2):
                 if sub_tuple1 is None:
@@ -613,7 +619,7 @@ class PPModelWorker:
         # logger.info(f"{self.rank} {cur_batch} {input.shape}")
         cur_id = cur_batch.batch_id
         _past_key_values = self.past_key_values_dict.get(cur_id, None)
-        attention_mask = make_attention_mask(cur_batch.prompt_lengths).to(input.device)
+        attention_mask = make_attention_mask(cur_batch.prompt_lengths, input.device)

         if self.rank == 0:
             input_ids = input
@@ -637,7 +643,7 @@ class PPModelWorker:
                 tmp_past_key_values = _past_key_values
                 _past_key_values = None
-                # torch.xpu.empty_cache()
+                torch.xpu.empty_cache()

             output = self.model(input_ids=input_ids,
                                 inputs_embeds=inputs_embeds,
                                 past_key_values=_past_key_values,
@@ -661,7 +667,7 @@ class PPModelWorker:

             if self.pp_config.is_tail:
                 _pre_output = self.partial_output_dict.get(cur_id, None)
-                tmp_output = output.logits.to(self.dtype)
+                tmp_output = output.logits
                 tmp_output = torch.argmax(tmp_output[:, -1:, :], dim=-1)
                 if _pre_output is None:
                     _pre_output = tmp_output
@@ -674,7 +680,10 @@ class PPModelWorker:
         self.past_key_values_dict[cur_id] = _past_key_values
         torch.xpu.synchronize()
         if not self.pp_config.is_tail:
-            return output[0].to(self.dtype), cur_batch
+            _output = output[0]
+            if _output.dtype != self.dtype:
+                _output = _output.to(self.dtype)
+            return _output, cur_batch
         else:
             if cur_batch.partial_prefilling > 0 and \
                     cur_batch.prefilled_index == cur_batch.batch_size:
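
Note (not part of the patch): a minimal standalone sketch of the vectorized make_attention_mask introduced above, checked against the previous loop-based construction it replaces. The "cpu" device and the __main__ comparison are illustrative assumptions for running the sketch locally; in the worker the mask is built directly on the per-rank xpu device.

# Sketch: old loop-based mask vs. new broadcasted mask from this diff.
import torch


def make_attention_mask_loop(prompt_lengths):
    # Previous behavior: left-padded mask filled row by row.
    max_length = max(prompt_lengths)
    attention_mask = torch.zeros((len(prompt_lengths), max_length), dtype=torch.int64)
    for i, length in enumerate(prompt_lengths):
        attention_mask[i, max_length - length:] = 1
    return attention_mask


def make_attention_mask(prompt_lengths, device):
    # New behavior: build the same left-padded mask with a single broadcasted
    # comparison, allocated directly on the target device.
    max_length = max(prompt_lengths)
    batch_size = len(prompt_lengths)
    range_tensor = torch.arange(max_length, device=device).expand(batch_size, max_length)
    prompt_lengths_tensor = torch.tensor(prompt_lengths, device=device).unsqueeze(1)
    attention_mask = range_tensor >= max_length - prompt_lengths_tensor
    return attention_mask.to(torch.int64)


if __name__ == "__main__":
    lengths = [3, 5, 1]
    # Both versions produce the same left-padded 0/1 mask.
    assert torch.equal(make_attention_mask_loop(lengths),
                       make_attention_mask(lengths, "cpu"))
    print(make_attention_mask(lengths, "cpu"))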