Skip to content

Commit a8cbab2

Browse files
PABannier authored and ggerganov committed
ggml: add GGML_SET Metal kernel + i32 CPU kernel (ggml/1037)
* implemented cpu kernel
* add i32 test cases in test-backend-ops
* typedef `ggml_metal_kargs_set`
* implemented `kernel_set`
* memcpy
1 parent c2082d9 commit a8cbab2

File tree

5 files changed

+206
-1
lines changed

5 files changed

+206
-1
lines changed

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1374,7 +1374,10 @@ struct ggml_compute_state {
13741374

13751375
// Element-wise vector helpers used by the CPU backend.
// Each writes (or copies) n scalar elements; callers guarantee the buffers hold at least n elements.
inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }

inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }

inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }

// copy n int32 values from x into y (used by ggml_compute_forward_set_i32)
inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }

// NOTE(review): v is declared int32_t, not ggml_fp16_t — each element gets the raw
// assigned value; whether a conversion happens depends on the ggml_fp16_t typedef. Confirm intended.
inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }

inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }

inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
@@ -8248,6 +8251,77 @@ static void ggml_compute_forward_set_f32(
82488251
}
82498252
}
82508253

8254+
// ggml_compute_forward_set_i32
// Implements GGML_OP_SET for I32 tensors: writes src1 into a strided view of dst.
// The view is described by byte strides nb1/nb2/nb3 and a byte offset stored in
// dst->op_params. Unless the op is flagged in-place, dst is first seeded with a
// full copy of src0 (done once, synchronized with a barrier).
static void ggml_compute_forward_set_i32(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(ggml_are_same_shape(src0, dst));
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));

    // view src0 and dst with these strides and data offset inbytes during set
    // nb0 is implicitly element_size because src0 and dst are contiguous
    size_t nb1     = ((int32_t *) dst->op_params)[0];
    size_t nb2     = ((int32_t *) dst->op_params)[1];
    size_t nb3     = ((int32_t *) dst->op_params)[2];
    size_t offset  = ((int32_t *) dst->op_params)[3];
    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];

    if (!inplace) {
        if (params->ith == 0) {
            // memcpy needs to be synchronized across threads to avoid race conditions.
            // => do it in INIT phase
            memcpy(
                ((char *)  dst->data),
                ((char *) src0->data),
                ggml_nbytes(dst));
        }
        // all threads wait for the copy before writing into the view
        ggml_barrier(params->threadpool);
    }

    const int ith = params->ith;
    const int nth = params->nth;

    const int nr = ggml_nrows(src1);
    const int nc = src1->ne[0];

    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)

    // src0 and dst as viewed during set
    const size_t nb0 = ggml_element_size(src0);

    // last valid index in each dimension of the src1-shaped view (guards the bounds check below)
    const int im0 = (ne10 == 0 ? 0 : ne10-1);
    const int im1 = (ne11 == 0 ? 0 : ne11-1);
    const int im2 = (ne12 == 0 ? 0 : ne12-1);
    const int im3 = (ne13 == 0 ? 0 : ne13-1);

    GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));

    GGML_ASSERT(nb10 == sizeof(int32_t));

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int ir = ir0; ir < ir1; ++ir) {
        // src0 and dst are viewed with shape of src1 and offset
        // => same indices
        const int i3 = ir/(ne12*ne11);
        const int i2 = (ir - i3*ne12*ne11)/ne11;
        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);

        ggml_vec_cpy_i32(nc,
                (int32_t *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
                (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
    }
}
8324+
82518325
static void ggml_compute_forward_set(
82528326
const struct ggml_compute_params * params,
82538327
struct ggml_tensor * dst) {
@@ -8259,6 +8333,10 @@ static void ggml_compute_forward_set(
82598333
{
82608334
ggml_compute_forward_set_f32(params, dst);
82618335
} break;
8336+
case GGML_TYPE_I32:
8337+
{
8338+
ggml_compute_forward_set_i32(params, dst);
8339+
} break;
82628340
case GGML_TYPE_F16:
82638341
case GGML_TYPE_BF16:
82648342
case GGML_TYPE_Q4_0:

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,21 @@ typedef struct {
102102
uint64_t nb3;
103103
} ggml_metal_kargs_cpy;
104104

105+
// Kernel arguments for the Metal GGML_OP_SET kernel (kernel_set).
// ne1x/nb1x describe src1 (the values being written); nb1..nb3 are the byte
// strides of the dst view taken from dst->op_params, and offs is the view's
// byte offset into dst.
// NOTE(review): the visible kernel_set body does not read `inplace` — the
// non-inplace src0 -> dst copy is performed host-side; the flag appears to be
// passed through for completeness. Confirm against the kernel.
typedef struct {
    int64_t  ne10;
    int64_t  ne11;
    int64_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    uint64_t offs;
    bool     inplace;
} ggml_metal_kargs_set;
119+
105120
typedef struct {
106121
int32_t ne00;
107122
int32_t ne01;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
372372
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
373373
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
374374
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
375+
GGML_METAL_KERNEL_TYPE_SET_I32,
376+
GGML_METAL_KERNEL_TYPE_SET_F32,
375377
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
376378
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
377379
GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
@@ -940,6 +942,8 @@ @implementation GGMLMetalClass
940942
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
941943
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
942944
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
945+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
946+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
943947
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
944948
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
945949
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
@@ -1159,6 +1163,16 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
11591163
return false;
11601164
};
11611165
}
1166+
case GGML_OP_SET:
1167+
{
1168+
switch (op->src[0]->type) {
1169+
case GGML_TYPE_F32:
1170+
case GGML_TYPE_I32:
1171+
return true;
1172+
default:
1173+
return false;
1174+
};
1175+
}
11621176
case GGML_OP_DIAG_MASK_INF:
11631177
case GGML_OP_GET_ROWS:
11641178
{
@@ -3824,6 +3838,68 @@ static void ggml_metal_encode_node(
38243838

38253839
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
38263840
} break;
3841+
case GGML_OP_SET:
    {
        // GGML_OP_SET: write src1 into a strided view of dst (strides/offset
        // come from dst->op_params); dst is otherwise a copy of src0.
        GGML_ASSERT(ggml_are_same_shape(src0, dst));
        GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));

        // src0 and dst as viewed during set
        const size_t dst_nb0 = ggml_element_size(src0);

        const size_t dst_nb1 = ((int32_t *) dst->op_params)[0];
        const size_t dst_nb2 = ((int32_t *) dst->op_params)[1];
        const size_t dst_nb3 = ((int32_t *) dst->op_params)[2];
        const size_t offset  = ((int32_t *) dst->op_params)[3];
        const bool   inplace = (bool) ((int32_t *) dst->op_params)[4];

        if (!inplace) {
            // non-inplace: seed dst with a full copy of src0 before the kernel
            // overwrites the view region.
            // NOTE(review): this memcpy runs at encode time, before the command
            // buffer executes — confirm ordering vs. previously encoded ops that
            // produce src0 or read dst.
            memcpy(((char *) dst->data), ((char *) src0->data), ggml_nbytes(dst));
        }

        // last valid index per dimension of the src1-shaped view, for the bounds check
        const int im0 = (ne10 == 0 ? 0 : ne10-1);
        const int im1 = (ne11 == 0 ? 0 : ne11-1);
        const int im2 = (ne12 == 0 ? 0 : ne12-1);
        const int im3 = (ne13 == 0 ? 0 : ne13-1);

        GGML_ASSERT(offset + im0*dst_nb0 + im1*dst_nb1 + im2*dst_nb2 + im3*dst_nb3 <= ggml_nbytes(dst));

        id<MTLComputePipelineState> pipeline = nil;

        switch (src0t) {
            case GGML_TYPE_F32:
                GGML_ASSERT(nb10 == sizeof(float));
                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break;
            case GGML_TYPE_I32:
                GGML_ASSERT(nb10 == sizeof(int32_t));
                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break;
            default: GGML_ABORT("fatal error");
        }

        ggml_metal_kargs_set args = {
            /*.ne10    =*/ ne10,
            /*.ne11    =*/ ne11,
            /*.ne12    =*/ ne12,
            /*.nb10    =*/ nb10,
            /*.nb11    =*/ nb11,
            /*.nb12    =*/ nb12,
            /*.nb13    =*/ nb13,
            /*.nb1     =*/ dst_nb1,
            /*.nb2     =*/ dst_nb2,
            /*.nb3     =*/ dst_nb3,
            /*.offs    =*/ offset,
            /*.inplace =*/ inplace,
        };

        // one threadgroup per src1 row; threads in the group stride across ne10 elements
        const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10);

        [encoder setComputePipelineState:pipeline];
        [encoder setBytes:&args length:sizeof(args) atIndex:0];
        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];

        [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
    } break;
38273903
case GGML_OP_POOL_2D:
38283904
{
38293905
GGML_ASSERT(ggml_is_contiguous(src0));

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3927,6 +3927,38 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_
39273927

39283928
#undef FA_TYPES
39293929

3930+
// kernel_set: implements GGML_OP_SET — writes src1 into a strided view of dst.
// Dispatched with one threadgroup per src1 row (grid = ne11 x ne12 x ne13);
// threads within the group stride across the row's ne10 elements.
//
// Fix vs. previous revision: the original flattened tgpig into a linear index n
// and then divided it back out into i3/i2/i1 — for this dispatch shape those
// always equal i13/i12/i11 (the remainders are strictly bounded by the
// divisors), so the div/mod chain was dead work and divided by zero whenever
// any of ne10/ne11/ne12 was 0. Use the threadgroup indices directly.
template<typename T>
kernel void kernel_set(
        constant ggml_metal_kargs_set & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int i13 = tgpig[2];
    const int i12 = tgpig[1];
    const int i11 = tgpig[0];

    // dst is addressed through the view strides (nb1..nb3) and byte offset from
    // dst->op_params; the threadgroup indices map 1:1 onto the view's outer dims.
    device T * dst_data = (device T *) (dst + i13*args.nb3 + i12*args.nb2 + i11*args.nb1 + args.offs);

    for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) {
        device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10);
        dst_data[i10] = (T) src[0];
    }
}

typedef decltype(kernel_set<float>) kernel_set_t;

template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set<float>;
template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set<int32_t>;
3961+
39303962
template<typename T0, typename T1>
39313963
kernel void kernel_cpy(
39323964
constant ggml_metal_kargs_cpy & args,

tests/test-backend-ops.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3521,6 +3521,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
35213521
test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim));
35223522
}
35233523

3524+
// i32 coverage for GGML_OP_SET: exercise every view dimension, mirroring the f32 cases above
for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
    test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim));
}
3527+
35243528
for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
35253529
for (ggml_type type_dst : all_types) {
35263530
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));

0 commit comments

Comments
 (0)