@@ -417,39 +417,45 @@ struct llama_buffer {
417
417
size_t size = 0 ;
418
418
419
419
llama_buffer () = default ;
420
+ ~llama_buffer () { resize (0 ); }
421
+ llama_buffer (const llama_buffer& rhs) { *this = rhs; }
422
+ llama_buffer& operator =(const llama_buffer& rhs) {
423
+ resize (rhs.size );
424
+ memcpy (addr, rhs.addr , size);
425
+ return *this ;
426
+ }
427
+ llama_buffer (llama_buffer&& rhs):addr(rhs.addr), size(rhs.size) {
428
+ new (&rhs) llama_buffer ();
429
+ }
430
+ llama_buffer& operator =(llama_buffer&& rhs) {
431
+ this ->~llama_buffer ();
432
+ addr = rhs.addr ;
433
+ size = rhs.size ;
434
+ new (&rhs) llama_buffer ();
435
+ return *this ;
436
+ }
420
437
421
438
void resize (size_t len) {
439
+ size = 0 ;
422
440
#ifdef GGML_USE_METAL
423
441
free (addr);
424
- int result = posix_memalign ((void **) &addr, getpagesize (), len);
425
- if (result == 0 ) {
426
- memset (addr, 0 , len);
427
- }
428
- else {
429
- addr = NULL ;
430
- len = 0 ;
442
+ if (len) {
443
+ int result = posix_memalign ((void **) &addr, getpagesize (), len);
444
+ if (result == 0 ) {
445
+ memset (addr, 0 , len);
446
+ size = len;
447
+ } else {
448
+ addr = NULL ;
449
+ }
431
450
}
432
451
#else
433
452
delete[] addr;
434
- addr = new uint8_t [len];
435
- #endif
436
- size = len;
437
- }
438
-
439
- ~llama_buffer () {
440
- #ifdef GGML_USE_METAL
441
- free (addr);
442
- #else
443
- delete[] addr;
453
+ if (len) {
454
+ addr = new uint8_t [len];
455
+ size = len;
456
+ }
444
457
#endif
445
- addr = NULL ;
446
458
}
447
-
448
- // disable copy and move
449
- llama_buffer (const llama_buffer&) = delete ;
450
- llama_buffer (llama_buffer&&) = delete ;
451
- llama_buffer& operator =(const llama_buffer&) = delete ;
452
- llama_buffer& operator =(llama_buffer&&) = delete ;
453
459
};
454
460
455
461
#ifdef GGML_USE_CUBLAS
@@ -459,44 +465,48 @@ struct llama_ctx_buffer {
459
465
bool is_cuda;
460
466
size_t size = 0 ;
461
467
462
- llama_ctx_buffer () = default ;
463
-
464
- void resize (size_t size) {
468
+ void resize (size_t len) {
465
469
free ();
466
-
467
- addr = (uint8_t *) ggml_cuda_host_malloc (size);
468
- if (addr) {
469
- is_cuda = true ;
470
- }
471
- else {
470
+ addr = (uint8_t *) ggml_cuda_host_malloc (len);
471
+ is_cuda = static_cast <bool >(addr);
472
+ if (!is_cuda) {
472
473
// fall back to pageable memory
473
474
addr = new uint8_t [size];
474
- is_cuda = false ;
475
475
}
476
- this -> size = size ;
476
+ size = len ;
477
477
}
478
478
479
479
void free () {
480
480
if (addr) {
481
481
if (is_cuda) {
482
482
ggml_cuda_host_free (addr);
483
- }
484
- else {
483
+ } else {
485
484
delete[] addr;
486
485
}
487
486
}
488
- addr = NULL ;
487
+ new ( this ) llama_ctx_buffer () ;
489
488
}
490
489
491
- ~llama_ctx_buffer () {
492
- free ();
490
+ llama_ctx_buffer () = default ;
491
+ ~llama_ctx_buffer () { free (); }
492
+ llama_ctx_buffer (const llama_ctx_buffer& rhs) { *this = rhs; }
493
+ llama_ctx_buffer& operator =(const llama_ctx_buffer& rhs) {
494
+ resize (rhs.size );
495
+ memcpy (addr, rhs.addr , size); // cuda memcpy if is_cuda?
496
+ return *this ;
497
+ }
498
+ llama_ctx_buffer (llama_ctx_buffer&& rhs):addr(rhs.addr), is_cuda(rhs.is_cuda), size(rhs.size) {
499
+ new (&rhs) llama_ctx_buffer ();
500
+ }
501
+ llama_ctx_buffer& operator =(llama_ctx_buffer&& rhs) {
502
+ this ->~llama_ctx_buffer ();
503
+ addr = rhs.addr ;
504
+ is_cuda = rhs.is_cuda ;
505
+ size = rhs.size ;
506
+ new (&rhs) llama_ctx_buffer ();
507
+ return *this ;
493
508
}
494
509
495
- // disable copy and move
496
- llama_ctx_buffer (const llama_ctx_buffer&) = delete ;
497
- llama_ctx_buffer (llama_ctx_buffer&&) = delete ;
498
- llama_ctx_buffer& operator =(const llama_ctx_buffer&) = delete ;
499
- llama_ctx_buffer& operator =(llama_ctx_buffer&&) = delete ;
500
510
};
501
511
#else
502
512
typedef llama_buffer llama_ctx_buffer;
0 commit comments