vq_train_jvm.c 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. /*---------------------------------------------------------------------------*\
  2. FILE........: vq_train_jvm.c
  3. AUTHOR......: Jean-Marc Valin
  4. DATE CREATED: 21 Jan 2012
  5. Multi-stage Vector Quantiser training program developed by Jean-Marc at
  6. linux.conf.au 2012. Minor mods by David Rowe
  7. \*---------------------------------------------------------------------------*/
  8. /*
  9. Copyright (C) 2012 Jean-Marc Valin
  10. All rights reserved.
  11. This program is free software; you can redistribute it and/or modify
  12. it under the terms of the GNU Lesser General Public License version 2, as
  13. published by the Free Software Foundation. This program is
  14. distributed in the hope that it will be useful, but WITHOUT ANY
  15. WARRANTY; without even the implied warranty of MERCHANTABILITY or
  16. FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public
  17. License for more details.
  18. You should have received a copy of the GNU Lesser General Public License
  19. along with this program; if not, see <http://www.gnu.org/licenses/>.
  20. */
  21. #ifdef VALGRIND
  22. #include <valgrind/memcheck.h>
  23. #endif
  24. #include <assert.h>
  25. #include <stdlib.h>
  26. #include <stdio.h>
  27. #include <math.h>
  /* Smaller of a and b. Each argument is evaluated twice, so do not pass
     expressions with side effects. */
  28. #define MIN(a,b) ((a)<(b)?(a):(b))
  /* Prediction coefficient: scales the previous quantised vector in
     quantize_lsp() and the frame two steps back when main() builds the
     training residual. With the value 0 that prediction term vanishes. */
  29. #define COEF 0.0f
  /* Upper limit on the number of codebook entries the training code is sized for. */
  30. #define MAX_ENTRIES 16384
  /* M_PI is a POSIX/X-Open extension; strict ISO C <math.h> need not define it. */
  #ifndef M_PI
  #define M_PI 3.14159265358979323846
  #endif

  /*
   * Derive one weight per element of x.
   *
   * For every element the distance to its lower and upper neighbour is
   * measured; the first element uses its own value as the lower distance and
   * the last element uses (pi - x) as the upper distance.  The weight is
   * 1/(0.01 + smaller distance), so closely spaced elements get large weights.
   * The first weight is then tripled and the second doubled.
   *
   * x    : input vector of ndim values (presumably ascending LSP frequencies
   *        in radians between 0 and pi -- confirm against the training data).
   * w    : output, receives ndim weights; must not overlap x.
   * ndim : number of elements; nothing is written when ndim <= 0.
   */
  void compute_weights(const float *x, float *w, int ndim)
  {
    int i;

    if (ndim <= 0)
      return;

    for (i = 0; i < ndim; i++) {
      /* Boundary elements have only one real neighbour. */
      float below = (i == 0) ? x[0] : x[i] - x[i-1];
      float above = (i == ndim-1) ? (float)(M_PI - x[i]) : x[i+1] - x[i];
      float gap = (below < above) ? below : above;

      w[i] = 1./(.01 + gap);
    }

    w[0] *= 3;
    /* A one-element vector has no second weight to boost. */
    if (ndim > 1)
      w[1] *= 2;
  }
  /*
   * Exhaustive nearest-neighbour search using squared Euclidean distance.
   *
   * codebook   : nb_entries rows of ndim floats, stored row after row.
   * x          : vector of ndim floats to match.
   * min_dist   : out, squared distance to the winning row.  Stays at 1E15
   *              when no row scores below that value.
   *
   * Returns the index of the closest row; on equal distance the lowest index
   * wins.  Returns 0 when no row is accepted.
   */
  int find_nearest(const float *codebook, int nb_entries, float *x, int ndim, float *min_dist)
  {
    int best = 0;
    int row;

    *min_dist = 1E15;
    for (row = 0; row < nb_entries; row++) {
      int base = row*ndim;
      float acc = 0;
      int k = 0;

      while (k < ndim) {
        float diff = x[k] - codebook[base + k];
        acc += diff*diff;
        k++;
      }

      /* Only a strictly smaller distance replaces the current winner. */
      if (!(acc < *min_dist))
        continue;
      *min_dist = acc;
      best = row;
    }
    return best;
  }
  /*
   * Exhaustive nearest-neighbour search using a weighted squared distance:
   * each squared difference in dimension k is scaled by w[k].
   *
   * codebook : nb_entries rows of ndim floats, stored row after row.
   * x        : vector of ndim floats to match.
   * w        : ndim per-dimension weights.
   *
   * Returns the index of the row with the smallest weighted distance; on
   * equal distance the lowest index wins.  Returns 0 when no row scores
   * below the initial bound of 1e15.
   */
  int find_nearest_weighted(const float *codebook, int nb_entries, float *x, const float *w, int ndim)
  {
    float best_dist = 1e15;
    int best = 0;
    int row;

    for (row = 0; row < nb_entries; row++) {
      int base = row*ndim;
      float acc = 0;
      int k = 0;

      while (k < ndim) {
        float diff = x[k] - codebook[base + k];
        acc += w[k]*diff*diff;
        k++;
      }

      /* Only a strictly smaller distance replaces the current winner. */
      if (!(acc < best_dist))
        continue;
      best_dist = acc;
      best = row;
    }
    return best;
  }
  79. int quantize_lsp(const float *x, const float *codebook1, const float *codebook2,
  80. const float *codebook3, int nb_entries, float *xq, int ndim)
  81. {
  82. int i, n1, n2, n3;
  83. float err[ndim], err2[ndim], err3[ndim];
  84. float w[ndim], w2[ndim], w3[ndim], min_dist;
  85. w[0] = MIN(x[0], x[1]-x[0]);
  86. for (i=1;i<ndim-1;i++)
  87. w[i] = MIN(x[i]-x[i-1], x[i+1]-x[i]);
  88. w[ndim-1] = MIN(x[ndim-1]-x[ndim-2], M_PI-x[ndim-1]);
  89. /*
  90. for (i=0;i<ndim;i++)
  91. w[i] = 1./(.003+w[i]);
  92. w[0]*=3;
  93. w[1]*=2;*/
  94. compute_weights(x, w, ndim);
  95. for (i=0;i<ndim;i++)
  96. err[i] = x[i]-COEF*xq[i];
  97. n1 = find_nearest(codebook1, nb_entries, err, ndim, &min_dist);
  98. for (i=0;i<ndim;i++)
  99. {
  100. xq[i] = COEF*xq[i] + codebook1[ndim*n1+i];
  101. err[i] -= codebook1[ndim*n1+i];
  102. }
  103. for (i=0;i<ndim/2;i++)
  104. {
  105. err2[i] = err[2*i];
  106. err3[i] = err[2*i+1];
  107. w2[i] = w[2*i];
  108. w3[i] = w[2*i+1];
  109. }
  110. n2 = find_nearest_weighted(codebook2, nb_entries, err2, w2, ndim/2);
  111. n3 = find_nearest_weighted(codebook3, nb_entries, err3, w3, ndim/2);
  112. for (i=0;i<ndim/2;i++)
  113. {
  114. xq[2*i] += codebook2[ndim*n2/2+i];
  115. xq[2*i+1] += codebook3[ndim*n3/2+i];
  116. }
  117. return 0;
  118. }
  /*
   * Double a codebook in place.
   *
   * Every value of the first nb_entries rows is shifted by a random amount
   * drawn with rand() from roughly [-0.005, 0.005].  The matching value in row
   * (i + nb_entries) is set to the shifted value minus that same amount, i.e.
   * approximately the value the row had before the call.
   *
   * codebook must have room for 2*nb_entries rows of ndim floats.
   */
  void split(float *codebook, int nb_entries, int ndim)
  {
    int total, k;

    if (nb_entries <= 0 || ndim <= 0)
      return;

    /* Row i+nb_entries sits exactly `total` floats after row i, so the two
       nested loops collapse into one pass over the first half. */
    total = nb_entries*ndim;
    for (k = 0; k < total; k++) {
      float delta = .01*(rand()/(float)RAND_MAX-.5);
      float moved = codebook[k] + delta;

      codebook[k] = moved;
      codebook[k + total] = moved - delta;
    }
  }
  /*
   * One refinement pass over the codebook (assign, then re-average).
   *
   * Each training vector is assigned to its closest codebook row with
   * find_nearest(); every row that received at least one vector is replaced
   * by the mean of the vectors assigned to it.  A row that received no
   * vector keeps its previous value instead of becoming NaN.
   *
   * A progress line is written to stderr: the inverse of the summed squared
   * cell occupancies, the number of entries, and the mean squared distance
   * to the chosen rows measured before the rows were moved.
   *
   * data     : nb_vectors rows of ndim floats.
   * codebook : nb_entries rows of ndim floats, updated in place.
   *
   * Does nothing when any of the sizes is not positive.  Terminates the
   * program if the scratch buffers cannot be allocated.
   */
  void update(float *data, int nb_vectors, float *codebook, int nb_entries, int ndim)
  {
    int i, j;
    int *count;
    int *nearest;
    float *sum;
    float min_dist;
    float total_min_dist = 0;
    float w2 = 0;

    if (nb_vectors <= 0 || nb_entries <= 0 || ndim <= 0)
      return;

    /* Scratch space lives on the heap: nb_vectors can be far too large for
       the stack. */
    count = calloc((size_t)nb_entries, sizeof *count);
    nearest = malloc((size_t)nb_vectors * sizeof *nearest);
    sum = calloc((size_t)nb_entries * (size_t)ndim, sizeof *sum);
    if (count == NULL || nearest == NULL || sum == NULL) {
      fprintf(stderr, "update: out of memory\n");
      free(count);
      free(nearest);
      free(sum);
      exit(1);
    }

    for (i = 0; i < nb_vectors; i++) {
      nearest[i] = find_nearest(codebook, nb_entries, data+i*ndim, ndim, &min_dist);
      total_min_dist += min_dist;
    }

    for (i = 0; i < nb_vectors; i++) {
      int n = nearest[i];
      count[n]++;
      for (j = 0; j < ndim; j++)
        sum[n*ndim+j] += data[i*ndim+j];
    }

    for (i = 0; i < nb_entries; i++) {
      /* Empty cell: leave the row where it was rather than dividing by zero. */
      if (count[i] > 0) {
        for (j = 0; j < ndim; j++)
          codebook[i*ndim+j] = sum[i*ndim+j] * (1./count[i]);
      }
      w2 += (count[i]/(float)nb_vectors)*(count[i]/(float)nb_vectors);
    }
    fprintf(stderr, "%f / %d var = %f\n", 1./w2, nb_entries, total_min_dist/nb_vectors );

    free(sum);
    free(nearest);
    free(count);
  }
  /*
   * One weighted refinement pass over the codebook.
   *
   * Each training vector is assigned to a codebook row with
   * find_nearest_weighted().  Every element of a row is then replaced by the
   * average of the assigned data, where each sample contributes with the
   * square root of its weight.  An element whose accumulated weight is not
   * positive keeps its previous value instead of becoming NaN.
   *
   * data     : nb_vectors rows of ndim floats.
   * weight   : nb_vectors rows of ndim floats, one weight per data element.
   * codebook : nb_entries rows of ndim floats, updated in place.
   *
   * Does nothing when any of the sizes is not positive.  Terminates the
   * program if the scratch buffers cannot be allocated.
   */
  void update_weighted(float *data, float *weight, int nb_vectors, float *codebook, int nb_entries, int ndim)
  {
    int i, j;
    size_t cells;
    float *wsum;
    float *acc;
    int *nearest;

    if (nb_vectors <= 0 || nb_entries <= 0 || ndim <= 0)
      return;

    /* Heap buffers sized from the real arguments: no fixed entry limit and no
       large arrays on the stack. */
    cells = (size_t)nb_entries * (size_t)ndim;
    wsum = calloc(cells, sizeof *wsum);
    acc = calloc(cells, sizeof *acc);
    nearest = malloc((size_t)nb_vectors * sizeof *nearest);
    if (wsum == NULL || acc == NULL || nearest == NULL) {
      fprintf(stderr, "update_weighted: out of memory\n");
      free(wsum);
      free(acc);
      free(nearest);
      exit(1);
    }

    for (i = 0; i < nb_vectors; i++)
      nearest[i] = find_nearest_weighted(codebook, nb_entries, data+i*ndim, weight+i*ndim, ndim);

    for (i = 0; i < nb_vectors; i++) {
      int n = nearest[i];
      for (j = 0; j < ndim; j++) {
        float w = sqrt(weight[i*ndim+j]);
        wsum[n*ndim+j] += w;
        acc[n*ndim+j] += w*data[i*ndim+j];
      }
    }

    for (i = 0; i < nb_entries; i++) {
      for (j = 0; j < ndim; j++) {
        /* No usable weight for this element: keep the old value. */
        if (wsum[i*ndim+j] > 0)
          codebook[i*ndim+j] = acc[i*ndim+j] * (1./wsum[i*ndim+j]);
      }
    }

    free(nearest);
    free(acc);
    free(wsum);
  }
  /*
   * Train a codebook by repeated splitting.
   *
   * Row 0 starts as the mean of all training vectors.  The codebook is then
   * grown with split() until it holds nb_entries rows; after each growth step
   * update() is run ndim times to refine the rows.  The current size is
   * printed to stderr before each growth step.
   *
   * When nb_entries is a power of two the codebook doubles at every step.
   * Otherwise the final step splits only as many of the highest-numbered rows
   * as are needed to reach nb_entries, so nothing is written past the end of
   * the codebook.
   *
   * data     : nb_vectors rows of ndim floats.
   * codebook : out, room for nb_entries rows of ndim floats.
   *
   * Does nothing when nb_entries or ndim is not positive.
   */
  void vq_train(float *data, int nb_vectors, float *codebook, int nb_entries, int ndim)
  {
    int i, j, e;

    if (nb_entries <= 0 || ndim <= 0)
      return;

    for (j = 0; j < ndim; j++)
      codebook[j] = 0;
    for (i = 0; i < nb_vectors; i++)
      for (j = 0; j < ndim; j++)
        codebook[j] += data[i*ndim+j];
    /* Without training vectors the starting row stays at zero instead of NaN. */
    if (nb_vectors > 0) {
      for (j = 0; j < ndim; j++)
        codebook[j] *= (1./nb_vectors);
    }

    e = 1;
    while (e < nb_entries) {
      /* Rows to add this round: all e rows, or fewer on the last round. */
      int grow = (nb_entries - e < e) ? nb_entries - e : e;

      /* Splitting rows [e-grow, e) places their copies in rows [e, e+grow). */
      split(codebook + (e - grow)*ndim, grow, ndim);
      fprintf(stderr, "%d\n", e);
      e += grow;
      for (j = 0; j < ndim; j++)
        update(data, nb_vectors, codebook, e, ndim);
    }
  }
  /*
   * Train a codebook by repeated splitting, using weighted refinement.
   *
   * Row 0 starts as the plain mean of all training vectors.  The codebook is
   * then grown with split() until it holds nb_entries rows; after each growth
   * step update_weighted() is run ndim times.  The current size is printed to
   * stderr before each growth step.
   *
   * When nb_entries is a power of two the codebook doubles at every step.
   * Otherwise the final step splits only as many of the highest-numbered rows
   * as are needed to reach nb_entries, so nothing is written past the end of
   * the codebook.
   *
   * data     : nb_vectors rows of ndim floats.
   * weight   : nb_vectors rows of ndim floats, one weight per data element.
   * codebook : out, room for nb_entries rows of ndim floats.
   *
   * Does nothing when nb_entries or ndim is not positive.
   */
  void vq_train_weighted(float *data, float *weight, int nb_vectors, float *codebook, int nb_entries, int ndim)
  {
    int i, j, e;

    if (nb_entries <= 0 || ndim <= 0)
      return;

    for (j = 0; j < ndim; j++)
      codebook[j] = 0;
    for (i = 0; i < nb_vectors; i++)
      for (j = 0; j < ndim; j++)
        codebook[j] += data[i*ndim+j];
    /* Without training vectors the starting row stays at zero instead of NaN. */
    if (nb_vectors > 0) {
      for (j = 0; j < ndim; j++)
        codebook[j] *= (1./nb_vectors);
    }

    e = 1;
    while (e < nb_entries) {
      /* Rows to add this round: all e rows, or fewer on the last round. */
      int grow = (nb_entries - e < e) ? nb_entries - e : e;

      /* Splitting rows [e-grow, e) places their copies in rows [e, e+grow). */
      split(codebook + (e - grow)*ndim, grow, ndim);
      fprintf(stderr, "%d\n", e);
      e += grow;
      for (j = 0; j < ndim; j++)
        update_weighted(data, weight, nb_vectors, codebook, e, ndim);
    }
  }
  237. int main(int argc, char **argv)
  238. {
  239. int i,j;
  240. FILE *ftrain;
  241. int nb_vectors, nb_entries, ndim;
  242. float *data, *pred, *codebook, *codebook2, *codebook3;
  243. float *weight, *weight2, *weight3;
  244. float *delta, *delta2;
  245. float tmp, err, min_dist, total_min_dist;
  246. int ret;
  247. char filename[256];
  248. FILE *fcb;
  249. printf("Jean-Marc Valin's Split VQ training program....\n");
  250. if (argc != 5) {
  251. printf("usage: %s TrainTextFile K(dimension) M(codebook size) VQFilesPrefix\n", argv[0]);
  252. exit(1);
  253. }
  254. ndim = atoi(argv[2]);
  255. nb_vectors = atoi(argv[3]);
  256. nb_entries = atoi(argv[3]);
  257. /* determine size of training file */
  258. ftrain = fopen(argv[1],"rt"); assert(ftrain != NULL);
  259. nb_vectors = 0;
  260. while (1) {
  261. if (feof(ftrain))
  262. break;
  263. for (j=0;j<ndim;j++)
  264. {
  265. ret = fscanf(ftrain, "%f ", &tmp);
  266. }
  267. nb_vectors++;
  268. if ((nb_vectors % 1000) == 0)
  269. printf("\r%d lines",nb_vectors);
  270. }
  271. rewind(ftrain);
  272. printf("\nndim %d nb_vectors %d nb_entries %d\n", ndim, nb_vectors, nb_entries);
  273. data = malloc(nb_vectors*ndim*sizeof(*data));
  274. weight = malloc(nb_vectors*ndim*sizeof(*weight));
  275. weight2 = malloc(nb_vectors*ndim*sizeof(*weight2));
  276. weight3 = malloc(nb_vectors*ndim*sizeof(*weight3));
  277. pred = malloc(nb_vectors*ndim*sizeof(*pred));
  278. codebook = malloc(nb_entries*ndim*sizeof(*codebook));
  279. codebook2 = malloc(nb_entries*ndim*sizeof(*codebook2));
  280. codebook3 = malloc(nb_entries*ndim*sizeof(*codebook3));
  281. for (i=0;i<nb_vectors;i++)
  282. {
  283. if (feof(ftrain))
  284. break;
  285. for (j=0;j<ndim;j++)
  286. {
  287. ret = fscanf(ftrain, "%f ", &data[i*ndim+j]);
  288. }
  289. }
  290. nb_vectors = i;
  291. #ifdef VALGRIND
  292. VALGRIND_CHECK_MEM_IS_DEFINED(data, nb_entries*ndim);
  293. #endif
  294. /* determine weights for each training vector */
  295. for (i=0;i<nb_vectors;i++)
  296. {
  297. compute_weights(data+i*ndim, weight+i*ndim, ndim);
  298. for (j=0;j<ndim/2;j++)
  299. {
  300. weight2[i*ndim/2+j] = weight[i*ndim+2*j];
  301. weight3[i*ndim/2+j] = weight[i*ndim+2*j+1];
  302. }
  303. }
  304. /* 20ms (two frame gaps) initial predictor state */
  305. for (i=0;i<ndim;i++) {
  306. pred[i+ndim] = pred[i] = data[i] - M_PI*(i+1)/(ndim+1);
  307. }
  308. /* generate predicted data for training */
  309. for (i=2;i<nb_vectors;i++)
  310. {
  311. for (j=0;j<ndim;j++)
  312. pred[i*ndim+j] = data[i*ndim+j] - COEF*data[(i-2)*ndim+j];
  313. }
  314. #ifdef VALGRIND
  315. VALGRIND_CHECK_MEM_IS_DEFINED(pred, nb_entries*ndim);
  316. #endif
  317. /* train first stage */
  318. vq_train(pred, nb_vectors, codebook, nb_entries, ndim);
  319. delta = malloc(nb_vectors*ndim*sizeof(*data));
  320. err = 0;
  321. total_min_dist = 0;
  322. for (i=0;i<nb_vectors;i++)
  323. {
  324. int nearest = find_nearest(codebook, nb_entries, &pred[i*ndim], ndim, &min_dist);
  325. total_min_dist += min_dist;
  326. for (j=0;j<ndim;j++)
  327. {
  328. //delta[i*ndim+j] = data[i*ndim+j] - codebook[nearest*ndim+j];
  329. //printf("%f ", delta[i*ndim+j]);
  330. //err += (delta[i*ndim+j])*(delta[i*ndim+j]);
  331. delta[i*ndim/2+j/2+(j&1)*nb_vectors*ndim/2] = pred[i*ndim+j] - codebook[nearest*ndim+j];
  332. //printf("%f ", delta[i*ndim/2+j/2+(j&1)*nb_vectors*ndim/2]);
  333. err += (delta[i*ndim/2+j/2+(j&1)*nb_vectors*ndim/2])*(delta[i*ndim/2+j/2+(j&1)*nb_vectors*ndim/2]);
  334. }
  335. //printf("\n");
  336. }
  337. fprintf(stderr, "Stage 1 LSP RMS error: %f\n", sqrt(err/nb_vectors/ndim));
  338. fprintf(stderr, "Stage 1 LSP variance.: %f\n", total_min_dist/nb_vectors);
  339. #if 1
  340. vq_train(delta, nb_vectors, codebook2, nb_entries, ndim/2);
  341. vq_train(delta+ndim*nb_vectors/2, nb_vectors, codebook3, nb_entries, ndim/2);
  342. #else
  343. vq_train_weighted(delta, weight2, nb_vectors, codebook2, nb_entries, ndim/2);
  344. vq_train_weighted(delta+ndim*nb_vectors/2, weight3, nb_vectors, codebook3, nb_entries, ndim/2);
  345. #endif
  346. err = 0;
  347. total_min_dist = 0;
  348. delta2 = delta + nb_vectors*ndim/2;
  349. for (i=0;i<nb_vectors;i++)
  350. {
  351. int n1, n2;
  352. n1 = find_nearest(codebook2, nb_entries, &delta[i*ndim/2], ndim/2, &min_dist);
  353. for (j=0;j<ndim/2;j++)
  354. {
  355. delta[i*ndim/2+j] = delta[i*ndim/2+j] - codebook2[n1*ndim/2+j];
  356. err += (delta[i*ndim/2+j])*(delta[i*ndim/2+j]);
  357. }
  358. total_min_dist += min_dist;
  359. n2 = find_nearest(codebook3, nb_entries, &delta2[i*ndim/2], ndim/2, &min_dist);
  360. for (j=0;j<ndim/2;j++)
  361. {
  362. delta[i*ndim/2+j] = delta[i*ndim/2+j] - codebook2[n2*ndim/2+j];
  363. err += (delta2[i*ndim/2+j])*(delta2[i*ndim/2+j]);
  364. }
  365. total_min_dist += min_dist;
  366. }
  367. fprintf(stderr, "Stage 2 LSP RMS error: %f\n", sqrt(err/nb_vectors/ndim));
  368. fprintf(stderr, "Stage 2 LSP Variance.: %f\n", total_min_dist/nb_vectors);
  369. float xq[ndim];
  370. for (i=0;i<ndim;i++)
  371. xq[i] = M_PI*(i+1)/(ndim+1);
  372. for (i=0;i<nb_vectors;i++)
  373. {
  374. quantize_lsp(data+i*ndim, codebook, codebook2,
  375. codebook3, nb_entries, xq, ndim);
  376. /*for (j=0;j<ndim;j++)
  377. printf("%f ", xq[j]);
  378. printf("\n");*/
  379. }
  380. /* save output tables to text files */
  381. sprintf(filename, "%s1.txt", argv[4]);
  382. fcb = fopen(filename, "wt"); assert(fcb != NULL);
  383. fprintf(fcb, "%d %d\n", ndim, nb_entries);
  384. for (i=0;i<nb_entries;i++)
  385. {
  386. for (j=0;j<ndim;j++)
  387. fprintf(fcb, "%f ", codebook[i*ndim+j]);
  388. fprintf(fcb, "\n");
  389. }
  390. fclose(fcb);
  391. sprintf(filename, "%s2.txt", argv[4]);
  392. fcb = fopen(filename, "wt"); assert(fcb != NULL);
  393. fprintf(fcb, "%d %d\n", ndim/2, nb_entries);
  394. for (i=0;i<nb_entries;i++)
  395. {
  396. for (j=0;j<ndim/2;j++)
  397. fprintf(fcb, "%f ", codebook2[i*ndim/2+j]);
  398. fprintf(fcb, "\n");
  399. }
  400. fclose(fcb);
  401. sprintf(filename, "%s3.txt", argv[4]);
  402. fcb = fopen(filename, "wt"); assert(fcb != NULL);
  403. fprintf(fcb, "%d %d\n", ndim/2, nb_entries);
  404. for (i=0;i<nb_entries;i++)
  405. {
  406. for (j=0;j<ndim/2;j++)
  407. fprintf(fcb, "%f ", codebook3[i*ndim/2+j]);
  408. fprintf(fcb, "\n");
  409. }
  410. fclose(fcb);
  411. return 0;
  412. }