fix gemma2 runtime error caused by sliding window (#11788)
* fix runtime error * revert workflow
This commit is contained in:
parent
dbd14251dd
commit
43cca3be27
2 changed files with 36 additions and 33 deletions
55
.github/workflows/llm_performance_tests.yml
vendored
55
.github/workflows/llm_performance_tests.yml
vendored
|
|
@ -1207,35 +1207,32 @@ jobs:
|
||||||
|
|
||||||
call conda deactivate
|
call conda deactivate
|
||||||
|
|
||||||
# NOTE: Gemma2 not working for 4096-512.
|
- name: Prepare igpu perf test for transformers 4.43 (4096-512 int4+fp16)
|
||||||
# When it works, uncomment this section and remember to change "'s/{today}_test3/{today}_test1/g'" in next section.
|
shell: bash
|
||||||
|
run: |
|
||||||
#- name: Prepare igpu perf test for transformers 4.43 (4096-512 int4+fp16)
|
sed -i 's/{today}_test3/{today}_test4/g' python/llm/dev/benchmark/all-in-one/run.py
|
||||||
# shell: bash
|
sed -i "s/path to your local model hub/$MODEL_HUB_DIR/g" python/llm/test/benchmark/igpu-perf/4096-512_int4_fp16_443.yaml
|
||||||
# run: |
|
|
||||||
# sed -i 's/{today}_test3/{today}_test4/g' python/llm/dev/benchmark/all-in-one/run.py
|
|
||||||
# sed -i "s/path to your local model hub/$MODEL_HUB_DIR/g" python/llm/test/benchmark/igpu-perf/4096-512_int4_fp16_443.yaml
|
|
||||||
|
|
||||||
#- name: Test on igpu for transformers 4.43 (4096-512 int4+fp16)
|
- name: Test on igpu for transformers 4.43 (4096-512 int4+fp16)
|
||||||
# shell: cmd
|
shell: cmd
|
||||||
# run: |
|
run: |
|
||||||
# call conda activate igpu-perf
|
call conda activate igpu-perf
|
||||||
# pip install transformers==4.43.1
|
pip install transformers==4.43.1
|
||||||
# pip install trl
|
pip install trl
|
||||||
#
|
|
||||||
# set SYCL_CACHE_PERSISTENT=1
|
set SYCL_CACHE_PERSISTENT=1
|
||||||
# set BIGDL_LLM_XMX_DISABLED=1
|
set BIGDL_LLM_XMX_DISABLED=1
|
||||||
#
|
|
||||||
# cd python\llm\dev\benchmark\all-in-one
|
cd python\llm\dev\benchmark\all-in-one
|
||||||
# move ..\..\..\test\benchmark\igpu-perf\4096-512_int4_fp16_443.yaml config.yaml
|
move ..\..\..\test\benchmark\igpu-perf\4096-512_int4_fp16_443.yaml config.yaml
|
||||||
# set PYTHONIOENCODING=utf-8
|
set PYTHONIOENCODING=utf-8
|
||||||
# python run.py >> %CSV_SAVE_PATH%\4096-512_int4_fp16\log\%LOG_FILE% 2>&1
|
python run.py >> %CSV_SAVE_PATH%\4096-512_int4_fp16\log\%LOG_FILE% 2>&1
|
||||||
# if %ERRORLEVEL% neq 0 (exit /b 1)
|
if %ERRORLEVEL% neq 0 (exit /b 1)
|
||||||
# python ..\..\..\test\benchmark\igpu-perf\check_csv_results.py --yaml-file config.yaml --suffix test4
|
python ..\..\..\test\benchmark\igpu-perf\check_csv_results.py --yaml-file config.yaml --suffix test4
|
||||||
# if %ERRORLEVEL% neq 0 (exit /b 1)
|
if %ERRORLEVEL% neq 0 (exit /b 1)
|
||||||
#
|
|
||||||
# pip uninstall trl -y
|
pip uninstall trl -y
|
||||||
# call conda deactivate
|
call conda deactivate
|
||||||
|
|
||||||
- name: Concat csv and generate html (4096-512 int4+fp16)
|
- name: Concat csv and generate html (4096-512 int4+fp16)
|
||||||
shell: cmd
|
shell: cmd
|
||||||
|
|
@ -1259,7 +1256,7 @@ jobs:
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
sed -i 's/4096-512/1024-128/g' python/llm/dev/benchmark/all-in-one/run.py
|
sed -i 's/4096-512/1024-128/g' python/llm/dev/benchmark/all-in-one/run.py
|
||||||
sed -i 's/{today}_test3/{today}_test1/g' python/llm/dev/benchmark/all-in-one/run.py
|
sed -i 's/{today}_test4/{today}_test1/g' python/llm/dev/benchmark/all-in-one/run.py
|
||||||
sed -i "s/path to your local model hub/$MODEL_HUB_DIR/g" python/llm/test/benchmark/igpu-perf/1024-128_int4_fp16_loadlowbit.yaml
|
sed -i "s/path to your local model hub/$MODEL_HUB_DIR/g" python/llm/test/benchmark/igpu-perf/1024-128_int4_fp16_loadlowbit.yaml
|
||||||
|
|
||||||
- name: Test on igpu (load_low_bit 1024-128 int4+fp16)
|
- name: Test on igpu (load_low_bit 1024-128 int4+fp16)
|
||||||
|
|
|
||||||
|
|
@ -129,7 +129,8 @@ def gemma2_attention_forward(
|
||||||
# IPEX_LLM OPT: sdp
|
# IPEX_LLM OPT: sdp
|
||||||
kv_seq_len = q_len if past_key_value is None else past_key_value.kv_seq_len
|
kv_seq_len = q_len if past_key_value is None else past_key_value.kv_seq_len
|
||||||
if (use_sdp_causal(q_len, kv_seq_len, -1, query_states, self.training)
|
if (use_sdp_causal(q_len, kv_seq_len, -1, query_states, self.training)
|
||||||
and kv_seq_len <= key_states.size(2)):
|
and kv_seq_len <= key_states.size(2) and
|
||||||
|
(self.sliding_window is None or kv_seq_len < self.sliding_window)):
|
||||||
import xe_addons
|
import xe_addons
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
attn_output = xe_addons.gemma2_sdp_causal(query_states,
|
attn_output = xe_addons.gemma2_sdp_causal(query_states,
|
||||||
|
|
@ -141,10 +142,15 @@ def gemma2_attention_forward(
|
||||||
elif use_sdp(q_len, kv_seq_len, -1, query_states):
|
elif use_sdp(q_len, kv_seq_len, -1, query_states):
|
||||||
import xe_addons
|
import xe_addons
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
|
if self.sliding_window is not None:
|
||||||
|
attn_mask = attention_mask[:, :, :q_len, : key_states.shape[-2]]
|
||||||
|
else:
|
||||||
|
attn_mask = attention_mask
|
||||||
|
|
||||||
attn_output = xe_addons.gemma2_sdp(query_states,
|
attn_output = xe_addons.gemma2_sdp(query_states,
|
||||||
key_states[:, :, :kv_seq_len, :],
|
key_states,
|
||||||
value_states[:, :, :kv_seq_len, :],
|
value_states,
|
||||||
attention_mask[:, :, :q_len, :kv_seq_len],
|
attn_mask,
|
||||||
self.config.attn_logit_softcapping,
|
self.config.attn_logit_softcapping,
|
||||||
self.scaling)
|
self.scaling)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue