Commit be9a78c

JohannesGaessler authored and mglambda committed
RoPE: fix back, CUDA support for back + noncont. (ggml-org#11240)

* RoPE: fix back, CUDA support for back + noncont.
* fix comments reg. non-cont. RoPE support [no-ci]
1 parent: 2e3edad · commit: be9a78c

File tree: 9 files changed, +270 −259 lines

ggml/include/ggml.h
Lines changed: 18 additions & 1 deletion

@@ -1500,7 +1500,7 @@ extern "C" {
 
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
-    GGML_API struct ggml_tensor * ggml_rope_back(
+    GGML_API struct ggml_tensor * ggml_rope_ext_back(
             struct ggml_context * ctx,
             struct ggml_tensor * a, // gradients of ggml_rope result
             struct ggml_tensor * b, // positions
@@ -1515,6 +1515,23 @@ extern "C" {
             float beta_fast,
             float beta_slow);
 
+    GGML_API struct ggml_tensor * ggml_rope_multi_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b,
+            struct ggml_tensor * c,
+            int n_dims,
+            int sections[4],
+            int mode,
+            int n_ctx_orig,
+            float freq_base,
+            float freq_scale,
+            float ext_factor,
+            float attn_factor,
+            float beta_fast,
+            float beta_slow);
+
+
     // clamp
     // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_clamp(
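
For context, a minimal sketch of how the renamed backward op might be invoked. The argument order is assumed to mirror ggml_rope_ext (the hunk above only shows the beginning and end of the signature), and the tensor names grad_out, positions and freq_factors are illustrative, not part of this commit:

    // Hedged sketch: compute dx from dy for a RoPE node.
    // Assumes ggml_rope_ext_back takes the same arguments as ggml_rope_ext.
    struct ggml_tensor * grad_in = ggml_rope_ext_back(
            ctx,
            grad_out,      // a: gradients of the ggml_rope_ext result (dy)
            positions,     // b: the same position ids used in the forward pass
            freq_factors,  // c: optional frequency factors, may be NULL
            n_dims, mode, n_ctx_orig,
            freq_base, freq_scale,
            ext_factor, attn_factor,
            beta_fast, beta_slow);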

ggml/src/ggml-cpu/ggml-cpu.c
Lines changed: 1 addition & 0 deletions

@@ -13668,6 +13668,7 @@ struct ggml_cplan ggml_graph_plan(
                 } break;
             case GGML_OP_SOFT_MAX:
             case GGML_OP_ROPE:
+            case GGML_OP_ROPE_BACK:
                 {
                     cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                 } break;

ggml/src/ggml-cpu/ggml-cpu.cpp
Lines changed: 0 additions & 2 deletions

@@ -403,8 +403,6 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
                 op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
             return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
-        case GGML_OP_ROPE_BACK:
-            return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
         case GGML_OP_IM2COL_BACK:
             return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
         case GGML_OP_OUT_PROD:

ggml/src/ggml-cuda/ggml-cuda.cu
Lines changed: 9 additions & 1 deletion

@@ -2141,6 +2141,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ROPE:
             ggml_cuda_op_rope(ctx, dst);
             break;
+        case GGML_OP_ROPE_BACK:
+            ggml_cuda_op_rope_back(ctx, dst);
+            break;
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;
@@ -3025,7 +3028,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SOFT_MAX:
             return true;
         case GGML_OP_ROPE:
-            return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_ROPE_BACK: {
+            const size_t ts = ggml_type_size(op->src[0]->type);
+            const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2];
+            return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts;
+        }
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
@@ -3081,6 +3088,7 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
             return op->ne[1];
         case GGML_OP_MUL_MAT_ID:
         case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
             return op->ne[2];
         default:
             return ggml_nrows(op);
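
The support check above relaxes the old ggml_is_contiguous requirement: src0 only needs contiguous rows (nb[0] equal to the element size) and a batch stride nb[3] spanning ne[0]*ne[1]*ne[2] elements, so views that are non-contiguous across dims 1 and 2 are now accepted. A standalone sketch of that predicate; the helper name and its placement are illustrative, the real check lives inline in ggml_backend_cuda_device_supports_op:

    // Hedged sketch: contiguity requirement the CUDA RoPE / RoPE-back
    // kernels impose on src0 after this change.
    static bool rope_src_is_supported(const struct ggml_tensor * t) {
        const size_t  ts      = ggml_type_size(t->type);
        const int64_t ne0_012 = t->ne[0] * t->ne[1] * t->ne[2];
        // rows must be contiguous, and dim 3 must stride over the first
        // three dims as a whole; strides within dims 1 and 2 are free
        return t->nb[0] == ts && t->nb[3] == ne0_012 * ts;
    }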
