Support fast rope for training (#9745)
* init * init * fix style * add test and fix * address comment * update * merge upstream main
This commit is contained in:
		
							parent
							
								
									0c498a7b64
								
							
						
					
					
						commit
						98b86f83d4
					
				
					 6 changed files with 344 additions and 2 deletions
				
			
		| 
						 | 
				
			
			@ -0,0 +1,67 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import logging
 | 
			
		||||
from bigdl.llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
 | 
			
		||||
LOG = logging.getLogger("bigdl.llm.rope_embedding")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Fast RoPE for finetuning, split the q and k
 | 
			
		||||
def apply_fast_rope_embedding(q, k, position_ids, model_family):
 | 
			
		||||
    if q.device.type != "xpu":
 | 
			
		||||
        invalidInputError(False,
 | 
			
		||||
                          f"only xpu is supported in this function")
 | 
			
		||||
    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
 | 
			
		||||
                        "mixtral"]:
 | 
			
		||||
        q_embed = FastRopeEmbedding.apply(q, position_ids)
 | 
			
		||||
        k_embed = FastRopeEmbedding.apply(k, position_ids)
 | 
			
		||||
        return q_embed, k_embed
 | 
			
		||||
    else:
 | 
			
		||||
        invalidInputError(False,
 | 
			
		||||
                          f"{model_family} is not supported.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Fast RoPE for finetuning, split the q and k
 | 
			
		||||
class FastRopeEmbedding(torch.autograd.Function):
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @custom_fwd
 | 
			
		||||
    def forward(ctx, x, position_ids):
 | 
			
		||||
        import linear_q4_0
 | 
			
		||||
        x_embed = torch.empty(x.shape, dtype=x.dtype, device=x.device)
 | 
			
		||||
        linear_q4_0.apply_rotary_embedding_half_q_or_k(x, position_ids,
 | 
			
		||||
                                                       x_embed, False)
 | 
			
		||||
        ctx.save_for_backward(position_ids)
 | 
			
		||||
        return x_embed
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @custom_bwd
 | 
			
		||||
    def backward(ctx, grad_output):
 | 
			
		||||
        import linear_q4_0
 | 
			
		||||
        # LOG.info(f"backward, grad_output: {grad_output}")
 | 
			
		||||
        position_ids, = ctx.saved_tensors
 | 
			
		||||
        dx = torch.empty(grad_output.shape,
 | 
			
		||||
                         dtype=grad_output.dtype,
 | 
			
		||||
                         device=grad_output.device)
 | 
			
		||||
        linear_q4_0.apply_rotary_embedding_half_q_or_k(grad_output,
 | 
			
		||||
                                                       position_ids,
 | 
			
		||||
                                                       dx,
 | 
			
		||||
                                                       True)
 | 
			
		||||
        # LOG.info(f"backward, dx: {dx}, position_ids: {position_ids},
 | 
			
		||||
        #          requires_grad: {ctx.needs_input_grad}")
 | 
			
		||||
        return dx, None
 | 
			
		||||
| 
						 | 
				
			
			@ -127,6 +127,15 @@ def should_use_fuse_rope(self, query_states, position_ids):
 | 
			
		|||
    return use_fuse_rope
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Only for xpu and training
 | 
			
		||||
def should_use_fast_rope(self, query_states, position_ids):
 | 
			
		||||
    use_fuse_rope = query_states.device.type == "xpu"
 | 
			
		||||
    use_fuse_rope = use_fuse_rope and (self.training or query_states.requires_grad)
 | 
			
		||||
    use_fuse_rope = use_fuse_rope and self.config.rope_scaling is None
 | 
			
		||||
    use_fuse_rope = use_fuse_rope and position_ids is not None
 | 
			
		||||
    return use_fuse_rope
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def llama_attention_forward_4_31(
 | 
			
		||||
    self,
 | 
			
		||||
    hidden_states: torch.Tensor,
 | 
			
		||||
| 
						 | 
				
			
			@ -911,3 +920,115 @@ def llama_model_selective_batching_forward_4_31(
 | 
			
		|||
        hidden_states=all_hidden_states,
 | 
			
		||||
        attentions=all_self_attns,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# For training
 | 
			
		||||
def llama_attention_fast_forward(
 | 
			
		||||
    self,
 | 
			
		||||
    hidden_states: torch.Tensor,
 | 
			
		||||
    attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
    position_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
 | 
			
		||||
    output_attentions: bool = False,
 | 
			
		||||
    use_cache: bool = False,
 | 
			
		||||
    padding_mask: Optional[torch.LongTensor] = None,
 | 
			
		||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 | 
			
		||||
    bsz, q_len, _ = hidden_states.size()
 | 
			
		||||
    device = hidden_states.device
 | 
			
		||||
    use_fast_rope = should_use_fast_rope(self, hidden_states, position_ids)
 | 
			
		||||
 | 
			
		||||
    # Check for inference
 | 
			
		||||
    if use_cache and past_key_value is not None and q_len == 1:
 | 
			
		||||
        A, past_key_value = llama_attention_forward_4_31(
 | 
			
		||||
            self,
 | 
			
		||||
            hidden_states,
 | 
			
		||||
            past_key_value,
 | 
			
		||||
            position_ids,
 | 
			
		||||
        )
 | 
			
		||||
        return A, None, past_key_value
 | 
			
		||||
 | 
			
		||||
    if self.config.pretraining_tp > 1:
 | 
			
		||||
        key_value_slicing = ((self.num_key_value_heads * self.head_dim) //
 | 
			
		||||
                             self.config.pretraining_tp)
 | 
			
		||||
        query_slices = self.q_proj.weight.split(
 | 
			
		||||
            (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
 | 
			
		||||
        )
 | 
			
		||||
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
 | 
			
		||||
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
 | 
			
		||||
 | 
			
		||||
        query_states = [F.linear(hidden_states, query_slices[i])
 | 
			
		||||
                        for i in range(self.config.pretraining_tp)]
 | 
			
		||||
        query_states = torch.cat(query_states, dim=-1)
 | 
			
		||||
 | 
			
		||||
        key_states = [F.linear(hidden_states, key_slices[i])
 | 
			
		||||
                      for i in range(self.config.pretraining_tp)]
 | 
			
		||||
        key_states = torch.cat(key_states, dim=-1)
 | 
			
		||||
 | 
			
		||||
        value_states = [F.linear(hidden_states, value_slices[i])
 | 
			
		||||
                        for i in range(self.config.pretraining_tp)]
 | 
			
		||||
        value_states = torch.cat(value_states, dim=-1)
 | 
			
		||||
 | 
			
		||||
    else:
 | 
			
		||||
        query_states = self.q_proj(hidden_states)
 | 
			
		||||
        key_states = self.k_proj(hidden_states)
 | 
			
		||||
        value_states = self.v_proj(hidden_states)
 | 
			
		||||
 | 
			
		||||
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
 | 
			
		||||
                                 self.head_dim).transpose(1, 2)
 | 
			
		||||
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
 | 
			
		||||
                                     self.head_dim).transpose(1, 2)
 | 
			
		||||
 | 
			
		||||
    kv_seq_len = key_states.shape[-2]
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        kv_seq_len += past_key_value[0].shape[-2]
 | 
			
		||||
 | 
			
		||||
    if use_fast_rope:
 | 
			
		||||
        from bigdl.llm.transformers.layers.rope_embedding import apply_fast_rope_embedding
 | 
			
		||||
        query_states, key_states = apply_fast_rope_embedding(query_states,
 | 
			
		||||
                                                             key_states,
 | 
			
		||||
                                                             position_ids,
 | 
			
		||||
                                                             "llama")
 | 
			
		||||
    else:
 | 
			
		||||
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
			
		||||
                                                        cos, sin, position_ids, "llama")
 | 
			
		||||
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        # reuse k, v, self_attention
 | 
			
		||||
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
 | 
			
		||||
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
 | 
			
		||||
 | 
			
		||||
    past_key_value = (key_states, value_states) if use_cache else None
 | 
			
		||||
 | 
			
		||||
    key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
			
		||||
    value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
			
		||||
 | 
			
		||||
    attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
 | 
			
		||||
                                           attention_mask,
 | 
			
		||||
                                           bsz, q_len, kv_seq_len,
 | 
			
		||||
                                           self.head_dim, self.num_heads)
 | 
			
		||||
 | 
			
		||||
    attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
 | 
			
		||||
    if attn_output.size() != attn_output_size:
 | 
			
		||||
        invalidInputError(False,
 | 
			
		||||
                          f"`attn_output` should be of size {attn_output_size},"
 | 
			
		||||
                          f" but is {attn_output.size()}")
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
			
		||||
 | 
			
		||||
    if self.config.pretraining_tp > 1:
 | 
			
		||||
        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
 | 
			
		||||
        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp,
 | 
			
		||||
                                                 dim=1)
 | 
			
		||||
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i])
 | 
			
		||||
                           for i in range(self.config.pretraining_tp)])
 | 
			
		||||
    else:
 | 
			
		||||
        attn_output = self.o_proj(attn_output)
 | 
			
		||||
 | 
			
		||||
    if not output_attentions:
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
 | 
			
		||||
    return attn_output, attn_weights, past_key_value
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -170,7 +170,7 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
 | 
			
		|||
    k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
 | 
			
		||||
    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
 | 
			
		||||
                        "mixtral"]:
 | 
			
		||||
        linear_q4_0.apply_rotary_embedding_half_qk(q, k, position_ids, q_embed, k_embed)
 | 
			
		||||
        linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids, q_embed, k_embed)
 | 
			
		||||
        return q_embed, k_embed
 | 
			
		||||
    else:
 | 
			
		||||
        invalidInputError(False,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -49,6 +49,7 @@
 | 
			
		|||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import logging
 | 
			
		||||
from torch.nn import Linear, Embedding
 | 
			
		||||
from bigdl.llm.transformers.low_bit_linear import LowBitLinear, BF16Linear, get_qk_size
 | 
			
		||||
from peft.tuners.lora import LoraLayer
 | 
			
		||||
| 
						 | 
				
			
			@ -58,6 +59,8 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 | 
			
		|||
import functools
 | 
			
		||||
from bigdl.llm.transformers import training_patch
 | 
			
		||||
 | 
			
		||||
LOG = logging.getLogger("bigdl.llm.qlora")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LoraLowBitLinear(LowBitLinear, LoraLayer):
 | 
			
		||||
    # Lora implemented in a dense layer
 | 
			
		||||
| 
						 | 
				
			
			@ -252,6 +255,7 @@ def get_peft_model(*args, **kwargs):
 | 
			
		|||
 | 
			
		||||
    if model.device.type == "xpu":
 | 
			
		||||
        cast_lora_weight(model, torch.bfloat16)
 | 
			
		||||
        _optimize_post(model)
 | 
			
		||||
        torch.xpu.synchronize()
 | 
			
		||||
 | 
			
		||||
    return model
 | 
			
		||||
| 
						 | 
				
			
			@ -345,3 +349,18 @@ def cast_lora_weight(model, dtype=torch.bfloat16):
 | 
			
		|||
            if hasattr(module, 'weight'):
 | 
			
		||||
                if module.weight.dtype == torch.float32:
 | 
			
		||||
                    module = module.to(dtype)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _optimize_post(model):
 | 
			
		||||
    import transformers
 | 
			
		||||
    from packaging import version
 | 
			
		||||
    from bigdl.llm.transformers.convert import convert_forward
 | 
			
		||||
    from bigdl.llm.transformers.models.llama import llama_attention_fast_forward
 | 
			
		||||
 | 
			
		||||
    trans_version = transformers.__version__
 | 
			
		||||
    if version.parse(trans_version) >= version.parse("4.31.0"):
 | 
			
		||||
        LOG.info("Optimizing Llama finetuning....")
 | 
			
		||||
        convert_forward(
 | 
			
		||||
            model,
 | 
			
		||||
            transformers.models.llama.modeling_llama.LlamaAttention,
 | 
			
		||||
            llama_attention_fast_forward,)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										124
									
								
								python/llm/test/inference_gpu/test_layer_fast_rope.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										124
									
								
								python/llm/test/inference_gpu/test_layer_fast_rope.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,124 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
#
 | 
			
		||||
# This file is adapted from 
 | 
			
		||||
# https://github.com/Dao-AILab/flash-attention/blob/main/tests/layers/test_rotary.py
 | 
			
		||||
#
 | 
			
		||||
# Copyright (c) 2023, Tri Dao.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
import pytest
 | 
			
		||||
import torch
 | 
			
		||||
import intel_extension_for_pytorch as ipex
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from einops import rearrange
 | 
			
		||||
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
 | 
			
		||||
from transformers.models.llama.modeling_llama import (
 | 
			
		||||
    apply_rotary_pos_emb as apply_rotary_pos_emb_llama,
 | 
			
		||||
)
 | 
			
		||||
from bigdl.llm.transformers.layers.rope_embedding import apply_fast_rope_embedding
 | 
			
		||||
 | 
			
		||||
device = os.environ['DEVICE']
 | 
			
		||||
print(f'Running on {device}')
 | 
			
		||||
if 'xpu' not in device:
 | 
			
		||||
    print(f"The layer.fast_rope test should running on xpu")
 | 
			
		||||
 | 
			
		||||
# llama-style rotary embedding
 | 
			
		||||
@pytest.mark.parametrize("seqlen_offset", [0, 711])
 | 
			
		||||
@pytest.mark.parametrize("rotary_emb_fraction", [0.5, 1.0])
 | 
			
		||||
def test_rotary(rotary_emb_fraction, seqlen_offset):
 | 
			
		||||
    device = "xpu"
 | 
			
		||||
    dtype = torch.float16
 | 
			
		||||
    rtol, atol = (1e-3, 5e-3)
 | 
			
		||||
    # set seed
 | 
			
		||||
    torch.random.manual_seed(0)
 | 
			
		||||
    batch_size = 8
 | 
			
		||||
    seqlen_total = 2048
 | 
			
		||||
    seqlen = seqlen_total - seqlen_offset
 | 
			
		||||
    seqlen_offset = torch.tensor([[seqlen_offset]], device=device)
 | 
			
		||||
    nheads = 32
 | 
			
		||||
    headdim = 128
 | 
			
		||||
    rotary_dim = int(headdim * rotary_emb_fraction)
 | 
			
		||||
    qkv = torch.randn(
 | 
			
		||||
        batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
 | 
			
		||||
        requires_grad=True
 | 
			
		||||
    )
 | 
			
		||||
    rotary_llama = LlamaRotaryEmbedding(rotary_dim, seqlen_total, device=device)
 | 
			
		||||
    # Doesn't matter what tensor we pass in, rotary_llama only uses the device
 | 
			
		||||
    # of the tensor
 | 
			
		||||
    cos_llama, sin_llama = rotary_llama(qkv, seq_len=seqlen_total)
 | 
			
		||||
    cos_llama, sin_llama = cos_llama.to(dtype=dtype), sin_llama.to(dtype=dtype)
 | 
			
		||||
    q_pt = (
 | 
			
		||||
        rearrange(qkv[:, :, 0, :, :rotary_dim], "b s h d -> b h s d")
 | 
			
		||||
        .detach()
 | 
			
		||||
        .clone()
 | 
			
		||||
        .requires_grad_(True)
 | 
			
		||||
    )
 | 
			
		||||
    k_pt = (
 | 
			
		||||
        rearrange(qkv[:, :, 1, :, :rotary_dim], "b s h d -> b h s d")
 | 
			
		||||
        .detach()
 | 
			
		||||
        .clone()
 | 
			
		||||
        .requires_grad_(True)
 | 
			
		||||
    )
 | 
			
		||||
    q_pt_fast = (
 | 
			
		||||
        rearrange(qkv[:, :, 0, :, :rotary_dim], "b s h d -> b h s d")
 | 
			
		||||
        .detach()
 | 
			
		||||
        .clone()
 | 
			
		||||
        .requires_grad_(True)
 | 
			
		||||
    )
 | 
			
		||||
    k_pt_fast = (
 | 
			
		||||
        rearrange(qkv[:, :, 1, :, :rotary_dim], "b s h d -> b h s d")
 | 
			
		||||
        .detach()
 | 
			
		||||
        .clone()
 | 
			
		||||
        .requires_grad_(True)
 | 
			
		||||
    )
 | 
			
		||||
    q_llama, k_llama = apply_rotary_pos_emb_llama(q_pt, k_pt, cos_llama,
 | 
			
		||||
                                                  sin_llama, position_ids=seqlen_offset)
 | 
			
		||||
    q_fast, k_fast = apply_fast_rope_embedding(q_pt_fast, k_pt_fast,
 | 
			
		||||
                                               position_ids=seqlen_offset,
 | 
			
		||||
                                               model_family="llama")
 | 
			
		||||
    assert torch.allclose(
 | 
			
		||||
        rearrange(q_llama, "b h s d -> b s h d"),
 | 
			
		||||
        rearrange(q_fast, "b h s d -> b s h d"), rtol=rtol, atol=atol
 | 
			
		||||
    )
 | 
			
		||||
    assert torch.allclose(
 | 
			
		||||
        rearrange(k_llama, "b h s d -> b s h d"),
 | 
			
		||||
        rearrange(k_fast, "b h s d -> b s h d"), rtol=rtol, atol=atol
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    g = torch.randn_like(q_fast)
 | 
			
		||||
    q_fast.backward(g)
 | 
			
		||||
    k_fast.backward(g)
 | 
			
		||||
    q_llama.backward(g)
 | 
			
		||||
    k_llama.backward(g)
 | 
			
		||||
 | 
			
		||||
    assert torch.allclose(
 | 
			
		||||
        q_pt.grad,
 | 
			
		||||
        q_pt_fast.grad,
 | 
			
		||||
        rtol=rtol,
 | 
			
		||||
        atol=atol,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert torch.allclose(
 | 
			
		||||
        k_pt.grad,
 | 
			
		||||
        k_pt_fast.grad,
 | 
			
		||||
        rtol=rtol,
 | 
			
		||||
        atol=atol,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    pytest.main([__file__])
 | 
			
		||||
| 
						 | 
				
			
			@ -22,5 +22,16 @@ pytest ${LLM_INFERENCE_TEST_DIR}/test_transformers_api.py -v -s
 | 
			
		|||
now=$(date "+%s")
 | 
			
		||||
time=$((now-start))
 | 
			
		||||
 | 
			
		||||
echo "Bigdl-llm gpu tests finished"
 | 
			
		||||
echo "Bigdl-llm gpu inference tests finished"
 | 
			
		||||
echo "Time used:$time seconds"
 | 
			
		||||
 | 
			
		||||
echo "# Start testing layers.fast_rope_embedding"
 | 
			
		||||
start=$(date "+%s")
 | 
			
		||||
 | 
			
		||||
pytest ${LLM_INFERENCE_TEST_DIR}/test_layer_fast_rope.py -v -s
 | 
			
		||||
 | 
			
		||||
now=$(date "+%s")
 | 
			
		||||
time=$((now-start))
 | 
			
		||||
 | 
			
		||||
echo "Bigdl-llm gpu layers.fast_rope_embedding tests finished"
 | 
			
		||||
echo "Time used:$time seconds"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue