commit 76e30d8ec8 (parent 2be8bbd236)
1 changed file with 19 additions and 0 deletions
@@ -532,8 +532,21 @@ class LowBitLinear(nn.Linear):
         self.compute_dtype = None  # only for training
         self.enable_xetla = enable_xetla
         self.optimize_lm_head = optimize_lm_head
+        self.device = None  # detected only once in the first forward
+        # empty cache before and after lm_head at first token (by default on arc) for models
+        # with large vocabulary (e.g. baichuan/qwen) when given long input at inference time.
+        # The condition makes sure that empty cache only takes effect if this layer is lm_head.
+        # TODO: may modify the value constraints for other models.
+        self.low_memory_mode = self.in_len * self.out_len >= 70000*4096

     def forward(self, x: torch.Tensor):
+        # empty cache before and after lm_head at first token when input > 1024
+        # on arc or BIGDL_LOW_MEMORY_MODE is set to 1 at inference time.
+        if self.device is None:
+            self.device = get_xpu_device_type(self.weight.data)
+            self.low_memory_mode = \
+                self.low_memory_mode and \
+                (self.device == "arc" or os.environ.get("BIGDL_LOW_MEMORY_MODE", None) == "1")
         # Due to inconsistent training status in some models like Baichuan-7b-Chat,
         # we should check both self.training and torch.is_inference_mode_enabled().
         is_training = self.training and not torch.is_inference_mode_enabled()
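Note on the threshold added above: self.in_len * self.out_len >= 70000*4096 is what restricts the new behaviour to lm_head-like layers of large-vocabulary models, since ordinary hidden-to-hidden projections stay far below that size. A quick sanity check of the threshold as a minimal standalone sketch (the layer shapes below are illustrative, not taken from this commit):

# Illustrative (in_len, out_len) pairs checked against the lm_head threshold.
THRESHOLD = 70000 * 4096

layers = {
    "lm_head, ~150k vocab, 4096 hidden": (4096, 150_000),
    "lm_head, 32k vocab, 4096 hidden": (4096, 32_000),
    "hidden-to-hidden projection": (4096, 4096),
}

for name, (in_len, out_len) in layers.items():
    # Mirrors the condition used to set self.low_memory_mode in __init__.
    print(name, "->", in_len * out_len >= THRESHOLD)
# lm_head, ~150k vocab, 4096 hidden -> True
# lm_head, 32k vocab, 4096 hidden -> False
# hidden-to-hidden projection -> False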
@@ -599,6 +612,10 @@ class LowBitLinear(nn.Linear):
             # current workaround to reduce first token latency of fp32 input
             # sometimes fp16 cause nan and training instability
             # disable the conversion when training
+            # TODO: may modify the input length condition for empty cache.
+            do_empty_cache = self.low_memory_mode and x_2d.shape[0] >= 1024
+            if do_empty_cache:
+                torch.xpu.empty_cache()
             if self.conver_to_half and x_2d.shape[0] > 1 and x_2d.dtype == torch.float32 and \
                     not use_xmx(x_2d, self.weight.qtype):
                 x_2d = x_2d.half()
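The x_2d.shape[0] >= 1024 guard limits the extra empty_cache calls to the prefill step of long prompts; during token-by-token decoding the flattened input has a single row, so the cache is never emptied there. A minimal sketch of that gate, assuming the same flattened [batch * seq, hidden] layout as x_2d (the tensor shapes are illustrative):

import torch

def should_empty_cache(x_2d: torch.Tensor, low_memory_mode: bool) -> bool:
    # Same condition as do_empty_cache in the hunk above.
    return low_memory_mode and x_2d.shape[0] >= 1024

prefill = torch.empty(2048, 4096)  # first token of a 2048-token prompt
decode = torch.empty(1, 4096)      # any later decoding step
print(should_empty_cache(prefill, True))  # True
print(should_empty_cache(decode, True))   # False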
@@ -608,6 +625,8 @@ class LowBitLinear(nn.Linear):
             else:
                 result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype,
                                                  input_seq_size)
+            if do_empty_cache:
+                torch.xpu.empty_cache()
             result = result.view(new_shape)
             if self.mp_group is not None:
                 from deepspeed import comm as dist
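On XPU devices other than arc, the same first-token empty-cache behaviour can be opted into through the BIGDL_LOW_MEMORY_MODE environment variable checked in forward. Because self.device is detected only once, in the first forward call, the variable has to be set before that call. A minimal usage sketch (only the variable name comes from this commit, the rest is illustrative):

import os

# Set before the model's first forward pass: LowBitLinear resolves
# self.device and finalizes self.low_memory_mode only once, on first use.
os.environ["BIGDL_LOW_MEMORY_MODE"] = "1"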