[LLM] Add unified default value for cli programs (#8310)
* add unified default value for threads and n_predict
This commit is contained in:

parent f41995051b
commit c4028d507c

2 changed files with 72 additions and 29 deletions
@@ -2,6 +2,9 @@
 
 # Default values
 model_family=""
+threads=8
+n_predict=128
+
 
 llm_dir="$(dirname "$(python -c "import bigdl.llm;print(bigdl.llm.__file__)")")"
 lib_dir="$llm_dir/libs"
@@ -27,23 +30,25 @@ function display_help {
   echo "  -h, --help           show this help message"
   echo "  -x, --model_family {llama,bloom,gptneox}"
   echo "                       family name of model"
+  echo "  -t N, --threads N    number of threads to use during computation (default: 8)"
+  echo "  -n N, --n_predict N  number of tokens to predict (default: 128, -1 = infinity)"
   echo "  args                 parameters passed to the specified model function"
 }
 
 function llama {
-  command="$lib_dir/main-llama_$avx_flag ${filteredArguments[*]}"
+  command="$lib_dir/main-llama_$avx_flag -t $threads -n $n_predict ${filteredArguments[*]}"
   echo "$command"
   eval "$command"
 }
 
 function bloom {
-  command="$lib_dir/main-bloom_$avx_flag ${filteredArguments[*]}"
+  command="$lib_dir/main-bloom_$avx_flag -t $threads -n $n_predict ${filteredArguments[*]}"
   echo "$command"
   eval "$command"
 }
 
 function gptneox {
-  command="$lib_dir/main-gptneox_$avx_flag ${filteredArguments[*]}"
+  command="$lib_dir/main-gptneox_$avx_flag -t $threads -n $n_predict ${filteredArguments[*]}"
   echo "$command"
   eval "$command"
 }
@@ -57,10 +62,18 @@ while [[ $# -gt 0 ]]; do
     filteredArguments+=("'$1'")
     shift
     ;;
-  -x | --model_family)
+  -x | --model_family | --model-family)
     model_family="$2"
     shift 2
     ;;
+  -t | --threads)
+    threads="$2"
+    shift 2
+    ;;
+  -n | --n_predict | --n-predict)
+    n_predict="$2"
+    shift 2
+    ;;
   *)
     filteredArguments+=("'$1'")
     shift
@@ -1,31 +1,41 @@
 $llm_dir = (Split-Path -Parent (python -c "import bigdl.llm;print(bigdl.llm.__file__)"))
 $lib_dir = Join-Path $llm_dir "libs"
 
+$model_family = ""
+$threads = 8
+$n_predict = 128
+
 # Function to display help message
-function Display-Help {
+function Display-Help
+{
     Write-Host "usage: ./llm-cli.ps1 -x MODEL_FAMILY [-h] [args]"
     Write-Host ""
     Write-Host "options:"
     Write-Host "  -h, --help           show this help message"
     Write-Host "  -x, --model_family {llama,bloom,gptneox}"
     Write-Host "                       family name of model"
+    Write-Host "  -t N, --threads N    number of threads to use during computation (default: 8)"
+    Write-Host "  -n N, --n_predict N  number of tokens to predict (default: 128, -1 = infinity)"
     Write-Host "  args                 parameters passed to the specified model function"
 }
 
-function llama {
-    $command = "$lib_dir/main-llama.exe $filteredArguments"
+function llama
+{
+    $command = "$lib_dir/main-llama.exe -t $threads -n $n_predict $filteredArguments"
     Write-Host "$command"
     Invoke-Expression $command
 }
 
-function bloom {
-    $command = "$lib_dir/main-bloom.exe $filteredArguments"
+function bloom
+{
+    $command = "$lib_dir/main-bloom.exe -t $threads -n $n_predict $filteredArguments"
     Write-Host "$command"
     Invoke-Expression $command
 }
 
-function gptneox {
-    $command = "$lib_dir/main-gptneox.exe $filteredArguments"
+function gptneox
+{
+    $command = "$lib_dir/main-gptneox.exe -t $threads -n $n_predict $filteredArguments"
     Write-Host "$command"
     Invoke-Expression $command
 }
@@ -33,22 +43,42 @@ function gptneox {
 # Remove model_family/x parameter
 $filteredArguments = @()
 for ($i = 0; $i -lt $args.Length; $i++) {
-    if ($args[$i] -eq '--model_family' -or $args[$i] -eq '-x') {
-        if ($i + 1 -lt $args.Length -and $args[$i + 1] -notlike '-*') {
+    if ($args[$i] -eq '--model_family' -or $args[$i] -eq '--model-family' -or $args[$i] -eq '-x')
+    {
+        if ($i + 1 -lt $args.Length -and $args[$i + 1] -notlike '-*')
+        {
             $i++
             $model_family = $args[$i]
         }
     }
-    else {
+    elseif ($args[$i] -eq '--threads' -or $args[$i] -eq '-t')
+    {
+        $i++
+        $threads = $args[$i]
+    }
+    elseif ($args[$i] -eq '--n_predict' -or $args[$i] -eq '--n-predict' -or $args[$i] -eq '-n')
+    {
+        $i++
+        $n_predict = $args[$i]
+    }
+    else
+    {
         $filteredArguments += "`'" + $args[$i] + "`'"
     }
 }
 
 # Perform actions based on the model_family
-switch ($model_family) {
-    "llama" { llama }
-    "bloom" { bloom }
-    "gptneox" { gptneox }
+switch ($model_family)
+{
+    "llama" {
+        llama
+    }
+    "bloom" {
+        bloom
+    }
+    "gptneox" {
+        gptneox
+    }
     default {
         Write-Host "Invalid model_family: $model_family"
         Display-Help
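For reference, an illustrative invocation of the updated CLI, using the PowerShell entry point named in the help text above (the llama family is just an example; any further arguments would be passed through to the underlying model binary):

    ./llm-cli.ps1 -x llama                 # runs with the unified defaults: -t 8 -n 128
    ./llm-cli.ps1 -x llama -t 16 -n 256    # overrides both defaults explicitly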