LLM: Install CPU version torch with extras [all] (#10868)

Modify setup.py to install CPU version torch with extras [all]
This commit is contained in:
Xiangyu Tian 2024-05-16 10:39:55 +08:00 committed by GitHub
parent 59df750326
commit 612a365479
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 8 additions and 5 deletions

View file

@ -38,7 +38,7 @@ runs:
pip install --upgrade --pre -i https://pypi.python.org/simple --force-reinstall "python/llm/dist/${whl_name}[xpu_2.1]" --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
pip install pytest expecttest
else
pip install --upgrade --pre -i https://pypi.python.org/simple --force-reinstall "python/llm/dist/${whl_name}[all]"
pip install --upgrade --pre -i https://pypi.python.org/simple --force-reinstall "python/llm/dist/${whl_name}[all]" --extra-index-url https://download.pytorch.org/whl/cpu
pip install pytest
bash python/llm/test/run-llm-install-tests.sh
fi

View file

@ -50,11 +50,13 @@ CORE_XE_VERSION = VERSION.replace("2.1.0", "2.5.0")
llm_home = os.path.join(os.path.dirname(os.path.abspath(__file__)), "src")
github_artifact_dir = os.path.join(llm_home, '../llm-binary')
libs_dir = os.path.join(llm_home, "ipex_llm", "libs")
cpu_torch_version = ["torch==2.1.2+cpu;platform_system=='Linux'", "torch==2.1.2;platform_system=='Windows'"]
CONVERT_DEP = ['numpy == 1.26.4', # lastet 2.0.0b1 will cause error
'torch',
'transformers == 4.31.0', 'sentencepiece', 'tokenizers == 0.13.3',
# TODO: Support accelerate 0.22.0
'accelerate == 0.21.0', 'tabulate']
'accelerate == 0.21.0', 'tabulate'] + cpu_torch_version
SERVING_DEP = ['fschat[model_worker, webui] == 0.2.36', 'protobuf']
windows_binarys = [
"llama.dll",
@ -277,7 +279,7 @@ def setup_package():
# Add internal requires for llama-index
llama_index_requires = copy.deepcopy(all_requires)
for exclude_require in ['torch', 'transformers == 4.31.0', 'tokenizers == 0.13.3']:
for exclude_require in ['transformers == 4.31.0', 'tokenizers == 0.13.3'] + cpu_torch_version:
llama_index_requires.remove(exclude_require)
llama_index_requires += ["torch<2.2.0",
"transformers>=4.34.0,<4.39.0",
@ -289,7 +291,8 @@ def setup_package():
"onednn==2024.0.0;platform_system=='Windows'"]
# Linux install with --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
xpu_21_requires = copy.deepcopy(all_requires)
xpu_21_requires.remove('torch')
for exclude_require in cpu_torch_version:
xpu_21_requires.remove(exclude_require)
xpu_21_requires += ["torch==2.1.0a0",
"torchvision==0.16.0a0",
"intel_extension_for_pytorch==2.1.10+xpu",