Skip to content

Commit d11f359

Browse files
committed
diarization : try to cluster embedings from last encoder layer
1 parent d5d7769 commit d11f359

File tree

2 files changed

+108
-37
lines changed

2 files changed

+108
-37
lines changed

ggml.c

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8652,16 +8652,16 @@ void ggml_svd_reduce_dims(
86528652
}
86538653

86548654
// normalize U
8655-
for (int i = 0; i < n; ++i) {
8656-
double sum = 0.0;
8657-
for (int j = 0; j < m; ++j) {
8658-
sum += U[i * m + j] * U[i * m + j];
8659-
}
8660-
sum = sqrt(sum);
8661-
for (int j = 0; j < m; ++j) {
8662-
U[i * m + j] /= sum*sqrt((double) m);
8663-
}
8664-
}
8655+
//for (int i = 0; i < n; ++i) {
8656+
// double sum = 0.0;
8657+
// for (int j = 0; j < m; ++j) {
8658+
// sum += U[i * m + j] * U[i * m + j];
8659+
// }
8660+
// sum = sqrt(sum);
8661+
// for (int j = 0; j < m; ++j) {
8662+
// U[i * m + j] /= sum*sqrt((double) m);
8663+
// }
8664+
//}
86658665

86668666
// print U
86678667
//printf("U:\n");
@@ -8675,9 +8675,10 @@ void ggml_svd_reduce_dims(
86758675
//printf("\n");
86768676

86778677

8678+
printf("n = %d, m = %d, nd = %d\n", n, m, nd);
86788679
// project A0 onto U
86798680
for (int i = 0; i < n; ++i) {
8680-
for (int j = 0; j < n; ++j) {
8681+
for (int j = 0; j < nd; ++j) {
86818682
A[i * nd + j] = 0.0f;
86828683
for (int k = 0; k < m; ++k) {
86838684
A[i * nd + j] += A0[i * m + k] * U[j * m + k];

whisper.cpp

Lines changed: 96 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,8 @@ struct whisper_context {
603603
// [EXPERIMENTAL] speed-up techniques
604604
int32_t exp_n_audio_ctx; // 0 - use default
605605

606+
std::vector<float> audio_embd;
607+
606608
void use_buf(struct ggml_context * ctx, int i) {
607609
#if defined(WHISPER_USE_SCRATCH)
608610
size_t last_size = 0;
@@ -1723,17 +1725,35 @@ static bool whisper_encode(
17231725
}
17241726

17251727
// cur
1728+
//{
1729+
// printf("ne0 = %d\n", cur->ne[0]);
1730+
// printf("ne1 = %d\n", cur->ne[1]);
1731+
// for (int i = 0; i < 10; ++i) {
1732+
// printf("%8.4f ", ((float *)(cur->data))[i]);
1733+
// }
1734+
// printf("... ");
1735+
// for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1736+
// printf("%8.4f ", ((float *)(cur->data))[i]);
1737+
// }
1738+
// printf("\n");
1739+
//}
1740+
17261741
{
1727-
//printf("ne0 = %d\n", cur->ne[0]);
1728-
//printf("ne1 = %d\n", cur->ne[1]);
1729-
//for (int i = 0; i < 10; ++i) {
1730-
// printf("%8.4f ", ((float *)(cur->data))[i]);
1731-
//}
1732-
//printf("... ");
1733-
//for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1734-
// printf("%8.4f ", ((float *)(cur->data))[i]);
1735-
//}
1736-
//printf("\n");
1742+
//const int i0 = std::min(mel_offset, mel_inp.n_len);
1743+
//const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
1744+
const int i0 = 0;
1745+
const int i1 = cur->ne[1];
1746+
1747+
//printf("i0 = %d, i1 = %d, (i1 - i0) = %d, embd size = %d\n", i0, i1, i1 - i0, cur->ne[0]);
1748+
1749+
wctx.audio_embd.clear();
1750+
wctx.audio_embd.resize(cur->ne[0], 0.0f);
1751+
for (int j = 0; j < cur->ne[0]; ++j) {
1752+
for (int i = i0; i < i1; ++i) {
1753+
wctx.audio_embd[j] += ((float *)(cur->data))[(i - i0)*cur->ne[0] + j];
1754+
}
1755+
wctx.audio_embd[j] /= (i1 - i0);
1756+
}
17371757
}
17381758

17391759
// pre-compute cross-attention memory
@@ -4838,6 +4858,28 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
48384858
const int n_state = ctx->model.hparams.n_audio_state;
48394859
const int n_layer = ctx->model.hparams.n_audio_layer;
48404860

4861+
#if 1
4862+
// use the last layer of the encoder
4863+
{
4864+
std::vector<float> embd(n_segments*n_state);
4865+
4866+
for (int i = 0; i < n_segments; ++i) {
4867+
const auto & segment_i = ctx->result_all[i];
4868+
printf("%s: segment %3d: t0 = %7d, t1 = %7d, text = %s\n", __func__, i, (int) segment_i.t0, (int) segment_i.t1, segment_i.text.c_str());
4869+
4870+
ctx->mel.n_len = segment_i.t1;
4871+
whisper_encode(*ctx, segment_i.t0, 7, true);
4872+
4873+
for (int j = 0; j < n_state; ++j) {
4874+
embd[i*n_state + j] = ctx->audio_embd[j];
4875+
}
4876+
}
4877+
4878+
const int n_features = std::min(4, n_segments);
4879+
4880+
ggml_svd_reduce_dims(n_state, n_segments, embd.data(), n_features);
4881+
#else
4882+
// use cross kv cache of various layers
48414883
for (int il = 0; il < n_layer; ++il) {
48424884
std::vector<float> embd(n_segments*n_ctx*n_state);
48434885

@@ -4856,9 +4898,10 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
48564898
}
48574899
}
48584900

4859-
const int n_features = 64;
4901+
const int n_features = std::min(4, n_segments);
48604902

48614903
ggml_svd_reduce_dims(n_ctx*n_state, n_segments, embd.data(), n_features);
4904+
#endif
48624905

48634906
std::vector<std::vector<float>> features(n_segments);
48644907

@@ -4927,32 +4970,59 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
49274970
for (int l = 0; l < n_clusters; ++l) {
49284971
//sum += std::pow(whisper_distance(features[j], centroids[k])/whisper_distance(features[j], centroids[l]), 2.0/(2.0 - 1.0));
49294972

4930-
// use the euclidean distance
49314973
double d0 = 0.0;
4932-
for (int m = 0; m < n_features; ++m) {
4933-
d0 += std::pow(features[j][m] - centroids[k][m], 2.0);
4934-
}
4935-
d0 = std::sqrt(d0);
4936-
49374974
double d1 = 0.0;
4938-
for (int m = 0; m < n_features; ++m) {
4939-
d1 += std::pow(features[j][m] - centroids[l][m], 2.0);
4940-
}
4941-
d1 = std::sqrt(d1);
49424975

4943-
if (d1 == 0.0) {
4944-
sum += 1.0;
4945-
} else {
4946-
sum += std::pow(d0/d1, 2.0/(1.10 - 1.0));
4976+
// use the euclidean distance
4977+
{
4978+
for (int m = 0; m < n_features; ++m) {
4979+
d0 += std::pow(features[j][m] - centroids[k][m], 2.0);
4980+
}
4981+
d0 = std::sqrt(d0);
4982+
4983+
for (int m = 0; m < n_features; ++m) {
4984+
d1 += std::pow(features[j][m] - centroids[l][m], 2.0);
4985+
}
4986+
d1 = std::sqrt(d1);
49474987
}
4988+
4989+
// use the cosine distance
4990+
//{
4991+
// double dot = 0.0;
4992+
// double norm0 = 0.0;
4993+
// double norm1 = 0.0;
4994+
4995+
// for (int m = 0; m < n_features; ++m) {
4996+
// dot += features[j][m]*centroids[k][m];
4997+
// norm0 += std::pow(features[j][m], 2.0);
4998+
// norm1 += std::pow(centroids[k][m], 2.0);
4999+
// }
5000+
5001+
// d0 = 1.0 - dot/(std::sqrt(norm0)*std::sqrt(norm1));
5002+
5003+
// dot = 0.0;
5004+
// norm0 = 0.0;
5005+
// norm1 = 0.0;
5006+
5007+
// for (int m = 0; m < n_features; ++m) {
5008+
// dot += features[j][m]*centroids[l][m];
5009+
// norm0 += std::pow(features[j][m], 2.0);
5010+
// norm1 += std::pow(centroids[l][m], 2.0);
5011+
// }
5012+
5013+
// d1 = 1.0 - dot/(std::sqrt(norm0)*std::sqrt(norm1));
5014+
//}
5015+
5016+
sum += std::pow(d0/d1, 2.0/(1.15 - 1.0));
49485017
}
49495018

4950-
membership[j][k] = 1.0/sum;
5019+
membership[j][k] = sum == 0.0 ? 0.0 : 1.0/sum;
49515020
}
49525021
}
49535022

49545023
// print the membership
49555024
if (i == niter - 1) {
5025+
//{
49565026
for (int i = 0; i < n_segments; ++i) {
49575027
printf("%s: membership %3d: ", __func__, i);
49585028
for (int j = 0; j < n_clusters; ++j) {

0 commit comments

Comments
 (0)