fix llama related import (#12611)
This commit is contained in:
parent
54b1d7d333
commit
5f5ac8a856
4 changed files with 2 additions and 137 deletions
|
|
@ -45,7 +45,7 @@ import torch.nn.functional as F
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from typing import Optional, Tuple, List
|
from typing import Optional, Tuple, List
|
||||||
from ipex_llm.transformers.models.llama import repeat_kv
|
from ipex_llm.transformers.models.utils import repeat_kv
|
||||||
from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
|
from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
|
||||||
from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
|
from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
|
||||||
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
|
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
|
||||||
|
|
|
||||||
|
|
@ -34,8 +34,8 @@
|
||||||
import torch
|
import torch
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from ipex_llm.transformers.models.utils import repeat_kv
|
||||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
|
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
|
||||||
from ipex_llm.transformers.models.llama import repeat_kv
|
|
||||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
|
from ipex_llm.transformers.models.utils import should_use_fuse_rope
|
||||||
from ipex_llm.transformers.models.utils import update_past_key_value
|
from ipex_llm.transformers.models.utils import update_past_key_value
|
||||||
from ipex_llm.utils.common import invalidInputError
|
from ipex_llm.utils.common import invalidInputError
|
||||||
|
|
|
||||||
|
|
@ -1,124 +0,0 @@
|
||||||
#
|
|
||||||
# Copyright 2016 The BigDL Authors.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# This file is adapted from
|
|
||||||
# https://github.com/Dao-AILab/flash-attention/blob/main/tests/layers/test_rotary.py
|
|
||||||
#
|
|
||||||
# Copyright (c) 2023, Tri Dao.
|
|
||||||
#
|
|
||||||
|
|
||||||
import os
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from einops import rearrange
|
|
||||||
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
|
|
||||||
from transformers.models.llama.modeling_llama import (
|
|
||||||
apply_rotary_pos_emb as apply_rotary_pos_emb_llama,
|
|
||||||
)
|
|
||||||
from ipex_llm.transformers.layers.rope_embedding import apply_fast_rope_embedding
|
|
||||||
|
|
||||||
device = os.environ['DEVICE']
|
|
||||||
print(f'Running on {device}')
|
|
||||||
if 'xpu' not in device:
|
|
||||||
print(f"The layer.fast_rope test should running on xpu")
|
|
||||||
|
|
||||||
# llama-style rotary embedding
|
|
||||||
@pytest.mark.parametrize("seqlen_offset", [0, 711])
|
|
||||||
@pytest.mark.parametrize("rotary_emb_fraction", [0.5, 1.0])
|
|
||||||
def test_rotary(rotary_emb_fraction, seqlen_offset):
|
|
||||||
device = "xpu"
|
|
||||||
dtype = torch.float16
|
|
||||||
rtol, atol = (1e-3, 5e-3)
|
|
||||||
# set seed
|
|
||||||
torch.random.manual_seed(0)
|
|
||||||
batch_size = 8
|
|
||||||
seqlen_total = 2048
|
|
||||||
seqlen = seqlen_total - seqlen_offset
|
|
||||||
seqlen_offset = torch.tensor([[seqlen_offset]], device=device)
|
|
||||||
nheads = 32
|
|
||||||
headdim = 128
|
|
||||||
rotary_dim = int(headdim * rotary_emb_fraction)
|
|
||||||
qkv = torch.randn(
|
|
||||||
batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
|
|
||||||
requires_grad=True
|
|
||||||
)
|
|
||||||
rotary_llama = LlamaRotaryEmbedding(rotary_dim, seqlen_total, device=device)
|
|
||||||
# Doesn't matter what tensor we pass in, rotary_llama only uses the device
|
|
||||||
# of the tensor
|
|
||||||
cos_llama, sin_llama = rotary_llama(qkv, seq_len=seqlen_total)
|
|
||||||
cos_llama, sin_llama = cos_llama.to(dtype=dtype), sin_llama.to(dtype=dtype)
|
|
||||||
q_pt = (
|
|
||||||
rearrange(qkv[:, :, 0, :, :rotary_dim], "b s h d -> b h s d")
|
|
||||||
.detach()
|
|
||||||
.clone()
|
|
||||||
.requires_grad_(True)
|
|
||||||
)
|
|
||||||
k_pt = (
|
|
||||||
rearrange(qkv[:, :, 1, :, :rotary_dim], "b s h d -> b h s d")
|
|
||||||
.detach()
|
|
||||||
.clone()
|
|
||||||
.requires_grad_(True)
|
|
||||||
)
|
|
||||||
q_pt_fast = (
|
|
||||||
rearrange(qkv[:, :, 0, :, :rotary_dim], "b s h d -> b h s d")
|
|
||||||
.detach()
|
|
||||||
.clone()
|
|
||||||
.requires_grad_(True)
|
|
||||||
)
|
|
||||||
k_pt_fast = (
|
|
||||||
rearrange(qkv[:, :, 1, :, :rotary_dim], "b s h d -> b h s d")
|
|
||||||
.detach()
|
|
||||||
.clone()
|
|
||||||
.requires_grad_(True)
|
|
||||||
)
|
|
||||||
q_llama, k_llama = apply_rotary_pos_emb_llama(q_pt, k_pt, cos_llama,
|
|
||||||
sin_llama, position_ids=seqlen_offset)
|
|
||||||
q_fast, k_fast = apply_fast_rope_embedding(q_pt_fast, k_pt_fast,
|
|
||||||
position_ids=seqlen_offset,
|
|
||||||
model_family="llama")
|
|
||||||
assert torch.allclose(
|
|
||||||
rearrange(q_llama, "b h s d -> b s h d"),
|
|
||||||
rearrange(q_fast, "b h s d -> b s h d"), rtol=rtol, atol=atol
|
|
||||||
)
|
|
||||||
assert torch.allclose(
|
|
||||||
rearrange(k_llama, "b h s d -> b s h d"),
|
|
||||||
rearrange(k_fast, "b h s d -> b s h d"), rtol=rtol, atol=atol
|
|
||||||
)
|
|
||||||
|
|
||||||
g = torch.randn_like(q_fast)
|
|
||||||
q_fast.backward(g)
|
|
||||||
k_fast.backward(g)
|
|
||||||
q_llama.backward(g)
|
|
||||||
k_llama.backward(g)
|
|
||||||
|
|
||||||
assert torch.allclose(
|
|
||||||
q_pt.grad,
|
|
||||||
q_pt_fast.grad,
|
|
||||||
rtol=rtol,
|
|
||||||
atol=atol,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert torch.allclose(
|
|
||||||
k_pt.grad,
|
|
||||||
k_pt_fast.grad,
|
|
||||||
rtol=rtol,
|
|
||||||
atol=atol,
|
|
||||||
)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pytest.main([__file__])
|
|
||||||
|
|
@ -43,14 +43,3 @@ time=$((now-start))
|
||||||
|
|
||||||
echo "Bigdl-llm gpu inference tests finished"
|
echo "Bigdl-llm gpu inference tests finished"
|
||||||
echo "Time used:$time seconds"
|
echo "Time used:$time seconds"
|
||||||
|
|
||||||
echo "# Start testing layers.fast_rope_embedding"
|
|
||||||
start=$(date "+%s")
|
|
||||||
|
|
||||||
pytest_check_error pytest ${LLM_INFERENCE_TEST_DIR}/test_layer_fast_rope.py -v -s
|
|
||||||
|
|
||||||
now=$(date "+%s")
|
|
||||||
time=$((now-start))
|
|
||||||
|
|
||||||
echo "Bigdl-llm gpu layers.fast_rope_embedding tests finished"
|
|
||||||
echo "Time used:$time seconds"
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue