LLM: Add XPU Memory Optimizations for Pipeline Parallel (#11567)
Add XPU Memory Optimizations for Pipeline Parallel

parent f06d2f72fb
commit 79c742dfd5

1 changed file with 243 additions and 1 deletion

@@ -19,13 +19,17 @@
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import torch.distributed as dist
import os
import time
import numpy as np
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Union, Tuple
from types import SimpleNamespace
import transformers
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ipex_llm.utils.common import invalidInputError
from ipex_llm.ggml.quantize import ggml_tensor_qtype
import logging

@@ -107,6 +111,34 @@ def init_pipeline_parallel():
    dist.init_process_group('ccl')


def low_mem_convert(model):
    from ipex_llm.transformers.convert import convert_forward
    import importlib
    if 'llama' in model.config.model_type:
        convert_forward(
            model,
            transformers.models.llama.modeling_llama.LlamaForCausalLM,
            llama_causallm_forward_4_37_lowmem)
    elif model.config.model_type == "chatglm" and not hasattr(model.config, "vision_config"):
        if model.config.num_layers == 40:
            # for glm4-9b
            modeling_module_name = model.__class__.__module__
            module = importlib.import_module(modeling_module_name)
            convert_forward(
                model,
                module.ChatGLMForConditionalGeneration,
                glm4_conditional_generation_forward_lowmem)
        else:
            # for chatglm3-6b
            modeling_module_name = model.__class__.__module__
            module = importlib.import_module(modeling_module_name)
            convert_forward(
                model,
                module.ChatGLMForConditionalGeneration,
                chatglm3_conditional_generation_forward_lowmem)
    return model


def _check_quantize_kv_cache(model, idx, batch_size):
    # align use_quantize_kv_cache setting for different GPU in pipeline parallel
    pp_quantize_kv_cache = (os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) == "1") or \
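
low_mem_convert above uses convert_forward from ipex_llm.transformers.convert to swap the forward of the matched causal-LM class for the low-memory variants defined at the bottom of this patch. As a rough illustration of that kind of forward replacement, here is a minimal sketch with a toy module; the helper and class names below are illustrative, not the actual ipex-llm API:

import types

import torch
from torch import nn


def convert_forward_sketch(model: nn.Module, target_cls: type, new_forward) -> None:
    # Rebind `forward` on every submodule (and the root) whose class matches.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)


class TinyCausalLM(nn.Module):
    # Toy stand-in for LlamaForCausalLM / ChatGLMForConditionalGeneration.
    def __init__(self):
        super().__init__()
        self.lm_head = nn.Linear(8, 16)

    def forward(self, hidden_states):
        return self.lm_head(hidden_states).float()


def tiny_forward_lowmem(self, hidden_states):
    # Low-memory variant: keep logits in the compute dtype (no .float() upcast).
    return self.lm_head(hidden_states)


model = TinyCausalLM().to(torch.bfloat16)
convert_forward_sketch(model, TinyCausalLM, tiny_forward_lowmem)
print(model(torch.randn(2, 8, dtype=torch.bfloat16)).dtype)  # stays torch.bfloat16 after patching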
				
			

@@ -186,6 +218,11 @@ def pipeline_parallel(model, pipeline_parallel_stages):
            model._modules['model'].norm = DummyLayer()
            model._modules['lm_head'] = DummyLayer()

    _enable_lowmem = os.getenv('IPEX_LLM_LOW_MEM')
    _enable_lowmem = (_enable_lowmem is not None) and (_enable_lowmem.lower() == "1")
    if _enable_lowmem:
        model = low_mem_convert(model)

    model.pipeline_parallel_stages = pipeline_parallel_stages
    model.layer_start = layer_start
    model.layer_end = layer_end
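
The conversion is opt-in: it only takes effect when IPEX_LLM_LOW_MEM is set to "1" in the environment of every pipeline-parallel rank before pipeline_parallel() runs. A sketch of how a launch script might opt in; the checkpoint id and loader arguments below follow the usual ipex-llm pipeline-parallel examples and are illustrative, not part of this patch:

import os

# Must be set before the model is loaded/partitioned, since pipeline_parallel()
# reads the flag while it is splitting the layers across XPUs.
os.environ["IPEX_LLM_LOW_MEM"] = "1"

from ipex_llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",   # placeholder checkpoint
    load_in_4bit=True,
    optimize_model=True,
    use_cache=True,
    pipeline_parallel_stages=2,        # e.g. split across two XPU cards/tiles
)

In a multi-rank launch the flag would typically be exported for each rank through the launcher's environment rather than set inside the script.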
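
For context, the DummyLayer assignments retained above are how stages that do not own the final norm or lm_head keep an attribute-compatible module tree; DummyLayer is assumed here to behave as a simple pass-through. A minimal sketch of that assumption (the class name is illustrative):

import torch
from torch import nn


class DummyLayerSketch(nn.Module):
    # Assumed behaviour: stand-in for a layer owned by another pipeline stage,
    # returning its input unchanged so the surrounding module structure stays intact.
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return x


print(DummyLayerSketch()(torch.ones(2, 3)).shape)  # torch.Size([2, 3])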
				
			

@@ -867,3 +904,208 @@ def _is_chinese_char(cp):
        return True

    return False


def llama_causallm_forward_4_37_lowmem(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions  # noqa
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states  # noqa
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]

    # ipex-llm change starts

    if self.config.pretraining_tp > 1:
        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)  # noqa
        logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]  # noqa
        logits = torch.cat(logits, dim=-1)
    else:
        torch.xpu.empty_cache()
        logits = self.lm_head(hidden_states)
        torch.xpu.empty_cache()
    # logits = logits.float()

    # ipex-llm change ends

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def chatglm3_conditional_generation_forward_lowmem(
    self,
    input_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    return_last_logit: Optional[bool] = False,
):
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    transformer_outputs = self.transformer(
        input_ids=input_ids,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = transformer_outputs[0]
    if return_last_logit:
        hidden_states = hidden_states[-1:]

    # ipex-llm change starts
    torch.xpu.empty_cache()
    lm_logits = self.transformer.output_layer(hidden_states)
    torch.xpu.empty_cache()
    lm_logits = lm_logits.transpose(0, 1).contiguous()

    loss = None
    if labels is not None:
        # lm_logits = lm_logits.to(torch.float32)

        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        lm_logits = lm_logits.to(hidden_states.dtype)
        loss = loss.to(hidden_states.dtype)
    # ipex-llm change ends

    if not return_dict:
        output = (lm_logits,) + transformer_outputs[1:]
        return ((loss,) + output) if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=lm_logits,
        past_key_values=transformer_outputs.past_key_values,
        hidden_states=transformer_outputs.hidden_states,
        attentions=transformer_outputs.attentions,
    )


def glm4_conditional_generation_forward_lowmem(
    self,
    input_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    return_last_logit: Optional[bool] = False,
):
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    transformer_outputs = self.transformer(
        input_ids=input_ids,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = transformer_outputs[0]
    if return_last_logit:
        hidden_states = hidden_states[:, -1:]
    # ipex-llm change starts
    torch.xpu.empty_cache()
    lm_logits = self.transformer.output_layer(hidden_states)
    torch.xpu.empty_cache()

    loss = None
    if labels is not None:
        # lm_logits = lm_logits.to(torch.float32)

        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        lm_logits = lm_logits.to(hidden_states.dtype)
        loss = loss.to(hidden_states.dtype)
    # ipex-llm change ends

    if not return_dict:
        output = (lm_logits,) + transformer_outputs[1:]
        return ((loss,) + output) if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=lm_logits,
        past_key_values=transformer_outputs.past_key_values,
        hidden_states=transformer_outputs.hidden_states,
        attentions=transformer_outputs.attentions,
    )
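
The three forwards above share the change that actually saves memory on XPU: the output projection (lm_head for LLaMA, transformer.output_layer for ChatGLM/GLM4) is bracketed with torch.xpu.empty_cache() so the allocator can hand back cached blocks before the large [batch, seq, vocab] logits tensor is created, and the usual .float() upcast of the logits is dropped so they stay in the model's fp16/bf16 compute dtype. A distilled sketch of that pattern; the helper name and device guard are illustrative, not part of the patch:

import torch
from torch import nn


def project_logits_low_mem(lm_head: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
    # Release cached-but-unused allocator blocks right before the big matmul,
    # so the vocab-sized logits tensor is less likely to hit fragmentation/OOM.
    if hidden_states.device.type == "xpu":
        torch.xpu.empty_cache()
    logits = lm_head(hidden_states)
    if hidden_states.device.type == "xpu":
        torch.xpu.empty_cache()
    # Note: no logits.float() here -- the logits keep the model's compute dtype.
    return logits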
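
The pretraining_tp > 1 branch kept from the stock LLaMA forward computes the vocabulary projection in row slices of lm_head.weight and concatenates the pieces, which is numerically the same as one full F.linear. A quick equivalence check with arbitrary sizes:

import torch
import torch.nn.functional as F

hidden_size, vocab_size, tp = 16, 32, 4
hidden_states = torch.randn(2, 5, hidden_size)
weight = torch.randn(vocab_size, hidden_size)   # plays the role of lm_head.weight

full = F.linear(hidden_states, weight)
slices = weight.split(vocab_size // tp, dim=0)
pieces = [F.linear(hidden_states, s) for s in slices]
assert torch.allclose(torch.cat(pieces, dim=-1), full, atol=1e-6)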
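
One detail worth noting is the return_last_logit slicing: the chatglm3-6b forward treats hidden states as sequence-first (hence hidden_states[-1:] followed by the transpose(0, 1) of the logits), while the glm4-9b forward is batch-first and slices hidden_states[:, -1:] with no transpose. A small shape check under those assumed layouts:

import torch

seq_len, batch, hidden = 7, 2, 4

# chatglm3-6b style: [seq, batch, hidden] -> keep only the last time step.
h_seq_first = torch.randn(seq_len, batch, hidden)
assert h_seq_first[-1:].shape == (1, batch, hidden)

# glm4-9b style: [batch, seq, hidden] -> slice along the sequence dimension.
h_batch_first = torch.randn(batch, seq_len, hidden)
assert h_batch_first[:, -1:].shape == (batch, 1, hidden)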