Last commit: fa261b8af1 by Heyang Sun, 2024-12-13 10:47:04 +08:00 (torch 2.3 inference docker, #12517)

Torch Graph Mode

Here we show how to run torch graph mode on Intel Arc™ A-Series Graphics with ipex-llm, using GPT2-Medium on a classification task as the example:

1. Install

conda create -n ipex-llm python=3.11
conda activate ipex-llm
pip install --pre --upgrade ipex-llm[xpu_arc] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
pip install --pre pytorch-triton-xpu==3.0.0+1b2f15840e --index-url https://download.pytorch.org/whl/nightly/xpu
conda install -c conda-forge libstdcxx-ng
unset OCL_ICD_VENDORS
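Before going further, a quick sanity check can confirm that the environment works. This is a minimal sketch: the helper name check_xpu_env is hypothetical, and it assumes step 1 installed PyTorch together with Intel Extension for PyTorch (imported as intel_extension_for_pytorch).

```shell
# Hypothetical helper: sanity-check the environment created in step 1.
check_xpu_env() {
  # Step 1 unsets OCL_ICD_VENDORS; warn if something has set it again.
  if [ -n "${OCL_ICD_VENDORS+x}" ]; then
    echo "warning: OCL_ICD_VENDORS is set to '$OCL_ICD_VENDORS'" >&2
  fi
  # Import torch and Intel Extension for PyTorch, then report XPU visibility.
  python -c 'import torch, intel_extension_for_pytorch as ipex; print(torch.__version__, ipex.__version__, "xpu available:", torch.xpu.is_available())'
}
```

Run check_xpu_env inside the activated ipex-llm environment. On a working Arc setup the last field should print True.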

2. Configure oneAPI environment variables

Note

Skip this step if you are running on Windows.

This step is required on Linux if oneAPI was installed via APT or the offline installer. Skip this step if oneAPI was installed via PIP.

source /opt/intel/oneapi/setvars.sh
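If the same script must run on machines with either kind of oneAPI install, the Linux step can be guarded. A minimal sketch, assuming the default APT/offline location /opt/intel/oneapi/setvars.sh shown above; the variable name ONEAPI_SETVARS is illustrative.

```shell
# Default location for APT or offline oneAPI installs (from the step above).
ONEAPI_SETVARS=/opt/intel/oneapi/setvars.sh

if [ -f "$ONEAPI_SETVARS" ]; then
  # APT/offline install: load the oneAPI environment into this shell.
  . "$ONEAPI_SETVARS"
else
  # No setvars.sh at that path: assume PIP-installed oneAPI, which needs no sourcing.
  echo "setvars.sh not found; assuming PIP-installed oneAPI" >&2
fi
```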

3. Run

Convert the text-generation GPT2-Medium model into a classification model:

# The convert step needs to access the internet
export http_proxy=http://your_proxy_url
export https_proxy=http://your_proxy_url

# This will yield gpt2-medium-classification under /llm/models in the container
python convert-model-textgen-to-classfication.py --model-path MODEL_PATH

This produces a model directory whose name ends with '-classification', next to your input model path.
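To locate the converted model afterwards, its directory can be derived from the input path. This sketch assumes the converter simply appends '-classification' to the input directory name (consistent with gpt2-medium-classification above); classification_dir is a hypothetical helper and /llm/models/gpt2-medium is only a placeholder default.

```shell
# Hypothetical helper: derive the directory the convert step is expected to
# create, assuming it appends "-classification" to the input model path.
classification_dir() {
  model_path=${1%/}   # drop one trailing slash, if present
  printf '%s-classification\n' "$model_path"
}

MODEL_PATH=${MODEL_PATH:-/llm/models/gpt2-medium}   # placeholder path
OUT_DIR=$(classification_dir "$MODEL_PATH")
if [ -d "$OUT_DIR" ]; then
  echo "Converted model found at $OUT_DIR"
else
  echo "Expected converted model at $OUT_DIR (not found yet)" >&2
fi
```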

Benchmark GPT2-Medium's performance with the IPEX-LLM engine:

ipexrun xpu gpt2-graph-mode-benchmark.py --device xpu --engine ipex-llm --batch 16 --model-path MODEL_PATH

# The key output looks like:
# Average time taken (excluding the first two loops): xxxx seconds, Classification per seconds is xxxx
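When comparing several runs (for example, different --batch values), it can help to pull the numbers out of that summary line. A minimal sketch using sed, assuming the line is printed to stdout in exactly the format shown above; the function names and bench.log are hypothetical.

```shell
# Extract the "Classification per seconds" value (throughput) from the summary line.
extract_throughput() {
  sed -n 's/.*Classification per seconds is \([0-9.][0-9.]*\).*/\1/p'
}

# Extract the average time in seconds reported on the same line.
extract_avg_time() {
  sed -n 's/.*loops): \([0-9.][0-9.]*\) seconds.*/\1/p'
}

# Example usage (assumes the benchmark output was saved to bench.log):
#   ipexrun xpu gpt2-graph-mode-benchmark.py --device xpu --engine ipex-llm \
#     --batch 16 --model-path MODEL_PATH | tee bench.log
#   extract_throughput < bench.log
#   extract_avg_time < bench.log
```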