]> git.basschouten.com Git - openhab-addons.git/blob
532ffbb22c3ced05fd6dd14232c258609f4b9656
[openhab-addons.git] /
1 /**
2  * Copyright (c) 2010-2022 Contributors to the openHAB project
3  *
4  * See the NOTICE file(s) distributed with this work for additional
5  * information.
6  *
7  * This program and the accompanying materials are made available under the
8  * terms of the Eclipse Public License 2.0 which is available at
9  * http://www.eclipse.org/legal/epl-2.0
10  *
11  * SPDX-License-Identifier: EPL-2.0
12  */
13 package org.openhab.voice.voskstt.internal;
14
15 import static org.openhab.voice.voskstt.internal.VoskSTTConstants.*;
16
17 import java.io.File;
18 import java.io.IOException;
19 import java.io.InputStream;
20 import java.nio.file.Path;
21 import java.util.Locale;
22 import java.util.Map;
23 import java.util.Set;
24 import java.util.concurrent.Future;
25 import java.util.concurrent.ScheduledExecutorService;
26 import java.util.concurrent.atomic.AtomicBoolean;
27
28 import org.eclipse.jdt.annotation.NonNullByDefault;
29 import org.eclipse.jdt.annotation.Nullable;
30 import org.openhab.core.OpenHAB;
31 import org.openhab.core.audio.AudioFormat;
32 import org.openhab.core.audio.AudioStream;
33 import org.openhab.core.common.ThreadPoolManager;
34 import org.openhab.core.config.core.ConfigurableService;
35 import org.openhab.core.config.core.Configuration;
36 import org.openhab.core.io.rest.LocaleService;
37 import org.openhab.core.voice.RecognitionStartEvent;
38 import org.openhab.core.voice.RecognitionStopEvent;
39 import org.openhab.core.voice.STTException;
40 import org.openhab.core.voice.STTListener;
41 import org.openhab.core.voice.STTService;
42 import org.openhab.core.voice.STTServiceHandle;
43 import org.openhab.core.voice.SpeechRecognitionErrorEvent;
44 import org.openhab.core.voice.SpeechRecognitionEvent;
45 import org.osgi.framework.Constants;
46 import org.osgi.service.component.annotations.Activate;
47 import org.osgi.service.component.annotations.Component;
48 import org.osgi.service.component.annotations.Deactivate;
49 import org.osgi.service.component.annotations.Modified;
50 import org.osgi.service.component.annotations.Reference;
51 import org.slf4j.Logger;
52 import org.slf4j.LoggerFactory;
53 import org.vosk.Model;
54 import org.vosk.Recognizer;
55
56 import com.fasterxml.jackson.databind.ObjectMapper;
57
58 /**
59  * The {@link VoskSTTService} class is a service implementation to use Vosk-API for Speech-to-Text.
60  *
61  * @author Miguel Álvarez - Initial contribution
62  */
63 @NonNullByDefault
64 @Component(configurationPid = SERVICE_PID, property = Constants.SERVICE_PID + "=" + SERVICE_PID)
65 @ConfigurableService(category = SERVICE_CATEGORY, label = SERVICE_NAME
66         + " Speech-to-Text", description_uri = SERVICE_CATEGORY + ":" + SERVICE_ID)
67 public class VoskSTTService implements STTService {
68     private static final String VOSK_FOLDER = Path.of(OpenHAB.getUserDataFolder(), "vosk").toString();
69     private static final String MODEL_PATH = Path.of(VOSK_FOLDER, "model").toString();
70     static {
71         Logger logger = LoggerFactory.getLogger(VoskSTTService.class);
72         File directory = new File(VOSK_FOLDER);
73         if (!directory.exists()) {
74             if (directory.mkdir()) {
75                 logger.info("vosk dir created {}", VOSK_FOLDER);
76             }
77         }
78     }
79     private final Logger logger = LoggerFactory.getLogger(VoskSTTService.class);
80     private final ScheduledExecutorService executor = ThreadPoolManager.getScheduledPool("OH-voice-voskstt");
81     private final LocaleService localeService;
82     private VoskSTTConfiguration config = new VoskSTTConfiguration();
83     private @Nullable Model model;
84
85     @Activate
86     public VoskSTTService(@Reference LocaleService localeService) {
87         this.localeService = localeService;
88     }
89
90     @Activate
91     protected void activate(Map<String, Object> config) {
92         configChange(config);
93     }
94
95     @Modified
96     protected void modified(Map<String, Object> config) {
97         configChange(config);
98     }
99
100     @Deactivate
101     protected void deactivate(Map<String, Object> config) {
102         try {
103             unloadModel();
104         } catch (IOException e) {
105             logger.warn("IOException unloading model: {}", e.getMessage());
106         }
107     }
108
109     private void configChange(Map<String, Object> config) {
110         this.config = new Configuration(config).as(VoskSTTConfiguration.class);
111         if (this.config.preloadModel) {
112             try {
113                 loadModel();
114             } catch (IOException e) {
115                 logger.warn("IOException loading model: {}", e.getMessage());
116             }
117         } else {
118             try {
119                 unloadModel();
120             } catch (IOException e) {
121                 logger.warn("IOException unloading model: {}", e.getMessage());
122             }
123         }
124     }
125
126     @Override
127     public String getId() {
128         return SERVICE_ID;
129     }
130
131     @Override
132     public String getLabel(@Nullable Locale locale) {
133         return SERVICE_NAME;
134     }
135
136     @Override
137     public Set<Locale> getSupportedLocales() {
138         // as it is not possible to determine the language of the model that was downloaded and setup by the user, it is
139         // assumed the language of the model is matching the locale of the openHAB server
140         return Set.of(localeService.getLocale(null));
141     }
142
143     @Override
144     public Set<AudioFormat> getSupportedFormats() {
145         return Set.of(
146                 new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, null, null, 16000L));
147     }
148
149     @Override
150     public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale, Set<String> set)
151             throws STTException {
152         AtomicBoolean aborted = new AtomicBoolean(false);
153         try {
154             var frequency = audioStream.getFormat().getFrequency();
155             if (frequency == null) {
156                 throw new IOException("missing audio stream frequency");
157             }
158             backgroundRecognize(sttListener, audioStream, frequency, aborted);
159         } catch (IOException e) {
160             throw new STTException(e);
161         }
162         return () -> {
163             aborted.set(true);
164         };
165     }
166
167     private Model getModel() throws IOException {
168         var model = this.model;
169         if (model != null) {
170             return model;
171         }
172         return loadModel();
173     }
174
175     private Model loadModel() throws IOException {
176         unloadModel();
177         var modelFile = new File(MODEL_PATH);
178         if (!modelFile.exists() || !modelFile.isDirectory()) {
179             throw new IOException("missing model dir: " + MODEL_PATH);
180         }
181         logger.debug("loading model");
182         var model = new Model(MODEL_PATH);
183         if (config.preloadModel) {
184             this.model = model;
185         }
186         return model;
187     }
188
189     private void unloadModel() throws IOException {
190         var model = this.model;
191         if (model != null) {
192             logger.debug("unloading model");
193             model.close();
194             this.model = null;
195         }
196     }
197
198     private Future<?> backgroundRecognize(STTListener sttListener, InputStream audioStream, long frequency,
199             AtomicBoolean aborted) {
200         StringBuilder transcriptBuilder = new StringBuilder();
201         long maxTranscriptionMillis = (config.maxTranscriptionSeconds * 1000L);
202         long maxSilenceMillis = (config.maxSilenceSeconds * 1000L);
203         long startTime = System.currentTimeMillis();
204         return executor.submit(() -> {
205             Recognizer recognizer = null;
206             Model model = null;
207             try {
208                 model = getModel();
209                 recognizer = new Recognizer(model, frequency);
210                 long lastInputTime = System.currentTimeMillis();
211                 int nbytes;
212                 byte[] b = new byte[4096];
213                 sttListener.sttEventReceived(new RecognitionStartEvent());
214                 while (!aborted.get()) {
215                     nbytes = audioStream.read(b);
216                     if (aborted.get()) {
217                         break;
218                     }
219                     if (isExpiredInterval(maxTranscriptionMillis, startTime)) {
220                         logger.debug("Stops listening, max transcription time reached");
221                         break;
222                     }
223                     if (!config.singleUtteranceMode && isExpiredInterval(maxSilenceMillis, lastInputTime)) {
224                         logger.debug("Stops listening, max silence time reached");
225                         break;
226                     }
227                     if (nbytes == 0) {
228                         trySleep(100);
229                         continue;
230                     }
231                     if (recognizer.acceptWaveForm(b, nbytes)) {
232                         lastInputTime = System.currentTimeMillis();
233                         var result = recognizer.getResult();
234                         logger.debug("Result: {}", result);
235                         ObjectMapper mapper = new ObjectMapper();
236                         var json = mapper.readTree(result);
237                         transcriptBuilder.append(json.get("text").asText()).append(" ");
238                         if (config.singleUtteranceMode) {
239                             break;
240                         }
241                     } else {
242                         logger.debug("Partial: {}", recognizer.getPartialResult());
243                     }
244                 }
245                 if (!aborted.get()) {
246                     sttListener.sttEventReceived(new RecognitionStopEvent());
247                     var transcript = transcriptBuilder.toString().trim();
248                     logger.debug("Final: {}", transcript);
249                     if (!transcript.isBlank()) {
250                         sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, 1F));
251                     } else {
252                         if (!config.noResultsMessage.isBlank()) {
253                             sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
254                         } else {
255                             sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results"));
256                         }
257                     }
258                 }
259             } catch (IOException e) {
260                 logger.warn("Error running speech to text: {}", e.getMessage());
261                 if (config.errorMessage.isBlank()) {
262                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("Error"));
263                 } else {
264                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
265                 }
266             } finally {
267                 if (recognizer != null) {
268                     recognizer.close();
269                 }
270                 if (!config.preloadModel && model != null) {
271                     model.close();
272                 }
273             }
274             try {
275                 audioStream.close();
276             } catch (IOException e) {
277                 logger.warn("IOException on close: {}", e.getMessage());
278             }
279         });
280     }
281
282     private void trySleep(long ms) {
283         try {
284             Thread.sleep(ms);
285         } catch (InterruptedException ignored) {
286         }
287     }
288
289     private boolean isExpiredInterval(long interval, long referenceTime) {
290         return System.currentTimeMillis() - referenceTime > interval;
291     }
292 }