Merge remote-tracking branch 'upstream/main'

commit e8f436453d

29 changed files with 471 additions and 130 deletions

@@ -22,13 +22,13 @@ Follow [here](https://github.com/kubeflow/mpi-operator/tree/master#installation)
 Follow [here](https://github.com/intel-analytics/BigDL/tree/main/docker/llm/finetune/lora/docker#prepare-bigdl-image-for-lora-finetuning) to prepare BigDL Lora Finetuning image in your cluster.
 
-As finetuning is from a base model, first download [Llama 7b hf model from the public download site of Hugging Face](https://huggingface.co/decapoda-research/llama-7b-hf/tree/main). Then, download [cleaned alpaca data](https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json), which contains all kinds of general knowledge and has already been cleaned. Next, move the downloaded files to a shared directory on your NFS server. In addition, make an empty directory under the same destination to save the finetuned model output later.
+As finetuning is from a base model, first download [Llama 7b hf model from the public download site of Hugging Face](https://huggingface.co/decapoda-research/llama-7b-hf/tree/main). Then, download [cleaned alpaca data](https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json), which contains all kinds of general knowledge and has already been cleaned. Next, move the downloaded files to a shared directory on your NFS server.
 
 ### 3. Deploy through Helm Chart
 
 You are allowed to edit and experiment with different parameters in `./kubernetes/values.yaml` to improve finetuning performance and accuracy. For example, you can adjust `trainerNum` and `cpuPerPod` according to the node and CPU core counts in your cluster to make full use of these resources, and different `microBatchSize` values result in different training speed and loss (note that `microBatchSize` × `trainerNum` should not be more than 128, as that product is the batch size).
 
-**Note: `dataSubPath`, `modelSubPath` and `outputPath` need to have the same names as files under the NFS directory in step 2.**
+**Note: `dataSubPath` and `modelSubPath` need to have the same names as files under the NFS directory in step 2.**
 
 After preparing parameters in `./kubernetes/values.yaml`, submit the job as below:
 
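The hunk above ends just before the submission command itself. For orientation, submitting the chart typically amounts to a single `helm install` from the repository's `kubernetes` directory; the release name below is a placeholder and is not taken from this diff.

```bash
# Hypothetical example: install the chart with the values prepared above.
# "bigdl-lora-finetuning" is a placeholder release name.
helm install bigdl-lora-finetuning ./kubernetes

# Check that the launcher and worker pods come up in the job's namespace.
kubectl get pods -n bigdl-ppml-finetuning
```
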
@@ -52,7 +52,9 @@ kubectl exec -it <launcher_pod_name> bash -n bigdl-ppml-finetuning # enter launcher pod
 cat launcher.log # display logs collected from other workers
 ```
 
-From the log, you can see whether the finetuning process has been invoked successfully in all MPI worker pods, and a progress bar with finetuning speed and estimated time will be shown after some data preprocessing steps (this may take quite a while). The fine-tuned model is written by worker 0 (which holds rank 0), so you can find the model output inside the pod or in the `output` folder under the NFS path (because it has been mounted to worker 0 as the output path).
+From the log, you can see whether the finetuning process has been invoked successfully in all MPI worker pods, and a progress bar with finetuning speed and estimated time will be shown after some data preprocessing steps (this may take quite a while).
+
+The fine-tuned model is written by worker 0 (which holds rank 0), so you can find the model output inside that pod; it can be copied to the host with command-line tools like `kubectl cp` or `scp`.
 
 
 ## To run in TDX-CoCo and enable Remote Attestation API
 
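Since this change drops the NFS-mounted output directory, copying the result to the host relies on standard tooling. A minimal sketch, assuming the worker pod name is looked up first (the in-pod path matches the `--output_dir` set later in this commit):

```bash
# Hypothetical example: copy the finetuned model from the rank-0 worker pod to the host.
# Replace <worker_0_pod_name> with the pod name reported by `kubectl get pods`.
kubectl cp bigdl-ppml-finetuning/<worker_0_pod_name>:/home/mpiuser/finetuned_model ./finetuned_model
```
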
@@ -8,7 +8,6 @@ if [ "$WORKER_ROLE" = "launcher" ]
 then
   sed "s/:1/ /g" /etc/mpi/hostfile > /home/mpiuser/hostfile
   export DATA_PATH="/ppml/data/$DATA_SUB_PATH"
-  export SAVE_PATH="/ppml/output"
   sleep 10
   mpirun \
     -n $WORLD_SIZE \
@@ -22,13 +21,13 @@ then
     python /ppml/lora_finetune.py \
       --base_model '/ppml/model/'  \
       --data_path "$DATA_PATH" \
-      --output_dir "$SAVE_PATH/finetuned_model" \
+      --output_dir "/home/mpiuser/finetuned_model" \
       --micro_batch_size $MICRO_BATCH_SIZE \
-      --bf16 > $SAVE_PATH/launcher.log 2>&1
+      --bf16 > /home/mpiuser/launcher.log 2>&1
   exit_status=$?
   if [ $exit_status -ne 0 ];
   then
-    cat $SAVE_PATH/launcher.log
+    cat /home/mpiuser/launcher.log
     exit $exit_status
   else
     while true
 
@@ -51,9 +51,6 @@ spec:
              - name: nfs-storage
                subPath: {{ .Values.dataSubPath }}
                mountPath: "/ppml/data/{{ .Values.dataSubPath }}"
-             - name: nfs-storage
-               subPath: {{ .Values.outputSubPath }}
-               mountPath: "/ppml/output"
     Worker:
       replicas: {{ .Values.trainerNum }}
       template:
@@ -86,9 +83,6 @@ spec:
            - name: nfs-storage
              subPath: {{ .Values.dataSubPath }}
              mountPath: "/ppml/data/{{ .Values.dataSubPath }}"
-           - name: nfs-storage
-             subPath: {{ .Values.outputSubPath }}
-             mountPath: "/ppml/output"
            resources:
              requests:
                cpu: {{ .Values.cpuPerPod }}
@@ -96,4 +90,4 @@ spec:
          - name: nfs-storage
            persistentVolumeClaim:
              claimName: nfs-pvc
 {{- end }}
 
@@ -71,9 +71,6 @@ spec:
              - name: nfs-storage
                subPath: {{ .Values.dataSubPath }}
                mountPath: "/ppml/data/{{ .Values.dataSubPath }}"
-             - name: nfs-storage
-               subPath: {{ .Values.outputSubPath }}
-               mountPath: "/ppml/output"
              - name: dev
                mountPath: /dev
              {{- if eq .Values.enableTLS true }}
@@ -118,9 +115,6 @@ spec:
            - name: nfs-storage
              subPath: {{ .Values.dataSubPath }}
              mountPath: "/ppml/data/{{ .Values.dataSubPath }}"
-           - name: nfs-storage
-             subPath: {{ .Values.outputSubPath }}
-             mountPath: "/ppml/output"
            - name: dev
              mountPath: /dev
            resources:
 
@@ -6,11 +6,10 @@ nfsServerIp: your_nfs_server_ip
 nfsPath: a_nfs_shared_folder_path_on_the_server
 dataSubPath: alpaca_data_cleaned_archive.json # a subpath of the data file under nfs directory
 modelSubPath: llama-7b-hf # a subpath of the model file (dir) under nfs directory
-outputSubPath: output # a subpath of the empty directory under the nfs directory to save finetuned model, for example, if you make an empty dir named 'output' at the nfsPath, the value should be 'output'
 ompNumThreads: 14
 cpuPerPod: 42
 attestionApiServicePort: 9870
 
 enableTLS: false # true or false
 base64ServerCrt: "your_base64_format_server_crt"
 base64ServerKey: "your_base64_format_server_key"
 
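With `outputSubPath` gone, the remaining values can also be overridden at install time instead of editing the file; a hedged sketch (the release name and chart path are placeholders):

```bash
# Hypothetical example: override selected values.yaml entries on the command line.
# Keep microBatchSize * trainerNum <= 128, since that product is the effective batch size.
helm install bigdl-lora-finetuning ./kubernetes \
  --set trainerNum=8 \
  --set cpuPerPod=42 \
  --set microBatchSize=8
```
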
@@ -176,18 +176,16 @@ def run_pytorch_autocast_bf16(repo_id,
     st = time.perf_counter()
     if repo_id in ['THUDM/chatglm-6b', 'THUDM/chatglm2-6b']:
         # TODO: need verify chatglm family run bf16.
-        model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype='auto').float()
-        #model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype='auto').bfloat()
-        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        invalidInputError(False, "Currently pytorch do not support bfloat16 on cpu for chatglm models.")
     elif repo_id in ['meta-llama/Llama-2-7b-chat-hf','meta-llama/Llama-2-13b-chat-hf',
                      'meta-llama/Llama-2-70b-chat-hf','decapoda-research/llama-7b-hf',
                      'decapoda-research/llama-65b-hf','lmsys/vicuna-7b-v1.5',
                      'lmsys/vicuna-13b-v1.3','project-baize/merged-baize-30b']:
-        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype='auto')
+        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
         # Need to use LlamaTokenizer, reason please refer to issue: https://github.com/intel-analytics/BigDL/issues/8944
         tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
     else:
-        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype='auto')
+        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     end = time.perf_counter()
     print(">> loading of model costs {}s".format(end - st))
 
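For context on what `run_pytorch_autocast_bf16` exercises after this change, a minimal, generic bf16 CPU inference sketch (not the benchmark's actual harness; the model path is a placeholder) looks roughly like this:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "path/to/your/model"  # placeholder

# Load the weights directly in bfloat16, as the updated benchmark does.
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True,
                                             torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

inputs = tokenizer("Once upon a time", return_tensors="pt")
# Run generation under CPU autocast so matmuls execute in bfloat16.
with torch.inference_mode(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
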
@@ -42,7 +42,6 @@ if __name__ == '__main__':
     # which convert the relevant layers in the model into INT4 format
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_4bit=True,
-                                                 optimize_model=False,
                                                  trust_remote_code=True,
                                                  use_cache=True)
     model = model.to('xpu')

@@ -46,7 +46,6 @@ if __name__ == '__main__':
     # to obtain optimal performance with BigDL-LLM INT4 optimizations
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_4bit=True,
-                                                 optimize_model=False,
                                                  trust_remote_code=True,
                                                  use_cache=True)
     model = model.to('xpu')

@@ -44,7 +44,6 @@ if __name__ == '__main__':
     # which convert the relevant layers in the model into INT4 format
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_4bit=True,
-                                                 optimize_model=False,
                                                  trust_remote_code=True,
                                                  use_cache=True)
     model = model.to('xpu')

@@ -42,7 +42,6 @@ if __name__ == '__main__':
     # which convert the relevant layers in the model into INT4 format
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_4bit=True,
-                                                 optimize_model=False,
                                                  trust_remote_code=True,
                                                  use_cache=True)
     model = model.to('xpu')
 
python/llm/portable-executable/.gitignore (new file, 2 lines, vendored)
@@ -0,0 +1,2 @@
+python-embed
+portable-executable.zip
 
python/llm/portable-executable/README.md (new file, 33 lines)
@@ -0,0 +1,33 @@
+# BigDL-LLM Portable Executable For Windows: User Guide
+
+This portable executable includes everything you need to run an LLM (except the models). Please refer to the How to use section to get started.
+
+## 13B model running on an Intel 11-Gen Core PC (real-time screen capture)
+
+<p align="left">
+            <img src=https://llm-assets.readthedocs.io/en/latest/_images/one-click-installer-screen-capture.gif width='80%' />
+
+</p>
+
+## Verified Models
+
+- ChatGLM2-6b
+- Baichuan-13B-Chat
+- Baichuan2-7B-Chat
+- internlm-chat-7b-8k
+- Llama-2-7b-chat-hf
+
+## How to use
+
+1. Download the model to your computer. Please ensure there is a file named `config.json` in the model folder, otherwise the script won't work.
+
+2. Run `chat.bat` in Terminal and input the path of the model (e.g. `path\to\model`, note that there is no slash at the end of the path).
+
+3. Press Enter and wait until the model finishes loading. Then enjoy chatting with the model!
+
+4. If you want to stop chatting, just input `stop` and the model will stop running.
 
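To make the How to use steps concrete, a session roughly follows the pattern below (illustrative only; the model path and the reply are invented):

```text
Please enter the model path: C:\models\chatglm2-6b

Human: What can you do?
BigDL-LLM: <the answer streams here token by token>

Human: stop
```
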
							
								
								
									
python/llm/portable-executable/chat.bat (new file, 8 lines)
@@ -0,0 +1,8 @@
+@echo off
+
+
+:: execute chat script
+set PYTHONUNBUFFERED=1
+
+set /p modelpath="Please enter the model path: "
+.\python-embed\python.exe .\chat.py --model-path="%modelpath%"
 
python/llm/portable-executable/chat.py (new file, 116 lines)
@@ -0,0 +1,116 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import argparse
+import sys
+
+# todo: support more model classes
+from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, AutoConfig
+from transformers import TextIteratorStreamer
+from transformers.tools.agents import StopSequenceCriteria
+from transformers.generation.stopping_criteria import StoppingCriteriaList
+
+from colorama import Fore
+
+from bigdl.llm import optimize_model
+
+SYSTEM_PROMPT = "A chat between a curious human <human> and an artificial intelligence assistant <bot>.\
+The assistant gives helpful, detailed, and polite answers to the human's questions."
+HUMAN_ID = "<human>"
+BOT_ID = "<bot>"
+
+# chat_history is formatted as [(input_str, output_str)]
+def format_prompt(input_str,
+                  chat_history):
+    prompt = [f"{SYSTEM_PROMPT}\n"]
+    for history_input_str, history_output_str in chat_history:
+        prompt.append(f"{HUMAN_ID} {history_input_str}\n{BOT_ID} {history_output_str}\n")
+    prompt.append(f"{HUMAN_ID} {input_str}\n{BOT_ID} ")
+
+    return "".join(prompt)
+
+def stream_chat(model,
+                tokenizer,
+                stopping_criteria,
+                input_str,
+                chat_history):
+    prompt = format_prompt(input_str, chat_history)
+    # print(prompt)
+    input_ids = tokenizer([prompt], return_tensors="pt")
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(input_ids, streamer=streamer, max_new_tokens=512, stopping_criteria=stopping_criteria)
+
+    from threading import Thread
+    # to ensure non-blocking access to the generated text, the generation process runs in a separate thread
+    thread = Thread(target=model.generate, kwargs=generate_kwargs)
+    thread.start()
+
+    output_str = []
+    print(Fore.BLUE+"BigDL-LLM: "+Fore.RESET, end="")
+    for partial_output_str in streamer:
+        output_str.append(partial_output_str)
+        # remove the last HUMAN_ID if it exists
+        print(partial_output_str.replace(f"{HUMAN_ID}", ""), end="")
+
+    chat_history.append((input_str, "".join(output_str).replace(f"{HUMAN_ID}", "").rstrip()))
+
+def auto_select_model(model_name):
+    try:
+        try:
+            model = AutoModelForCausalLM.from_pretrained(model_path,
+                                                         low_cpu_mem_usage=True,
+                                                         torch_dtype="auto",
+                                                         trust_remote_code=True,
+                                                         use_cache=True)
+        except:
+            model = AutoModel.from_pretrained(model_path,
+                                              low_cpu_mem_usage=True,
+                                              torch_dtype="auto",
+                                              trust_remote_code=True,
+                                              use_cache=True)
+    except:
+        print("Sorry, the model you entered is not supported in installer.")
+        sys.exit()
+
+    return model
+
+if __name__ == "__main__":
+  parser = argparse.ArgumentParser()
+  parser.add_argument("--model-path", type=str, help="path to an llm")
+  args = parser.parse_args()
+
+  model_path = args.model_path
+
+  model = auto_select_model(model_path)
+  model = optimize_model(model)
+
+  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+  stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(HUMAN_ID, tokenizer)])
+
+  chat_history = []
+
+  while True:
+      with torch.inference_mode():
+          user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
+          if user_input == "stop": # stop the conversation when the user inputs "stop"
+              break
+          stream_chat(model=model,
+                      tokenizer=tokenizer,
+                      stopping_criteria=stopping_criteria,
+                      input_str=user_input,
+                      chat_history=chat_history)
 
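The script's streaming output is the standard `TextIteratorStreamer` pattern: `generate` runs in a background thread while the main thread consumes tokens as they arrive. Stripped of the chat bookkeeping, the core of that pattern is the sketch below (assuming `model` and `tokenizer` are already loaded):

```python
from threading import Thread
from transformers import TextIteratorStreamer

def stream_generate(model, tokenizer, prompt, max_new_tokens=512):
    # Sketch of the pattern used by chat.py above, not a drop-in replacement.
    inputs = tokenizer([prompt], return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    # generate() blocks, so it runs in a worker thread while tokens are printed here.
    Thread(target=model.generate, kwargs=kwargs).start()
    for token_text in streamer:
        print(token_text, end="", flush=True)
```
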
							
								
								
									
python/llm/portable-executable/setup.bat (new file, 23 lines)
@@ -0,0 +1,23 @@
+:: download python and extract zip
+powershell -Command "Start-BitsTransfer -Source https://www.python.org/ftp/python/3.9.13/python-3.9.13-embed-amd64.zip -Destination python-3.9.13-embed-amd64.zip"
+powershell -Command "Expand-Archive .\python-3.9.13-embed-amd64.zip -DestinationPath .\python-embed"
+del .\python-3.9.13-embed-amd64.zip
+
+set "python-embed=.\python-embed\python.exe"
+
+:: download get-pip.py and install
+powershell -Command "Invoke-WebRequest https://bootstrap.pypa.io/get-pip.py -OutFile .\python-embed\get-pip.py"
+%python-embed% .\python-embed\get-pip.py
+
+:: enable run site.main() automatically
+cd .\python-embed
+set "search=#import site"
+set "replace=import site"
+powershell -Command "(gc python39._pth) -replace '%search%', '%replace%' | Out-File -encoding ASCII python39._pth"
+cd ..
+
+:: install pip packages
+%python-embed% -m pip install bigdl-llm[all] transformers_stream_generator tiktoken einops colorama
+
+:: compress the python and scripts
+powershell -Command "Compress-Archive -Path '.\python-embed', '.\chat.bat', '.\chat.py', '.\README.md' -DestinationPath .\portable-executable.zip"
 
python/llm/portable-executable/setup.md (new file, 5 lines)
@@ -0,0 +1,5 @@
+# BigDL-LLM Portable Executable Setup Script For Windows
+
+# How to use
+
+Simply run `setup.bat`; it will download and install all dependencies and generate a zip file for users to use.
 
@@ -173,6 +173,15 @@ def optimize(model):
                         module.SelfAttention,
                         chatglm_attention_forward
                         )
+    elif "mpt" in model.config._name_or_path:
+        modeling_module_name = model.__class__.__module__
+        attention_module_name = '.'.join(modeling_module_name.split('.')[:-1]) + ".attention"
+        module = importlib.import_module(attention_module_name)
+        from bigdl.llm.transformers.models.mpt import mpt_multihead_attention_forward
+        convert_forward(model,
+                        module.MultiheadAttention,
+                        mpt_multihead_attention_forward
+                        )
     elif "gptj" in model.config.model_type:
         # dolly-v1-6b
         modeling_module_name = model.__class__.__module__
 
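For readers new to this file, `convert_forward` follows the usual forward-patching idea: import the model's own attention module and swap in an optimized forward. A generic sketch of that idea (an illustration only, not BigDL's actual `convert_forward` implementation) is:

```python
import types
import torch.nn as nn

def patch_forward_sketch(model: nn.Module, target_class, new_forward):
    # Rebind `forward` on every matching sub-module so all instances use the
    # replacement implementation. Illustrative only.
    for module in model.modules():
        if isinstance(module, target_class):
            module.forward = types.MethodType(new_forward, module)
```
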
@@ -181,7 +190,7 @@ def optimize(model):
         convert_forward(model,
                         module.GPTJAttention,
                         gptj_attention_forward)
-    elif "bloom" in model.config._name_or_path:
+    elif "bloom" in model.config.model_type:
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.bloom import bloom_attention_forward
@@ -189,17 +198,18 @@ def optimize(model):
                         module.BloomAttention,
                         bloom_attention_forward
                         )
-    elif "falcon" in model.config._name_or_path:
+    elif "falcon" in model.config.model_type or "RefinedWeb" in model.config.model_type:
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         if "RWForCausalLM" in model.config.architectures:
             if hasattr(model.config, "multi_query"):
-                # falcon-7b
-                from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
-                convert_forward(model,
-                                module.Attention,
-                                rw_attention_forward_7b
-                                )
+                # falcon-7b need to check performance drop after kv cache support.
+                # from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
+                # convert_forward(model,
+                #                 module.Attention,
+                #                 rw_attention_forward_7b
+                #                 )
+                pass
             else:
                 # falcon-40b
                 from bigdl.llm.transformers.models.falcon import rw_attention_forward_40b
@@ -262,5 +272,4 @@ def optimize(model):
                         transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention,
                         gptneox_attention_forward
                         )
-
     return model
 
@@ -30,6 +30,7 @@ def save_low_bit(self, *args, **kwargs):
     invalidInputError(self.config.to_dict().get("bigdl_transformers_low_bit", False),
                       f"Detected this model is not a low-bit model, please use from_pretrained's"
                       f" load_in_4bit or load_in_low_bit parameter to load a 4-bit model first.")
+    self.to('cpu')
     self.save_pretrained(*args, **kwargs)
     import json
     import os
 
@@ -26,7 +26,7 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -71,7 +71,7 @@ def baichuan_attention_forward_7b(
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(bsz,
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
                                                        self.head_dim,
                                                        cache_k.size(2),
@@ -87,13 +87,13 @@ def baichuan_attention_forward_7b(
 
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
-                                                           self.num_heads,
-                                                           self.head_dim,
-                                                           kv_seq_len,
-                                                           max_cache_length,
-                                                           dtype=key_states.dtype,
-                                                           device=device)
+        new_key_states, new_value_states = init_kv_cache(bsz,
+                                                         self.num_heads,
+                                                         self.head_dim,
+                                                         kv_seq_len,
+                                                         max_cache_length,
+                                                         dtype=key_states.dtype,
+                                                         device=device)
         new_key_states[:] = key_states
         new_value_states[:] = value_states
         key_states = new_key_states
@@ -169,7 +169,7 @@ def baichuan_attention_forward_13b(
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(bsz,
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
                                                        self.head_dim,
                                                        cache_k.size(2),
@@ -185,13 +185,13 @@ def baichuan_attention_forward_13b(
 
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
-                                                           self.num_heads,
-                                                           self.head_dim,
-                                                           kv_seq_len,
-                                                           max_cache_length,
-                                                           dtype=key_states.dtype,
-                                                           device=device)
+        new_key_states, new_value_states = init_kv_cache(bsz,
+                                                         self.num_heads,
+                                                         self.head_dim,
+                                                         kv_seq_len,
+                                                         max_cache_length,
+                                                         dtype=key_states.dtype,
+                                                         device=device)
         new_key_states[:] = key_states
         new_value_states[:] = value_states
         key_states = new_key_states
 
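The rename from `create_kv_cache` to `init_kv_cache`/`extend_kv_cache` mirrors the usual pre-allocation scheme: reserve `KV_CACHE_ALLOC_BLOCK_LENGTH` extra positions up front so new tokens can be appended in place, and only re-allocate when the slack runs out (which is what the `stride()` check above detects). A generic sketch of the initialization step (an illustration of the idea, not the library's actual code):

```python
import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # same constant as in these files

def init_kv_cache_sketch(bsz, num_heads, head_dim, kv_seq_len, max_cache_length,
                         dtype, device):
    # Allocate room for max_cache_length positions, but hand back views covering
    # only the kv_seq_len positions that are valid right now; later tokens can be
    # appended into the reserved slack without re-allocating every step.
    key_cache = torch.empty(bsz, num_heads, max_cache_length, head_dim,
                            dtype=dtype, device=device)
    value_cache = torch.empty(bsz, num_heads, max_cache_length, head_dim,
                              dtype=dtype, device=device)
    return key_cache[:, :, :kv_seq_len, :], value_cache[:, :, :kv_seq_len, :]
```
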
@@ -26,7 +26,7 @@ from torch import nn
 from torch.nn import functional as F
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 from transformers.utils import logging, ContextManagers
 logger = logging.get_logger(__name__)
@@ -83,7 +83,7 @@ def baichuan_attention_forward_7b(
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(bsz,
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
                                                        self.head_dim,
                                                        cache_k.size(2),
@@ -99,13 +99,13 @@ def baichuan_attention_forward_7b(
 
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
-                                                           self.num_heads,
-                                                           self.head_dim,
-                                                           kv_seq_len,
-                                                           max_cache_length,
-                                                           dtype=key_states.dtype,
-                                                           device=device)
+        new_key_states, new_value_states = init_kv_cache(bsz,
+                                                         self.num_heads,
+                                                         self.head_dim,
+                                                         kv_seq_len,
+                                                         max_cache_length,
+                                                         dtype=key_states.dtype,
+                                                         device=device)
         new_key_states[:] = key_states
         new_value_states[:] = value_states
         key_states = new_key_states
@@ -177,8 +177,10 @@ def baichuan_attention_forward_13b(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(bsz,
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
                                                        self.head_dim,
                                                        cache_k.size(2),
@@ -194,13 +196,13 @@ def baichuan_attention_forward_13b(
 
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
-                                                           self.num_heads,
-                                                           self.head_dim,
-                                                           kv_seq_len,
-                                                           max_cache_length,
-                                                           dtype=key_states.dtype,
-                                                           device=device)
+        new_key_states, new_value_states = init_kv_cache(bsz,
+                                                         self.num_heads,
+                                                         self.head_dim,
+                                                         kv_seq_len,
+                                                         max_cache_length,
+                                                         dtype=key_states.dtype,
+                                                         device=device)
         new_key_states[:] = key_states
         new_value_states[:] = value_states
         key_states = new_key_states
 
| 
						 | 
					@ -37,7 +37,7 @@ from typing import Optional, Tuple
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import torch.utils.checkpoint
 | 
					import torch.utils.checkpoint
 | 
				
			||||||
from torch.nn import functional as F
 | 
					from torch.nn import functional as F
 | 
				
			||||||
from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
 | 
					from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
					KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
				
			||||||
| 
						 | 
					@ -96,6 +96,8 @@ def bloom_attention_forward(
 | 
				
			||||||
        self.head_dim
 | 
					        self.head_dim
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    _, _, kv_length = key_layer.shape
 | 
					    _, _, kv_length = key_layer.shape
 | 
				
			||||||
 | 
					    if layer_past is not None:
 | 
				
			||||||
 | 
					        kv_length += layer_past[0].shape[-1]
 | 
				
			||||||
    query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
 | 
					    query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
 | 
				
			||||||
    key_layer = key_layer.transpose(1, 2).view(batch_size, self.num_heads, q_length, self.head_dim)
 | 
					    key_layer = key_layer.transpose(1, 2).view(batch_size, self.num_heads, q_length, self.head_dim)
 | 
				
			||||||
    value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
 | 
					    value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
 | 
				
			||||||
| 
						 | 
					@ -106,7 +108,7 @@ def bloom_attention_forward(
 | 
				
			||||||
        cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
 | 
					        cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
 | 
				
			||||||
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
					        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
				
			||||||
            # allocate new
 | 
					            # allocate new
 | 
				
			||||||
            new_cache_k, new_cache_v = create_kv_cache(
 | 
					            new_cache_k, new_cache_v = extend_kv_cache(
 | 
				
			||||||
                batch_size,
 | 
					                batch_size,
 | 
				
			||||||
                self.num_heads,
 | 
					                self.num_heads,
 | 
				
			||||||
                self.head_dim,
 | 
					                self.head_dim,
 | 
				
			||||||
| 
						 | 
					@ -124,7 +126,7 @@ def bloom_attention_forward(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    elif use_cache:
 | 
					    elif use_cache:
 | 
				
			||||||
        max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
					        max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
				
			||||||
        new_key_states, new_value_states = create_kv_cache(
 | 
					        new_key_states, new_value_states = init_kv_cache(
 | 
				
			||||||
            batch_size,
 | 
					            batch_size,
 | 
				
			||||||
            self.num_heads,
 | 
					            self.num_heads,
 | 
				
			||||||
            self.head_dim,
 | 
					            self.head_dim,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
@@ -22,7 +22,7 @@ import torch
 import torch.utils.checkpoint
 import torch.nn.functional as F
 from typing import Optional, Tuple
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache


 def rotate_half(x):
@@ -68,7 +68,7 @@ def attention_fn(
         past_length = cache_k.size(2)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-            new_cache_k, new_cache_v = create_kv_cache(batch_size,
+            new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
                                                        self.hidden_size_per_attention_head,
                                                        past_length,
@@ -82,10 +82,10 @@ def attention_fn(
     elif use_cache:
         max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
             + KV_CACHE_ALLOC_BLOCK_LENGTH
-        key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
-                                                 self.hidden_size_per_attention_head, cur_length,
-                                                 max_cache_length,
-                                                 dtype=query_layer.dtype, device=device)
+        key_cache, value_cache = init_kv_cache(batch_size, self.num_attention_heads_per_partition,
+                                               self.hidden_size_per_attention_head, cur_length,
+                                               max_cache_length,
+                                               dtype=query_layer.dtype, device=device)
         key_cache[:] = key_layer
         value_cache[:] = value_layer
         key_layer = key_cache
@@ -20,7 +20,7 @@
 import torch
 from typing import Optional, Tuple, Union, List, Callable, Dict, Any
 import torch.nn.functional as F
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -152,7 +152,7 @@ def chatglm2_attention_forward_8eb45c(

         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-            new_cache_k, new_cache_v = create_kv_cache(batch_size,
+            new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
                                                        self.hidden_size_per_attention_head,
                                                        past_length,
@@ -170,10 +170,10 @@ def chatglm2_attention_forward_8eb45c(

         max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
             + KV_CACHE_ALLOC_BLOCK_LENGTH
-        key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
-                                                 self.hidden_size_per_attention_head, cur_length,
-                                                 max_cache_length,
-                                                 dtype=query_layer.dtype, device=device)
+        key_cache, value_cache = init_kv_cache(batch_size, self.num_attention_heads_per_partition,
+                                               self.hidden_size_per_attention_head, cur_length,
+                                               max_cache_length,
+                                               dtype=query_layer.dtype, device=device)
         key_cache[:] = key_layer
         value_cache[:] = value_layer
         key_layer = key_cache
@@ -38,7 +38,7 @@ from typing import Optional, Tuple
 import torch
 from torch.nn import functional as F
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -86,7 +86,8 @@ def rw_attention_forward_7b(
     query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)

     _, kv_length, _ = key_layer.shape
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     key_layer = key_layer.view(batch_size, self.num_kv, q_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_kv, q_length, self.head_dim)
@@ -98,7 +99,7 @@ def rw_attention_forward_7b(
         cache_v = layer_past[1].view(batch_size, self.num_kv, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(
+            new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
                 self.num_kv,
                 self.head_dim,
@@ -116,7 +117,7 @@ def rw_attention_forward_7b(

     elif use_cache:
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(
+        new_key_states, new_value_states = init_kv_cache(
             batch_size,
             self.num_kv,
             self.head_dim,
@@ -264,6 +265,8 @@ def rw_attention_forward_40b(
     query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)

     _, kv_length, _ = key_layer.shape
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     key_layer = key_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
@@ -275,7 +278,7 @@ def rw_attention_forward_40b(
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(
+            new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
                 self.num_heads,
                 self.head_dim,
@@ -293,7 +296,7 @@ def rw_attention_forward_40b(

     elif use_cache:
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(
+        new_key_states, new_value_states = init_kv_cache(
             batch_size,
             self.num_heads,
             self.head_dim,
@@ -437,7 +440,8 @@ def falcon_attention_forward(
     query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)

     _, kv_length, _ = key_layer.shape
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
     key_layer = key_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
@@ -448,7 +452,7 @@ def falcon_attention_forward(
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(
+            new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
                 self.num_heads,
                 self.head_dim,
@@ -466,7 +470,7 @@ def falcon_attention_forward(

     elif use_cache:
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(
+        new_key_states, new_value_states = init_kv_cache(
             batch_size,
             self.num_heads,
             self.head_dim,
@@ -19,8 +19,8 @@

 import torch
 from typing import Optional, Tuple, Union
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache, \
-    apply_rotary_pos_emb
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    apply_rotary_pos_emb, append_kv_cache
 from transformers.utils.import_utils import is_torch_fx_proxy


@@ -144,7 +144,7 @@ def gptj_attention_forward(
         past_length = cache_k.size(2)

         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            new_cache_k, new_cache_v = create_kv_cache(batch_size,
+            new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads,
                                                        self.head_dim,
                                                        past_length,
@@ -158,13 +158,13 @@ def gptj_attention_forward(
         key, value = append_kv_cache(cache_k, cache_v, key, value)

     elif use_cache:
-        key_cache, value_cache = create_kv_cache(batch_size,
-                                                 self.num_attention_heads,
-                                                 self.head_dim,
-                                                 kv_seq_len,
-                                                 kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                 dtype=key.dtype,
-                                                 device=device)
+        key_cache, value_cache = init_kv_cache(batch_size,
+                                               self.num_attention_heads,
+                                               self.head_dim,
+                                               kv_seq_len,
+                                               kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                               dtype=key.dtype,
+                                               device=device)
         key_cache[:] = key
         value_cache[:] = value
         key = key_cache
@@ -34,7 +34,7 @@
 import torch
 from typing import Optional, Tuple
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -91,7 +91,7 @@ def gptneox_attention_forward(
         past_value = layer_past[1]
         if past_key.stride()[1] <= past_key.size(2) * past_key.size(3):
             # allocate new
-            new_past_key, new_past_value = create_kv_cache(bsz,
+            new_past_key, new_past_value = extend_kv_cache(bsz,
                                                            self.num_attention_heads,
                                                            self.head_size,
                                                            past_key.size(2),
@@ -106,13 +106,13 @@ def gptneox_attention_forward(
         key, value = append_kv_cache(past_key, past_value, key, value)
     elif use_cache:
         max_cache_length = seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key, new_value = create_kv_cache(bsz,
-                                             self.num_attention_heads,
-                                             self.head_size,
-                                             seq_len,
-                                             max_cache_length,
-                                             dtype=key.dtype,
-                                             device=device)
+        new_key, new_value = init_kv_cache(bsz,
+                                           self.num_attention_heads,
+                                           self.head_size,
+                                           seq_len,
+                                           max_cache_length,
+                                           dtype=key.dtype,
+                                           device=device)
         new_key[:] = key
         new_value[:] = value
         key = new_key
@@ -37,7 +37,7 @@ from typing import Optional, Tuple
 import math
 import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb


@@ -113,7 +113,7 @@ def llama_attention_forward_4_31(
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(bsz,
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_key_value_heads,  # Support GQA
                                                        self.head_dim,
                                                        cache_k.size(2),
@@ -129,13 +129,13 @@ def llama_attention_forward_4_31(

     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
-                                                            self.num_key_value_heads,
-                                                            self.head_dim,
-                                                            kv_seq_len,
-                                                            max_cache_length,
-                                                            dtype=key_states.dtype,
-                                                            device=device)
+        new_key_states, new_value_states = init_kv_cache(bsz,
+                                                          self.num_key_value_heads,
+                                                          self.head_dim,
+                                                          kv_seq_len,
+                                                          max_cache_length,
+                                                          dtype=key_states.dtype,
+                                                          device=device)
         new_key_states[:] = key_states
         new_value_states[:] = value_states
         key_states = new_key_states
							
								
								
									
python/llm/src/bigdl/llm/transformers/models/mpt.py  (new file, 149 lines)
@@ -0,0 +1,149 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/attention.py
#

import warnings
import torch
from einops import rearrange
import math
import torch.nn.functional as F
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache


KV_CACHE_ALLOC_BLOCK_LENGTH = 256


def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None,
                                    attention_mask=None, is_causal=True, needs_weights=False):
    qkv = self.Wqkv(x)
    if self.clip_qkv:
        qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
    (query, key, value) = qkv.chunk(3, dim=2)
    key_padding_mask = attention_mask
    if self.qk_ln:
        dtype = query.dtype
        query = self.q_ln(query).to(dtype)
        key = self.k_ln(key).to(dtype)
    (context, attn_weights, past_key_value) = \
        mpt_scaled_multihead_dot_product_attention(query, key, value, self.n_heads,
                                                   past_key_value=past_key_value,
                                                   softmax_scale=self.softmax_scale,
                                                   attn_bias=attn_bias,
                                                   key_padding_mask=key_padding_mask,
                                                   is_causal=is_causal,
                                                   dropout_p=self.attn_dropout_p,
                                                   training=self.training,
                                                   needs_weights=needs_weights)
    return (self.out_proj(context), attn_weights, past_key_value)


def mpt_scaled_multihead_dot_product_attention(query, key, value, n_heads,
                                               past_key_value=None,
                                               softmax_scale=None,
                                               attn_bias=None,
                                               key_padding_mask=None,
                                               is_causal=False,
                                               dropout_p=0.0,
                                               training=False,
                                               needs_weights=False,
                                               multiquery=False):
    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
    bsz, n_heads, q_len, head_dim = q.size()
    device = q.device
    kv_n_heads = 1 if multiquery else n_heads
    k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
    kv_seq_len = k.shape[-1]
    if past_key_value is not None:
        if len(past_key_value) != 0:
            # k = torch.cat([past_key_value[0], k], dim=3)
            # v = torch.cat([past_key_value[1], v], dim=2)
            cache_k = past_key_value[0].transpose(2, 3)
            cache_v = past_key_value[1]
            kv_seq_len += cache_k.shape[-2]
            if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
                # allocate new
                new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                           kv_n_heads,  # Support GQA
                                                           head_dim,
                                                           cache_k.size(2),
                                                           kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
                                                           dtype=cache_k.dtype,
                                                           device=device)
                new_cache_k[:] = cache_k
                new_cache_v[:] = cache_v
                cache_k = new_cache_k
                cache_v = new_cache_v
            key_states, value_states = append_kv_cache(cache_k, cache_v, k.transpose(2, 3), v)
            k = key_states.transpose(2, 3)
            v = value_states
        else:
            max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
            new_key_states, new_value_states = init_kv_cache(bsz,
                                                             kv_n_heads,
                                                             head_dim,
                                                             kv_seq_len,
                                                             max_cache_length,
                                                             dtype=k.dtype,
                                                             device=device)
            new_key_states[:] = k.transpose(2, 3)
            new_value_states[:] = v
            k = new_key_states.transpose(2, 3)
            v = new_value_states
        past_key_value = (k, v)
    (b, _, s_q, d) = q.shape
    s_k = k.size(-1)
    if softmax_scale is None:
        softmax_scale = 1 / math.sqrt(d)
    attn_weight = q.matmul(k) * softmax_scale
    if attn_bias is not None:
        _s_q = max(0, attn_bias.size(2) - s_q)
        _s_k = max(0, attn_bias.size(3) - s_k)
        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
        if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k \
                or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
            invalidInputError(False, f'attn_bias (shape: {attn_bias.shape}) '
                                     f'is expected to broadcast to shape: {attn_weight.shape}.')
        attn_weight = attn_weight + attn_bias
    min_val = torch.finfo(q.dtype).min
    if key_padding_mask is not None:
        if attn_bias is not None:
            warnings.warn('Propogating key_padding_mask to the attention module '
                          + 'and applying it within the attention module can cause '
                          + 'unneccessary computation/memory usage. Consider integrating '
                          + 'into attn_bias once and passing that to each attention '
                          + 'module instead.')
        attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
    if is_causal and (not q.size(2) == 1):
        s = max(s_q, s_k)
        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
        causal_mask = causal_mask.tril()
        causal_mask = causal_mask.to(torch.bool)
        causal_mask = ~causal_mask
        causal_mask = causal_mask[-s_q:, -s_k:]
        attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
    attn_weight = torch.softmax(attn_weight, dim=-1)
    if dropout_p:
        attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p,
                                                  training=training, inplace=True)
    out = attn_weight.to(v.dtype).matmul(v)
    out = rearrange(out, 'b h s d -> b s (h d)')
    if needs_weights:
        return (out, attn_weight, past_key_value)
    return (out, None, past_key_value)
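
Not part of the commit: a minimal shape check illustrating the layout used by the new mpt.py above. Keys are rearranged to (b, h, d, s) so that q.matmul(k) directly yields (b, h, s_q, s_k) attention scores, which is why the kv-cache helpers are always fed k.transpose(2, 3). The concrete sizes below are arbitrary placeholders.

import torch
from einops import rearrange

b, s, h, d = 2, 5, 4, 8                          # arbitrary example sizes
x = torch.randn(b, s, h * d)

q = rearrange(x, 'b s (h d) -> b h s d', h=h)    # queries: (2, 4, 5, 8)
k = rearrange(x, 'b s (h d) -> b h d s', h=h)    # keys kept pre-transposed: (2, 4, 8, 5)
v = rearrange(x, 'b s (h d) -> b h s d', h=h)    # values: (2, 4, 5, 8)

attn = q.matmul(k)                               # (2, 4, 5, 5) == (b, h, s_q, s_k)
k_for_cache = k.transpose(2, 3)                  # (2, 4, 5, 8): the (b, h, s, d) layout the cache helpers expect
print(attn.shape, k_for_cache.shape)
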
@@ -18,9 +18,7 @@ import torch
 from bigdl.llm.utils.common import invalidInputError


-def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
-    if device.type == 'xpu':
-        torch.xpu.empty_cache()
+def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
     key_cache_storage = torch.empty(batch_size, num_heads,
                                     max_length, head_dim,
                                     dtype=dtype, device=device)
@@ -29,7 +27,7 @@ def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length,
                                       dtype=dtype, device=device)

     key_cache = key_cache_storage.as_strided((batch_size, num_heads,
-                                             current_length, head_dim),
+                                              current_length, head_dim),
                                              key_cache_storage.stride(),
                                              storage_offset=0)
     value_cache = value_cache_storage.as_strided((batch_size, num_heads,
@@ -39,6 +37,13 @@ def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length,
     return key_cache, value_cache


+def extend_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
+    # empty cache to reduce gpu memory
+    if device.type == 'xpu':
+        torch.xpu.empty_cache()
+    return init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device)
+
+
 def append_kv_cache(cache_k, cache_v, key_states, value_states):
     new_size = (cache_k.size(0),
                 cache_k.size(1),
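
The call sites changed above all converge on the same pattern around these helpers: pre-allocate KV_CACHE_ALLOC_BLOCK_LENGTH extra slots on the first token with init_kv_cache, append in place on later tokens, and fall back to extend_kv_cache (which also empties the XPU cache) once the pre-allocated storage is used up. The sketch below is illustrative only and not part of the commit; update_kv_cache and its arguments (bsz, num_heads, head_dim, kv_seq_len, past_key_value) are placeholder names standing in for the per-model code, while the imported helpers and the block length come from this file.

from bigdl.llm.transformers.models.utils import (init_kv_cache, extend_kv_cache,
                                                 append_kv_cache)

KV_CACHE_ALLOC_BLOCK_LENGTH = 256


def update_kv_cache(past_key_value, key_states, value_states,
                    bsz, num_heads, head_dim, kv_seq_len, device):
    # key_states/value_states hold only the new tokens, shaped (bsz, num_heads, seq, head_dim);
    # kv_seq_len is the total length including any cached tokens
    if past_key_value is None:
        # first token: allocate max_length slots up front, get strided views of the current length
        new_k, new_v = init_kv_cache(bsz, num_heads, head_dim, kv_seq_len,
                                     kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
                                     dtype=key_states.dtype, device=device)
        new_k[:] = key_states
        new_v[:] = value_states
        return new_k, new_v
    cache_k, cache_v = past_key_value
    if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
        # the pre-allocated storage is full: grab a bigger block (extend_kv_cache also
        # calls torch.xpu.empty_cache() on XPU first) and copy the existing cache into it
        new_k, new_v = extend_kv_cache(bsz, num_heads, head_dim, cache_k.size(2),
                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
                                       dtype=cache_k.dtype, device=device)
        new_k[:] = cache_k
        new_v[:] = cache_v
        cache_k, cache_v = new_k, new_v
    # append the new tokens after the cached ones and return the combined key/value views
    return append_kv_cache(cache_k, cache_v, key_states, value_states)
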