LLM: Add XPU Memory Optimizations for Pipeline Parallel (#11567)

Add XPU Memory Optimizations for Pipeline Parallel
This commit is contained in:
Xiangyu Tian 2024-07-16 09:44:50 +08:00 committed by GitHub
parent f06d2f72fb
commit 79c742dfd5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -19,13 +19,17 @@
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import torch.distributed as dist
import os
import time
import numpy as np
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Union, Tuple
from types import SimpleNamespace
import transformers
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ipex_llm.utils.common import invalidInputError
from ipex_llm.ggml.quantize import ggml_tensor_qtype
import logging
@ -107,6 +111,34 @@ def init_pipeline_parallel():
dist.init_process_group('ccl')
def low_mem_convert(model):
from ipex_llm.transformers.convert import convert_forward
import importlib
if 'llama' in model.config.model_type:
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaForCausalLM,
llama_causallm_forward_4_37_lowmem)
elif model.config.model_type == "chatglm" and not hasattr(model.config, "vision_config"):
if model.config.num_layers == 40:
# for glm4-9b
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
convert_forward(
model,
module.ChatGLMForConditionalGeneration,
glm4_conditional_generation_forward_lowmem)
else:
# for chatglm3-6b
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
convert_forward(
model,
module.ChatGLMForConditionalGeneration,
chatglm3_conditional_generation_forward_lowmem)
return model
def _check_quantize_kv_cache(model, idx, batch_size):
# align use_quantize_kv_cache setting for different GPU in pipeline parallel
pp_quantize_kv_cache = (os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) == "1") or \
@ -186,6 +218,11 @@ def pipeline_parallel(model, pipeline_parallel_stages):
model._modules['model'].norm = DummyLayer()
model._modules['lm_head'] = DummyLayer()
_enable_lowmem = os.getenv('IPEX_LLM_LOW_MEM')
_enable_lowmem = (_enable_lowmem is not None) and (_enable_lowmem.lower() == "1")
if _enable_lowmem:
model = low_mem_convert(model)
model.pipeline_parallel_stages = pipeline_parallel_stages
model.layer_start = layer_start
model.layer_end = layer_end
@ -867,3 +904,208 @@ def _is_chinese_char(cp):
return True
return False
def llama_causallm_forward_4_37_lowmem(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions # noqa
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states # noqa
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
# ipex-llm change starts
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) # noqa
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] # noqa
logits = torch.cat(logits, dim=-1)
else:
torch.xpu.empty_cache()
logits = self.lm_head(hidden_states)
torch.xpu.empty_cache()
# logits = logits.float()
# ipex-llm change ends
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def chatglm3_conditional_generation_forward_lowmem(
self,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
return_last_logit: Optional[bool] = False,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
if return_last_logit:
hidden_states = hidden_states[-1:]
# ipex-llm change starts
torch.xpu.empty_cache()
lm_logits = self.transformer.output_layer(hidden_states)
torch.xpu.empty_cache()
lm_logits = lm_logits.transpose(0, 1).contiguous()
loss = None
if labels is not None:
# lm_logits = lm_logits.to(torch.float32)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
lm_logits = lm_logits.to(hidden_states.dtype)
loss = loss.to(hidden_states.dtype)
# ipex-llm change ends
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
def glm4_conditional_generation_forward_lowmem(
self,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
return_last_logit: Optional[bool] = False,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
if return_last_logit:
hidden_states = hidden_states[:, -1:]
# ipex-llm change starts
torch.xpu.empty_cache()
lm_logits = self.transformer.output_layer(hidden_states)
torch.xpu.empty_cache()
loss = None
if labels is not None:
# lm_logits = lm_logits.to(torch.float32)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
lm_logits = lm_logits.to(hidden_states.dtype)
loss = loss.to(hidden_states.dtype)
# ipex-llm change ends
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)