ge_train.c 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. /*
  2. ge_train.c
  3. Jean Marc Valin Feb 2012
  4. Joint pitch and energy VQ training program
  5. usage:
  6. cat GE | ./ge_train 2 1000000 8 > quantized
  7. The first column is the log2 of the pitch compared to the lowest freq,
  8. so log2(wo/pi*4000/50) where wo is the frequency your patch outputs. The
  9. second column is the energy in dB, so 10*log10(1e-4+E)
  10. */
  11. /*
  12. Copyright (C) 2012 Jean-Marc Valin
  13. All rights reserved.
  14. This program is free software; you can redistribute it and/or modify
  15. it under the terms of the GNU Lesser General Public License version 2, as
  16. published by the Free Software Foundation. This program is
  17. distributed in the hope that it will be useful, but WITHOUT ANY
  18. WARRANTY; without even the implied warranty of MERCHANTABILITY or
  19. FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public
  20. License for more details.
  21. You should have received a copy of the GNU Lesser General Public License
  22. along with this program; if not, see <http://www.gnu.org/licenses/>.
  23. */
  24. #include <valgrind/memcheck.h>
  25. #include <stdlib.h>
  26. #include <stdio.h>
  27. #include <math.h>
  28. #define MIN(a,b) ((a)<(b)?(a):(b))
  29. //#define COEF 0.0
  30. static float COEF[2] = {0.8, 0.9};
  31. //static float COEF[2] = {0.0, 0.};
  32. #define MAX_ENTRIES 16384
  33. void compute_weights2(const float *x, const float *xp, float *w, int ndim)
  34. {
  35. w[0] = 30;
  36. w[1] = 1;
  37. if (x[1]<0)
  38. {
  39. w[0] *= .6;
  40. w[1] *= .3;
  41. }
  42. if (x[1]<-10)
  43. {
  44. w[0] *= .3;
  45. w[1] *= .3;
  46. }
  47. /* Higher weight if pitch is stable */
  48. if (fabs(x[0]-xp[0])<.2)
  49. {
  50. w[0] *= 2;
  51. w[1] *= 1.5;
  52. } else if (fabs(x[0]-xp[0])>.5) /* Lower if not stable */
  53. {
  54. w[0] *= .5;
  55. }
  56. /* Lower weight for low energy */
  57. if (x[1] < xp[1]-10)
  58. {
  59. w[1] *= .5;
  60. }
  61. if (x[1] < xp[1]-20)
  62. {
  63. w[1] *= .5;
  64. }
  65. //w[0] = 30;
  66. //w[1] = 1;
  67. /* Square the weights because it's applied on the squared error */
  68. w[0] *= w[0];
  69. w[1] *= w[1];
  70. }
  71. int find_nearest_weighted(const float *codebook, int nb_entries, float *x, const float *w, int ndim)
  72. {
  73. int i, j;
  74. float min_dist = 1e15;
  75. int nearest = 0;
  76. for (i=0;i<nb_entries;i++)
  77. {
  78. float dist=0;
  79. for (j=0;j<ndim;j++)
  80. dist += w[j]*(x[j]-codebook[i*ndim+j])*(x[j]-codebook[i*ndim+j]);
  81. if (dist<min_dist)
  82. {
  83. min_dist = dist;
  84. nearest = i;
  85. }
  86. }
  87. return nearest;
  88. }
  89. int quantize_ge(const float *x, const float *codebook1, int nb_entries, float *xq, int ndim)
  90. {
  91. int i, n1;
  92. float err[ndim];
  93. float w[ndim];
  94. compute_weights2(x, xq, w, ndim);
  95. for (i=0;i<ndim;i++)
  96. err[i] = x[i]-COEF[i]*xq[i];
  97. n1 = find_nearest_weighted(codebook1, nb_entries, err, w, ndim);
  98. for (i=0;i<ndim;i++)
  99. {
  100. xq[i] = COEF[i]*xq[i] + codebook1[ndim*n1+i];
  101. err[i] -= codebook1[ndim*n1+i];
  102. }
  103. return 0;
  104. }
  105. void split(float *codebook, int nb_entries, int ndim)
  106. {
  107. int i,j;
  108. for (i=0;i<nb_entries;i++)
  109. {
  110. for (j=0;j<ndim;j++)
  111. {
  112. float delta = .01*(rand()/(float)RAND_MAX-.5);
  113. codebook[i*ndim+j] += delta;
  114. codebook[(i+nb_entries)*ndim+j] = codebook[i*ndim+j] - delta;
  115. }
  116. }
  117. }
  118. void update_weighted(float *data, float *weight, int nb_vectors, float *codebook, int nb_entries, int ndim)
  119. {
  120. int i,j;
  121. float count[MAX_ENTRIES][ndim];
  122. int nearest[nb_vectors];
  123. //fprintf(stderr, "weighted: %d %d\n", nb_entries, ndim);
  124. for (i=0;i<nb_entries;i++)
  125. for (j=0;j<ndim;j++)
  126. count[i][j] = 0;
  127. for (i=0;i<nb_vectors;i++)
  128. {
  129. nearest[i] = find_nearest_weighted(codebook, nb_entries, data+i*ndim, weight+i*ndim, ndim);
  130. }
  131. for (i=0;i<nb_entries*ndim;i++)
  132. codebook[i] = 0;
  133. for (i=0;i<nb_vectors;i++)
  134. {
  135. int n = nearest[i];
  136. for (j=0;j<ndim;j++)
  137. {
  138. float w = sqrt(weight[i*ndim+j]);
  139. count[n][j]+=w;
  140. codebook[n*ndim+j] += w*data[i*ndim+j];
  141. }
  142. }
  143. //float w2=0;
  144. for (i=0;i<nb_entries;i++)
  145. {
  146. for (j=0;j<ndim;j++)
  147. codebook[i*ndim+j] *= (1./count[i][j]);
  148. //w2 += (count[i]/(float)nb_vectors)*(count[i]/(float)nb_vectors);
  149. }
  150. //fprintf(stderr, "%f / %d\n", 1./w2, nb_entries);
  151. }
  152. void vq_train_weighted(float *data, float *weight, int nb_vectors, float *codebook, int nb_entries, int ndim)
  153. {
  154. int i, j, e;
  155. e = 1;
  156. for (j=0;j<ndim;j++)
  157. codebook[j] = 0;
  158. for (i=0;i<nb_vectors;i++)
  159. for (j=0;j<ndim;j++)
  160. codebook[j] += data[i*ndim+j];
  161. for (j=0;j<ndim;j++)
  162. codebook[j] *= (1./nb_vectors);
  163. while (e< nb_entries)
  164. {
  165. #if 1
  166. split(codebook, e, ndim);
  167. e<<=1;
  168. #else
  169. split1(codebook, e, data, nb_vectors, ndim);
  170. e++;
  171. #endif
  172. fprintf(stderr, "%d\n", e);
  173. for (j=0;j<10;j++)
  174. update_weighted(data, weight, nb_vectors, codebook, e, ndim);
  175. }
  176. }
  177. int main(int argc, char **argv)
  178. {
  179. int i,j;
  180. int nb_vectors, nb_entries, ndim;
  181. float *data, *pred, *codebook, *codebook2, *codebook3;
  182. float *weight, *weight2, *weight3;
  183. float *delta;
  184. double err[2] = {0, 0};
  185. double werr[2] = {0, 0};
  186. double wsum[2] = {0, 0};
  187. ndim = atoi(argv[1]);
  188. nb_vectors = atoi(argv[2]);
  189. nb_entries = 1<<atoi(argv[3]);
  190. data = malloc(nb_vectors*ndim*sizeof(*data));
  191. weight = malloc(nb_vectors*ndim*sizeof(*weight));
  192. weight2 = malloc(nb_vectors*ndim*sizeof(*weight2));
  193. weight3 = malloc(nb_vectors*ndim*sizeof(*weight3));
  194. pred = malloc(nb_vectors*ndim*sizeof(*pred));
  195. codebook = malloc(nb_entries*ndim*sizeof(*codebook));
  196. codebook2 = malloc(nb_entries*ndim*sizeof(*codebook2));
  197. codebook3 = malloc(nb_entries*ndim*sizeof(*codebook3));
  198. for (i=0;i<nb_vectors;i++)
  199. {
  200. if (feof(stdin))
  201. break;
  202. for (j=0;j<ndim;j++)
  203. {
  204. scanf("%f ", &data[i*ndim+j]);
  205. }
  206. }
  207. nb_vectors = i;
  208. VALGRIND_CHECK_MEM_IS_DEFINED(data, nb_entries*ndim);
  209. for (i=0;i<nb_vectors;i++)
  210. {
  211. if (i==0)
  212. compute_weights2(data+i*ndim, data+i*ndim, weight+i*ndim, ndim);
  213. else
  214. compute_weights2(data+i*ndim, data+(i-1)*ndim, weight+i*ndim, ndim);
  215. }
  216. for (i=0;i<ndim;i++)
  217. pred[i] = data[i];
  218. for (i=1;i<nb_vectors;i++)
  219. {
  220. for (j=0;j<ndim;j++)
  221. pred[i*ndim+j] = data[i*ndim+j] - COEF[j]*data[(i-1)*ndim+j];
  222. }
  223. VALGRIND_CHECK_MEM_IS_DEFINED(pred, nb_entries*ndim);
  224. vq_train_weighted(pred, weight, nb_vectors, codebook, nb_entries, ndim);
  225. printf("%d %d\n", ndim, nb_entries);
  226. for (i=0;i<nb_entries;i++)
  227. {
  228. for (j=0;j<ndim;j++)
  229. {
  230. printf("%f ", codebook[i*ndim+j]);
  231. }
  232. printf("\n");
  233. }
  234. delta = malloc(nb_vectors*ndim*sizeof(*data));
  235. float xq[2] = {0,0};
  236. for (i=0;i<nb_vectors;i++)
  237. {
  238. //int nearest = find_nearest_weighted(codebook, nb_entries, &pred[i*ndim], &weight[i*ndim], ndim);
  239. quantize_ge(&data[i*ndim], codebook, nb_entries, xq, ndim);
  240. //printf("%f %f\n", xq[0], xq[1]);
  241. for (j=0;j<ndim;j++)
  242. {
  243. delta[i*ndim+j] = xq[j]-data[i*ndim+j];
  244. err[j] += (delta[i*ndim+j])*(delta[i*ndim+j]);
  245. werr[j] += weight[i*ndim+j]*(delta[i*ndim+j])*(delta[i*ndim+j]);
  246. wsum[j] += weight[i*ndim+j];
  247. //delta[i*ndim+j] = pred[i*ndim+j] - codebook[nearest*ndim+j];
  248. //printf("%f ", delta[i*ndim+j]);
  249. //err[j] += (delta[i*ndim+j])*(delta[i*ndim+j]);
  250. }
  251. //printf("\n");
  252. }
  253. fprintf(stderr, "GE RMS error: %f %f\n", sqrt(err[0]/nb_vectors), sqrt(err[1]/nb_vectors));
  254. fprintf(stderr, "Weighted GE error: %f %f\n", sqrt(werr[0]/wsum[0]), sqrt(werr[1]/wsum[1]));
  255. return 0;
  256. }