]> git.basschouten.com Git - openhab-addons.git/blob
70bff2a574206469e82e154978757aaedcb1ebd8
[openhab-addons.git] /
1 /**
2  * Copyright (c) 2010-2024 Contributors to the openHAB project
3  *
4  * See the NOTICE file(s) distributed with this work for additional
5  * information.
6  *
7  * This program and the accompanying materials are made available under the
8  * terms of the Eclipse Public License 2.0 which is available at
9  * http://www.eclipse.org/legal/epl-2.0
10  *
11  * SPDX-License-Identifier: EPL-2.0
12  */
13 package org.openhab.voice.watsonstt.internal;
14
15 import static org.openhab.voice.watsonstt.internal.WatsonSTTConstants.*;
16
17 import java.io.IOException;
18 import java.util.List;
19 import java.util.Locale;
20 import java.util.Map;
21 import java.util.Set;
22 import java.util.concurrent.ScheduledExecutorService;
23 import java.util.concurrent.atomic.AtomicBoolean;
24 import java.util.concurrent.atomic.AtomicReference;
25 import java.util.stream.Collectors;
26 import java.util.stream.Stream;
27
28 import org.eclipse.jdt.annotation.NonNullByDefault;
29 import org.eclipse.jdt.annotation.Nullable;
30 import org.openhab.core.audio.AudioFormat;
31 import org.openhab.core.audio.AudioStream;
32 import org.openhab.core.audio.utils.AudioWaveUtils;
33 import org.openhab.core.common.ThreadPoolManager;
34 import org.openhab.core.config.core.ConfigurableService;
35 import org.openhab.core.config.core.Configuration;
36 import org.openhab.core.voice.RecognitionStartEvent;
37 import org.openhab.core.voice.RecognitionStopEvent;
38 import org.openhab.core.voice.STTException;
39 import org.openhab.core.voice.STTListener;
40 import org.openhab.core.voice.STTService;
41 import org.openhab.core.voice.STTServiceHandle;
42 import org.openhab.core.voice.SpeechRecognitionErrorEvent;
43 import org.openhab.core.voice.SpeechRecognitionEvent;
44 import org.osgi.framework.Constants;
45 import org.osgi.service.component.annotations.Activate;
46 import org.osgi.service.component.annotations.Component;
47 import org.osgi.service.component.annotations.Modified;
48 import org.slf4j.Logger;
49 import org.slf4j.LoggerFactory;
50
51 import com.google.gson.JsonObject;
52 import com.ibm.cloud.sdk.core.http.HttpMediaType;
53 import com.ibm.cloud.sdk.core.security.IamAuthenticator;
54 import com.ibm.watson.speech_to_text.v1.SpeechToText;
55 import com.ibm.watson.speech_to_text.v1.model.RecognizeWithWebsocketsOptions;
56 import com.ibm.watson.speech_to_text.v1.model.SpeechRecognitionAlternative;
57 import com.ibm.watson.speech_to_text.v1.model.SpeechRecognitionResult;
58 import com.ibm.watson.speech_to_text.v1.model.SpeechRecognitionResults;
59 import com.ibm.watson.speech_to_text.v1.websocket.RecognizeCallback;
60
61 import okhttp3.WebSocket;
62
63 /**
64  * The {@link WatsonSTTService} allows to use Watson as Speech-to-Text engine
65  *
66  * @author Miguel Álvarez - Initial contribution
67  */
68 @NonNullByDefault
69 @Component(configurationPid = SERVICE_PID, property = Constants.SERVICE_PID + "=" + SERVICE_PID)
70 @ConfigurableService(category = SERVICE_CATEGORY, label = SERVICE_NAME
71         + " Speech-to-Text", description_uri = SERVICE_CATEGORY + ":" + SERVICE_ID)
72 public class WatsonSTTService implements STTService {
73     private final Logger logger = LoggerFactory.getLogger(WatsonSTTService.class);
74     private final ScheduledExecutorService executor = ThreadPoolManager.getScheduledPool("OH-voice-watsonstt");
75     private final List<String> telephonyModels = List.of("ar-MS_Telephony", "zh-CN_Telephony", "nl-BE_Telephony",
76             "nl-NL_Telephony", "en-AU_Telephony", "en-IN_Telephony", "en-GB_Telephony", "en-US_Telephony",
77             "fr-CA_Telephony", "fr-FR_Telephony", "hi-IN_Telephony", "pt-BR_Telephony", "es-ES_Telephony");
78     private final List<String> multimediaModels = List.of("en-AU_Multimedia", "en-GB_Multimedia", "en-US_Multimedia",
79             "fr-FR_Multimedia", "de-DE_Multimedia", "it-IT_Multimedia", "ja-JP_Multimedia", "ko-KR_Multimedia",
80             "pt-BR_Multimedia", "es-ES_Multimedia");
81     // model 'en-WW_Medical_Telephony' and 'es-LA_Telephony' will be used as fallbacks for es and en
82     private final List<Locale> fallbackLocales = List.of(Locale.forLanguageTag("es"), Locale.ENGLISH);
83     private final Set<Locale> supportedLocales = Stream
84             .concat(Stream.concat(telephonyModels.stream(), multimediaModels.stream()).map(name -> name.split("_")[0])
85                     .distinct().map(Locale::forLanguageTag), fallbackLocales.stream())
86             .collect(Collectors.toSet());
87     private WatsonSTTConfiguration config = new WatsonSTTConfiguration();
88     private @Nullable SpeechToText speechToText = null;
89
90     @Activate
91     protected void activate(Map<String, Object> config) {
92         modified(config);
93     }
94
95     @Modified
96     protected void modified(Map<String, Object> config) {
97         this.config = new Configuration(config).as(WatsonSTTConfiguration.class);
98         if (this.config.apiKey.isBlank() || this.config.instanceUrl.isBlank()) {
99             this.speechToText = null;
100         } else {
101             var speechToText = new SpeechToText(new IamAuthenticator.Builder().apikey(this.config.apiKey).build());
102             speechToText.setServiceUrl(this.config.instanceUrl);
103             if (this.config.optOutLogging) {
104                 speechToText.setDefaultHeaders(Map.of("X-Watson-Learning-Opt-Out", "1"));
105             }
106             this.speechToText = speechToText;
107         }
108     }
109
110     @Override
111     public String getId() {
112         return SERVICE_ID;
113     }
114
115     @Override
116     public String getLabel(@Nullable Locale locale) {
117         return SERVICE_NAME;
118     }
119
120     @Override
121     public Set<Locale> getSupportedLocales() {
122         return supportedLocales;
123     }
124
125     @Override
126     public Set<AudioFormat> getSupportedFormats() {
127         return Set.of(AudioFormat.PCM_SIGNED, AudioFormat.WAV);
128     }
129
130     @Override
131     public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale, Set<String> set)
132             throws STTException {
133         var stt = this.speechToText;
134         if (stt == null) {
135             throw new STTException("service is not correctly configured");
136         }
137         String contentType = getContentType(audioStream);
138         if (contentType == null) {
139             throw new STTException("Unsupported format, unable to resolve audio content type");
140         }
141         logger.debug("Content-Type: {}", contentType);
142         RecognizeWithWebsocketsOptions wsOptions = new RecognizeWithWebsocketsOptions.Builder().audio(audioStream)
143                 .contentType(contentType).redaction(config.redaction).smartFormatting(config.smartFormatting)
144                 .model(getModel(locale)).interimResults(true)
145                 .backgroundAudioSuppression(config.backgroundAudioSuppression)
146                 .speechDetectorSensitivity(config.speechDetectorSensitivity).inactivityTimeout(config.maxSilenceSeconds)
147                 .build();
148         final AtomicReference<@Nullable WebSocket> socketRef = new AtomicReference<>();
149         final AtomicBoolean aborted = new AtomicBoolean(false);
150         executor.submit(() -> {
151             if (AudioFormat.CONTAINER_WAVE.equals(audioStream.getFormat().getContainer())) {
152                 try {
153                     AudioWaveUtils.removeFMT(audioStream);
154                 } catch (IOException e) {
155                     logger.warn("Error removing format header: {}", e.getMessage());
156                 }
157             }
158             socketRef.set(stt.recognizeUsingWebSocket(wsOptions,
159                     new TranscriptionListener(socketRef, sttListener, config, aborted)));
160         });
161         return new STTServiceHandle() {
162             @Override
163             public void abort() {
164                 if (!aborted.getAndSet(true)) {
165                     var socket = socketRef.get();
166                     if (socket != null) {
167                         sendStopMessage(socket);
168                     }
169                 }
170             }
171         };
172     }
173
174     private String getModel(Locale locale) throws STTException {
175         String languageTag = locale.toLanguageTag();
176         Stream<String> allModels;
177         if (config.preferMultimediaModel) {
178             allModels = Stream.concat(multimediaModels.stream(), telephonyModels.stream());
179         } else {
180             allModels = Stream.concat(telephonyModels.stream(), multimediaModels.stream());
181         }
182         var modelOption = allModels.filter(model -> model.startsWith(languageTag)).findFirst();
183         if (modelOption.isEmpty()) {
184             if ("es".equals(locale.getLanguage())) {
185                 // fallback for latin american spanish languages
186                 var model = "es-LA_Telephony";
187                 logger.debug("Falling back to model: {}", model);
188             }
189             if ("en".equals(locale.getLanguage())) {
190                 // fallback english dialects
191                 var model = "en-WW_Medical_Telephony";
192                 logger.debug("Falling back to model: {}", model);
193             }
194             throw new STTException("No compatible model for language " + languageTag);
195         }
196         var model = modelOption.get();
197         logger.debug("Using model: {}", model);
198         return model;
199     }
200
201     private @Nullable String getContentType(AudioStream audioStream) throws STTException {
202         AudioFormat format = audioStream.getFormat();
203         String container = format.getContainer();
204         String codec = format.getCodec();
205         if (container == null || codec == null) {
206             throw new STTException("Missing audio stream info");
207         }
208         Long frequency = format.getFrequency();
209         Integer bitDepth = format.getBitDepth();
210         switch (container) {
211             case AudioFormat.CONTAINER_NONE:
212                 if (AudioFormat.CODEC_MP3.equals(codec)) {
213                     return "audio/mp3";
214                 }
215             case AudioFormat.CONTAINER_WAVE:
216                 if (AudioFormat.CODEC_PCM_SIGNED.equals(codec)) {
217                     if (bitDepth == null || bitDepth != 16) {
218                         return "audio/wav";
219                     }
220                     // rate is a required parameter for this type
221                     if (frequency == null) {
222                         return null;
223                     }
224                     StringBuilder contentTypeL16 = new StringBuilder(HttpMediaType.AUDIO_PCM).append(";rate=")
225                             .append(frequency);
226                     // // those are optional
227                     Integer channels = format.getChannels();
228                     if (channels != null) {
229                         contentTypeL16.append(";channels=").append(channels);
230                     }
231                     Boolean bigEndian = format.isBigEndian();
232                     if (bigEndian != null) {
233                         contentTypeL16.append(";")
234                                 .append(bigEndian ? "endianness=big-endian" : "endianness=little-endian");
235                     }
236                     return contentTypeL16.toString();
237                 }
238             case AudioFormat.CONTAINER_OGG:
239                 switch (codec) {
240                     case AudioFormat.CODEC_VORBIS:
241                         return "audio/ogg;codecs=vorbis";
242                     case "OPUS":
243                         return "audio/ogg;codecs=opus";
244                 }
245                 break;
246         }
247         return null;
248     }
249
250     private static void sendStopMessage(WebSocket ws) {
251         JsonObject stopMessage = new JsonObject();
252         stopMessage.addProperty("action", "stop");
253         ws.send(stopMessage.toString());
254     }
255
256     private static class TranscriptionListener implements RecognizeCallback {
257         private final Logger logger = LoggerFactory.getLogger(TranscriptionListener.class);
258         private final StringBuilder transcriptBuilder = new StringBuilder();
259         private final STTListener sttListener;
260         private final WatsonSTTConfiguration config;
261         private final AtomicBoolean aborted;
262         private final AtomicReference<@Nullable WebSocket> socketRef;
263         private float confidenceSum = 0f;
264         private int responseCount = 0;
265         private boolean disconnected = false;
266
267         public TranscriptionListener(AtomicReference<@Nullable WebSocket> socketRef, STTListener sttListener,
268                 WatsonSTTConfiguration config, AtomicBoolean aborted) {
269             this.socketRef = socketRef;
270             this.sttListener = sttListener;
271             this.config = config;
272             this.aborted = aborted;
273         }
274
275         @Override
276         public void onTranscription(@Nullable SpeechRecognitionResults speechRecognitionResults) {
277             logger.debug("onTranscription");
278             if (speechRecognitionResults == null) {
279                 return;
280             }
281             speechRecognitionResults.getResults().stream().filter(SpeechRecognitionResult::isXFinal).forEach(result -> {
282                 SpeechRecognitionAlternative alternative = result.getAlternatives().stream().findFirst().orElse(null);
283                 if (alternative == null) {
284                     return;
285                 }
286                 logger.debug("onTranscription Final");
287                 Double confidence = alternative.getConfidence();
288                 transcriptBuilder.append(alternative.getTranscript());
289                 confidenceSum += confidence != null ? confidence.floatValue() : 0f;
290                 responseCount++;
291                 if (config.singleUtteranceMode) {
292                     var socket = socketRef.get();
293                     if (socket != null) {
294                         sendStopMessage(socket);
295                     }
296                 }
297             });
298         }
299
300         @Override
301         public void onConnected() {
302             logger.debug("onConnected");
303         }
304
305         @Override
306         public void onError(@Nullable Exception e) {
307             var errorMessage = e != null ? e.getMessage() : null;
308             if (errorMessage != null && disconnected && errorMessage.contains("Socket closed")) {
309                 logger.debug("Error ignored: {}", errorMessage);
310                 return;
311             }
312             logger.warn("TranscriptionError: {}", errorMessage);
313             if (!aborted.getAndSet(true)) {
314                 sttListener.sttEventReceived(
315                         new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
316             }
317         }
318
319         @Override
320         public void onDisconnected() {
321             logger.debug("onDisconnected");
322             disconnected = true;
323             if (!aborted.getAndSet(true)) {
324                 sttListener.sttEventReceived(new RecognitionStopEvent());
325                 float averageConfidence = confidenceSum / (float) responseCount;
326                 String transcript = transcriptBuilder.toString().trim();
327                 if (!transcript.isBlank()) {
328                     sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence));
329                 } else {
330                     if (!config.noResultsMessage.isBlank()) {
331                         sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
332                     } else {
333                         sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results"));
334                     }
335                 }
336             }
337         }
338
339         @Override
340         public void onInactivityTimeout(@Nullable RuntimeException e) {
341             if (e != null) {
342                 logger.debug("InactivityTimeout: {}", e.getMessage());
343             }
344         }
345
346         @Override
347         public void onListening() {
348             logger.debug("onListening");
349             sttListener.sttEventReceived(new RecognitionStartEvent());
350         }
351
352         @Override
353         public void onTranscriptionComplete() {
354             logger.debug("onTranscriptionComplete");
355         }
356     }
357 }