LLM: fix input length logic for run_transformer_int4_gpu (#9864)

* LLM: fix input length logic for run_transformer_int4_gpu

* small fix

* small fix

* small fix
WeiguangHan 2024-01-10 18:20:14 +08:00 committed by GitHub
parent 53531ae4ee
commit 33fd1f9c76

@@ -399,8 +399,10 @@ def run_transformer_int4_gpu(repo_id,
     # in_len.txt maybe shorter than we need,
     # use much longer context to make sure input length
     test_length = min(in_len*2, 8192)
-    while test_length not in [32, 256, 1024, 2048, 8192]:
+    while test_length not in [32, 256, 1024, 2048, 8192] and test_length < 8192:
         test_length = test_length * 2
+    # For the sequence length not in [32, 256, 1024, 2048, 8192], it will be truncated from 8192.txt.
+    test_length = min(test_length, 8192)
     input_str = open(f"prompt/{test_length}.txt", 'r').read()
     # As different tokenizer has different encodings,
     # slice the input_ids to ensure the prompt length is required length.
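The fix matters because the old loop never terminated for requested lengths that overshoot every prompt-file bucket: for example, `in_len = 3000` gives `test_length = 6000`, which doubles to 12000, 24000, and so on, never landing in the list. Below is a minimal standalone sketch of the fixed selection logic (the function name `pick_test_length` is hypothetical; the real code runs inline inside `run_transformer_int4_gpu`):

```python
def pick_test_length(in_len: int) -> int:
    # Use roughly double the requested length so the prompt file
    # is guaranteed to contain enough tokens.
    test_length = min(in_len * 2, 8192)
    # Double until we hit an available prompt file size; the added
    # `test_length < 8192` guard stops the loop once the doubled value
    # has overshot every bucket, preventing an infinite loop.
    while test_length not in [32, 256, 1024, 2048, 8192] and test_length < 8192:
        test_length = test_length * 2
    # Any length that escaped the buckets falls back to 8192.txt,
    # whose contents are truncated to the required length afterwards.
    return min(test_length, 8192)

print(pick_test_length(16))    # -> 32
print(pick_test_length(1024))  # -> 2048
print(pick_test_length(3000))  # -> 8192 (would have looped forever before the fix)
```

With the guard and the final `min(...)` clamp, every input maps to one of the existing prompt files, and lengths outside the fixed buckets are served by `8192.txt` and sliced down via the tokenizer's `input_ids` as noted in the diff.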