@@ -4,13 +4,14 @@ template <size_t split_d_inner, size_t d_conv>
 static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                     const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
                                     float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
-                                    const int nc, const int ncs, const int nr, const int n_t, const int n_s) {
+                                    const int64_t n_t) {
+    GGML_UNUSED(src0_nb0);
     const int tid  = threadIdx.x;
     const int bidx = blockIdx.x;
     const int bidy = blockIdx.y;

-    const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
-    const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
+    const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
+    const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
     float * y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0);

     const int stride_x = src0_nb1 / sizeof(float);
@@ -21,43 +22,42 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
     float w[d_conv] = { 0.0f };

 #pragma unroll
-    for (int j = 0; j < d_conv; j++) {
+    for (size_t j = 0; j < d_conv; j++) {
         w[j] = w_block[tid * stride_w + j];
     }

-    for (int i = 0; i < n_t; i++) {
+    for (int64_t i = 0; i < n_t; i++) {
         float sumf = 0.0f;

         if (i == 0) {
-            for (int j = 0; j < d_conv; j++) {
+            for (size_t j = 0; j < d_conv; j++) {
                 x[j] = x_block[tid * stride_x + j];
             }
         } else {
             x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
         }

 #pragma unroll
-        for (int j = 0; j < d_conv; j++) {
+        for (size_t j = 0; j < d_conv; j++) {
             sumf += x[(i + j) % d_conv] * w[j];
         }
         y_block[i * stride_y + tid] = sumf;
     }
 }

-template <size_t split_d_inner, size_t d_conv, size_t split_n_t>
+template <size_t split_d_inner, size_t d_conv, int64_t split_n_t>
 static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                                const int src0_nb0, const int src0_nb1, const int src0_nb2,
                                                const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
-                                               const int dst_nb1, const int dst_nb2, const int nc, const int ncs,
-                                               const int nr, const int n_t, const int n_s) {
+                                               const int dst_nb1, const int dst_nb2, const int64_t n_t) {
     const int tid  = threadIdx.x;
     const int bidx = blockIdx.x;
     const int bidy = blockIdx.y;
     const int bidz = blockIdx.z;

-    const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
+    const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
                                              bidz * split_n_t * src0_nb0);
-    const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
+    const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
     float * y_block =
         (float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);

@@ -69,25 +69,25 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
     float w[d_conv] = { 0.0f };

 #pragma unroll
-    for (int j = 0; j < d_conv; j++) {
+    for (size_t j = 0; j < d_conv; j++) {
         w[j] = w_block[tid * stride_w + j];
     }

 #pragma unroll
-    for (int i = 0; i < split_n_t; i++) {
+    for (int64_t i = 0; i < split_n_t; i++) {
         if (bidz * split_n_t + i < n_t) {
             float sumf = 0.0f;

             if (i == 0) {
-                for (int j = 0; j < d_conv; j++) {
+                for (size_t j = 0; j < d_conv; j++) {
                     x[j] = x_block[tid * stride_x + j];
                 }
             } else {
                 x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
             }

 #pragma unroll
-            for (int j = 0; j < d_conv; j++) {
+            for (size_t j = 0; j < d_conv; j++) {
                 sumf += x[(i + j) % d_conv] * w[j];
             }
             y_block[i * stride_y + tid] = sumf;
@@ -97,27 +97,25 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,

 static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
                               const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
-                              const int dst_nb2, const int nc, const int ncs, const int nr, const int n_t,
-                              const int n_s, cudaStream_t stream) {
+                              const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
+                              const int64_t n_s, cudaStream_t stream) {
     const int threads = 128;
     GGML_ASSERT(nr % threads == 0);

     if (n_t <= 32) {
         const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
         if (nc == 4) {
             ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
-                                                                      dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t,
-                                                                      n_s);
+                                                                      dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         } else {
             GGML_ABORT("Only support kernel size = 4 now.");
         }
     } else {
         if (nc == 4) {
-            const int split_n_t = 32;
-            dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
-            ssm_conv_long_token_f32<threads, 4, split_n_t>
-                <<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0,
-                                                 dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s);
+            const int64_t split_n_t = 32;
+            dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
+            ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
+                src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         } else {
             GGML_ABORT("Only support kernel size = 4 right now.");
         }
@@ -128,11 +126,10 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0];  // conv_x
     const struct ggml_tensor * src1 = dst->src[1];  // conv1d.weight

-    const int nc  = src1->ne[0];  // d_conv
-    const int ncs = src0->ne[0];  // d_conv - 1 + n_t
-    const int nr  = src0->ne[1];  // d_inner
-    const int n_t = dst->ne[1];   // tokens per sequence
-    const int n_s = dst->ne[2];   // number of sequences in the batch
+    const int64_t nc  = src1->ne[0];  // d_conv
+    const int64_t nr  = src0->ne[1];  // d_inner
+    const int64_t n_t = dst->ne[1];   // tokens per sequence
+    const int64_t n_s = dst->ne[2];   // number of sequences in the batch

     GGML_ASSERT(dst->ne[0] == nr);
     GGML_ASSERT(src0->nb[0] == sizeof(float));
@@ -147,5 +144,5 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
     ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
-                      dst->nb[2], nc, ncs, nr, n_t, n_s, stream);
+                      dst->nb[2], nc, nr, n_t, n_s, stream);
 }
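Note (not part of the diff): the long-token branch tiles the token dimension along blockIdx.z in chunks of split_n_t = 32, and ssm_conv_long_token_f32 masks the tail tile with the bidz * split_n_t + i < n_t check. A minimal standalone sketch of the launch grid the host wrapper computes for that branch; the n_s, nr and n_t values below are made-up examples, not taken from this PR:

// grid_sketch.cu -- hypothetical example values, compile with nvcc
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    const int     threads   = 128;  // block size used by ssm_conv_f32_cuda
    const int64_t n_s       = 2;    // sequences in the batch (example)
    const int64_t nr        = 256;  // d_inner (example)
    const int64_t n_t       = 100;  // tokens per sequence (example, > 32 so the long-token path is taken)
    const int64_t split_n_t = 32;   // token tile handled per block along blockIdx.z

    // same grid expression as the long-token branch above:
    // x = sequence, y = chunk of d_inner, z = token tile
    dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
    printf("grid = (%u, %u, %u)\n", blocks.x, blocks.y, blocks.z);  // prints: grid = (2, 2, 4)
    return 0;
}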