LLM: Add XPU Memory Optimizations for Pipeline Parallel (#11567)
Add XPU Memory Optimizations for Pipeline Parallel
This commit is contained in:
parent f06d2f72fb
commit 79c742dfd5
1 changed file with 243 additions and 1 deletion
@@ -19,13 +19,17 @@
 import torch
 from torch import nn
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
 import torch.distributed as dist
 import os
 import time
 import numpy as np
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Union, Tuple
 from types import SimpleNamespace
+import transformers
 from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.ggml.quantize import ggml_tensor_qtype
 import logging
@@ -107,6 +111,34 @@ def init_pipeline_parallel():
     dist.init_process_group('ccl')
 
 
+def low_mem_convert(model):
+    from ipex_llm.transformers.convert import convert_forward
+    import importlib
+    if 'llama' in model.config.model_type:
+        convert_forward(
+            model,
+            transformers.models.llama.modeling_llama.LlamaForCausalLM,
+            llama_causallm_forward_4_37_lowmem)
+    elif model.config.model_type == "chatglm" and not hasattr(model.config, "vision_config"):
+        if model.config.num_layers == 40:
+            # for glm4-9b
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            convert_forward(
+                model,
+                module.ChatGLMForConditionalGeneration,
+                glm4_conditional_generation_forward_lowmem)
+        else:
+            # for chatglm3-6b
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            convert_forward(
+                model,
+                module.ChatGLMForConditionalGeneration,
+                chatglm3_conditional_generation_forward_lowmem)
+    return model
+
+
 def _check_quantize_kv_cache(model, idx, batch_size):
     # align use_quantize_kv_cache setting for different GPU in pipeline parallel
     pp_quantize_kv_cache = (os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) == "1") or \
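Note: convert_forward (from ipex_llm.transformers.convert) rebinds the forward of the matching model class to the low-memory variants defined at the bottom of this file. As a rough illustration only, and not ipex-llm's actual implementation, the kind of runtime patching involved looks like the minimal Python sketch below; patch_forward and lowmem_linear_forward are hypothetical names introduced just for this example.

import types
import torch
from torch import nn


def patch_forward(model: nn.Module, target_cls: type, new_forward) -> nn.Module:
    # Walk the module tree and rebind `forward` on every instance of `target_cls`.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)
    return model


def lowmem_linear_forward(self, x):
    # Toy low-memory forward: release cached device blocks around the matmul,
    # mirroring the torch.xpu.empty_cache() pattern used in this commit.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
    out = nn.functional.linear(x, self.weight, self.bias)
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
    return out


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
model = patch_forward(model, nn.Linear, lowmem_linear_forward)

Rebinding the bound method per instance means the replacement is picked up by nn.Module.__call__ without subclassing, which is why this commit can swap in a different forward per model family.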
@@ -186,6 +218,11 @@ def pipeline_parallel(model, pipeline_parallel_stages):
         model._modules['model'].norm = DummyLayer()
         model._modules['lm_head'] = DummyLayer()
 
+    _enable_lowmem = os.getenv('IPEX_LLM_LOW_MEM')
+    _enable_lowmem = (_enable_lowmem is not None) and (_enable_lowmem.lower() == "1")
+    if _enable_lowmem:
+        model = low_mem_convert(model)
+
     model.pipeline_parallel_stages = pipeline_parallel_stages
     model.layer_start = layer_start
     model.layer_end = layer_end
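The low-memory conversion is opt-in: pipeline_parallel() only calls low_mem_convert when the IPEX_LLM_LOW_MEM environment variable is set to "1". A hedged usage sketch follows; the model id is a placeholder and the from_pretrained keyword arguments mirror ipex-llm's pipeline-parallel example scripts, so they may differ between releases.

import os

# Must be set before the model is loaded (pipeline_parallel() reads it once
# via os.getenv('IPEX_LLM_LOW_MEM')), e.g. exported in the launch script.
os.environ["IPEX_LLM_LOW_MEM"] = "1"

from ipex_llm.transformers import AutoModelForCausalLM

# Illustrative two-stage pipeline-parallel load; typically launched with one
# process per GPU (e.g. via mpirun), as in the ipex-llm examples.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",   # placeholder model id
    load_in_4bit=True,
    optimize_model=True,
    use_cache=True,
    pipeline_parallel_stages=2,        # assumed kwarg, per the example scripts
)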
@@ -867,3 +904,208 @@ def _is_chinese_char(cp):
         return True
 
     return False
+
+
+def llama_causallm_forward_4_37_lowmem(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions  # noqa
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states  # noqa
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+
+    # ipex-llm change starts
+
+    if self.config.pretraining_tp > 1:
+        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)  # noqa
+        logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]  # noqa
+        logits = torch.cat(logits, dim=-1)
+    else:
+        torch.xpu.empty_cache()
+        logits = self.lm_head(hidden_states)
+        torch.xpu.empty_cache()
+    # logits = logits.float()
+
+    # ipex-llm change ends
+
+    loss = None
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        shift_labels = shift_labels.view(-1)
+        # Enable model parallelism
+        shift_labels = shift_labels.to(shift_logits.device)
+        loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+
+def chatglm3_conditional_generation_forward_lowmem(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        return_last_logit: Optional[bool] = False,
+):
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    transformer_outputs = self.transformer(
+        input_ids=input_ids,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    hidden_states = transformer_outputs[0]
+    if return_last_logit:
+        hidden_states = hidden_states[-1:]
+
+    # ipex-llm change starts
+    torch.xpu.empty_cache()
+    lm_logits = self.transformer.output_layer(hidden_states)
+    torch.xpu.empty_cache()
+    lm_logits = lm_logits.transpose(0, 1).contiguous()
+
+    loss = None
+    if labels is not None:
+        # lm_logits = lm_logits.to(torch.float32)
+
+        # Shift so that tokens < n predict n
+        shift_logits = lm_logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss(ignore_index=-100)
+        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+        lm_logits = lm_logits.to(hidden_states.dtype)
+        loss = loss.to(hidden_states.dtype)
+    # ipex-llm change ends
+
+    if not return_dict:
+        output = (lm_logits,) + transformer_outputs[1:]
+        return ((loss,) + output) if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=lm_logits,
+        past_key_values=transformer_outputs.past_key_values,
+        hidden_states=transformer_outputs.hidden_states,
+        attentions=transformer_outputs.attentions,
+    )
+
+
+def glm4_conditional_generation_forward_lowmem(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        return_last_logit: Optional[bool] = False,
+):
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    transformer_outputs = self.transformer(
+        input_ids=input_ids,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    hidden_states = transformer_outputs[0]
+    if return_last_logit:
+        hidden_states = hidden_states[:, -1:]
+    # ipex-llm change starts
+    torch.xpu.empty_cache()
+    lm_logits = self.transformer.output_layer(hidden_states)
+    torch.xpu.empty_cache()
+
+    loss = None
+    if labels is not None:
+        # lm_logits = lm_logits.to(torch.float32)
+
+        # Shift so that tokens < n predict n
+        shift_logits = lm_logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss(ignore_index=-100)
+        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+        lm_logits = lm_logits.to(hidden_states.dtype)
+        loss = loss.to(hidden_states.dtype)
+    # ipex-llm change ends
+
+    if not return_dict:
+        output = (lm_logits,) + transformer_outputs[1:]
+        return ((loss,) + output) if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=lm_logits,
+        past_key_values=transformer_outputs.past_key_values,
+        hidden_states=transformer_outputs.hidden_states,
+        attentions=transformer_outputs.attentions,
+    )
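All three replacement forwards apply the same change: torch.xpu.empty_cache() is called immediately before and after the vocabulary projection (lm_head for Llama, transformer.output_layer for ChatGLM3/GLM-4), and the float32 upcast of the logits is left commented out, so the logits tensor, whose last dimension is the full vocabulary size and which dominates peak memory during prefill, is allocated from a freshly released pool and is never duplicated in fp32. Distilled into a standalone helper, here is a minimal sketch of that pattern; it assumes an XPU-enabled PyTorch/IPEX build, and project_vocab_lowmem is an illustrative name, not part of the commit.

import torch
from torch import nn


def project_vocab_lowmem(lm_head: nn.Linear, hidden_states: torch.Tensor) -> torch.Tensor:
    """Sketch of the low-memory projection used above: release cached XPU blocks
    right before and after the vocabulary projection, and keep the logits in the
    model's half-precision dtype instead of upcasting to float32."""
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()          # return unused cached blocks to the allocator
    logits = lm_head(hidden_states)      # the peak-memory tensor (last dim = vocab size)
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
    return logits                        # no logits.float() upcast

Releasing cached blocks has a cost, so it is paid only around the single largest allocation, and only on the last pipeline stage, where lm_head actually exists (the other stages replace it with DummyLayer).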