LLM: Add XPU Memory Optimizations for Pipeline Parallel (#11567)
Add XPU Memory Optimizations for Pipeline Parallel
parent f06d2f72fb
commit 79c742dfd5
1 changed file with 243 additions and 1 deletion
@@ -19,13 +19,17 @@
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import torch.distributed as dist
import os
import time
import numpy as np
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Union, Tuple
from types import SimpleNamespace
import transformers
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ipex_llm.utils.common import invalidInputError
from ipex_llm.ggml.quantize import ggml_tensor_qtype
import logging

@@ -107,6 +111,34 @@ def init_pipeline_parallel():
    dist.init_process_group('ccl')


def low_mem_convert(model):
    from ipex_llm.transformers.convert import convert_forward
    import importlib
    if 'llama' in model.config.model_type:
        convert_forward(
            model,
            transformers.models.llama.modeling_llama.LlamaForCausalLM,
            llama_causallm_forward_4_37_lowmem)
    elif model.config.model_type == "chatglm" and not hasattr(model.config, "vision_config"):
        if model.config.num_layers == 40:
            # for glm4-9b
            modeling_module_name = model.__class__.__module__
            module = importlib.import_module(modeling_module_name)
            convert_forward(
                model,
                module.ChatGLMForConditionalGeneration,
                glm4_conditional_generation_forward_lowmem)
        else:
            # for chatglm3-6b
            modeling_module_name = model.__class__.__module__
            module = importlib.import_module(modeling_module_name)
            convert_forward(
                model,
                module.ChatGLMForConditionalGeneration,
                chatglm3_conditional_generation_forward_lowmem)
    return model


def _check_quantize_kv_cache(model, idx, batch_size):
    # align use_quantize_kv_cache setting for different GPU in pipeline parallel
    pp_quantize_kv_cache = (os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) == "1") or \

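For context (illustrative sketch, not part of this commit): `low_mem_convert` relies on `ipex_llm.transformers.convert.convert_forward` to swap the stock forwards for the low-memory variants defined later in this file. A minimal sketch of what such a forward-patching helper might look like, assuming it simply rebinds `forward` on every submodule of the target class:

import torch.nn as nn

def patch_forward(module: nn.Module, target_cls: type, new_forward) -> None:
    # Hypothetical stand-in for convert_forward (not the ipex-llm code):
    # rebind `forward` as a bound method on every submodule whose class
    # matches target_cls, then recurse into children.
    if module.__class__ == target_cls:
        module.forward = new_forward.__get__(module, module.__class__)
    for child in module.children():
        patch_forward(child, target_cls, new_forward)
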
@@ -186,6 +218,11 @@ def pipeline_parallel(model, pipeline_parallel_stages):
        model._modules['model'].norm = DummyLayer()
        model._modules['lm_head'] = DummyLayer()

    _enable_lowmem = os.getenv('IPEX_LLM_LOW_MEM')
    _enable_lowmem = (_enable_lowmem is not None) and (_enable_lowmem.lower() == "1")
    if _enable_lowmem:
        model = low_mem_convert(model)

    model.pipeline_parallel_stages = pipeline_parallel_stages
    model.layer_start = layer_start
    model.layer_end = layer_end

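The low-memory forwards only take effect when the `IPEX_LLM_LOW_MEM` environment variable is set to "1" before `pipeline_parallel` runs. A minimal sketch of the launch-time setup (illustrative, not from this diff):

import os

# Set before the model is split across pipeline stages, e.g. via
# `export IPEX_LLM_LOW_MEM=1` in the shell, or from Python:
os.environ["IPEX_LLM_LOW_MEM"] = "1"

# Mirrors the gating logic added above: any other value, or an unset
# variable, leaves the default forwards in place.
flag = os.getenv("IPEX_LLM_LOW_MEM")
enable_lowmem = (flag is not None) and (flag.lower() == "1")
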
@@ -867,3 +904,208 @@ def _is_chinese_char(cp):
        return True

    return False


def llama_causallm_forward_4_37_lowmem(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions  # noqa
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states  # noqa
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]

    # ipex-llm change starts

    if self.config.pretraining_tp > 1:
        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)  # noqa
        logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]  # noqa
        logits = torch.cat(logits, dim=-1)
    else:
        torch.xpu.empty_cache()
        logits = self.lm_head(hidden_states)
        torch.xpu.empty_cache()
    # logits = logits.float()

    # ipex-llm change ends

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

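The Llama forward above trims peak XPU memory at the logits projection by releasing the caching allocator's unused blocks immediately before and after the vocab-sized matmul, and by leaving the usual `.float()` upcast commented out. The same pattern in isolation (a sketch; `project_logits_low_mem` is a hypothetical helper, not part of this commit):

import torch

def project_logits_low_mem(lm_head, hidden_states):
    # Free cached blocks so the allocator can reuse them for the large
    # [*, hidden] x [hidden, vocab] projection, then free the temporary
    # workspace again afterwards. Guarded so the sketch also runs on CPU.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
    logits = lm_head(hidden_states)  # kept in native dtype, no .float() upcast
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
    return logits
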
def chatglm3_conditional_generation_forward_lowmem(
    self,
    input_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    return_last_logit: Optional[bool] = False,
):
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    transformer_outputs = self.transformer(
        input_ids=input_ids,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = transformer_outputs[0]
    if return_last_logit:
        hidden_states = hidden_states[-1:]

    # ipex-llm change starts
    torch.xpu.empty_cache()
    lm_logits = self.transformer.output_layer(hidden_states)
    torch.xpu.empty_cache()
    lm_logits = lm_logits.transpose(0, 1).contiguous()

    loss = None
    if labels is not None:
        # lm_logits = lm_logits.to(torch.float32)

        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        lm_logits = lm_logits.to(hidden_states.dtype)
        loss = loss.to(hidden_states.dtype)
    # ipex-llm change ends

    if not return_dict:
        output = (lm_logits,) + transformer_outputs[1:]
        return ((loss,) + output) if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=lm_logits,
        past_key_values=transformer_outputs.past_key_values,
        hidden_states=transformer_outputs.hidden_states,
        attentions=transformer_outputs.attentions,
    )

def glm4_conditional_generation_forward_lowmem(
    self,
    input_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    return_last_logit: Optional[bool] = False,
):
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    transformer_outputs = self.transformer(
        input_ids=input_ids,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = transformer_outputs[0]
    if return_last_logit:
        hidden_states = hidden_states[:, -1:]
    # ipex-llm change starts
    torch.xpu.empty_cache()
    lm_logits = self.transformer.output_layer(hidden_states)
    torch.xpu.empty_cache()

    loss = None
    if labels is not None:
        # lm_logits = lm_logits.to(torch.float32)

        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        lm_logits = lm_logits.to(hidden_states.dtype)
        loss = loss.to(hidden_states.dtype)
    # ipex-llm change ends

    if not return_dict:
        output = (lm_logits,) + transformer_outputs[1:]
        return ((loss,) + output) if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=lm_logits,
        past_key_values=transformer_outputs.past_key_values,
        hidden_states=transformer_outputs.hidden_states,
        attentions=transformer_outputs.attentions,
    )
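One difference between the two ChatGLM variants above: the chatglm3 path appears to work with sequence-first hidden states, so `return_last_logit` slices dim 0 and the logits are transposed back to batch-first, while the glm4 path slices dim 1 and needs no transpose. A small shape check with made-up sizes (illustrative, not from this diff) shows the two slices select the same positions:

import torch

seq_first = torch.randn(5, 2, 8)           # chatglm3-style [seq, batch, hidden]
batch_first = seq_first.transpose(0, 1)    # glm4-style [batch, seq, hidden]

last_seq_first = seq_first[-1:]            # chatglm3: slice along dim 0 -> [1, 2, 8]
last_batch_first = batch_first[:, -1:]     # glm4: slice along dim 1 -> [2, 1, 8]

# After transposing back to batch-first, both hold the same last-position values.
assert torch.equal(last_seq_first.transpose(0, 1), last_batch_first)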