af_asr.c 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. /*
  2. * Copyright (c) 2019 Paul B Mahol
  3. *
  4. * This file is part of FFmpeg.
  5. *
  6. * FFmpeg is free software; you can redistribute it and/or
  7. * modify it under the terms of the GNU Lesser General Public
  8. * License as published by the Free Software Foundation; either
  9. * version 2.1 of the License, or (at your option) any later version.
  10. *
  11. * FFmpeg is distributed in the hope that it will be useful,
  12. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  13. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  14. * Lesser General Public License for more details.
  15. *
  16. * You should have received a copy of the GNU Lesser General Public
  17. * License along with FFmpeg; if not, write to the Free Software
  18. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  19. */
  20. #include <pocketsphinx/pocketsphinx.h>
  21. #include "libavutil/avassert.h"
  22. #include "libavutil/avstring.h"
  23. #include "libavutil/channel_layout.h"
  24. #include "libavutil/opt.h"
  25. #include "audio.h"
  26. #include "avfilter.h"
  27. #include "internal.h"
  28. typedef struct ASRContext {
  29. const AVClass *class;
  30. int rate;
  31. char *hmm;
  32. char *dict;
  33. char *lm;
  34. char *lmctl;
  35. char *lmname;
  36. char *logfn;
  37. ps_decoder_t *ps;
  38. cmd_ln_t *config;
  39. int utt_started;
  40. } ASRContext;
  41. #define OFFSET(x) offsetof(ASRContext, x)
  42. #define FLAGS AV_OPT_FLAG_AUDIO_PARAM | AV_OPT_FLAG_FILTERING_PARAM
  43. static const AVOption asr_options[] = {
  44. { "rate", "set sampling rate", OFFSET(rate), AV_OPT_TYPE_INT, {.i64=16000}, 0, INT_MAX, .flags = FLAGS },
  45. { "hmm", "set directory containing acoustic model files", OFFSET(hmm), AV_OPT_TYPE_STRING, {.str=NULL}, .flags = FLAGS },
  46. { "dict", "set pronunciation dictionary", OFFSET(dict), AV_OPT_TYPE_STRING, {.str=NULL}, .flags = FLAGS },
  47. { "lm", "set language model file", OFFSET(lm), AV_OPT_TYPE_STRING, {.str=NULL}, .flags = FLAGS },
  48. { "lmctl", "set language model set", OFFSET(lmctl), AV_OPT_TYPE_STRING, {.str=NULL}, .flags = FLAGS },
  49. { "lmname","set which language model to use", OFFSET(lmname), AV_OPT_TYPE_STRING, {.str=NULL}, .flags = FLAGS },
  50. { "logfn", "set output for log messages", OFFSET(logfn), AV_OPT_TYPE_STRING, {.str="/dev/null"}, .flags = FLAGS },
  51. { NULL }
  52. };
  53. AVFILTER_DEFINE_CLASS(asr);
  54. static int filter_frame(AVFilterLink *inlink, AVFrame *in)
  55. {
  56. AVFilterContext *ctx = inlink->dst;
  57. AVDictionary **metadata = &in->metadata;
  58. ASRContext *s = ctx->priv;
  59. int have_speech;
  60. const char *speech;
  61. ps_process_raw(s->ps, (const int16_t *)in->data[0], in->nb_samples, 0, 0);
  62. have_speech = ps_get_in_speech(s->ps);
  63. if (have_speech && !s->utt_started)
  64. s->utt_started = 1;
  65. if (!have_speech && s->utt_started) {
  66. ps_end_utt(s->ps);
  67. speech = ps_get_hyp(s->ps, NULL);
  68. if (speech != NULL)
  69. av_dict_set(metadata, "lavfi.asr.text", speech, 0);
  70. ps_start_utt(s->ps);
  71. s->utt_started = 0;
  72. }
  73. return ff_filter_frame(ctx->outputs[0], in);
  74. }
  75. static int config_input(AVFilterLink *inlink)
  76. {
  77. AVFilterContext *ctx = inlink->dst;
  78. ASRContext *s = ctx->priv;
  79. ps_start_utt(s->ps);
  80. return 0;
  81. }
  82. static av_cold int asr_init(AVFilterContext *ctx)
  83. {
  84. ASRContext *s = ctx->priv;
  85. const float frate = s->rate;
  86. char *rate = av_asprintf("%f", frate);
  87. const char *argv[] = { "-logfn", s->logfn,
  88. "-hmm", s->hmm,
  89. "-lm", s->lm,
  90. "-lmctl", s->lmctl,
  91. "-lmname", s->lmname,
  92. "-dict", s->dict,
  93. "-samprate", rate,
  94. NULL };
  95. s->config = cmd_ln_parse_r(NULL, ps_args(), 14, (char **)argv, 0);
  96. av_free(rate);
  97. if (!s->config)
  98. return AVERROR(ENOMEM);
  99. ps_default_search_args(s->config);
  100. s->ps = ps_init(s->config);
  101. if (!s->ps)
  102. return AVERROR(ENOMEM);
  103. return 0;
  104. }
  105. static int query_formats(AVFilterContext *ctx)
  106. {
  107. ASRContext *s = ctx->priv;
  108. int sample_rates[] = { s->rate, -1 };
  109. int ret;
  110. AVFilterFormats *formats = NULL;
  111. AVFilterChannelLayouts *layout = NULL;
  112. if ((ret = ff_add_format (&formats, AV_SAMPLE_FMT_S16 )) < 0 ||
  113. (ret = ff_set_common_formats (ctx , formats )) < 0 ||
  114. (ret = ff_add_channel_layout (&layout , AV_CH_LAYOUT_MONO )) < 0 ||
  115. (ret = ff_set_common_channel_layouts (ctx , layout )) < 0 ||
  116. (ret = ff_set_common_samplerates (ctx , ff_make_format_list(sample_rates) )) < 0)
  117. return ret;
  118. return 0;
  119. }
  120. static av_cold void asr_uninit(AVFilterContext *ctx)
  121. {
  122. ASRContext *s = ctx->priv;
  123. ps_free(s->ps);
  124. s->ps = NULL;
  125. cmd_ln_free_r(s->config);
  126. s->config = NULL;
  127. }
  128. static const AVFilterPad asr_inputs[] = {
  129. {
  130. .name = "default",
  131. .type = AVMEDIA_TYPE_AUDIO,
  132. .filter_frame = filter_frame,
  133. .config_props = config_input,
  134. },
  135. { NULL }
  136. };
  137. static const AVFilterPad asr_outputs[] = {
  138. {
  139. .name = "default",
  140. .type = AVMEDIA_TYPE_AUDIO,
  141. },
  142. { NULL }
  143. };
  144. AVFilter ff_af_asr = {
  145. .name = "asr",
  146. .description = NULL_IF_CONFIG_SMALL("Automatic Speech Recognition."),
  147. .priv_size = sizeof(ASRContext),
  148. .priv_class = &asr_class,
  149. .init = asr_init,
  150. .uninit = asr_uninit,
  151. .query_formats = query_formats,
  152. .inputs = asr_inputs,
  153. .outputs = asr_outputs,
  154. };