Commit 1792346

Add run_inference_server.py for Running llama.cpp Built-in Server (ggml-org#204)
* Update CMakeLists.txt: add a CMake option to compile the llama.cpp server. This lets us easily build and deploy the server using BitNet.

* Create run_inference_server.py: the same as run_inference.py, but for use with llama.cpp's built-in server, for extra convenience. In particular:
  - The build directory is determined based on whether the system is running on Windows.
  - A list of arguments (`--model`, `-m`, etc.) is created.
  - The parsed arguments are assembled into a command and passed to `subprocess.run()` to execute it.
1 parent: c17d1c5 · commit: 1792346

2 files changed, 66 insertions(+), 1 deletion(-)

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -39,6 +39,7 @@ endif()
 find_package(Threads REQUIRED)

 add_subdirectory(src)
+set(LLAMA_BUILD_SERVER ON CACHE BOOL "Build llama.cpp server" FORCE)
 add_subdirectory(3rdparty/llama.cpp)

 # install
@@ -74,4 +75,4 @@ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfig.cmake
         DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/Llama)

 set_target_properties(llama PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/llama.h)
-install(TARGETS llama LIBRARY PUBLIC_HEADER)
+install(TARGETS llama LIBRARY PUBLIC_HEADER)
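
With LLAMA_BUILD_SERVER forced ON, a regular BitNet build should now also produce the llama-server binary under build/bin (or build/bin/Release on Windows, as run_inference_server.py below expects). As a minimal sketch, the configure-and-build steps could be driven from Python in the same subprocess style this commit's script uses; the exact generator and toolchain flags are platform-dependent and not shown here:

    import subprocess

    def build_with_server(build_dir="build"):
        # Configure the project. LLAMA_BUILD_SERVER is forced ON by the
        # top-level CMakeLists.txt, so no extra -D flag should be needed
        # (assumption based on the diff above).
        subprocess.run(["cmake", "-B", build_dir], check=True)
        # Build all targets, including the llama.cpp server.
        subprocess.run(["cmake", "--build", build_dir, "--config", "Release"], check=True)

    if __name__ == "__main__":
        build_with_server()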

run_inference_server.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+import os
+import sys
+import signal
+import platform
+import argparse
+import subprocess
+
+def run_command(command, shell=False):
+    """Run a system command and ensure it succeeds."""
+    try:
+        subprocess.run(command, shell=shell, check=True)
+    except subprocess.CalledProcessError as e:
+        print(f"Error occurred while running command: {e}")
+        sys.exit(1)
+
+def run_server():
+    build_dir = "build"
+    if platform.system() == "Windows":
+        server_path = os.path.join(build_dir, "bin", "Release", "llama-server.exe")
+        if not os.path.exists(server_path):
+            server_path = os.path.join(build_dir, "bin", "llama-server")
+    else:
+        server_path = os.path.join(build_dir, "bin", "llama-server")
+
+    command = [
+        f'{server_path}',
+        '-m', args.model,
+        '-c', str(args.ctx_size),
+        '-t', str(args.threads),
+        '-n', str(args.n_predict),
+        '-ngl', '0',
+        '--temp', str(args.temperature),
+        '--host', args.host,
+        '--port', str(args.port),
+        '-cb'  # Enable continuous batching
+    ]
+
+    if args.prompt:
+        command.extend(['-p', args.prompt])
+
+    # Note: -cnv flag is removed as it's not supported by the server
+
+    print(f"Starting server on {args.host}:{args.port}")
+    run_command(command)
+
+def signal_handler(sig, frame):
+    print("Ctrl+C pressed, shutting down server...")
+    sys.exit(0)
+
+if __name__ == "__main__":
+    signal.signal(signal.SIGINT, signal_handler)
+
+    parser = argparse.ArgumentParser(description='Run llama.cpp server')
+    parser.add_argument("-m", "--model", type=str, help="Path to model file", required=False, default="models/bitnet_b1_58-3B/ggml-model-i2_s.gguf")
+    parser.add_argument("-p", "--prompt", type=str, help="System prompt for the model", required=False)
+    parser.add_argument("-n", "--n-predict", type=int, help="Number of tokens to predict", required=False, default=4096)
+    parser.add_argument("-t", "--threads", type=int, help="Number of threads to use", required=False, default=2)
+    parser.add_argument("-c", "--ctx-size", type=int, help="Size of the context window", required=False, default=2048)
+    parser.add_argument("--temperature", type=float, help="Temperature for sampling", required=False, default=0.8)
+    parser.add_argument("--host", type=str, help="IP address to listen on", required=False, default="127.0.0.1")
+    parser.add_argument("--port", type=int, help="Port to listen on", required=False, default=8080)
+
+    args = parser.parse_args()
+    run_server()
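
Once the server is running, it exposes llama.cpp's HTTP API. A minimal client sketch using only the Python standard library, assuming the script's defaults (127.0.0.1:8080) and the server's /completion endpoint; the field names follow llama.cpp's completion API and may differ across server versions:

    import json
    import urllib.request

    # Assumes run_inference_server.py is running with its defaults (127.0.0.1:8080).
    payload = {
        "prompt": "What is 1-bit quantization?",
        "n_predict": 64,        # cap the number of generated tokens
        "temperature": 0.8,
    }
    req = urllib.request.Request(
        "http://127.0.0.1:8080/completion",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        print(json.loads(resp.read())["content"])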
