diff --git a/docker/llm/serving/cpu/docker/entrypoint.sh b/docker/llm/serving/cpu/docker/entrypoint.sh index 92eab4ed..8943e21d 100644 --- a/docker/llm/serving/cpu/docker/entrypoint.sh +++ b/docker/llm/serving/cpu/docker/entrypoint.sh @@ -1,7 +1,7 @@ #!/bin/bash usage() { - echo "Usage: $0 [-m --mode ] [-h --help]" + echo "Usage: $0 [-m --mode ] [-h --help] [-w --worker ]" echo "-h: Print help message." echo "Controller mode reads the following env:" echo "CONTROLLER_HOST (default: localhost)." @@ -85,6 +85,7 @@ mode="" omp_num_threads="" dispatch_method="shortest_queue" # shortest_queue or lottery stream_interval=1 +worker_type="model_worker" # Update rootCA config if needed update-ca-certificates @@ -101,7 +102,7 @@ if [ "$#" == 0 ]; then exec /usr/bin/tini -s -- "bash" else # Parse command-line options - options=$(getopt -o "m:h" --long "mode:,help" -n "$0" -- "$@") + options=$(getopt -o "m:hw:" --long "mode:,help,worker:" -n "$0" -- "$@") if [ $? != 0 ]; then usage fi @@ -114,6 +115,11 @@ else [[ $mode == "controller" || $mode == "worker" ]] || usage shift 2 ;; + -w|--worker) + worker_type="$2" + [[ $worker_type == "model_worker" || $worker_type == "vllm_worker" ]] || usage + shift 2 + ;; -h|--help) usage ;; @@ -127,6 +133,12 @@ else esac done + if [ "$worker_type" == "model_worker" ]; then + worker_type="fastchat.serve.model_worker" + elif [ "$worker_type" == "vllm_worker" ]; then + worker_type="fastchat.serve.vllm_worker" + fi + if [[ -n $CONTROLLER_HOST ]]; then controller_host=$CONTROLLER_HOST fi @@ -198,9 +210,10 @@ else echo "Please set env MODEL_PATH used for worker" usage fi + echo "Worker type: $worker_type" echo "Worker address: $worker_address" echo "Controller address: $controller_address" - python3 -m fastchat.serve.model_worker --model-path $model_path --device cpu --host $worker_host --port $worker_port --worker-address $worker_address --controller-address $controller_address --stream-interval $stream_interval + python3 -m "$worker_type" --model-path $model_path --device cpu --host $worker_host --port $worker_port --worker-address $worker_address --controller-address $controller_address --stream-interval $stream_interval fi fi