[LLM] Add unified default values for CLI programs (#8310)
* Add unified default values for threads (8) and n_predict (128)
This commit is contained in:
		
							parent
							
								
									f41995051b
								
							
						
					
					
						commit
						c4028d507c
					
				
					 2 changed files with 72 additions and 29 deletions
				
			
		| 
						 | 
					@ -2,6 +2,9 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Default values
 | 
					# Default values
 | 
				
			||||||
model_family=""
 | 
					model_family=""
 | 
				
			||||||
 | 
					threads=8
 | 
				
			||||||
 | 
					n_predict=128
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
llm_dir="$(dirname "$(python -c "import bigdl.llm;print(bigdl.llm.__file__)")")"
 | 
					llm_dir="$(dirname "$(python -c "import bigdl.llm;print(bigdl.llm.__file__)")")"
 | 
				
			||||||
lib_dir="$llm_dir/libs"
 | 
					lib_dir="$llm_dir/libs"
 | 
				
			||||||
| 
						 | 
					@ -24,26 +27,28 @@ function display_help {
 | 
				
			||||||
  echo "usage: ./llm-cli.sh -x MODEL_FAMILY [-h] [args]"
 | 
					  echo "usage: ./llm-cli.sh -x MODEL_FAMILY [-h] [args]"
 | 
				
			||||||
  echo ""
 | 
					  echo ""
 | 
				
			||||||
  echo "options:"
 | 
					  echo "options:"
 | 
				
			||||||
  echo "  -h, --help  show this help message"
 | 
					  echo "  -h, --help           show this help message"
 | 
				
			||||||
  echo "  -x, --model_family {llama,bloom,gptneox}"
 | 
					  echo "  -x, --model_family {llama,bloom,gptneox}"
 | 
				
			||||||
  echo "              family name of model"
 | 
					  echo "                       family name of model"
 | 
				
			||||||
  echo "  args        parameters passed to the specified model function"
 | 
					  echo "  -t N, --threads N    number of threads to use during computation (default: 8)"
 | 
				
			||||||
 | 
					  echo "  -n N, --n_predict N  number of tokens to predict (default: 128, -1 = infinity)"
 | 
				
			||||||
 | 
					  echo "  args                 parameters passed to the specified model function"
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
function llama {
 | 
					function llama {
 | 
				
			||||||
  command="$lib_dir/main-llama_$avx_flag ${filteredArguments[*]}"
 | 
					  command="$lib_dir/main-llama_$avx_flag -t $threads -n $n_predict ${filteredArguments[*]}"
 | 
				
			||||||
  echo "$command"
 | 
					  echo "$command"
 | 
				
			||||||
  eval "$command"
 | 
					  eval "$command"
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
function bloom {
 | 
					function bloom {
 | 
				
			||||||
  command="$lib_dir/main-bloom_$avx_flag ${filteredArguments[*]}"
 | 
					  command="$lib_dir/main-bloom_$avx_flag -t $threads -n $n_predict ${filteredArguments[*]}"
 | 
				
			||||||
  echo "$command"
 | 
					  echo "$command"
 | 
				
			||||||
  eval "$command"
 | 
					  eval "$command"
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
function gptneox {
 | 
					function gptneox {
 | 
				
			||||||
  command="$lib_dir/main-gptneox_$avx_flag ${filteredArguments[*]}"
 | 
					  command="$lib_dir/main-gptneox_$avx_flag -t $threads -n $n_predict ${filteredArguments[*]}"
 | 
				
			||||||
  echo "$command"
 | 
					  echo "$command"
 | 
				
			||||||
  eval "$command"
 | 
					  eval "$command"
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -57,10 +62,18 @@ while [[ $# -gt 0 ]]; do
 | 
				
			||||||
    filteredArguments+=("'$1'")
 | 
					    filteredArguments+=("'$1'")
 | 
				
			||||||
    shift
 | 
					    shift
 | 
				
			||||||
    ;;
 | 
					    ;;
 | 
				
			||||||
  -x | --model_family)
 | 
					  -x | --model_family | --model-family)
 | 
				
			||||||
    model_family="$2"
 | 
					    model_family="$2"
 | 
				
			||||||
    shift 2
 | 
					    shift 2
 | 
				
			||||||
    ;;
 | 
					    ;;
 | 
				
			||||||
 | 
					  -t | --threads)
 | 
				
			||||||
 | 
					    threads="$2"
 | 
				
			||||||
 | 
					    shift 2
 | 
				
			||||||
 | 
					    ;;
 | 
				
			||||||
 | 
					  -n | --n_predict | --n-predict)
 | 
				
			||||||
 | 
					    n_predict="$2"
 | 
				
			||||||
 | 
					    shift 2
 | 
				
			||||||
 | 
					    ;;
 | 
				
			||||||
  *)
 | 
					  *)
 | 
				
			||||||
    filteredArguments+=("'$1'")
 | 
					    filteredArguments+=("'$1'")
 | 
				
			||||||
    shift
 | 
					    shift
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,31 +1,41 @@
 | 
				
			||||||
$llm_dir = (Split-Path -Parent (python -c "import bigdl.llm;print(bigdl.llm.__file__)"))
 | 
					$llm_dir = (Split-Path -Parent (python -c "import bigdl.llm;print(bigdl.llm.__file__)"))
 | 
				
			||||||
$lib_dir = Join-Path $llm_dir "libs"
 | 
					$lib_dir = Join-Path $llm_dir "libs"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					$model_family = ""
 | 
				
			||||||
 | 
					$threads = 8
 | 
				
			||||||
 | 
					$n_predict = 128
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Function to display help message
 | 
					# Function to display help message
 | 
				
			||||||
function Display-Help {
 | 
					function Display-Help
 | 
				
			||||||
  Write-Host "usage: ./llm-cli.ps1 -x MODEL_FAMILY [-h] [args]"
 | 
					{
 | 
				
			||||||
  Write-Host ""
 | 
					    Write-Host "usage: ./llm-cli.ps1 -x MODEL_FAMILY [-h] [args]"
 | 
				
			||||||
  Write-Host "options:"
 | 
					    Write-Host ""
 | 
				
			||||||
  Write-Host "  -h, --help  show this help message"
 | 
					    Write-Host "options:"
 | 
				
			||||||
  Write-Host "  -x, --model_family {llama,bloom,gptneox}"
 | 
					    Write-Host "  -h, --help           show this help message"
 | 
				
			||||||
  Write-Host "              family name of model"
 | 
					    Write-Host "  -x, --model_family {llama,bloom,gptneox}"
 | 
				
			||||||
  Write-Host "  args        parameters passed to the specified model function"
 | 
					    Write-Host "                       family name of model"
 | 
				
			||||||
 | 
					    Write-Host "  -t N, --threads N    number of threads to use during computation (default: 8)"
 | 
				
			||||||
 | 
					    Write-Host "  -n N, --n_predict N  number of tokens to predict (default: 128, -1 = infinity)"
 | 
				
			||||||
 | 
					    Write-Host "  args                 parameters passed to the specified model function"
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
function llama {
 | 
					function llama
 | 
				
			||||||
    $command = "$lib_dir/main-llama.exe $filteredArguments"
 | 
					{
 | 
				
			||||||
 | 
					    $command = "$lib_dir/main-llama.exe -t $threads -n $n_predict $filteredArguments"
 | 
				
			||||||
    Write-Host "$command"
 | 
					    Write-Host "$command"
 | 
				
			||||||
    Invoke-Expression $command
 | 
					    Invoke-Expression $command
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
function bloom {
 | 
					function bloom
 | 
				
			||||||
    $command = "$lib_dir/main-bloom.exe $filteredArguments"
 | 
					{
 | 
				
			||||||
 | 
					    $command = "$lib_dir/main-bloom.exe -t $threads -n $n_predict $filteredArguments"
 | 
				
			||||||
    Write-Host "$command"
 | 
					    Write-Host "$command"
 | 
				
			||||||
    Invoke-Expression $command
 | 
					    Invoke-Expression $command
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
function gptneox {
 | 
					function gptneox
 | 
				
			||||||
    $command = "$lib_dir/main-gptneox.exe $filteredArguments"
 | 
					{
 | 
				
			||||||
 | 
					    $command = "$lib_dir/main-gptneox.exe -t $threads -n $n_predict $filteredArguments"
 | 
				
			||||||
    Write-Host "$command"
 | 
					    Write-Host "$command"
 | 
				
			||||||
    Invoke-Expression $command
 | 
					    Invoke-Expression $command
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -33,22 +43,42 @@ function gptneox {
 | 
				
			||||||
# Remove model_family/x parameter
 | 
					# Remove model_family/x parameter
 | 
				
			||||||
$filteredArguments = @()
 | 
					$filteredArguments = @()
 | 
				
			||||||
for ($i = 0; $i -lt $args.Length; $i++) {
 | 
					for ($i = 0; $i -lt $args.Length; $i++) {
 | 
				
			||||||
    if ($args[$i] -eq '--model_family' -or $args[$i] -eq '-x') {
 | 
					    if ($args[$i] -eq '--model_family' -or $args[$i] -eq '--model-family' -or $args[$i] -eq '-x')
 | 
				
			||||||
        if ($i + 1 -lt $args.Length -and $args[$i + 1] -notlike '-*') {
 | 
					    {
 | 
				
			||||||
 | 
					        if ($i + 1 -lt $args.Length -and $args[$i + 1] -notlike '-*')
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
            $i++
 | 
					            $i++
 | 
				
			||||||
            $model_family = $args[$i]
 | 
					            $model_family = $args[$i]
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    else {
 | 
					    elseif ($args[$i] -eq '--threads' -or $args[$i] -eq '-t')
 | 
				
			||||||
        $filteredArguments += "`'"+$args[$i]+"`'"
 | 
					    {
 | 
				
			||||||
 | 
					        $i++
 | 
				
			||||||
 | 
					        $threads = $args[$i]
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    elseif ($args[$i] -eq '--n_predict' -or $args[$i] -eq '--n-predict' -or $args[$i] -eq '-n')
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        $i++
 | 
				
			||||||
 | 
					        $n_predict = $args[$i]
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    else
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        $filteredArguments += "`'" + $args[$i] + "`'"
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Perform actions based on the model_family
 | 
					# Perform actions based on the model_family
 | 
				
			||||||
switch ($model_family) {
 | 
					switch ($model_family)
 | 
				
			||||||
    "llama" { llama }
 | 
					{
 | 
				
			||||||
    "bloom" { bloom }
 | 
					    "llama" {
 | 
				
			||||||
    "gptneox" { gptneox }
 | 
					        llama
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    "bloom" {
 | 
				
			||||||
 | 
					        bloom
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    "gptneox" {
 | 
				
			||||||
 | 
					        gptneox
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
    default {
 | 
					    default {
 | 
				
			||||||
        Write-Host "Invalid model_family: $model_family"
 | 
					        Write-Host "Invalid model_family: $model_family"
 | 
				
			||||||
        Display-Help
 | 
					        Display-Help
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue