lm_head empty_cache for more models (#10490)
* modify constraint
* fix style
This commit is contained in:
parent
1579ee4421
commit
30f111cd32
1 changed file with 8 additions and 5 deletions
@@ -533,17 +533,20 @@ class LowBitLinear(nn.Linear):
         self.enable_xetla = enable_xetla
         self.optimize_lm_head = optimize_lm_head
         self.device = None  # detected only once in the first forward
-        # empty cache before and after lm_head at first token (by default on arc) for models
-        # with large vocabulary (e.g. baichuan/qwen) when given long input at inference time.
+        # empty cache before and after lm_head at first token (by default on arc)
+        # especially for baichuan/qwen when given long input at inference time.
         # The condition makes sure that empty cache only takes effect if this layer is lm_head.
-        # TODO: may modify the value constraints for other models.
-        self.low_memory_mode = self.in_len * self.out_len >= 70000*4096
+        # For other models like llama, lm_cache will be applied as well
+        # since performance isn't impacted.
+        self.is_lm_head = self.in_len * self.out_len >= 30000 * 4096
+        self.low_memory_mode = self.is_lm_head
 
     def forward(self, x: torch.Tensor):
         # empty cache before and after lm_head at first token when input > 1024
         # on arc or BIGDL_LOW_MEMORY_MODE is set to 1 at inference time.
         if self.device is None:
             self.device = get_xpu_device_type(self.weight.data)
+            # TODO: may remove BIGDL_LOW_MEMORY_MODE here, probably not necessary
             self.low_memory_mode = \
                 self.low_memory_mode and \
                 (self.device == "arc" or os.environ.get("BIGDL_LOW_MEMORY_MODE", None) == "1")
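
For reference, a minimal sketch of what the relaxed constraint means in practice. The threshold drops from 70000 * 4096 to 30000 * 4096 weight elements, so medium-vocabulary heads such as llama's are now treated as lm_head as well; the shapes below are illustrative examples, not values taken from this commit.

# Hypothetical illustration of the is_lm_head heuristic (shapes are examples only).
OLD_THRESHOLD = 70000 * 4096
NEW_THRESHOLD = 30000 * 4096

def is_lm_head(in_len: int, out_len: int, threshold: int) -> bool:
    # A layer counts as lm_head when its weight is large enough,
    # which in practice means a projection into a big vocabulary.
    return in_len * out_len >= threshold

# Qwen-like head  (hidden 4096, vocab ~152k): caught by both thresholds.
print(is_lm_head(4096, 151936, OLD_THRESHOLD), is_lm_head(4096, 151936, NEW_THRESHOLD))
# Llama-like head (hidden 4096, vocab 32000): only caught by the new threshold.
print(is_lm_head(4096, 32000, OLD_THRESHOLD), is_lm_head(4096, 32000, NEW_THRESHOLD))
# Regular 4096x4096 projection: never treated as lm_head.
print(is_lm_head(4096, 4096, OLD_THRESHOLD), is_lm_head(4096, 4096, NEW_THRESHOLD))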
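
And a rough sketch of how low_memory_mode is meant to be consumed in forward, based on the comments in the diff ("empty cache before and after lm_head at first token when input > 1024"). The function name, the first-token check, and the use of torch.xpu.empty_cache() are assumptions for illustration, not the actual LowBitLinear implementation.

import torch
import torch.nn.functional as F
from typing import Optional

def lm_head_matmul_low_memory(x: torch.Tensor, weight: torch.Tensor,
                              bias: Optional[torch.Tensor],
                              low_memory_mode: bool) -> torch.Tensor:
    # Release cached device memory around the large lm_head matmul, but only
    # for the prefill step of a long prompt (> 1024 tokens), as the comments describe.
    prefill_long_input = x.dim() == 3 and x.size(1) > 1024
    release_cache = (low_memory_mode and prefill_long_input
                     and hasattr(torch, "xpu") and torch.xpu.is_available())
    if release_cache:
        torch.xpu.empty_cache()          # before the lm_head matmul
    out = F.linear(x, weight, bias)      # stands in for the low-bit matmul kernel
    if release_cache:
        torch.xpu.empty_cache()          # and again right after it
    return out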