[LLM] Support CPU deepspeed distributed inference (#9259)
* [LLM] Support CPU Deepspeed distributed inference * Update run_deepspeed.py * Rename * fix style * add new codes * refine * remove annotated codes * refine * Update README.md * refine doc and example code
This commit is contained in:
		
							parent
							
								
									f9bf5382ff
								
							
						
					
					
						commit
						af94058203
					
				
					 8 changed files with 346 additions and 4 deletions
				
			
		| 
						 | 
				
			
			@ -17,4 +17,6 @@ test_api:
 | 
			
		|||
  - "pytorch_autocast_bf16"
 | 
			
		||||
  # - "ipex_fp16_gpu" # on Intel GPU
 | 
			
		||||
  # - "transformer_int4_gpu"  # on Intel GPU
 | 
			
		||||
  # - "optimize_model_gpu"  # on Intel GPU
 | 
			
		||||
  # - "optimize_model_gpu"  # on Intel GPU
 | 
			
		||||
  # - "deepspeed_transformer_int4_cpu" # on Intel SPR Server
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										18
									
								
								python/llm/dev/benchmark/all-in-one/run-deepspeed-spr.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								python/llm/dev/benchmark/all-in-one/run-deepspeed-spr.sh
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,18 @@
 | 
			
		|||
#!/bin/bash
 | 
			
		||||
source bigdl-nano-init
 | 
			
		||||
unset OMP_NUM_THREADS # deepspeed will set it for each instance automatically
 | 
			
		||||
source /opt/intel/oneccl/env/setvars.sh
 | 
			
		||||
export WORLD_SIZE=2 # run 1 instance per SPR socket, thus 2 instances on 2 sockets, 96 cores
 | 
			
		||||
export MASTER_ADDR=127.0.0.1
 | 
			
		||||
export CCL_ZE_IPC_EXCHANGE=sockets
 | 
			
		||||
export DS_ACCELERATOR="cpu"
 | 
			
		||||
export CCL_WORKER_AFFINITY=auto
 | 
			
		||||
unset KMP_AFFINITY # deepspeed will set it for each instance automatically
 | 
			
		||||
export FI_PROVIDER=tcp
 | 
			
		||||
export CCL_ATL_TRANSPORT=ofi
 | 
			
		||||
export CCL_PROCESS_LAUNCHER=none
 | 
			
		||||
 | 
			
		||||
deepspeed \
 | 
			
		||||
  --bind_cores_to_rank \
 | 
			
		||||
  --bind_core_list 0-95 \
 | 
			
		||||
  run.py
 | 
			
		||||
| 
						 | 
				
			
			@ -55,6 +55,8 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
 | 
			
		|||
        result = run_pytorch_autocast_bf16(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams)
 | 
			
		||||
    elif test_api == 'ipex_fp16_gpu':
 | 
			
		||||
        result = run_ipex_fp16_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams)
 | 
			
		||||
    elif test_api == 'deepspeed_transformer_int4_cpu':
 | 
			
		||||
        result = run_deepspeed_transformer_int4_cpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit)
 | 
			
		||||
 | 
			
		||||
    for in_out_pair in in_out_pairs:
 | 
			
		||||
        if result:
 | 
			
		||||
| 
						 | 
				
			
			@ -540,6 +542,92 @@ def run_ipex_fp16_gpu(repo_id,
 | 
			
		|||
    torch.xpu.empty_cache()
 | 
			
		||||
    return result
 | 
			
		||||
 | 
			
		||||
def run_deepspeed_transformer_int4_cpu(repo_id,
 | 
			
		||||
                         local_model_hub,
 | 
			
		||||
                         in_out_pairs,
 | 
			
		||||
                         warm_up,
 | 
			
		||||
                         num_trials,
 | 
			
		||||
                         num_beams,
 | 
			
		||||
                         low_bit):
 | 
			
		||||
    from transformers import AutoModelForCausalLM, LlamaTokenizer, AutoTokenizer
 | 
			
		||||
    import deepspeed
 | 
			
		||||
    from bigdl.llm import optimize_model
 | 
			
		||||
    import argparse
 | 
			
		||||
    # parser is for deepspeed subprocesses' inline parameter
 | 
			
		||||
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
 | 
			
		||||
    parser.add_argument('--local_rank', type=str, default=0, help='this is automatically set when using deepspeed launcher')
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    local_rank = int(os.getenv("RANK", "1"))
 | 
			
		||||
    if local_rank == -1:
 | 
			
		||||
        local_rank = args.local_rank
 | 
			
		||||
    world_size = int(os.getenv("WORLD_SIZE", "1"))
 | 
			
		||||
    model_path = get_model_path(repo_id, local_model_hub)
 | 
			
		||||
 | 
			
		||||
    st = time.perf_counter()
 | 
			
		||||
    # Note: only tested cpu Llama2-7b
 | 
			
		||||
    # Native Huggingface transformers loading to enable deepspeed init
 | 
			
		||||
    if repo_id in ['THUDM/chatglm-6b', 'THUDM/chatglm2-6b']:
 | 
			
		||||
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True, use_cache=True)
 | 
			
		||||
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 | 
			
		||||
    elif repo_id in LLAMA_IDS:
 | 
			
		||||
        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True,
 | 
			
		||||
                                                     use_cache=True)
 | 
			
		||||
        tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
 | 
			
		||||
    else:
 | 
			
		||||
        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, use_cache=True)
 | 
			
		||||
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 | 
			
		||||
 | 
			
		||||
    # Parallelize model on deepspeed
 | 
			
		||||
    model = deepspeed.init_inference(model, mp_size=world_size,
 | 
			
		||||
                                     dtype=torch.float16,
 | 
			
		||||
                                     replace_method="auto")
 | 
			
		||||
 | 
			
		||||
    # Apply BigDL-LLM INT4 optimization to enable BenchmarkWrapper
 | 
			
		||||
    # Note: only tested sym_int4
 | 
			
		||||
    model = optimize_model(model.module.to(f'cpu'), low_bit=low_bit)
 | 
			
		||||
    model = model.to(f'cpu:{local_rank}')
 | 
			
		||||
 | 
			
		||||
    end = time.perf_counter()
 | 
			
		||||
    print(">> loading of model costs {}s".format(end - st))
 | 
			
		||||
 | 
			
		||||
    model = BenchmarkWrapper(model)
 | 
			
		||||
 | 
			
		||||
    result = {}
 | 
			
		||||
    with torch.inference_mode():
 | 
			
		||||
        for in_out in in_out_pairs:
 | 
			
		||||
            in_out_len = in_out.split("-")
 | 
			
		||||
            in_len = int(in_out_len[0])
 | 
			
		||||
            out_len = int(in_out_len[1])
 | 
			
		||||
            # As different tokenizer has different encodings,
 | 
			
		||||
            # in_len.txt maybe shorter than we need,
 | 
			
		||||
            # use much longer context to make sure input length
 | 
			
		||||
            test_length = min(in_len*2, 8192)
 | 
			
		||||
            while test_length not in [32, 256, 1024, 2048, 8192]:
 | 
			
		||||
                test_length = test_length * 2
 | 
			
		||||
            input_str = open(f"prompt/{test_length}.txt", 'r').read()
 | 
			
		||||
            # As different tokenizer has different encodings,
 | 
			
		||||
            # slice the input_ids to ensure the prompt length is required length.
 | 
			
		||||
            input_ids = tokenizer.encode(input_str, return_tensors="pt")
 | 
			
		||||
            input_ids = input_ids[:, :in_len]
 | 
			
		||||
            true_str = tokenizer.batch_decode(input_ids)[0]
 | 
			
		||||
            input_ids = tokenizer.encode(true_str, return_tensors="pt")
 | 
			
		||||
            actual_in_len = input_ids.shape[1]
 | 
			
		||||
            result[in_out] = []
 | 
			
		||||
            for i in range(num_trials + warm_up):
 | 
			
		||||
                st = time.perf_counter()
 | 
			
		||||
                output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
 | 
			
		||||
                                                num_beams=num_beams)
 | 
			
		||||
                end = time.perf_counter()
 | 
			
		||||
                if local_rank == 0:
 | 
			
		||||
                    print("model generate cost: " + str(end - st))
 | 
			
		||||
                output = tokenizer.batch_decode(output_ids)
 | 
			
		||||
                if local_rank == 0:
 | 
			
		||||
                    print(output[0])
 | 
			
		||||
                actual_out_len = output_ids.shape[1] - actual_in_len
 | 
			
		||||
                if i >= warm_up :
 | 
			
		||||
                    result[in_out].append([model.first_cost, model.rest_cost_mean, model.encoder_time,
 | 
			
		||||
                                           actual_in_len, actual_out_len])
 | 
			
		||||
    return result
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    from omegaconf import OmegaConf
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										69
									
								
								python/llm/example/CPU/Deepspeed-AutoTP/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								python/llm/example/CPU/Deepspeed-AutoTP/README.md
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,69 @@
 | 
			
		|||
### Run Tensor-Parallel BigDL Transformers INT4 Inference with Deepspeed
 | 
			
		||||
 | 
			
		||||
#### 1. Install Dependencies
 | 
			
		||||
 | 
			
		||||
Install necessary packages (here Python 3.9 is our test environment):
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
bash install.sh
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
#### 2. Initialize Deepspeed Distributed Context
 | 
			
		||||
 | 
			
		||||
Like shown in example code `deepspeed_autotp.py`, you can construct parallel model with Python API:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
# Load in HuggingFace Transformers' model
 | 
			
		||||
from transformers import AutoModelForCausalLM
 | 
			
		||||
 | 
			
		||||
model = AutoModelForCausalLM.from_pretrained(...)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Parallelize model on deepspeed
 | 
			
		||||
import deepspeed
 | 
			
		||||
 | 
			
		||||
model = deepspeed.init_inference(
 | 
			
		||||
    model, # an AutoModel of Transformers
 | 
			
		||||
    mp_size = world_size, # instance (process) count
 | 
			
		||||
    dtype=torch.float16,
 | 
			
		||||
    replace_method="auto")
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Then, returned model is converted into a deepspeed InferenceEnginee type.
 | 
			
		||||
 | 
			
		||||
#### 3. Optimize Model with BigDL-LLM Low Bit
 | 
			
		||||
 | 
			
		||||
Distributed model managed by deepspeed can be further optimized with BigDL low-bit Python API, e.g. sym_int4:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
# Apply BigDL-LLM INT4 optimizations on transformers
 | 
			
		||||
from bigdl.llm import optimize_model
 | 
			
		||||
 | 
			
		||||
model = optimize_model(model.module.to(f'cpu'), low_bit='sym_int4')
 | 
			
		||||
model = model.to(f'cpu:{local_rank}') # move partial model to local rank
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Then, a bigdl-llm transformers is returned, which in the following, can serve in parallel with native APIs.
 | 
			
		||||
 | 
			
		||||
#### 4. Start Python Code
 | 
			
		||||
 | 
			
		||||
You can try deepspeed with BigDL LLM by:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
bash run.sh
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
If you want to run your own application, there are **necessary configurations in the script** which can also be ported to run your custom deepspeed application:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
# run.sh
 | 
			
		||||
source bigdl-nano-init
 | 
			
		||||
unset OMP_NUM_THREADS # deepspeed will set it for each instance automatically
 | 
			
		||||
source /opt/intel/oneccl/env/setvars.sh
 | 
			
		||||
......
 | 
			
		||||
export FI_PROVIDER=tcp
 | 
			
		||||
export CCL_ATL_TRANSPORT=ofi
 | 
			
		||||
export CCL_PROCESS_LAUNCHER=none
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Set the above configurations before running `deepspeed` please to ensure right parallel communication and high performance.
 | 
			
		||||
							
								
								
									
										125
									
								
								python/llm/example/CPU/Deepspeed-AutoTP/deepspeed_autotp.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										125
									
								
								python/llm/example/CPU/Deepspeed-AutoTP/deepspeed_autotp.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,125 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
# Some parts of this file is adapted from
 | 
			
		||||
# https://github.com/TimDettmers/bitsandbytes/blob/0.39.1/bitsandbytes/nn/modules.py
 | 
			
		||||
# which is licensed under the MIT license:
 | 
			
		||||
#
 | 
			
		||||
# MIT License
 | 
			
		||||
#
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates.
 | 
			
		||||
#
 | 
			
		||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
 | 
			
		||||
# of this software and associated documentation files (the "Software"), to deal
 | 
			
		||||
# in the Software without restriction, including without limitation the rights
 | 
			
		||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 | 
			
		||||
# copies of the Software, and to permit persons to whom the Software is
 | 
			
		||||
# furnished to do so, subject to the following conditions:
 | 
			
		||||
 | 
			
		||||
# The above copyright notice and this permission notice shall be included in all
 | 
			
		||||
# copies or substantial portions of the Software.
 | 
			
		||||
 | 
			
		||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 | 
			
		||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 | 
			
		||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 | 
			
		||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 | 
			
		||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 | 
			
		||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 | 
			
		||||
# SOFTWARE.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
import torch
 | 
			
		||||
from transformers import AutoModelForCausalLM, LlamaTokenizer, AutoTokenizer
 | 
			
		||||
import deepspeed
 | 
			
		||||
from bigdl.llm import optimize_model
 | 
			
		||||
import torch
 | 
			
		||||
import intel_extension_for_pytorch as ipex
 | 
			
		||||
import time
 | 
			
		||||
import argparse
 | 
			
		||||
from benchmark_util import BenchmarkWrapper
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
 | 
			
		||||
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
 | 
			
		||||
                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
 | 
			
		||||
                             ', or the path to the huggingface checkpoint folder')
 | 
			
		||||
    parser.add_argument('--prompt', type=str, default="Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun",
 | 
			
		||||
                        help='Prompt to infer')
 | 
			
		||||
    parser.add_argument('--n-predict', type=int, default=32,
 | 
			
		||||
                        help='Max tokens to predict')
 | 
			
		||||
    parser.add_argument('--local_rank', type=int, default=0, help='this is automatically set when using deepspeed launcher')
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    model_path = args.repo_id_or_model_path
 | 
			
		||||
    world_size = int(os.getenv("WORLD_SIZE", "1"))
 | 
			
		||||
    local_rank = int(os.getenv("RANK", "-1")) # RANK is automatically set by CCL distributed backend
 | 
			
		||||
    if local_rank == -1: # args.local_rank is automatically set by deepspeed subprocess command
 | 
			
		||||
        local_rank = args.local_rank
 | 
			
		||||
 | 
			
		||||
    # Native Huggingface transformers loading
 | 
			
		||||
    model = AutoModelForCausalLM.from_pretrained(
 | 
			
		||||
        model_path,
 | 
			
		||||
        device_map={"": "cpu"},
 | 
			
		||||
        low_cpu_mem_usage=True,
 | 
			
		||||
        torch_dtype=torch.float16,
 | 
			
		||||
        trust_remote_code=True,
 | 
			
		||||
        use_cache=True
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Parallelize model on deepspeed
 | 
			
		||||
    model = deepspeed.init_inference(
 | 
			
		||||
        model,
 | 
			
		||||
        mp_size = world_size,
 | 
			
		||||
        dtype=torch.float16,
 | 
			
		||||
        replace_method="auto"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Apply BigDL-LLM INT4 optimizations on transformers
 | 
			
		||||
    model = optimize_model(model.module.to(f'cpu'), low_bit='sym_int4')
 | 
			
		||||
 | 
			
		||||
    model = model.to(f'cpu:{local_rank}')
 | 
			
		||||
 | 
			
		||||
    print(model)
 | 
			
		||||
    model = BenchmarkWrapper(model, do_print=True)
 | 
			
		||||
 | 
			
		||||
    # Load tokenizer
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 | 
			
		||||
 | 
			
		||||
    # Generate predicted tokens
 | 
			
		||||
    with torch.inference_mode():
 | 
			
		||||
        # Batch tokenizing
 | 
			
		||||
        prompt = args.prompt
 | 
			
		||||
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f'cpu:{local_rank}')
 | 
			
		||||
        # ipex model needs a warmup, then inference time can be accurate
 | 
			
		||||
        output = model.generate(input_ids,
 | 
			
		||||
                                max_new_tokens=args.n_predict,
 | 
			
		||||
                                use_cache=True)
 | 
			
		||||
        # start inference
 | 
			
		||||
        start = time.time()
 | 
			
		||||
        # if your selected model is capable of utilizing previous key/value attentions
 | 
			
		||||
        # to enhance decoding speed, but has `"use_cache": false` in its model config,
 | 
			
		||||
        # it is important to set `use_cache=True` explicitly in the `generate` function
 | 
			
		||||
        # to obtain optimal performance with BigDL-LLM INT4 optimizations
 | 
			
		||||
        output = model.generate(input_ids,
 | 
			
		||||
                                do_sample=False,
 | 
			
		||||
                                max_new_tokens=args.n_predict)
 | 
			
		||||
        end = time.time()
 | 
			
		||||
        if local_rank == 0:
 | 
			
		||||
            output_str = tokenizer.decode(output[0], skip_special_tokens=True)
 | 
			
		||||
            print('-'*20, 'Output', '-'*20)
 | 
			
		||||
            print(output_str)
 | 
			
		||||
            print(f'Inference time: {end - start} s')
 | 
			
		||||
							
								
								
									
										9
									
								
								python/llm/example/CPU/Deepspeed-AutoTP/install.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								python/llm/example/CPU/Deepspeed-AutoTP/install.sh
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,9 @@
 | 
			
		|||
#!/bin/bash
 | 
			
		||||
# install torch
 | 
			
		||||
pip install torch==2.1.0
 | 
			
		||||
# install torchccl
 | 
			
		||||
pip install https://intel-extension-for-pytorch.s3.amazonaws.com/torch_ccl/cpu/oneccl_bind_pt-2.1.0%2Bcpu-cp39-cp39-linux_x86_64.whl
 | 
			
		||||
# install deepspeed
 | 
			
		||||
pip install deepspeed==0.11.1
 | 
			
		||||
# exclude intel deepspeed extension, which is only for XPU
 | 
			
		||||
pip uninstall intel-extension-for-deepspeed --ignore-missing
 | 
			
		||||
							
								
								
									
										18
									
								
								python/llm/example/CPU/Deepspeed-AutoTP/run.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								python/llm/example/CPU/Deepspeed-AutoTP/run.sh
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,18 @@
 | 
			
		|||
#/bin/bash
 | 
			
		||||
source bigdl-nano-init
 | 
			
		||||
unset OMP_NUM_THREADS # deepspeed will set it for each instance automatically
 | 
			
		||||
source /opt/intel/oneccl/env/setvars.sh
 | 
			
		||||
export WORLD_SIZE=2 # run 1 instance per SPR socket, thus 2 instances on 2 sockets, 96 cores
 | 
			
		||||
export MASTER_ADDR=127.0.0.1
 | 
			
		||||
export CCL_ZE_IPC_EXCHANGE=sockets
 | 
			
		||||
export DS_ACCELERATOR="cpu"
 | 
			
		||||
export CCL_WORKER_AFFINITY=auto
 | 
			
		||||
unset KMP_AFFINITY # deepspeed will set it for each instance automatically
 | 
			
		||||
export FI_PROVIDER=tcp
 | 
			
		||||
export CCL_ATL_TRANSPORT=ofi
 | 
			
		||||
export CCL_PROCESS_LAUNCHER=none
 | 
			
		||||
 | 
			
		||||
deepspeed \
 | 
			
		||||
  --bind_cores_to_rank \
 | 
			
		||||
  --bind_core_list 0-95 \
 | 
			
		||||
  deepspeed_autotp.py
 | 
			
		||||
| 
						 | 
				
			
			@ -464,17 +464,30 @@ class LowBitLinear(nn.Linear):
 | 
			
		|||
                              " supported on CPU")
 | 
			
		||||
            if self.training and x.requires_grad:
 | 
			
		||||
                result = MatMulLowBitCPU.apply(x, self.weight)
 | 
			
		||||
                if self.bias is not None:
 | 
			
		||||
                    result = result + self.bias
 | 
			
		||||
            else:
 | 
			
		||||
                if IS_SERVER and (not IS_SPR) and \
 | 
			
		||||
                        self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
 | 
			
		||||
                    x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
 | 
			
		||||
                    result = F.linear(x, x0_fp32, self.bias)
 | 
			
		||||
                    if self.mp_group is None:
 | 
			
		||||
                        # none-distributed mode
 | 
			
		||||
                        result = F.linear(x, x0_fp32, self.bias)
 | 
			
		||||
                    else:
 | 
			
		||||
                        result = F.linear(x, x0_fp32)
 | 
			
		||||
                        from deepspeed import comm as dist
 | 
			
		||||
                        # Parallel F.linear should be avoided,
 | 
			
		||||
                        # thus deepspeed allreduce after the operation
 | 
			
		||||
                        dist.inference_all_reduce(result, group=self.mp_group)
 | 
			
		||||
                        if self.bias is not None:
 | 
			
		||||
                            result += self.bias
 | 
			
		||||
                else:
 | 
			
		||||
                    result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
 | 
			
		||||
                    new_shape = x_shape[:-1] + (self.out_len,)
 | 
			
		||||
                    result = result.view(new_shape)
 | 
			
		||||
                    # bias is consistent among multi instances,
 | 
			
		||||
                    # deepspeed only allreduce result without bias to reduce comunication
 | 
			
		||||
                    if self.mp_group is not None:
 | 
			
		||||
                        from deepspeed import comm as dist
 | 
			
		||||
                        dist.inference_all_reduce(result, group=self.mp_group)
 | 
			
		||||
                    if self.bias is not None:
 | 
			
		||||
                        result += self.bias
 | 
			
		||||
        return result
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue