]> git.basschouten.com Git - openhab-addons.git/blob
161c742a5a797b0bfba17866183d8c79d39b95d2
[openhab-addons.git] /
1 /**
2  * Copyright (c) 2010-2023 Contributors to the openHAB project
3  *
4  * See the NOTICE file(s) distributed with this work for additional
5  * information.
6  *
7  * This program and the accompanying materials are made available under the
8  * terms of the Eclipse Public License 2.0 which is available at
9  * http://www.eclipse.org/legal/epl-2.0
10  *
11  * SPDX-License-Identifier: EPL-2.0
12  */
13 package org.openhab.voice.watsonstt.internal;
14
15 import static org.openhab.voice.watsonstt.internal.WatsonSTTConstants.*;
16
17 import java.util.List;
18 import java.util.Locale;
19 import java.util.Map;
20 import java.util.Set;
21 import java.util.concurrent.ScheduledExecutorService;
22 import java.util.concurrent.atomic.AtomicBoolean;
23 import java.util.concurrent.atomic.AtomicReference;
24 import java.util.stream.Collectors;
25 import java.util.stream.Stream;
26
27 import org.eclipse.jdt.annotation.NonNullByDefault;
28 import org.eclipse.jdt.annotation.Nullable;
29 import org.openhab.core.audio.AudioFormat;
30 import org.openhab.core.audio.AudioStream;
31 import org.openhab.core.common.ThreadPoolManager;
32 import org.openhab.core.config.core.ConfigurableService;
33 import org.openhab.core.config.core.Configuration;
34 import org.openhab.core.voice.RecognitionStartEvent;
35 import org.openhab.core.voice.RecognitionStopEvent;
36 import org.openhab.core.voice.STTException;
37 import org.openhab.core.voice.STTListener;
38 import org.openhab.core.voice.STTService;
39 import org.openhab.core.voice.STTServiceHandle;
40 import org.openhab.core.voice.SpeechRecognitionErrorEvent;
41 import org.openhab.core.voice.SpeechRecognitionEvent;
42 import org.osgi.framework.Constants;
43 import org.osgi.service.component.annotations.Activate;
44 import org.osgi.service.component.annotations.Component;
45 import org.osgi.service.component.annotations.Modified;
46 import org.slf4j.Logger;
47 import org.slf4j.LoggerFactory;
48
49 import com.google.gson.JsonObject;
50 import com.ibm.cloud.sdk.core.http.HttpMediaType;
51 import com.ibm.cloud.sdk.core.security.IamAuthenticator;
52 import com.ibm.watson.speech_to_text.v1.SpeechToText;
53 import com.ibm.watson.speech_to_text.v1.model.RecognizeWithWebsocketsOptions;
54 import com.ibm.watson.speech_to_text.v1.model.SpeechRecognitionAlternative;
55 import com.ibm.watson.speech_to_text.v1.model.SpeechRecognitionResult;
56 import com.ibm.watson.speech_to_text.v1.model.SpeechRecognitionResults;
57 import com.ibm.watson.speech_to_text.v1.websocket.RecognizeCallback;
58
59 import okhttp3.WebSocket;
60
61 /**
62  * The {@link WatsonSTTService} allows to use Watson as Speech-to-Text engine
63  *
64  * @author Miguel Álvarez - Initial contribution
65  */
66 @NonNullByDefault
67 @Component(configurationPid = SERVICE_PID, property = Constants.SERVICE_PID + "=" + SERVICE_PID)
68 @ConfigurableService(category = SERVICE_CATEGORY, label = SERVICE_NAME
69         + " Speech-to-Text", description_uri = SERVICE_CATEGORY + ":" + SERVICE_ID)
70 public class WatsonSTTService implements STTService {
71     private final Logger logger = LoggerFactory.getLogger(WatsonSTTService.class);
72     private final ScheduledExecutorService executor = ThreadPoolManager.getScheduledPool("OH-voice-watsonstt");
73     private final List<String> telephonyModels = List.of("ar-MS_Telephony", "zh-CN_Telephony", "nl-BE_Telephony",
74             "nl-NL_Telephony", "en-AU_Telephony", "en-IN_Telephony", "en-GB_Telephony", "en-US_Telephony",
75             "fr-CA_Telephony", "fr-FR_Telephony", "hi-IN_Telephony", "pt-BR_Telephony", "es-ES_Telephony");
76     private final List<String> multimediaModels = List.of("en-AU_Multimedia", "en-GB_Multimedia", "en-US_Multimedia",
77             "fr-FR_Multimedia", "de-DE_Multimedia", "it-IT_Multimedia", "ja-JP_Multimedia", "ko-KR_Multimedia",
78             "pt-BR_Multimedia", "es-ES_Multimedia");
79     // model 'en-WW_Medical_Telephony' and 'es-LA_Telephony' will be used as fallbacks for es and en
80     private final List<Locale> fallbackLocales = List.of(Locale.forLanguageTag("es"), Locale.ENGLISH);
81     private final Set<Locale> supportedLocales = Stream
82             .concat(Stream.concat(telephonyModels.stream(), multimediaModels.stream()).map(name -> name.split("_")[0])
83                     .distinct().map(Locale::forLanguageTag), fallbackLocales.stream())
84             .collect(Collectors.toSet());
85     private WatsonSTTConfiguration config = new WatsonSTTConfiguration();
86     private @Nullable SpeechToText speechToText = null;
87
88     @Activate
89     protected void activate(Map<String, Object> config) {
90         modified(config);
91     }
92
93     @Modified
94     protected void modified(Map<String, Object> config) {
95         this.config = new Configuration(config).as(WatsonSTTConfiguration.class);
96         if (this.config.apiKey.isBlank() || this.config.instanceUrl.isBlank()) {
97             this.speechToText = null;
98         } else {
99             var speechToText = new SpeechToText(new IamAuthenticator.Builder().apikey(this.config.apiKey).build());
100             speechToText.setServiceUrl(this.config.instanceUrl);
101             if (this.config.optOutLogging) {
102                 speechToText.setDefaultHeaders(Map.of("X-Watson-Learning-Opt-Out", "1"));
103             }
104             this.speechToText = speechToText;
105         }
106     }
107
108     @Override
109     public String getId() {
110         return SERVICE_ID;
111     }
112
113     @Override
114     public String getLabel(@Nullable Locale locale) {
115         return SERVICE_NAME;
116     }
117
118     @Override
119     public Set<Locale> getSupportedLocales() {
120         return supportedLocales;
121     }
122
123     @Override
124     public Set<AudioFormat> getSupportedFormats() {
125         return Set.of(AudioFormat.WAV, AudioFormat.OGG, new AudioFormat("OGG", "OPUS", null, null, null, null),
126                 AudioFormat.MP3);
127     }
128
129     @Override
130     public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale, Set<String> set)
131             throws STTException {
132         var stt = this.speechToText;
133         if (stt == null) {
134             throw new STTException("service is not correctly configured");
135         }
136         String contentType = getContentType(audioStream);
137         if (contentType == null) {
138             throw new STTException("Unsupported format, unable to resolve audio content type");
139         }
140         logger.debug("Content-Type: {}", contentType);
141         RecognizeWithWebsocketsOptions wsOptions = new RecognizeWithWebsocketsOptions.Builder().audio(audioStream)
142                 .contentType(contentType).redaction(config.redaction).smartFormatting(config.smartFormatting)
143                 .model(getModel(locale)).interimResults(true)
144                 .backgroundAudioSuppression(config.backgroundAudioSuppression)
145                 .speechDetectorSensitivity(config.speechDetectorSensitivity).inactivityTimeout(config.maxSilenceSeconds)
146                 .build();
147         final AtomicReference<@Nullable WebSocket> socketRef = new AtomicReference<>();
148         final AtomicBoolean aborted = new AtomicBoolean(false);
149         executor.submit(() -> {
150             socketRef.set(stt.recognizeUsingWebSocket(wsOptions,
151                     new TranscriptionListener(socketRef, sttListener, config, aborted)));
152         });
153         return new STTServiceHandle() {
154             @Override
155             public void abort() {
156                 if (!aborted.getAndSet(true)) {
157                     var socket = socketRef.get();
158                     if (socket != null) {
159                         sendStopMessage(socket);
160                     }
161                 }
162             }
163         };
164     }
165
166     private String getModel(Locale locale) throws STTException {
167         String languageTag = locale.toLanguageTag();
168         Stream<String> allModels;
169         if (config.preferMultimediaModel) {
170             allModels = Stream.concat(multimediaModels.stream(), telephonyModels.stream());
171         } else {
172             allModels = Stream.concat(telephonyModels.stream(), multimediaModels.stream());
173         }
174         var modelOption = allModels.filter(model -> model.startsWith(languageTag)).findFirst();
175         if (modelOption.isEmpty()) {
176             if ("es".equals(locale.getLanguage())) {
177                 // fallback for latin american spanish languages
178                 var model = "es-LA_Telephony";
179                 logger.debug("Falling back to model: {}", model);
180             }
181             if ("en".equals(locale.getLanguage())) {
182                 // fallback english dialects
183                 var model = "en-WW_Medical_Telephony";
184                 logger.debug("Falling back to model: {}", model);
185             }
186             throw new STTException("No compatible model for language " + languageTag);
187         }
188         var model = modelOption.get();
189         logger.debug("Using model: {}", model);
190         return model;
191     }
192
193     private @Nullable String getContentType(AudioStream audioStream) throws STTException {
194         AudioFormat format = audioStream.getFormat();
195         String container = format.getContainer();
196         String codec = format.getCodec();
197         if (container == null || codec == null) {
198             throw new STTException("Missing audio stream info");
199         }
200         Long frequency = format.getFrequency();
201         Integer bitDepth = format.getBitDepth();
202         switch (container) {
203             case AudioFormat.CONTAINER_WAVE:
204                 if (AudioFormat.CODEC_PCM_SIGNED.equals(codec)) {
205                     if (bitDepth == null || bitDepth != 16) {
206                         return "audio/wav";
207                     }
208                     // rate is a required parameter for this type
209                     if (frequency == null) {
210                         return null;
211                     }
212                     StringBuilder contentTypeL16 = new StringBuilder(HttpMediaType.AUDIO_PCM).append(";rate=")
213                             .append(frequency);
214                     // // those are optional
215                     Integer channels = format.getChannels();
216                     if (channels != null) {
217                         contentTypeL16.append(";channels=").append(channels);
218                     }
219                     Boolean bigEndian = format.isBigEndian();
220                     if (bigEndian != null) {
221                         contentTypeL16.append(";")
222                                 .append(bigEndian ? "endianness=big-endian" : "endianness=little-endian");
223                     }
224                     return contentTypeL16.toString();
225                 }
226             case AudioFormat.CONTAINER_OGG:
227                 switch (codec) {
228                     case AudioFormat.CODEC_VORBIS:
229                         return "audio/ogg;codecs=vorbis";
230                     case "OPUS":
231                         return "audio/ogg;codecs=opus";
232                 }
233                 break;
234             case AudioFormat.CONTAINER_NONE:
235                 if (AudioFormat.CODEC_MP3.equals(codec)) {
236                     return "audio/mp3";
237                 }
238                 break;
239         }
240         return null;
241     }
242
243     private static void sendStopMessage(WebSocket ws) {
244         JsonObject stopMessage = new JsonObject();
245         stopMessage.addProperty("action", "stop");
246         ws.send(stopMessage.toString());
247     }
248
249     private static class TranscriptionListener implements RecognizeCallback {
250         private final Logger logger = LoggerFactory.getLogger(TranscriptionListener.class);
251         private final StringBuilder transcriptBuilder = new StringBuilder();
252         private final STTListener sttListener;
253         private final WatsonSTTConfiguration config;
254         private final AtomicBoolean aborted;
255         private final AtomicReference<@Nullable WebSocket> socketRef;
256         private float confidenceSum = 0f;
257         private int responseCount = 0;
258         private boolean disconnected = false;
259
260         public TranscriptionListener(AtomicReference<@Nullable WebSocket> socketRef, STTListener sttListener,
261                 WatsonSTTConfiguration config, AtomicBoolean aborted) {
262             this.socketRef = socketRef;
263             this.sttListener = sttListener;
264             this.config = config;
265             this.aborted = aborted;
266         }
267
268         @Override
269         public void onTranscription(@Nullable SpeechRecognitionResults speechRecognitionResults) {
270             logger.debug("onTranscription");
271             if (speechRecognitionResults == null) {
272                 return;
273             }
274             speechRecognitionResults.getResults().stream().filter(SpeechRecognitionResult::isXFinal).forEach(result -> {
275                 SpeechRecognitionAlternative alternative = result.getAlternatives().stream().findFirst().orElse(null);
276                 if (alternative == null) {
277                     return;
278                 }
279                 logger.debug("onTranscription Final");
280                 Double confidence = alternative.getConfidence();
281                 transcriptBuilder.append(alternative.getTranscript());
282                 confidenceSum += confidence != null ? confidence.floatValue() : 0f;
283                 responseCount++;
284                 if (config.singleUtteranceMode) {
285                     var socket = socketRef.get();
286                     if (socket != null) {
287                         sendStopMessage(socket);
288                     }
289                 }
290             });
291         }
292
293         @Override
294         public void onConnected() {
295             logger.debug("onConnected");
296         }
297
298         @Override
299         public void onError(@Nullable Exception e) {
300             var errorMessage = e != null ? e.getMessage() : null;
301             if (errorMessage != null && disconnected && errorMessage.contains("Socket closed")) {
302                 logger.debug("Error ignored: {}", errorMessage);
303                 return;
304             }
305             logger.warn("TranscriptionError: {}", errorMessage);
306             if (!aborted.getAndSet(true)) {
307                 sttListener.sttEventReceived(
308                         new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
309             }
310         }
311
312         @Override
313         public void onDisconnected() {
314             logger.debug("onDisconnected");
315             disconnected = true;
316             if (!aborted.getAndSet(true)) {
317                 sttListener.sttEventReceived(new RecognitionStopEvent());
318                 float averageConfidence = confidenceSum / (float) responseCount;
319                 String transcript = transcriptBuilder.toString().trim();
320                 if (!transcript.isBlank()) {
321                     sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence));
322                 } else {
323                     if (!config.noResultsMessage.isBlank()) {
324                         sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
325                     } else {
326                         sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results"));
327                     }
328                 }
329             }
330         }
331
332         @Override
333         public void onInactivityTimeout(@Nullable RuntimeException e) {
334             if (e != null) {
335                 logger.debug("InactivityTimeout: {}", e.getMessage());
336             }
337         }
338
339         @Override
340         public void onListening() {
341             logger.debug("onListening");
342             sttListener.sttEventReceived(new RecognitionStartEvent());
343         }
344
345         @Override
346         public void onTranscriptionComplete() {
347             logger.debug("onTranscriptionComplete");
348         }
349     }
350 }