# ipex-llm/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py
#

import torch
from torch import nn
import torch.distributed as dist
import os
import time
import numpy as np
from typing import Callable, List, Optional
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from ipex_llm.utils.common import invalidInputError
import logging
logger = logging.getLogger(__name__)

# patch GenerationMixin.generate
from transformers import GenerationMixin
original_generate = GenerationMixin.generate


class DummyLayer(nn.Module):
    def __init__(self, *args):
        super().__init__()
        # to avoid AttributeError in https://github.com/intel-analytics/ipex-llm/blob/main/
        # python/llm/src/ipex_llm/transformers/models/llama.py#L2076
        self.weight = torch.randn(1,)

    def forward(self, x):
        return x


class Dummy_MLPLayer(nn.Module):
    def __init__(self, *args):
        super().__init__()
        # to avoid AttributeError in https://github.com/intel-analytics/ipex-llm/blob/main/
        # python/llm/src/ipex_llm/transformers/models/llama.py#L119
        self.up_proj = DummyLayer()
        self.down_proj = DummyLayer()

    def forward(self, x):
        return x


class Dummy_DecoderLayer(nn.Module):
    def __init__(self, *args):
        super().__init__()
        # to avoid AttributeError
        self.input_layernorm = DummyLayer()
        self.mlp = Dummy_MLPLayer()

    def forward(self, hidden_states, past_key_value=None, use_cache=False, **kwargs):
        outputs = (hidden_states,)
        if use_cache:
            outputs += (past_key_value,)
        return outputs
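
# The Dummy_* modules above keep the module tree's shape intact on ranks that
# do not own a given layer: attribute lookups still succeed, and hidden states
# pass through unchanged.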


def init_pipeline_parallel():
    import oneccl_bindings_for_pytorch
    os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
    os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
    dist.init_process_group('ccl')
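
# init_pipeline_parallel() relies on torch.distributed's default env:// rendezvous,
# so RANK and WORLD_SIZE must come from the launcher: one process per pipeline
# stage, e.g. via mpirun (an illustrative assumption; any launcher that sets
# those variables should work).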


def pipeline_parallel(model, pipeline_parallel_stages):
    slice_size = (model.config.num_hidden_layers + pipeline_parallel_stages - 1) // \
        pipeline_parallel_stages

    local_rank = dist.get_rank()

    global layer_start
    global layer_end
    layer_start = slice_size * local_rank
    layer_end = layer_start + min(slice_size, model.config.num_hidden_layers - layer_start)

    for i in range(model.config.num_hidden_layers):
        if i < layer_start or i >= layer_end:
            model._modules['model'].layers[i] = Dummy_DecoderLayer()
        else:
            # align layer_idx and len(past_key_values), otherwise abnormal output
            model._modules['model'].layers[i].self_attn.layer_idx = i - layer_start

    if local_rank != 0:
        model._modules['model'].embed_tokens = DummyLayer()
    if local_rank != pipeline_parallel_stages - 1:
        model._modules['model'].norm = DummyLayer()
        model._modules['lm_head'] = DummyLayer()

    model.pipeline_parallel_stages = pipeline_parallel_stages
    model = model.to(f'xpu:{local_rank}')
    return model
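
# A minimal usage sketch (illustrative, not part of this module): after starting
# one process per stage, each rank runs
#
#     init_pipeline_parallel()
#     model = pipeline_parallel(model, pipeline_parallel_stages=2)
#     output_ids = model.generate(input_ids, max_new_tokens=32)
#
# Worked example of the slicing above: with num_hidden_layers=32 and
# pipeline_parallel_stages=2, slice_size=16, so rank 0 keeps layers [0, 16) and
# rank 1 keeps layers [16, 32); layers a rank does not own become
# Dummy_DecoderLayer pass-throughs.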


@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    **kwargs,
):
    if hasattr(self, 'pipeline_parallel_stages') and self.pipeline_parallel_stages > 1:
        # priority: `generation_config` argument > `model.generation_config`
        if generation_config is None:
            if (
                self.generation_config._from_model_config
                and self.generation_config._original_object_hash == hash(self.generation_config)
                and self.config._has_non_default_generation_parameters()
            ):
                new_generation_config = GenerationConfig.from_model_config(self.config)
                if new_generation_config != self.generation_config:
                    self.generation_config = new_generation_config
            generation_config = self.generation_config

        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
            eos_token_id = generation_config.eos_token_id
            if isinstance(eos_token_id, list):
                eos_token_id = eos_token_id[0]
            logger.warning("Setting `pad_token_id` to `eos_token_id`: "
                           f"{eos_token_id} for open-end generation.")
            generation_config.pad_token_id = eos_token_id

        if generation_config is not None and generation_config.max_new_tokens is not None:
            max_new_tokens = generation_config.max_new_tokens
        else:
            max_new_tokens = kwargs.get("max_new_tokens", None)

        return self.pipeline_parallel_generate(inputs=inputs,
                                               max_new_tokens=max_new_tokens,
                                               generation_config=generation_config,)

    return original_generate(self,
                             inputs=inputs,
                             generation_config=generation_config,
                             logits_processor=logits_processor,
                             stopping_criteria=stopping_criteria,
                             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                             synced_gpus=synced_gpus,
                             assistant_model=assistant_model,
                             streamer=streamer,
                             **kwargs)


GenerationMixin.generate = generate
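
# After this assignment, every `model.generate(...)` call routes through the
# wrapper above: models prepared with pipeline_parallel() take the pipeline
# path, and everything else falls back to the stock transformers implementation.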


@torch.no_grad()
def pipeline_parallel_generate(self,
                               inputs: Optional[torch.Tensor] = None,
                               max_new_tokens: int = 32,
                               generation_config: Optional[GenerationConfig] = None,
                               **kwargs):
    local_rank = dist.get_rank()
    pre_rank = (local_rank - 1) % self.pipeline_parallel_stages
    next_rank = (local_rank + 1) % self.pipeline_parallel_stages

    global layer_start
    global layer_end

    self.first_token_time = 0
    self.next_token_time = []

    pad_token_id = generation_config.pad_token_id
    eos_token_id = generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(inputs.device) \
        if eos_token_id is not None else None

    _input_ids = None
    _past_key_values = None
    bs = inputs.shape[0]
    output_ids = inputs.clone()

    step = 0
    # keep track of which sequences are already finished
    unfinished_sequences = torch.ones(inputs.shape[0], dtype=torch.long, device=inputs.device)
    this_peer_finished = False
    while True:
        if step >= max_new_tokens:
            break

        if _input_ids is None:
            _input_ids = inputs

        tic = time.time()
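        # Stage 0 consumes token ids directly; every later stage instead
        # receives the previous stage's hidden states over the process group.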
        if local_rank == 0:
            outputs = self(input_ids=_input_ids, inputs_embeds=None,
                           past_key_values=_past_key_values, use_cache=True)
        else:
            inputs_embeds = torch.empty(_input_ids.shape + (self.config.hidden_size,),
                                        device=f'xpu:{local_rank}', dtype=self.dtype)
            dist.recv(inputs_embeds, src=pre_rank)
            outputs = self(input_ids=None, inputs_embeds=inputs_embeds,
                           past_key_values=_past_key_values, use_cache=True)
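        # Only the last stage still owns lm_head, so it decodes the next token
        # greedily and broadcasts it; every other stage sends its hidden states
        # downstream and waits for the broadcast.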
        if local_rank == self.pipeline_parallel_stages - 1:
            logits = outputs.logits
            next_ids = torch.argmax(logits[:, -1:, :], dim=-1)
            dist.broadcast(next_ids, src=local_rank)
        else:
            dist.send(outputs[0].to(self.dtype), dst=next_rank)
            next_ids = torch.empty((bs, 1), device=f'xpu:{local_rank}', dtype=torch.int64)
            dist.broadcast(next_ids, src=self.pipeline_parallel_stages - 1)

        _input_ids = next_ids
        output_ids = torch.cat([output_ids, next_ids], dim=-1)

        # finished sentences should have their next token be a padding token
        next_ids = next_ids.squeeze()
        if eos_token_id is not None:
            if pad_token_id is None:
                invalidInputError(False, "If `eos_token_id` is defined, "
                                  "make sure that `pad_token_id` is defined.")
            next_ids = next_ids * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
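        # Ranks other than 0 only hold KV cache for their own layer slice; pad
        # the leading (pre-layer_start) entries with placeholders so indexing
        # by layer_idx keeps lining up with len(past_key_values).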
        if isinstance(outputs.past_key_values, tuple) and local_rank != 0:
            value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
            past_key_values_placeholder = tuple(
                (value_placeholder, value_placeholder) for _ in range(layer_start)
            ) + (outputs.past_key_values)[layer_start:]
            _past_key_values = past_key_values_placeholder
        else:
            _past_key_values = outputs.past_key_values

        toc = time.time()
        if step == 0:
            self.first_token_time = toc - tic
        else:
            self.next_token_time.append(toc - tic)

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id_tensor is not None:
            unfinished_sequences = unfinished_sequences.mul(
                next_ids.tile(eos_token_id_tensor.shape[0], 1)
                .ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )
            # stop when each sentence is finished
            if unfinished_sequences.max() == 0:
                this_peer_finished = True

        if this_peer_finished:
            break

        step += 1

    if self.device.type == 'xpu':
        torch.xpu.synchronize()
    self.rest_cost_mean = np.mean(self.next_token_time)
    return output_ids