Fix No module named 'transformers.cache_utils' with transformers < 4.36 (#10835)
* update sdp condition
* update
* fix
* fix 431 error
* revert sdp & style fix
* fix
* meet comments
parent c6e868f7ad
commit 3daad242b8
2 changed files with 14 additions and 12 deletions
@@ -316,7 +316,8 @@ def lookup_generate(self,
             if max_of_max_matched != max_matched:
                 output_ids = output_ids[:, :max_matched]
                 new_cache_size = max_of_max_matched - max_matched
-                past_key_values = _crop_past_key_values(self, past_key_values, new_cache_size)
+                past_key_values = _crop_past_key_values(self, past_key_values,
+                                                        new_cache_size)
 
             input_ids = torch.cat((input_ids, output_ids), dim=-1)
 
@@ -451,19 +451,20 @@ def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_l
 
 
 def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=False):
-    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
-    if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache)):
-        if hasattr(past_key_values, "_seen_tokens"):
-            past_key_values._seen_tokens -= new_cache_size
-        else:
-            past_key_values.seen_tokens -= new_cache_size
+    if version.parse(trans_version) >= version.parse("4.36.0"):
+        from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
+        if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache)):
+            if hasattr(past_key_values, "_seen_tokens"):
+                past_key_values._seen_tokens -= new_cache_size
+            else:
+                past_key_values.seen_tokens -= new_cache_size
 
-        for i, k in enumerate(past_key_values.key_cache):
-            past_key_values.key_cache[i] = k[:, :, :-new_cache_size, :]
-        for i, v in enumerate(past_key_values.value_cache):
-            past_key_values.value_cache[i] = v[:, :, :-new_cache_size, :]
+            for i, k in enumerate(past_key_values.key_cache):
+                past_key_values.key_cache[i] = k[:, :, :-new_cache_size, :]
+            for i, v in enumerate(past_key_values.value_cache):
+                past_key_values.value_cache[i] = v[:, :, :-new_cache_size, :]
 
-        return past_key_values
+            return past_key_values
 
     if _enable_ipex:
         cur_len = past_key_values[0][0].size(1)
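
The point of this hunk: `ipex_llm.transformers.kv` imports `transformers.cache_utils`, which only exists in transformers >= 4.36, so importing it unconditionally inside `_crop_past_key_values` raised `ModuleNotFoundError: No module named 'transformers.cache_utils'` on older transformers releases. Guarding the import and the `DynamicFp8Cache`/`DynamicNormalCache` handling behind a runtime version check keeps the legacy tuple-cache paths below reachable on transformers < 4.36. A minimal standalone sketch of the same gating pattern follows; it assumes `trans_version` is `transformers.__version__` and `version` comes from `packaging`, mirroring what the surrounding module appears to import:

# Sketch only (assumed names noted below): reproduces the guarded-import
# pattern from the hunk above, outside of the real _crop_past_key_values.
import transformers
from packaging import version

trans_version = transformers.__version__  # assumed to match the module-level alias


def crop_dynamic_cache(past_key_values, new_cache_size):
    # transformers.cache_utils (and therefore ipex_llm.transformers.kv)
    # only ships with transformers >= 4.36; on older versions the import
    # itself raises ModuleNotFoundError, which is the bug this commit fixes.
    if version.parse(trans_version) >= version.parse("4.36.0"):
        from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
        if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache)):
            # The token counter was renamed across transformers releases,
            # hence the hasattr check.
            if hasattr(past_key_values, "_seen_tokens"):
                past_key_values._seen_tokens -= new_cache_size
            else:
                past_key_values.seen_tokens -= new_cache_size
            # Drop the last `new_cache_size` positions from every layer's
            # key/value tensors (batch, heads, seq_len, head_dim).
            for i, k in enumerate(past_key_values.key_cache):
                past_key_values.key_cache[i] = k[:, :, :-new_cache_size, :]
            for i, v in enumerate(past_key_values.value_cache):
                past_key_values.value_cache[i] = v[:, :, :-new_cache_size, :]
            return past_key_values
    # transformers < 4.36, or a legacy tuple cache: the real function
    # continues into the tuple-based cropping paths here; this sketch
    # simply returns the cache unchanged.
    return past_key_values

Anything that transitively touches `transformers.cache_utils` needs to sit behind such a check (or a lazy import) for the module to keep importing on pre-4.36 releases; the "431 error" in the commit message presumably refers to transformers 4.31 hitting this path.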