]> git.basschouten.com Git - openhab-addons.git/blob
a3dc80d0d7eb5a20a8baed00b71156ff28246613
[openhab-addons.git] /
1 /**
2  * Copyright (c) 2010-2022 Contributors to the openHAB project
3  *
4  * See the NOTICE file(s) distributed with this work for additional
5  * information.
6  *
7  * This program and the accompanying materials are made available under the
8  * terms of the Eclipse Public License 2.0 which is available at
9  * http://www.eclipse.org/legal/epl-2.0
10  *
11  * SPDX-License-Identifier: EPL-2.0
12  */
13 package org.openhab.voice.watsonstt.internal;
14
15 import static org.openhab.voice.watsonstt.internal.WatsonSTTConstants.*;
16
17 import java.util.List;
18 import java.util.Locale;
19 import java.util.Map;
20 import java.util.Set;
21 import java.util.concurrent.ScheduledExecutorService;
22 import java.util.concurrent.atomic.AtomicReference;
23 import java.util.stream.Collectors;
24
25 import javax.net.ssl.SSLPeerUnverifiedException;
26
27 import org.eclipse.jdt.annotation.NonNullByDefault;
28 import org.eclipse.jdt.annotation.Nullable;
29 import org.openhab.core.audio.AudioFormat;
30 import org.openhab.core.audio.AudioStream;
31 import org.openhab.core.common.ThreadPoolManager;
32 import org.openhab.core.config.core.ConfigurableService;
33 import org.openhab.core.config.core.Configuration;
34 import org.openhab.core.voice.RecognitionStartEvent;
35 import org.openhab.core.voice.RecognitionStopEvent;
36 import org.openhab.core.voice.STTException;
37 import org.openhab.core.voice.STTListener;
38 import org.openhab.core.voice.STTService;
39 import org.openhab.core.voice.STTServiceHandle;
40 import org.openhab.core.voice.SpeechRecognitionErrorEvent;
41 import org.openhab.core.voice.SpeechRecognitionEvent;
42 import org.osgi.framework.Constants;
43 import org.osgi.service.component.annotations.Activate;
44 import org.osgi.service.component.annotations.Component;
45 import org.osgi.service.component.annotations.Modified;
46 import org.slf4j.Logger;
47 import org.slf4j.LoggerFactory;
48
49 import com.ibm.cloud.sdk.core.http.HttpMediaType;
50 import com.ibm.cloud.sdk.core.security.IamAuthenticator;
51 import com.ibm.watson.speech_to_text.v1.SpeechToText;
52 import com.ibm.watson.speech_to_text.v1.model.RecognizeWithWebsocketsOptions;
53 import com.ibm.watson.speech_to_text.v1.model.SpeechRecognitionAlternative;
54 import com.ibm.watson.speech_to_text.v1.model.SpeechRecognitionResult;
55 import com.ibm.watson.speech_to_text.v1.model.SpeechRecognitionResults;
56 import com.ibm.watson.speech_to_text.v1.websocket.RecognizeCallback;
57
58 import okhttp3.WebSocket;
59
60 /**
61  * The {@link WatsonSTTService} allows to use Watson as Speech-to-Text engine
62  *
63  * @author Miguel Álvarez - Initial contribution
64  */
65 @NonNullByDefault
66 @Component(configurationPid = SERVICE_PID, property = Constants.SERVICE_PID + "=" + SERVICE_PID)
67 @ConfigurableService(category = SERVICE_CATEGORY, label = SERVICE_NAME
68         + " Speech-to-Text", description_uri = SERVICE_CATEGORY + ":" + SERVICE_ID)
69 public class WatsonSTTService implements STTService {
70     private final Logger logger = LoggerFactory.getLogger(WatsonSTTService.class);
71     private final ScheduledExecutorService executor = ThreadPoolManager.getScheduledPool("OH-voice-watsonstt");
72     private final List<String> models = List.of("ar-AR_BroadbandModel", "de-DE_BroadbandModel", "en-AU_BroadbandModel",
73             "en-GB_BroadbandModel", "en-US_BroadbandModel", "es-AR_BroadbandModel", "es-CL_BroadbandModel",
74             "es-CO_BroadbandModel", "es-ES_BroadbandModel", "es-MX_BroadbandModel", "es-PE_BroadbandModel",
75             "fr-CA_BroadbandModel", "fr-FR_BroadbandModel", "it-IT_BroadbandModel", "ja-JP_BroadbandModel",
76             "ko-KR_BroadbandModel", "nl-NL_BroadbandModel", "pt-BR_BroadbandModel", "zh-CN_BroadbandModel");
77     private final Set<Locale> supportedLocales = models.stream().map(name -> name.split("_")[0])
78             .map(Locale::forLanguageTag).collect(Collectors.toSet());
79     private WatsonSTTConfiguration config = new WatsonSTTConfiguration();
80
81     @Activate
82     protected void activate(Map<String, Object> config) {
83         this.config = new Configuration(config).as(WatsonSTTConfiguration.class);
84     }
85
86     @Modified
87     protected void modified(Map<String, Object> config) {
88         this.config = new Configuration(config).as(WatsonSTTConfiguration.class);
89     }
90
91     @Override
92     public String getId() {
93         return SERVICE_ID;
94     }
95
96     @Override
97     public String getLabel(@Nullable Locale locale) {
98         return SERVICE_NAME;
99     }
100
101     @Override
102     public Set<Locale> getSupportedLocales() {
103         return supportedLocales;
104     }
105
106     @Override
107     public Set<AudioFormat> getSupportedFormats() {
108         return Set.of(AudioFormat.WAV, AudioFormat.OGG, new AudioFormat("OGG", "OPUS", null, null, null, null),
109                 AudioFormat.MP3);
110     }
111
112     @Override
113     public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale, Set<String> set)
114             throws STTException {
115         if (config.apiKey.isBlank() || config.instanceUrl.isBlank()) {
116             throw new STTException("service is not correctly configured");
117         }
118         String contentType = getContentType(audioStream);
119         if (contentType == null) {
120             throw new STTException("Unsupported format, unable to resolve audio content type");
121         }
122         logger.debug("Content-Type: {}", contentType);
123         var speechToText = new SpeechToText(new IamAuthenticator.Builder().apikey(config.apiKey).build());
124         speechToText.setServiceUrl(config.instanceUrl);
125         if (config.optOutLogging) {
126             speechToText.setDefaultHeaders(Map.of("X-Watson-Learning-Opt-Out", "1"));
127         }
128         RecognizeWithWebsocketsOptions wsOptions = new RecognizeWithWebsocketsOptions.Builder().audio(audioStream)
129                 .contentType(contentType).redaction(config.redaction).smartFormatting(config.smartFormatting)
130                 .model(locale.toLanguageTag() + "_BroadbandModel").interimResults(true)
131                 .backgroundAudioSuppression(config.backgroundAudioSuppression)
132                 .speechDetectorSensitivity(config.speechDetectorSensitivity).inactivityTimeout(config.inactivityTimeout)
133                 .build();
134         final AtomicReference<@Nullable WebSocket> socketRef = new AtomicReference<>();
135         var task = executor.submit(() -> {
136             int retries = 2;
137             while (retries > 0) {
138                 try {
139                     socketRef.set(speechToText.recognizeUsingWebSocket(wsOptions,
140                             new TranscriptionListener(sttListener, config)));
141                     break;
142                 } catch (RuntimeException e) {
143                     var cause = e.getCause();
144                     if (cause instanceof SSLPeerUnverifiedException) {
145                         logger.debug("Retrying on error: {}", cause.getMessage());
146                         retries--;
147                     } else {
148                         var errorMessage = e.getMessage();
149                         logger.warn("Aborting on error: {}", errorMessage);
150                         sttListener.sttEventReceived(
151                                 new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
152                         break;
153                     }
154                 }
155             }
156         });
157         return new STTServiceHandle() {
158             @Override
159             public void abort() {
160                 var socket = socketRef.get();
161                 if (socket != null) {
162                     socket.close(1000, null);
163                     socket.cancel();
164                     try {
165                         Thread.sleep(100);
166                     } catch (InterruptedException ignored) {
167                     }
168                 }
169                 task.cancel(true);
170             }
171         };
172     }
173
174     private @Nullable String getContentType(AudioStream audioStream) throws STTException {
175         AudioFormat format = audioStream.getFormat();
176         String container = format.getContainer();
177         String codec = format.getCodec();
178         if (container == null || codec == null) {
179             throw new STTException("Missing audio stream info");
180         }
181         Long frequency = format.getFrequency();
182         Integer bitDepth = format.getBitDepth();
183         switch (container) {
184             case AudioFormat.CONTAINER_WAVE:
185                 if (AudioFormat.CODEC_PCM_SIGNED.equals(codec)) {
186                     if (bitDepth == null || bitDepth != 16) {
187                         return "audio/wav";
188                     }
189                     // rate is a required parameter for this type
190                     if (frequency == null) {
191                         return null;
192                     }
193                     StringBuilder contentTypeL16 = new StringBuilder(HttpMediaType.AUDIO_PCM).append(";rate=")
194                             .append(frequency);
195                     // // those are optional
196                     Integer channels = format.getChannels();
197                     if (channels != null) {
198                         contentTypeL16.append(";channels=").append(channels);
199                     }
200                     Boolean bigEndian = format.isBigEndian();
201                     if (bigEndian != null) {
202                         contentTypeL16.append(";")
203                                 .append(bigEndian ? "endianness=big-endian" : "endianness=little-endian");
204                     }
205                     return contentTypeL16.toString();
206                 }
207             case AudioFormat.CONTAINER_OGG:
208                 switch (codec) {
209                     case AudioFormat.CODEC_VORBIS:
210                         return "audio/ogg;codecs=vorbis";
211                     case "OPUS":
212                         return "audio/ogg;codecs=opus";
213                 }
214                 break;
215             case AudioFormat.CONTAINER_NONE:
216                 if (AudioFormat.CODEC_MP3.equals(codec)) {
217                     return "audio/mp3";
218                 }
219                 break;
220         }
221         return null;
222     }
223
224     private static class TranscriptionListener implements RecognizeCallback {
225         private final Logger logger = LoggerFactory.getLogger(TranscriptionListener.class);
226         private final StringBuilder transcriptBuilder = new StringBuilder();
227         private final STTListener sttListener;
228         private final WatsonSTTConfiguration config;
229         private float confidenceSum = 0f;
230         private int responseCount = 0;
231         private boolean disconnected = false;
232
233         public TranscriptionListener(STTListener sttListener, WatsonSTTConfiguration config) {
234             this.sttListener = sttListener;
235             this.config = config;
236         }
237
238         @Override
239         public void onTranscription(@Nullable SpeechRecognitionResults speechRecognitionResults) {
240             logger.debug("onTranscription");
241             if (speechRecognitionResults == null) {
242                 return;
243             }
244             speechRecognitionResults.getResults().stream().filter(SpeechRecognitionResult::isXFinal).forEach(result -> {
245                 SpeechRecognitionAlternative alternative = result.getAlternatives().stream().findFirst().orElse(null);
246                 if (alternative == null) {
247                     return;
248                 }
249                 logger.debug("onTranscription Final");
250                 Double confidence = alternative.getConfidence();
251                 transcriptBuilder.append(alternative.getTranscript());
252                 confidenceSum += confidence != null ? confidence.floatValue() : 0f;
253                 responseCount++;
254             });
255         }
256
257         @Override
258         public void onConnected() {
259             logger.debug("onConnected");
260         }
261
262         @Override
263         public void onError(@Nullable Exception e) {
264             var errorMessage = e != null ? e.getMessage() : null;
265             if (errorMessage != null && disconnected && errorMessage.contains("Socket closed")) {
266                 logger.debug("Error ignored: {}", errorMessage);
267                 return;
268             }
269             logger.warn("TranscriptionError: {}", errorMessage);
270             sttListener.sttEventReceived(
271                     new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
272         }
273
274         @Override
275         public void onDisconnected() {
276             logger.debug("onDisconnected");
277             disconnected = true;
278             sttListener.sttEventReceived(new RecognitionStopEvent());
279             float averageConfidence = confidenceSum / (float) responseCount;
280             String transcript = transcriptBuilder.toString();
281             if (!transcript.isBlank()) {
282                 sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence));
283             } else {
284                 if (!config.noResultsMessage.isBlank()) {
285                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
286                 } else {
287                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results"));
288                 }
289             }
290         }
291
292         @Override
293         public void onInactivityTimeout(@Nullable RuntimeException e) {
294             if (e != null) {
295                 logger.debug("InactivityTimeout: {}", e.getMessage());
296             }
297         }
298
299         @Override
300         public void onListening() {
301             logger.debug("onListening");
302             sttListener.sttEventReceived(new RecognitionStartEvent());
303         }
304
305         @Override
306         public void onTranscriptionComplete() {
307             logger.debug("onTranscriptionComplete");
308         }
309     }
310 }