2
0

vqtrainsp.c 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. /*--------------------------------------------------------------------------*\
  2. FILE........: vqtrainsp.c
  3. AUTHOR......: David Rowe
  4. DATE CREATED: 7 August 2012
  5. This program trains sparse amplitude vector quantisers.
  6. Modified from vqtrainph.c
  7. \*--------------------------------------------------------------------------*/
  8. /*
  9. Copyright (C) 2012 David Rowe
  10. All rights reserved.
  11. This program is free software; you can redistribute it and/or modify
  12. it under the terms of the GNU Lesser General Public License version 2, as
  13. published by the Free Software Foundation. This program is
  14. distributed in the hope that it will be useful, but WITHOUT ANY
  15. WARRANTY; without even the implied warranty of MERCHANTABILITY or
  16. FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public
  17. License for more details.
  18. You should have received a copy of the GNU Lesser General Public License
  19. along with this program; if not, see <http://www.gnu.org/licenses/>.
  20. */
  21. /*-----------------------------------------------------------------------*\
  22. INCLUDES
  23. \*-----------------------------------------------------------------------*/
  #include <assert.h>
  #include <ctype.h>
  #include <errno.h>
  #include <limits.h>
  #include <math.h>
  #include <stdio.h>
  #include <stdlib.h>
  #include <string.h>
  /* Complex value stored as a pair of single-precision floats.
     NOTE(review): COMP is not referenced anywhere else in the visible
     source -- presumably a leftover from the file this was derived from. */
  typedef struct {
      float real;    /* real part      */
      float imag;    /* imaginary part */
  } COMP;
  34. /*-----------------------------------------------------------------------* \
  35. DEFINES
  36. \*-----------------------------------------------------------------------*/
  /* Quitting threshold: a training stage stops iterating once the relative
     improvement in variance between two iterations drops below this value. */
  #define DELTAQ  0.01

  /* Maximum string length. */
  #define MAX_STR 80
  39. /*-----------------------------------------------------------------------*\
  40. FUNCTION PROTOTYPES
  41. \*-----------------------------------------------------------------------*/
  /* Parameter names below match the ones used by the definitions later in
     this file: d is the vector dimension, e the number of codebook entries. */

  /* Sets the d elements of v to zero. */
  void zero(float v[], int d);

  /* Adds v2 to v1 element by element, leaving the sum in v1. */
  void acc(float v1[], float v2[], int d);

  /* Divides each v[i] by n[i], skipping elements whose count n[i] is zero. */
  void norm(float v[], int d, int n[]);

  /* Returns the index of the codebook entry nearest to vec and adds the
     squared error of that entry to *se. */
  int quantise(float cb[], float vec[], int d, int e, float *se);

  /* Prints e vectors of dimension d, one vector per line. */
  void print_vec(float cb[], int d, int e);

  /* Doubles a codebook of `levels` entries to 2*levels entries. */
  void split(float cb[], int d, int levels);

  /* As quantise(), but also fits a per-vector gain offset, returned through
     *best_gain. */
  int gain_shape_quantise(float cb[], float vec[], int d, int e, float *se, float *best_gain);
  49. /*-----------------------------------------------------------------------* \
  50. MAIN
  51. \*-----------------------------------------------------------------------*/
  52. int main(int argc, char *argv[]) {
  53. int d,e; /* dimension and codebook size */
  54. float *vec; /* current vector */
  55. float *cb; /* vector codebook */
  56. float *cent; /* centroids for each codebook entry */
  57. int *n; /* number of vectors in this interval */
  58. int J; /* number of vectors in training set */
  59. int ind; /* index of current vector */
  60. float se; /* total squared error for this iteration */
  61. float var; /* variance */
  62. float var_1; /* previous variance */
  63. float delta; /* improvement in distortion */
  64. FILE *ftrain; /* file containing training set */
  65. FILE *fvq; /* file containing vector quantiser */
  66. int ret;
  67. int i,j, finished, iterations;
  68. float sd;
  69. int var_n, bits, b, levels;
  70. /* Interpret command line arguments */
  71. if (argc < 5) {
  72. printf("usage: %s TrainFile D(dimension) B(number of bits) VQFile [error.txt file]\n", argv[0]);
  73. exit(1);
  74. }
  75. /* Open training file */
  76. ftrain = fopen(argv[1],"rb");
  77. if (ftrain == NULL) {
  78. printf("Error opening training database file: %s\n",argv[1]);
  79. exit(1);
  80. }
  81. /* determine k and m, and allocate arrays */
  82. d = atoi(argv[2]);
  83. bits = atoi(argv[3]);
  84. e = 1<<bits;
  85. printf("\n");
  86. printf("dimension D=%d number of bits B=%d entries E=%d\n", d, bits, e);
  87. vec = (float*)malloc(sizeof(float)*d);
  88. cb = (float*)malloc(sizeof(float)*d*e);
  89. cent = (float*)malloc(sizeof(float)*d*e);
  90. n = (int*)malloc(sizeof(int)*d*e);
  91. if (cb == NULL || cb == NULL || cent == NULL || vec == NULL) {
  92. printf("Error in malloc.\n");
  93. exit(1);
  94. }
  95. /* determine size of training set */
  96. J = 0;
  97. var_n = 0;
  98. while(fread(vec, sizeof(float), d, ftrain) == (size_t)d) {
  99. for(j=0; j<d; j++)
  100. if (vec[j] != 0.0)
  101. var_n++;
  102. J++;
  103. }
  104. printf("J=%d sparse vectors in training set, %d non-zero values\n", J, var_n);
  105. /* set up initial codebook from centroid of training set */
  106. //#define DBG
  107. zero(cent, d);
  108. for(j=0; j<d; j++)
  109. n[j] = 0;
  110. rewind(ftrain);
  111. #ifdef DBG
  112. printf("initial codebook...\n");
  113. #endif
  114. for(i=0; i<J; i++) {
  115. ret = fread(vec, sizeof(float), d, ftrain);
  116. #ifdef DBG
  117. print_vec(vec, d, 1);
  118. #endif
  119. acc(cent, vec, d);
  120. for(j=0; j<d; j++)
  121. if (vec[j] != 0.0)
  122. n[j]++;
  123. }
  124. norm(cent, d, n);
  125. memcpy(cb, cent, d*sizeof(float));
  126. #ifdef DBG
  127. printf("\n");
  128. print_vec(cb, d, 1);
  129. #endif
  130. /* main loop */
  131. printf("\n");
  132. printf("bits Iteration delta var std dev\n");
  133. printf("---------------------------------------\n");
  134. for(b=1; b<=bits; b++) {
  135. levels = 1<<b;
  136. iterations = 0;
  137. finished = 0;
  138. delta = 0;
  139. var_1 = 0.0;
  140. split(cb, d, levels/2);
  141. //print_vec(cb, d, levels);
  142. do {
  143. /* zero centroids */
  144. for(i=0; i<levels; i++) {
  145. zero(&cent[i*d], d);
  146. for(j=0; j<d; j++)
  147. n[i*d+j] = 0;
  148. }
  149. //#define DBG
  150. #ifdef DBG
  151. printf("cb...\n");
  152. print_vec(cb, d, levels);
  153. printf("\n\nquantise...\n");
  154. #endif
  155. /* quantise training set */
  156. se = 0.0;
  157. rewind(ftrain);
  158. for(i=0; i<J; i++) {
  159. ret = fread(vec, sizeof(float), d, ftrain);
  160. ind = quantise(cb, vec, d, levels, &se);
  161. //ind = gain_shape_quantise(cb, vec, d, levels, &se, &best_gain);
  162. //for(j=0; j<d; j++)
  163. // if (vec[j] != 0.0)
  164. // vec[j] += best_gain;
  165. #ifdef DBG
  166. print_vec(vec, d, 1);
  167. printf(" ind %d se: %f\n", ind, se);
  168. #endif
  169. acc(&cent[ind*d], vec, d);
  170. for(j=0; j<d; j++)
  171. if (vec[j] != 0.0)
  172. n[ind*d+j]++;
  173. }
  174. #ifdef DBG
  175. printf("cent...\n");
  176. print_vec(cent, d, e);
  177. printf("\n");
  178. #endif
  179. /* work out stats */
  180. var = se/var_n;
  181. sd = sqrt(var);
  182. iterations++;
  183. if (iterations > 1) {
  184. if (var > 0.0) {
  185. delta = (var_1 - var)/var;
  186. }
  187. else
  188. delta = 0;
  189. if (delta < DELTAQ)
  190. finished = 1;
  191. }
  192. if (!finished) {
  193. /* determine new codebook from centroids */
  194. for(i=0; i<levels; i++) {
  195. norm(&cent[i*d], d, &n[i*d]);
  196. memcpy(&cb[i*d], &cent[i*d], d*sizeof(float));
  197. }
  198. }
  199. #ifdef DBG
  200. printf("new cb ...\n");
  201. print_vec(cent, d, e);
  202. printf("\n");
  203. #endif
  204. printf("%2d %2d %4.3f %6.3f %4.3f\r",b,iterations, delta, var, sd);
  205. fflush(stdout);
  206. var_1 = var;
  207. } while (!finished);
  208. printf("\n");
  209. }
  210. //print_vec(cb, d, 1);
  211. /* save codebook to disk */
  212. fvq = fopen(argv[4],"wt");
  213. if (fvq == NULL) {
  214. printf("Error opening VQ file: %s\n",argv[4]);
  215. exit(1);
  216. }
  217. fprintf(fvq,"%d %d\n",d,e);
  218. for(j=0; j<e; j++) {
  219. for(i=0; i<d; i++)
  220. fprintf(fvq,"% 7.3f ", cb[j*d+i]);
  221. fprintf(fvq,"\n");
  222. }
  223. fclose(fvq);
  224. /* optionally dump error file for multi-stage work */
  225. if (argc == 6) {
  226. FILE *ferr = fopen(argv[5],"wt");
  227. assert(ferr != NULL);
  228. rewind(ftrain);
  229. for(i=0; i<J; i++) {
  230. ret = fread(vec, sizeof(float), d, ftrain);
  231. ind = quantise(cb, vec, d, levels, &se);
  232. for(j=0; j<d; j++) {
  233. if (vec[j] != 0.0)
  234. vec[j] -= cb[ind*d+j];
  235. fprintf(ferr, "%f ", vec[j]);
  236. }
  237. fprintf(ferr, "\n");
  238. }
  239. }
  240. return 0;
  241. }
  242. /*-----------------------------------------------------------------------*\
  243. FUNCTIONS
  244. \*-----------------------------------------------------------------------*/
  /* Prints e consecutive d-element vectors from cb to stdout, one vector
     per line.  Each line starts with a space and each value is printed
     with the "% 7.3f " format. */
  void print_vec(float cb[], int d, int e)
  {
      int row;

      for (row = 0; row < e; row++) {
          const float *entry = &cb[row*d];
          int col = 0;

          fputs(" ", stdout);
          while (col < d) {
              printf("% 7.3f ", entry[col]);
              col++;
          }
          putchar('\n');
      }
  }
  /*---------------------------------------------------------------------------*\

    FUNCTION....: zero()
    AUTHOR......: David Rowe
    DATE CREATED: 23/2/95

    Sets the first d elements of v to 0.0.  Nothing is written when d is
    zero or negative.

  \*---------------------------------------------------------------------------*/

  void zero(float v[], int d)
  {
      float *p = v;
      int    remaining;

      for (remaining = d; remaining > 0; remaining--)
          *p++ = 0.0;
  }
  /*---------------------------------------------------------------------------*\

    FUNCTION....: acc()
    AUTHOR......: David Rowe
    DATE CREATED: 23/2/95

    Element-wise accumulate: for each of the d elements, v2[i] is added to
    v1[i] and the sum is stored back in v1[i].  v2 is only read.

    Unused entries of a sparse vector hold zero, so adding them leaves the
    running sum unchanged.

  \*---------------------------------------------------------------------------*/

  void acc(float v1[], float v2[], int d)
  {
      int k = 0;

      while (k < d) {
          v1[k] = v1[k] + v2[k];
          k++;
      }
  }
  /*---------------------------------------------------------------------------*\

    FUNCTION....: norm()
    AUTHOR......: David Rowe
    DATE CREATED: 23/2/95

    Divides each of the d elements of v by its own count n[i].  Elements
    whose count is zero are left untouched, which avoids a divide by zero.

  \*---------------------------------------------------------------------------*/

  void norm(float v[], int d, int n[])
  {
      int k;

      for (k = 0; k < d; k++) {
          const int count = n[k];

          if (count == 0)
              continue;
          v[k] = v[k] / count;
      }
  }
  /*---------------------------------------------------------------------------*\

    FUNCTION....: quantise()
    AUTHOR......: David Rowe
    DATE CREATED: 23/2/95

    Searches the e entries of codebook cb (each d elements long) for the
    one with the smallest squared error against vec and returns its index.
    The squared error of that entry is added to *se.

    Elements of vec equal to 0.0 are treated as unused and contribute
    nothing to the error.  On ties the lowest index wins.  If no entry has
    an error below the 1E32 starting value, index 0 is returned and 1E32
    is what gets added to *se.

  \*---------------------------------------------------------------------------*/

  int quantise(float cb[], float vec[], int d, int e, float *se)
  {
      int   winner = 0;       /* index of the nearest entry so far */
      float lowest = 1E32;    /* squared error of that entry       */
      int   entry;

      for (entry = 0; entry < e; entry++) {
          const float *cand = &cb[entry*d];
          float        dist = 0.0;
          int          k;

          for (k = 0; k < d; k++) {
              float delta;

              if (vec[k] == 0.0)
                  continue;               /* unused element */
              delta = cand[k] - vec[k];
              dist += delta*delta;
          }

          if (dist < lowest) {
              lowest = dist;
              winner = entry;
          }
      }

      *se += lowest;

      return winner;
  }
  /*---------------------------------------------------------------------------*\

    FUNCTION....: gain_shape_quantise()

    Like quantise(), but before measuring the error against a codebook
    entry a single gain offset is fitted for that entry: the mean
    difference between vec and the entry, taken over the used (non-zero)
    elements of vec.  The error is the sum of squares of
    (vec[i] - cb[i] - gain) over the used elements.

    cb         codebook of e entries, each d elements long
    vec        vector to quantise; elements equal to 0.0 are unused
    se         squared error of the chosen entry is added to *se
    best_gain  receives the gain fitted for the chosen entry

    Returns the index of the entry with the smallest error (lowest index
    on ties).

    A vector with no used elements gets a gain of 0 rather than the 0/0
    (NaN) a plain mean would give.  *best_gain is always written; it is 0
    when no entry is selected (e <= 0, or every error is >= 1E32).

  \*---------------------------------------------------------------------------*/

  int gain_shape_quantise(float cb[], float vec[], int d, int e, float *se, float *best_gain)
  {
      float error;               /* current error      */
      int   besti;               /* best index so far  */
      float best_error;          /* best error so far  */
      int   i,j,m;
      float diff, gain, sumAm, sumCb;

      besti = 0;
      best_error = 1E32;
      *best_gain = 0.0;

      for(j=0; j<e; j++) {

          /* compute optimum gain over the used elements */

          sumAm = sumCb = 0.0;
          m = 0;
          for(i=0; i<d; i++) {
              if (vec[i] != 0.0) {
                  m++;
                  sumAm += vec[i];
                  sumCb += cb[j*d+i];
              }
          }

          if (m > 0)
              gain = (sumAm - sumCb)/m;
          else
              gain = 0.0;        /* nothing to fit: avoid 0/0 */

          /* compute error */

          error = 0.0;
          for(i=0; i<d; i++) {
              if (vec[i] != 0.0) {
                  diff = vec[i] - cb[j*d+i] - gain;
                  error += diff*diff;
              }
          }

          if (error < best_error) {
              best_error = error;
              *best_gain = gain;
              besti = j;
          }
      }

      *se += best_error;

      return(besti);
  }
  /* Doubles a codebook from `levels` to 2*levels entries of dimension d.
     cb must have room for 2*levels*d floats.

     For every element of the first `levels` entries a random offset in
     the range -0.005..+0.005 is drawn with rand().  The existing element
     has the offset added; the matching element of the new entry (stored
     `levels` entries further on) is set to that perturbed value minus the
     same offset, i.e. approximately the original value.

     NOTE(review): a symmetric split would give the new entry
     original - offset; as written it ends up at about the original value.
     Confirm whether that is intended before changing it. */
  void split(float cb[], int d, int levels)
  {
      int total;
      int k;

      if (d <= 0 || levels <= 0)
          return;

      total = levels*d;          /* number of floats in the existing half */
      for (k = 0; k < total; k++) {
          const float jitter = .01*(rand()/(float)RAND_MAX-.5);

          cb[k] += jitter;
          cb[k + total] = cb[k] - jitter;
      }
  }