123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299 |
- /*
- ge_train.c
- Jean Marc Valin Feb 2012
- Joint pitch and energy VQ training program
- usage:
- cat GE | ./ge_train 2 1000000 8 > quantized
- The first column is the log2 of the pitch compared to the lowest freq,
- so log2(wo/pi*4000/50) where wo is the frequency your patch outputs. The
- second column is the energy in dB, so 10*log10(1e-4+E)
- */
- /*
- Copyright (C) 2012 Jean-Marc Valin
- All rights reserved.
- This program is free software; you can redistribute it and/or modify
- it under the terms of the GNU Lesser General Public License version 2, as
- published by the Free Software Foundation. This program is
- distributed in the hope that it will be useful, but WITHOUT ANY
- WARRANTY; without even the implied warranty of MERCHANTABILITY or
- FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public
- License for more details.
- You should have received a copy of the GNU Lesser General Public License
- along with this program; if not, see <http://www.gnu.org/licenses/>.
- */
- #include <valgrind/memcheck.h>
- #include <stdlib.h>
- #include <stdio.h>
- #include <math.h>
/* Smaller of two values. Each argument may be evaluated twice, so do not
   pass expressions with side effects. */
#define MIN(a,b) ((a)<(b)?(a):(b))

/* Per-dimension prediction coefficients: vector i is coded as the residual
   x[j] - COEF[j]*previous[j]. The table is only ever read, hence const.
   Setting both coefficients to 0 disables the prediction. */
static const float COEF[2] = {0.8f, 0.9f};

/* Upper bound on the number of codebook entries. */
#define MAX_ENTRIES 16384
/*
 * Compute the error weights for one two-component vector.
 *
 *   x    : current vector; x[0] is the pitch term, x[1] the energy term
 *   xp   : reference (previous) vector, same layout as x
 *   w    : output, receives two weights (already squared)
 *   ndim : not used; exactly two components are always processed
 */
void compute_weights2(const float *x, const float *xp, float *w, int ndim)
{
    const float energy = x[1];
    const float prev_energy = xp[1];
    const double pitch_jump = fabs(x[0] - xp[0]);
    float pitch_w = 30;
    float energy_w = 1;

    (void)ndim;

    /* Negative energies count less, very negative ones less still */
    if (energy < 0) {
        pitch_w *= .6;
        energy_w *= .3;
        if (energy < -10) {
            pitch_w *= .3;
            energy_w *= .3;
        }
    }

    /* A stable pitch is weighted up, an unstable one is weighted down */
    if (pitch_jump < .2) {
        pitch_w *= 2;
        energy_w *= 1.5;
    } else if (pitch_jump > .5) {
        pitch_w *= .5;
    }

    /* Energy weight is halved for a drop of more than 10, and halved again
       for a drop of more than 20, relative to the reference vector */
    if (energy < prev_energy - 10) {
        energy_w *= .5;
        if (energy < prev_energy - 20)
            energy_w *= .5;
    }

    /* The weights multiply a squared error, so they are stored squared */
    w[0] = pitch_w * pitch_w;
    w[1] = energy_w * energy_w;
}
/*
 * Return the index of the codebook entry nearest to x under the weighted
 * squared error  sum_j w[j] * (x[j] - entry[j])^2.
 *
 *   codebook   : nb_entries entries of ndim floats each, stored contiguously
 *   nb_entries : number of entries to search
 *   x          : vector to match (ndim floats)
 *   w          : per-component weights (ndim floats)
 *   ndim       : number of components per vector
 *
 * On a tie the lowest index wins. 0 is returned when nb_entries <= 0 or when
 * no distance compares below infinity (e.g. every distance is NaN).
 */
int find_nearest_weighted(const float *codebook, int nb_entries, const float *x, const float *w, int ndim)
{
    /* Start from +infinity so that arbitrarily large finite distances are
       still ranked against each other. */
    float min_dist = INFINITY;
    int nearest = 0;

    for (int i = 0; i < nb_entries; i++) {
        const float *entry = codebook + i*ndim;
        float dist = 0;

        for (int j = 0; j < ndim; j++) {
            const float diff = x[j] - entry[j];
            dist += w[j]*diff*diff;
        }
        if (dist < min_dist) {
            min_dist = dist;
            nearest = i;
        }
    }
    return nearest;
}
- int quantize_ge(const float *x, const float *codebook1, int nb_entries, float *xq, int ndim)
- {
- int i, n1;
- float err[ndim];
- float w[ndim];
-
- compute_weights2(x, xq, w, ndim);
-
- for (i=0;i<ndim;i++)
- err[i] = x[i]-COEF[i]*xq[i];
- n1 = find_nearest_weighted(codebook1, nb_entries, err, w, ndim);
-
- for (i=0;i<ndim;i++)
- {
- xq[i] = COEF[i]*xq[i] + codebook1[ndim*n1+i];
- err[i] -= codebook1[ndim*n1+i];
- }
- return 0;
- }
/*
 * Double the size of a codebook by splitting every entry in two.
 *
 *   codebook   : holds nb_entries entries on entry and must have room for
 *                2*nb_entries entries of ndim floats
 *   nb_entries : number of entries before the split
 *   ndim       : number of components per entry
 *
 * Each component c of entry i is replaced by c + delta, and entry
 * i + nb_entries receives c - delta, where delta is a random offset of
 * magnitude at most 0.005 drawn with rand(). The two children therefore sit
 * symmetrically around the parent.
 */
void split(float *codebook, int nb_entries, int ndim)
{
    for (int i = 0; i < nb_entries; i++) {
        float *upper = codebook + i*ndim;
        float *lower = codebook + (i + nb_entries)*ndim;

        for (int j = 0; j < ndim; j++) {
            const float delta = .01*(rand()/(float)RAND_MAX-.5);
            /* Read the parent once so both children are offset from it */
            const float centre = upper[j];

            upper[j] = centre + delta;
            lower[j] = centre - delta;
        }
    }
}
/*
 * Run one weighted k-means update of the codebook.
 *
 *   data       : nb_vectors training vectors of ndim floats, stored contiguously
 *   weight     : one weight per data component, same layout as data; the
 *                square root of each value is used as the averaging weight
 *   nb_vectors : number of training vectors
 *   codebook   : nb_entries entries of ndim floats, updated in place
 *   nb_entries : number of codebook entries
 *   ndim       : number of components per vector
 *
 * Every vector is assigned to its nearest entry (find_nearest_weighted), and
 * each entry component becomes the weighted mean of the components assigned
 * to it. A component that accumulates no weight keeps its previous value
 * instead of turning into 0/0. Nothing is done when any of the counts is not
 * positive. The working buffers live on the heap; if they cannot be
 * allocated a message is printed to stderr and the program exits.
 */
void update_weighted(float *data, float *weight, int nb_vectors, float *codebook, int nb_entries, int ndim)
{
    int i, j;
    size_t k, nb_cells;
    float *sum, *count;

    if (nb_vectors <= 0 || nb_entries <= 0 || ndim <= 0)
        return;

    nb_cells = (size_t)nb_entries * (size_t)ndim;
    sum = calloc(nb_cells, sizeof *sum);
    count = calloc(nb_cells, sizeof *count);
    if (sum == NULL || count == NULL) {
        fprintf(stderr, "update_weighted: out of memory\n");
        free(sum);
        free(count);
        exit(EXIT_FAILURE);
    }

    /* The sums go to a scratch buffer, so the codebook stays intact while
       it is still being searched. */
    for (i = 0; i < nb_vectors; i++) {
        const int n = find_nearest_weighted(codebook, nb_entries, data + i*ndim, weight + i*ndim, ndim);

        for (j = 0; j < ndim; j++) {
            /* Weights are stored squared; take the root to average with */
            const float w = sqrt(weight[i*ndim + j]);

            count[n*ndim + j] += w;
            sum[n*ndim + j] += w*data[i*ndim + j];
        }
    }

    for (k = 0; k < nb_cells; k++) {
        if (count[k] > 0)
            codebook[k] = sum[k] * (1./count[k]);
    }

    free(sum);
    free(count);
}
/*
 * Train a weighted VQ codebook by repeated splitting.
 *
 *   data       : nb_vectors training vectors of ndim floats
 *   weight     : per-component weights, same layout as data
 *   nb_vectors : number of training vectors
 *   codebook   : output; the size doubles until it reaches at least
 *                nb_entries, so there must be room for the next power of
 *                two at or above nb_entries
 *   nb_entries : requested number of entries
 *   ndim       : number of components per vector
 *
 * The codebook starts as the single unweighted mean of the data. Each round
 * splits every entry, prints the new size on stderr and runs ten
 * update_weighted() passes.
 */
void vq_train_weighted(float *data, float *weight, int nb_vectors, float *codebook, int nb_entries, int ndim)
{
    int size = 1;
    int d, v, pass;

    /* Entry 0 = plain average of all training vectors */
    for (d = 0; d < ndim; d++)
        codebook[d] = 0;
    for (v = 0; v < nb_vectors; v++) {
        const float *vec = data + v*ndim;

        for (d = 0; d < ndim; d++)
            codebook[d] += vec[d];
    }
    for (d = 0; d < ndim; d++)
        codebook[d] *= (1./nb_vectors);

    while (size < nb_entries) {
        split(codebook, size, ndim);
        size <<= 1;
        fprintf(stderr, "%d\n", size);

        for (pass = 0; pass < 10; pass++)
            update_weighted(data, weight, nb_vectors, codebook, size, ndim);
    }
}
- int main(int argc, char **argv)
- {
- int i,j;
- int nb_vectors, nb_entries, ndim;
- float *data, *pred, *codebook, *codebook2, *codebook3;
- float *weight, *weight2, *weight3;
- float *delta;
- double err[2] = {0, 0};
- double werr[2] = {0, 0};
- double wsum[2] = {0, 0};
-
- ndim = atoi(argv[1]);
- nb_vectors = atoi(argv[2]);
- nb_entries = 1<<atoi(argv[3]);
-
- data = malloc(nb_vectors*ndim*sizeof(*data));
- weight = malloc(nb_vectors*ndim*sizeof(*weight));
- weight2 = malloc(nb_vectors*ndim*sizeof(*weight2));
- weight3 = malloc(nb_vectors*ndim*sizeof(*weight3));
- pred = malloc(nb_vectors*ndim*sizeof(*pred));
- codebook = malloc(nb_entries*ndim*sizeof(*codebook));
- codebook2 = malloc(nb_entries*ndim*sizeof(*codebook2));
- codebook3 = malloc(nb_entries*ndim*sizeof(*codebook3));
-
- for (i=0;i<nb_vectors;i++)
- {
- if (feof(stdin))
- break;
- for (j=0;j<ndim;j++)
- {
- scanf("%f ", &data[i*ndim+j]);
- }
- }
- nb_vectors = i;
- VALGRIND_CHECK_MEM_IS_DEFINED(data, nb_entries*ndim);
- for (i=0;i<nb_vectors;i++)
- {
- if (i==0)
- compute_weights2(data+i*ndim, data+i*ndim, weight+i*ndim, ndim);
- else
- compute_weights2(data+i*ndim, data+(i-1)*ndim, weight+i*ndim, ndim);
- }
- for (i=0;i<ndim;i++)
- pred[i] = data[i];
- for (i=1;i<nb_vectors;i++)
- {
- for (j=0;j<ndim;j++)
- pred[i*ndim+j] = data[i*ndim+j] - COEF[j]*data[(i-1)*ndim+j];
- }
- VALGRIND_CHECK_MEM_IS_DEFINED(pred, nb_entries*ndim);
- vq_train_weighted(pred, weight, nb_vectors, codebook, nb_entries, ndim);
- printf("%d %d\n", ndim, nb_entries);
- for (i=0;i<nb_entries;i++)
- {
- for (j=0;j<ndim;j++)
- {
- printf("%f ", codebook[i*ndim+j]);
- }
- printf("\n");
- }
-
- delta = malloc(nb_vectors*ndim*sizeof(*data));
- float xq[2] = {0,0};
- for (i=0;i<nb_vectors;i++)
- {
- //int nearest = find_nearest_weighted(codebook, nb_entries, &pred[i*ndim], &weight[i*ndim], ndim);
- quantize_ge(&data[i*ndim], codebook, nb_entries, xq, ndim);
- //printf("%f %f\n", xq[0], xq[1]);
- for (j=0;j<ndim;j++)
- {
- delta[i*ndim+j] = xq[j]-data[i*ndim+j];
- err[j] += (delta[i*ndim+j])*(delta[i*ndim+j]);
- werr[j] += weight[i*ndim+j]*(delta[i*ndim+j])*(delta[i*ndim+j]);
- wsum[j] += weight[i*ndim+j];
- //delta[i*ndim+j] = pred[i*ndim+j] - codebook[nearest*ndim+j];
- //printf("%f ", delta[i*ndim+j]);
- //err[j] += (delta[i*ndim+j])*(delta[i*ndim+j]);
- }
- //printf("\n");
- }
- fprintf(stderr, "GE RMS error: %f %f\n", sqrt(err[0]/nb_vectors), sqrt(err[1]/nb_vectors));
- fprintf(stderr, "Weighted GE error: %f %f\n", sqrt(werr[0]/wsum[0]), sqrt(werr[1]/wsum[1]));
- return 0;
- }
|