1 file changed, +18 -5 lines changed

@@ -998,6 +998,15 @@ def set_cache(self, cache: Optional[BaseLlamaCache]):
         """
         self.cache = cache
 
+    def set_seed(self, seed: int):
+        """Set the random seed.
+
+        Args:
+            seed: The random seed.
+        """
+        assert self._ctx.ctx is not None
+        llama_cpp.llama_set_rng_seed(self._ctx.ctx, seed)
+
     def reset(self):
         """Reset the model state."""
         self.n_tokens = 0
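
With set_seed in place, re-seeding before each generation should make sampling reproducible. A minimal usage sketch (the model path and prompt are hypothetical, and this assumes no other source of nondeterminism in the sampling path):

    from llama_cpp import Llama

    llm = Llama(model_path="./models/model.gguf")  # hypothetical path

    # Re-seed before each call; identical seeds should yield identical
    # samples even at temperature > 0.
    llm.set_seed(42)
    out1 = llm("Q: Name a color. A:", max_tokens=8, temperature=0.8)
    llm.set_seed(42)
    out2 = llm("Q: Name a color. A:", max_tokens=8, temperature=0.8)
    # out1["choices"][0]["text"] and out2["choices"][0]["text"] should match.
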
@@ -1318,10 +1327,14 @@ def _create_completion(
         completion_tokens: List[int] = []
         # Add blank space to start of prompt to match OG llama tokenizer
         prompt_tokens: List[int] = (
-            self.tokenize(prompt.encode("utf-8"), special=True)
-            if prompt != ""
-            else [self.token_bos()]
-        ) if isinstance(prompt, str) else prompt
+            (
+                self.tokenize(prompt.encode("utf-8"), special=True)
+                if prompt != ""
+                else [self.token_bos()]
+            )
+            if isinstance(prompt, str)
+            else prompt
+        )
         text: bytes = b""
         returned_tokens: int = 0
         stop = (
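
The restructured conditional is behavior-preserving; it only adds parentheses and line breaks for readability. Its dispatch, written out as a hypothetical standalone helper for clarity (resolve_prompt_tokens is not part of the library):

    from typing import List, Union

    def resolve_prompt_tokens(llm, prompt: Union[str, List[int]]) -> List[int]:
        # String prompts are tokenized with special tokens enabled; an empty
        # string falls back to a lone BOS token; pre-tokenized prompts
        # (lists of ints) pass through unchanged.
        if isinstance(prompt, str):
            if prompt != "":
                return llm.tokenize(prompt.encode("utf-8"), special=True)
            return [llm.token_bos()]
        return prompt
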
@@ -1374,7 +1387,7 @@ def _create_completion(
             except KeyError:
                 if self.verbose:
                     print("Llama._create_completion: cache miss", file=sys.stderr)
-
+
         if seed is not None:
             self._ctx.set_rng_seed(seed)
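
Together with set_seed above, this gives two ways to control sampling: a sticky seed set on the Llama object, or a per-request seed handled here in _create_completion. A minimal sketch (prompt and values are illustrative, and it assumes the public create_completion forwards its seed argument to _create_completion):

    llm = Llama(model_path="./models/model.gguf")  # hypothetical path

    # A per-call seed re-seeds the context RNG for this request only, so
    # identical calls with the same seed should produce identical samples.
    a = llm.create_completion("Tell me a joke.", max_tokens=32, seed=1234)
    b = llm.create_completion("Tell me a joke.", max_tokens=32, seed=1234)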