diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index 9f6fe65c..e26405d5 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -86,6 +86,7 @@ def chatglm2_attention_forward_8eb45c(
     # =====================
 
     # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+    device = hidden_states.device
     mixed_x_layer = self.query_key_value(hidden_states)
 
     if self.multi_query_attention:
@@ -151,11 +152,13 @@ def chatglm2_attention_forward_8eb45c(
             self.kv_cache = (torch.empty(batch_size,
                                          self.num_attention_heads_per_partition,
                                          self.max_cache_length,
-                                         self.hidden_size_per_attention_head,),
+                                         self.hidden_size_per_attention_head,
+                                         device=device),
                              torch.empty(batch_size,
                                          self.num_attention_heads_per_partition,
                                          self.max_cache_length,
-                                         self.hidden_size_per_attention_head,))
+                                         self.hidden_size_per_attention_head,
+                                         device=device))
             self.kv_cache[0][:, :, :past_length, :] = cache_k
             self.kv_cache[1][:, :, :past_length, :] = cache_v
         self.kv_cache[0][:, :, past_length:past_length + cur_length, :] = key_layer
@@ -168,9 +171,11 @@ def chatglm2_attention_forward_8eb45c(
         self.max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
             + KV_CACHE_ALLOC_BLOCK_LENGTH
         self.kv_cache = (torch.empty(batch_size, self.num_attention_heads_per_partition,
-                                     self.max_cache_length, self.hidden_size_per_attention_head,),
+                                     self.max_cache_length, self.hidden_size_per_attention_head,
+                                     device=device),
                          torch.empty(batch_size, self.num_attention_heads_per_partition,
-                                     self.max_cache_length, self.hidden_size_per_attention_head,))
+                                     self.max_cache_length, self.hidden_size_per_attention_head,
+                                     device=device))
         self.kv_cache[0][:, :, :cur_length, :] = key_layer
         self.kv_cache[1][:, :, :cur_length, :] = value_layer
 
@@ -196,7 +201,7 @@ def chatglm2_attention_forward_8eb45c(
 
 def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
     pytorch_major_version = int(torch.__version__.split('.')[0])
-    if query_layer.size(0) > 1 and pytorch_major_version >= 2:
+    if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
         query_layer = query_layer.permute(1, 2, 0, 3)
 
         if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
@@ -222,8 +227,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
                                                                               is_causal=True)
         else:
             if attention_mask is not None:
-                attention_mask = ~attention_mask
-                attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )
+                attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )
             if torch.is_autocast_cpu_enabled():
                 query_layer = query_layer.to(torch.get_autocast_cpu_dtype())
                 key_layer = key_layer.to(torch.get_autocast_cpu_dtype())
@@ -258,6 +262,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
 
         matmul_result = torch.empty(
             output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
+            device=query_layer.device
         )
 
         # Raw attention scores. [b * np, sq, sk]
@@ -313,6 +318,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
 
         context_layer = torch.empty(
             output_size[0] * output_size[1], output_size[2], value_layer.size(-1), dtype=value_layer.dtype,
+            device=value_layer.device,
         )
         torch.bmm(attention_probs, value_layer, out=context_layer)
         # change view [b, np, sq, hn]
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index a13e5146..558ebda9 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -129,11 +129,13 @@ def llama_attention_forward_4_31(
         # value_states = torch.cat([past_key_value[1], value_states], dim=2)
         if kv_seq_len > self.max_cache_length:
             new_cache_key = torch.empty(bsz, self.num_heads,
-                                        kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim)
+                                        kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim,
+                                        device=device)
             new_cache_key[:, :, :kv_seq_len-1, :] = self.kv_cache[0][:, :, :kv_seq_len-1, :]
 
             new_cache_value = torch.empty(bsz, self.num_heads,
-                                          kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim)
+                                          kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim,
+                                          device=device)
             new_cache_value[:, :, :kv_seq_len-1, :] = self.kv_cache[1][:, :, :kv_seq_len-1, :]
             self.kv_cache = (new_cache_key, new_cache_value)
             self.max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
@@ -146,8 +148,10 @@ def llama_attention_forward_4_31(
         # first token case
         self.max_cache_length = max(min(self.max_position_embeddings, 2 * kv_seq_len),
                                     kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH)
-        self.kv_cache = (torch.empty(bsz, self.num_heads, self.max_cache_length, self.head_dim),
-                         torch.empty(bsz, self.num_heads, self.max_cache_length, self.head_dim))
+        self.kv_cache = (torch.empty(bsz, self.num_heads, self.max_cache_length, self.head_dim,
+                                     dtype=key_states.dtype, device=device),
+                         torch.empty(bsz, self.num_heads, self.max_cache_length, self.head_dim,
+                                     dtype=key_states.dtype, device=device))
         self.kv_cache[0][:, :, :kv_seq_len, :] = key_states
         self.kv_cache[1][:, :, :kv_seq_len, :] = value_states
 