]> git.basschouten.com Git - openhab-addons.git/blob
ebd5c0759a069eb3c0e7461185e90d80adaefdf8
[openhab-addons.git] /
1 /**
2  * Copyright (c) 2010-2022 Contributors to the openHAB project
3  *
4  * See the NOTICE file(s) distributed with this work for additional
5  * information.
6  *
7  * This program and the accompanying materials are made available under the
8  * terms of the Eclipse Public License 2.0 which is available at
9  * http://www.eclipse.org/legal/epl-2.0
10  *
11  * SPDX-License-Identifier: EPL-2.0
12  */
13 package org.openhab.voice.watsonstt.internal;
14
15 import static org.openhab.voice.watsonstt.internal.WatsonSTTConstants.*;
16
17 import java.util.List;
18 import java.util.Locale;
19 import java.util.Map;
20 import java.util.Set;
21 import java.util.concurrent.ScheduledExecutorService;
22 import java.util.concurrent.atomic.AtomicBoolean;
23 import java.util.concurrent.atomic.AtomicReference;
24 import java.util.stream.Collectors;
25
26 import javax.net.ssl.SSLPeerUnverifiedException;
27
28 import org.eclipse.jdt.annotation.NonNullByDefault;
29 import org.eclipse.jdt.annotation.Nullable;
30 import org.openhab.core.audio.AudioFormat;
31 import org.openhab.core.audio.AudioStream;
32 import org.openhab.core.common.ThreadPoolManager;
33 import org.openhab.core.config.core.ConfigurableService;
34 import org.openhab.core.config.core.Configuration;
35 import org.openhab.core.voice.RecognitionStartEvent;
36 import org.openhab.core.voice.RecognitionStopEvent;
37 import org.openhab.core.voice.STTException;
38 import org.openhab.core.voice.STTListener;
39 import org.openhab.core.voice.STTService;
40 import org.openhab.core.voice.STTServiceHandle;
41 import org.openhab.core.voice.SpeechRecognitionErrorEvent;
42 import org.openhab.core.voice.SpeechRecognitionEvent;
43 import org.osgi.framework.Constants;
44 import org.osgi.service.component.annotations.Activate;
45 import org.osgi.service.component.annotations.Component;
46 import org.osgi.service.component.annotations.Modified;
47 import org.slf4j.Logger;
48 import org.slf4j.LoggerFactory;
49
50 import com.ibm.cloud.sdk.core.http.HttpMediaType;
51 import com.ibm.cloud.sdk.core.security.IamAuthenticator;
52 import com.ibm.watson.speech_to_text.v1.SpeechToText;
53 import com.ibm.watson.speech_to_text.v1.model.RecognizeWithWebsocketsOptions;
54 import com.ibm.watson.speech_to_text.v1.model.SpeechRecognitionAlternative;
55 import com.ibm.watson.speech_to_text.v1.model.SpeechRecognitionResult;
56 import com.ibm.watson.speech_to_text.v1.model.SpeechRecognitionResults;
57 import com.ibm.watson.speech_to_text.v1.websocket.RecognizeCallback;
58
59 import okhttp3.WebSocket;
60
61 /**
 * The {@link WatsonSTTService} allows using IBM Watson as a Speech-to-Text engine.
63  *
64  * @author Miguel Álvarez - Initial contribution
65  */
66 @NonNullByDefault
67 @Component(configurationPid = SERVICE_PID, property = Constants.SERVICE_PID + "=" + SERVICE_PID)
68 @ConfigurableService(category = SERVICE_CATEGORY, label = SERVICE_NAME
69         + " Speech-to-Text", description_uri = SERVICE_CATEGORY + ":" + SERVICE_ID)
70 public class WatsonSTTService implements STTService {
71     private final Logger logger = LoggerFactory.getLogger(WatsonSTTService.class);
72     private final ScheduledExecutorService executor = ThreadPoolManager.getScheduledPool("OH-voice-watsonstt");
73     private final List<String> models = List.of("ar-AR_BroadbandModel", "de-DE_BroadbandModel", "en-AU_BroadbandModel",
74             "en-GB_BroadbandModel", "en-US_BroadbandModel", "es-AR_BroadbandModel", "es-CL_BroadbandModel",
75             "es-CO_BroadbandModel", "es-ES_BroadbandModel", "es-MX_BroadbandModel", "es-PE_BroadbandModel",
76             "fr-CA_BroadbandModel", "fr-FR_BroadbandModel", "it-IT_BroadbandModel", "ja-JP_BroadbandModel",
77             "ko-KR_BroadbandModel", "nl-NL_BroadbandModel", "pt-BR_BroadbandModel", "zh-CN_BroadbandModel");
78     private final Set<Locale> supportedLocales = models.stream().map(name -> name.split("_")[0])
79             .map(Locale::forLanguageTag).collect(Collectors.toSet());
80     private WatsonSTTConfiguration config = new WatsonSTTConfiguration();
81
82     @Activate
83     protected void activate(Map<String, Object> config) {
84         this.config = new Configuration(config).as(WatsonSTTConfiguration.class);
85     }
86
87     @Modified
88     protected void modified(Map<String, Object> config) {
89         this.config = new Configuration(config).as(WatsonSTTConfiguration.class);
90     }
91
92     @Override
93     public String getId() {
94         return SERVICE_ID;
95     }
96
97     @Override
98     public String getLabel(@Nullable Locale locale) {
99         return SERVICE_NAME;
100     }
101
102     @Override
103     public Set<Locale> getSupportedLocales() {
104         return supportedLocales;
105     }
106
107     @Override
108     public Set<AudioFormat> getSupportedFormats() {
109         return Set.of(AudioFormat.WAV, AudioFormat.OGG, new AudioFormat("OGG", "OPUS", null, null, null, null),
110                 AudioFormat.MP3);
111     }
112
113     @Override
114     public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale, Set<String> set)
115             throws STTException {
116         if (config.apiKey.isBlank() || config.instanceUrl.isBlank()) {
117             throw new STTException("service is not correctly configured");
118         }
119         String contentType = getContentType(audioStream);
120         if (contentType == null) {
121             throw new STTException("Unsupported format, unable to resolve audio content type");
122         }
123         logger.debug("Content-Type: {}", contentType);
124         var speechToText = new SpeechToText(new IamAuthenticator.Builder().apikey(config.apiKey).build());
125         speechToText.setServiceUrl(config.instanceUrl);
126         if (config.optOutLogging) {
127             speechToText.setDefaultHeaders(Map.of("X-Watson-Learning-Opt-Out", "1"));
128         }
129         RecognizeWithWebsocketsOptions wsOptions = new RecognizeWithWebsocketsOptions.Builder().audio(audioStream)
130                 .contentType(contentType).redaction(config.redaction).smartFormatting(config.smartFormatting)
131                 .model(locale.toLanguageTag() + "_BroadbandModel").interimResults(true)
132                 .backgroundAudioSuppression(config.backgroundAudioSuppression)
133                 .speechDetectorSensitivity(config.speechDetectorSensitivity).inactivityTimeout(config.inactivityTimeout)
134                 .build();
135         final AtomicReference<@Nullable WebSocket> socketRef = new AtomicReference<>();
136         final AtomicBoolean aborted = new AtomicBoolean(false);
137         executor.submit(() -> {
138             int retries = 2;
139             while (retries > 0) {
140                 try {
141                     socketRef.set(speechToText.recognizeUsingWebSocket(wsOptions,
142                             new TranscriptionListener(sttListener, config, aborted)));
143                     break;
144                 } catch (RuntimeException e) {
145                     var cause = e.getCause();
146                     if (cause instanceof SSLPeerUnverifiedException) {
147                         logger.debug("Retrying on error: {}", cause.getMessage());
148                         retries--;
149                     } else {
150                         var errorMessage = e.getMessage();
151                         logger.warn("Aborting on error: {}", errorMessage);
152                         sttListener.sttEventReceived(
153                                 new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
154                         break;
155                     }
156                 }
157             }
158         });
159         return new STTServiceHandle() {
160             @Override
161             public void abort() {
162                 if (!aborted.getAndSet(true)) {
163                     var socket = socketRef.get();
164                     if (socket != null) {
165                         socket.close(1000, null);
166                         socket.cancel();
167                         try {
168                             Thread.sleep(100);
169                         } catch (InterruptedException ignored) {
170                         }
171                     }
172                 }
173             }
174         };
175     }
176
177     private @Nullable String getContentType(AudioStream audioStream) throws STTException {
178         AudioFormat format = audioStream.getFormat();
179         String container = format.getContainer();
180         String codec = format.getCodec();
181         if (container == null || codec == null) {
182             throw new STTException("Missing audio stream info");
183         }
184         Long frequency = format.getFrequency();
185         Integer bitDepth = format.getBitDepth();
186         switch (container) {
187             case AudioFormat.CONTAINER_WAVE:
188                 if (AudioFormat.CODEC_PCM_SIGNED.equals(codec)) {
189                     if (bitDepth == null || bitDepth != 16) {
190                         return "audio/wav";
191                     }
192                     // rate is a required parameter for this type
193                     if (frequency == null) {
194                         return null;
195                     }
196                     StringBuilder contentTypeL16 = new StringBuilder(HttpMediaType.AUDIO_PCM).append(";rate=")
197                             .append(frequency);
198                     // // those are optional
199                     Integer channels = format.getChannels();
200                     if (channels != null) {
201                         contentTypeL16.append(";channels=").append(channels);
202                     }
203                     Boolean bigEndian = format.isBigEndian();
204                     if (bigEndian != null) {
205                         contentTypeL16.append(";")
206                                 .append(bigEndian ? "endianness=big-endian" : "endianness=little-endian");
207                     }
208                     return contentTypeL16.toString();
209                 }
210             case AudioFormat.CONTAINER_OGG:
211                 switch (codec) {
212                     case AudioFormat.CODEC_VORBIS:
213                         return "audio/ogg;codecs=vorbis";
214                     case "OPUS":
215                         return "audio/ogg;codecs=opus";
216                 }
217                 break;
218             case AudioFormat.CONTAINER_NONE:
219                 if (AudioFormat.CODEC_MP3.equals(codec)) {
220                     return "audio/mp3";
221                 }
222                 break;
223         }
224         return null;
225     }
226
227     private static class TranscriptionListener implements RecognizeCallback {
228         private final Logger logger = LoggerFactory.getLogger(TranscriptionListener.class);
229         private final StringBuilder transcriptBuilder = new StringBuilder();
230         private final STTListener sttListener;
231         private final WatsonSTTConfiguration config;
232         private final AtomicBoolean aborted;
233         private float confidenceSum = 0f;
234         private int responseCount = 0;
235         private boolean disconnected = false;
236
237         public TranscriptionListener(STTListener sttListener, WatsonSTTConfiguration config, AtomicBoolean aborted) {
238             this.sttListener = sttListener;
239             this.config = config;
240             this.aborted = aborted;
241         }
242
243         @Override
244         public void onTranscription(@Nullable SpeechRecognitionResults speechRecognitionResults) {
245             logger.debug("onTranscription");
246             if (speechRecognitionResults == null) {
247                 return;
248             }
249             speechRecognitionResults.getResults().stream().filter(SpeechRecognitionResult::isXFinal).forEach(result -> {
250                 SpeechRecognitionAlternative alternative = result.getAlternatives().stream().findFirst().orElse(null);
251                 if (alternative == null) {
252                     return;
253                 }
254                 logger.debug("onTranscription Final");
255                 Double confidence = alternative.getConfidence();
256                 transcriptBuilder.append(alternative.getTranscript());
257                 confidenceSum += confidence != null ? confidence.floatValue() : 0f;
258                 responseCount++;
259             });
260         }
261
262         @Override
263         public void onConnected() {
264             logger.debug("onConnected");
265         }
266
267         @Override
268         public void onError(@Nullable Exception e) {
269             var errorMessage = e != null ? e.getMessage() : null;
270             if (errorMessage != null && disconnected && errorMessage.contains("Socket closed")) {
271                 logger.debug("Error ignored: {}", errorMessage);
272                 return;
273             }
274             logger.warn("TranscriptionError: {}", errorMessage);
275             if (!aborted.get()) {
276                 sttListener.sttEventReceived(
277                         new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
278             }
279         }
280
281         @Override
282         public void onDisconnected() {
283             logger.debug("onDisconnected");
284             disconnected = true;
285             if (!aborted.getAndSet(true)) {
286                 sttListener.sttEventReceived(new RecognitionStopEvent());
287                 float averageConfidence = confidenceSum / (float) responseCount;
288                 String transcript = transcriptBuilder.toString();
289                 if (!transcript.isBlank()) {
290                     sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence));
291                 } else {
292                     if (!config.noResultsMessage.isBlank()) {
293                         sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
294                     } else {
295                         sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results"));
296                     }
297                 }
298             }
299         }
300
301         @Override
302         public void onInactivityTimeout(@Nullable RuntimeException e) {
303             if (e != null) {
304                 logger.debug("InactivityTimeout: {}", e.getMessage());
305             }
306         }
307
308         @Override
309         public void onListening() {
310             logger.debug("onListening");
311             sttListener.sttEventReceived(new RecognitionStartEvent());
312         }
313
314         @Override
315         public void onTranscriptionComplete() {
316             logger.debug("onTranscriptionComplete");
317         }
318     }
319 }