@@ -79,6 +79,10 @@ def _load_shared_library(lib_base_name: str):
 # llama.h bindings
 
+GGML_USE_CUBLAS = hasattr(_lib, "ggml_init_cublas")
+GGML_CUDA_MAX_DEVICES = ctypes.c_int(16)
+LLAMA_MAX_DEVICES = GGML_CUDA_MAX_DEVICES if GGML_USE_CUBLAS else ctypes.c_int(1)
+
 # #define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt'
 LLAMA_FILE_MAGIC_GGJT = ctypes.c_uint(0x67676A74)
 # #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
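The new constants probe the already-loaded shared library for the `ggml_init_cublas` export, so `LLAMA_MAX_DEVICES` ends up matching the value llama.h was compiled with. A minimal standalone sketch of that detection pattern (the library path here is only an illustration; the real bindings resolve it per platform in `_load_shared_library()`):

```python
import ctypes

# Hypothetical path, for illustration only.
_lib = ctypes.CDLL("./libllama.so")

# ctypes.CDLL resolves exports lazily in __getattr__, so hasattr() is a
# cheap way to test whether the binary was built with cuBLAS support.
GGML_USE_CUBLAS = hasattr(_lib, "ggml_init_cublas")
GGML_CUDA_MAX_DEVICES = ctypes.c_int(16)
LLAMA_MAX_DEVICES = GGML_CUDA_MAX_DEVICES if GGML_USE_CUBLAS else ctypes.c_int(1)

print("cuBLAS build:", GGML_USE_CUBLAS, "-> max devices:", LLAMA_MAX_DEVICES.value)
```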
@@ -142,9 +146,12 @@ class llama_token_data_array(Structure):
 # struct llama_context_params {
-# int n_ctx; // text context
-# int n_gpu_layers; // number of layers to store in VRAM
-# int seed; // RNG seed, -1 for random
+# int n_ctx; // text context
+# int n_batch; // prompt processing batch size
+# int n_gpu_layers; // number of layers to store in VRAM
+# int main_gpu; // the GPU that is used for scratch and small tensors
+# float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
+# int seed; // RNG seed, -1 for random
 
 # bool f16_kv; // use fp16 for KV cache
 # bool logits_all; // the llama_eval() call computes all logits, not just the last one
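The `tensor_split` member is a fixed-size C array, which on the Python side becomes a ctypes array type (see the `_fields_` change further down). A small sketch of that mapping, assuming the cuBLAS value of 16 devices:

```python
import ctypes

LLAMA_MAX_DEVICES = ctypes.c_int(16)  # assumed cuBLAS build, for illustration

# "c_float * N" creates an array type of N floats, the ctypes analogue of
# "float tensor_split[LLAMA_MAX_DEVICES]" in llama.h.
TensorSplit = ctypes.c_float * LLAMA_MAX_DEVICES.value

split = TensorSplit()            # zero-initialized, like a fresh C struct field
split[0], split[1] = 0.75, 0.25  # e.g. put 75%/25% of the layers on GPU 0/1
print(len(split), list(split)[:4])
```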
@@ -153,7 +160,6 @@ class llama_token_data_array(Structure):
 # bool use_mlock; // force system to keep model in RAM
 # bool embedding; // embedding mode only
 
-
 # // called with a progress value between 0 and 1, pass NULL to disable
 # llama_progress_callback progress_callback;
 # // context pointer passed to the progress callback
@@ -162,7 +168,10 @@ class llama_token_data_array(Structure):
 class llama_context_params(Structure):
     _fields_ = [
         ("n_ctx", c_int),
+        ("n_batch", c_int),
         ("n_gpu_layers", c_int),
+        ("main_gpu", c_int),
+        ("tensor_split", c_float * LLAMA_MAX_DEVICES.value),
         ("seed", c_int),
         ("f16_kv", c_bool),
         (
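For context, a hedged usage sketch of the extended structure. It assumes the module-level helpers `llama_context_default_params()` and `llama_init_from_file()` that mirror llama.h at this revision, and a hypothetical model path:

```python
import llama_cpp

params = llama_cpp.llama_context_default_params()
params.n_ctx = 2048
params.n_batch = 512       # prompt processing batch size
params.n_gpu_layers = 32   # layers offloaded to VRAM
params.main_gpu = 0        # GPU used for scratch and small tensors

# Only meaningful on a cuBLAS build, where tensor_split has 16 slots;
# unset entries stay 0.0.
params.tensor_split[0] = 0.75
params.tensor_split[1] = 0.25

ctx = llama_cpp.llama_init_from_file(b"./models/7B/ggml-model.bin", params)
```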