]> git.basschouten.com Git - openhab-addons.git/blob
9454a778a4e8fb88cb252cf62fe8f31e3154bd54
[openhab-addons.git] /
1 /**
2  * Copyright (c) 2010-2022 Contributors to the openHAB project
3  *
4  * See the NOTICE file(s) distributed with this work for additional
5  * information.
6  *
7  * This program and the accompanying materials are made available under the
8  * terms of the Eclipse Public License 2.0 which is available at
9  * http://www.eclipse.org/legal/epl-2.0
10  *
11  * SPDX-License-Identifier: EPL-2.0
12  */
13 package org.openhab.voice.watsonstt.internal;
14
15 import static org.openhab.voice.watsonstt.internal.WatsonSTTConstants.*;
16
17 import java.util.List;
18 import java.util.Locale;
19 import java.util.Map;
20 import java.util.Set;
21 import java.util.concurrent.ScheduledExecutorService;
22 import java.util.concurrent.atomic.AtomicBoolean;
23 import java.util.concurrent.atomic.AtomicReference;
24 import java.util.stream.Collectors;
25
26 import org.eclipse.jdt.annotation.NonNullByDefault;
27 import org.eclipse.jdt.annotation.Nullable;
28 import org.openhab.core.audio.AudioFormat;
29 import org.openhab.core.audio.AudioStream;
30 import org.openhab.core.common.ThreadPoolManager;
31 import org.openhab.core.config.core.ConfigurableService;
32 import org.openhab.core.config.core.Configuration;
33 import org.openhab.core.voice.RecognitionStartEvent;
34 import org.openhab.core.voice.RecognitionStopEvent;
35 import org.openhab.core.voice.STTException;
36 import org.openhab.core.voice.STTListener;
37 import org.openhab.core.voice.STTService;
38 import org.openhab.core.voice.STTServiceHandle;
39 import org.openhab.core.voice.SpeechRecognitionErrorEvent;
40 import org.openhab.core.voice.SpeechRecognitionEvent;
41 import org.osgi.framework.Constants;
42 import org.osgi.service.component.annotations.Activate;
43 import org.osgi.service.component.annotations.Component;
44 import org.osgi.service.component.annotations.Modified;
45 import org.slf4j.Logger;
46 import org.slf4j.LoggerFactory;
47
48 import com.google.gson.JsonObject;
49 import com.ibm.cloud.sdk.core.http.HttpMediaType;
50 import com.ibm.cloud.sdk.core.security.IamAuthenticator;
51 import com.ibm.watson.speech_to_text.v1.SpeechToText;
52 import com.ibm.watson.speech_to_text.v1.model.RecognizeWithWebsocketsOptions;
53 import com.ibm.watson.speech_to_text.v1.model.SpeechRecognitionAlternative;
54 import com.ibm.watson.speech_to_text.v1.model.SpeechRecognitionResult;
55 import com.ibm.watson.speech_to_text.v1.model.SpeechRecognitionResults;
56 import com.ibm.watson.speech_to_text.v1.websocket.RecognizeCallback;
57
58 import okhttp3.WebSocket;
59
60 /**
61  * The {@link WatsonSTTService} allows to use Watson as Speech-to-Text engine
62  *
63  * @author Miguel Álvarez - Initial contribution
64  */
65 @NonNullByDefault
66 @Component(configurationPid = SERVICE_PID, property = Constants.SERVICE_PID + "=" + SERVICE_PID)
67 @ConfigurableService(category = SERVICE_CATEGORY, label = SERVICE_NAME
68         + " Speech-to-Text", description_uri = SERVICE_CATEGORY + ":" + SERVICE_ID)
69 public class WatsonSTTService implements STTService {
70     private final Logger logger = LoggerFactory.getLogger(WatsonSTTService.class);
71     private final ScheduledExecutorService executor = ThreadPoolManager.getScheduledPool("OH-voice-watsonstt");
72     private final List<String> models = List.of("ar-AR_BroadbandModel", "de-DE_BroadbandModel", "en-AU_BroadbandModel",
73             "en-GB_BroadbandModel", "en-US_BroadbandModel", "es-AR_BroadbandModel", "es-CL_BroadbandModel",
74             "es-CO_BroadbandModel", "es-ES_BroadbandModel", "es-MX_BroadbandModel", "es-PE_BroadbandModel",
75             "fr-CA_BroadbandModel", "fr-FR_BroadbandModel", "it-IT_BroadbandModel", "ja-JP_BroadbandModel",
76             "ko-KR_BroadbandModel", "nl-NL_BroadbandModel", "pt-BR_BroadbandModel", "zh-CN_BroadbandModel");
77     private final Set<Locale> supportedLocales = models.stream().map(name -> name.split("_")[0])
78             .map(Locale::forLanguageTag).collect(Collectors.toSet());
79     private WatsonSTTConfiguration config = new WatsonSTTConfiguration();
80     private @Nullable SpeechToText speechToText = null;
81
82     @Activate
83     protected void activate(Map<String, Object> config) {
84         modified(config);
85     }
86
87     @Modified
88     protected void modified(Map<String, Object> config) {
89         this.config = new Configuration(config).as(WatsonSTTConfiguration.class);
90         if (this.config.apiKey.isBlank() || this.config.instanceUrl.isBlank()) {
91             this.speechToText = null;
92         } else {
93             var speechToText = new SpeechToText(new IamAuthenticator.Builder().apikey(this.config.apiKey).build());
94             speechToText.setServiceUrl(this.config.instanceUrl);
95             if (this.config.optOutLogging) {
96                 speechToText.setDefaultHeaders(Map.of("X-Watson-Learning-Opt-Out", "1"));
97             }
98             this.speechToText = speechToText;
99         }
100     }
101
102     @Override
103     public String getId() {
104         return SERVICE_ID;
105     }
106
107     @Override
108     public String getLabel(@Nullable Locale locale) {
109         return SERVICE_NAME;
110     }
111
112     @Override
113     public Set<Locale> getSupportedLocales() {
114         return supportedLocales;
115     }
116
117     @Override
118     public Set<AudioFormat> getSupportedFormats() {
119         return Set.of(AudioFormat.WAV, AudioFormat.OGG, new AudioFormat("OGG", "OPUS", null, null, null, null),
120                 AudioFormat.MP3);
121     }
122
123     @Override
124     public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale, Set<String> set)
125             throws STTException {
126         var stt = this.speechToText;
127         if (stt == null) {
128             throw new STTException("service is not correctly configured");
129         }
130         String contentType = getContentType(audioStream);
131         if (contentType == null) {
132             throw new STTException("Unsupported format, unable to resolve audio content type");
133         }
134         logger.debug("Content-Type: {}", contentType);
135         RecognizeWithWebsocketsOptions wsOptions = new RecognizeWithWebsocketsOptions.Builder().audio(audioStream)
136                 .contentType(contentType).redaction(config.redaction).smartFormatting(config.smartFormatting)
137                 .model(locale.toLanguageTag() + "_BroadbandModel").interimResults(true)
138                 .backgroundAudioSuppression(config.backgroundAudioSuppression)
139                 .speechDetectorSensitivity(config.speechDetectorSensitivity).inactivityTimeout(config.maxSilenceSeconds)
140                 .build();
141         final AtomicReference<@Nullable WebSocket> socketRef = new AtomicReference<>();
142         final AtomicBoolean aborted = new AtomicBoolean(false);
143         executor.submit(() -> {
144             socketRef.set(stt.recognizeUsingWebSocket(wsOptions,
145                     new TranscriptionListener(socketRef, sttListener, config, aborted)));
146         });
147         return new STTServiceHandle() {
148             @Override
149             public void abort() {
150                 if (!aborted.getAndSet(true)) {
151                     var socket = socketRef.get();
152                     if (socket != null) {
153                         sendStopMessage(socket);
154                     }
155                 }
156             }
157         };
158     }
159
160     private @Nullable String getContentType(AudioStream audioStream) throws STTException {
161         AudioFormat format = audioStream.getFormat();
162         String container = format.getContainer();
163         String codec = format.getCodec();
164         if (container == null || codec == null) {
165             throw new STTException("Missing audio stream info");
166         }
167         Long frequency = format.getFrequency();
168         Integer bitDepth = format.getBitDepth();
169         switch (container) {
170             case AudioFormat.CONTAINER_WAVE:
171                 if (AudioFormat.CODEC_PCM_SIGNED.equals(codec)) {
172                     if (bitDepth == null || bitDepth != 16) {
173                         return "audio/wav";
174                     }
175                     // rate is a required parameter for this type
176                     if (frequency == null) {
177                         return null;
178                     }
179                     StringBuilder contentTypeL16 = new StringBuilder(HttpMediaType.AUDIO_PCM).append(";rate=")
180                             .append(frequency);
181                     // // those are optional
182                     Integer channels = format.getChannels();
183                     if (channels != null) {
184                         contentTypeL16.append(";channels=").append(channels);
185                     }
186                     Boolean bigEndian = format.isBigEndian();
187                     if (bigEndian != null) {
188                         contentTypeL16.append(";")
189                                 .append(bigEndian ? "endianness=big-endian" : "endianness=little-endian");
190                     }
191                     return contentTypeL16.toString();
192                 }
193             case AudioFormat.CONTAINER_OGG:
194                 switch (codec) {
195                     case AudioFormat.CODEC_VORBIS:
196                         return "audio/ogg;codecs=vorbis";
197                     case "OPUS":
198                         return "audio/ogg;codecs=opus";
199                 }
200                 break;
201             case AudioFormat.CONTAINER_NONE:
202                 if (AudioFormat.CODEC_MP3.equals(codec)) {
203                     return "audio/mp3";
204                 }
205                 break;
206         }
207         return null;
208     }
209
210     private static void sendStopMessage(WebSocket ws) {
211         JsonObject stopMessage = new JsonObject();
212         stopMessage.addProperty("action", "stop");
213         ws.send(stopMessage.toString());
214     }
215
216     private static class TranscriptionListener implements RecognizeCallback {
217         private final Logger logger = LoggerFactory.getLogger(TranscriptionListener.class);
218         private final StringBuilder transcriptBuilder = new StringBuilder();
219         private final STTListener sttListener;
220         private final WatsonSTTConfiguration config;
221         private final AtomicBoolean aborted;
222         private final AtomicReference<@Nullable WebSocket> socketRef;
223         private float confidenceSum = 0f;
224         private int responseCount = 0;
225         private boolean disconnected = false;
226
227         public TranscriptionListener(AtomicReference<@Nullable WebSocket> socketRef, STTListener sttListener,
228                 WatsonSTTConfiguration config, AtomicBoolean aborted) {
229             this.socketRef = socketRef;
230             this.sttListener = sttListener;
231             this.config = config;
232             this.aborted = aborted;
233         }
234
235         @Override
236         public void onTranscription(@Nullable SpeechRecognitionResults speechRecognitionResults) {
237             logger.debug("onTranscription");
238             if (speechRecognitionResults == null) {
239                 return;
240             }
241             speechRecognitionResults.getResults().stream().filter(SpeechRecognitionResult::isXFinal).forEach(result -> {
242                 SpeechRecognitionAlternative alternative = result.getAlternatives().stream().findFirst().orElse(null);
243                 if (alternative == null) {
244                     return;
245                 }
246                 logger.debug("onTranscription Final");
247                 Double confidence = alternative.getConfidence();
248                 transcriptBuilder.append(alternative.getTranscript());
249                 confidenceSum += confidence != null ? confidence.floatValue() : 0f;
250                 responseCount++;
251                 if (config.singleUtteranceMode) {
252                     var socket = socketRef.get();
253                     if (socket != null) {
254                         sendStopMessage(socket);
255                     }
256                 }
257             });
258         }
259
260         @Override
261         public void onConnected() {
262             logger.debug("onConnected");
263         }
264
265         @Override
266         public void onError(@Nullable Exception e) {
267             var errorMessage = e != null ? e.getMessage() : null;
268             if (errorMessage != null && disconnected && errorMessage.contains("Socket closed")) {
269                 logger.debug("Error ignored: {}", errorMessage);
270                 return;
271             }
272             logger.warn("TranscriptionError: {}", errorMessage);
273             if (!aborted.getAndSet(true)) {
274                 sttListener.sttEventReceived(
275                         new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
276             }
277         }
278
279         @Override
280         public void onDisconnected() {
281             logger.debug("onDisconnected");
282             disconnected = true;
283             if (!aborted.getAndSet(true)) {
284                 sttListener.sttEventReceived(new RecognitionStopEvent());
285                 float averageConfidence = confidenceSum / (float) responseCount;
286                 String transcript = transcriptBuilder.toString().trim();
287                 if (!transcript.isBlank()) {
288                     sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence));
289                 } else {
290                     if (!config.noResultsMessage.isBlank()) {
291                         sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
292                     } else {
293                         sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results"));
294                     }
295                 }
296             }
297         }
298
299         @Override
300         public void onInactivityTimeout(@Nullable RuntimeException e) {
301             if (e != null) {
302                 logger.debug("InactivityTimeout: {}", e.getMessage());
303             }
304         }
305
306         @Override
307         public void onListening() {
308             logger.debug("onListening");
309             sttListener.sttEventReceived(new RecognitionStartEvent());
310         }
311
312         @Override
313         public void onTranscriptionComplete() {
314             logger.debug("onTranscriptionComplete");
315         }
316     }
317 }