]> git.basschouten.com Git - openhab-addons.git/blob
120460921c9efa83465e1b0327ef34d055b328ed
[openhab-addons.git] /
1 /**
2  * Copyright (c) 2010-2023 Contributors to the openHAB project
3  *
4  * See the NOTICE file(s) distributed with this work for additional
5  * information.
6  *
7  * This program and the accompanying materials are made available under the
8  * terms of the Eclipse Public License 2.0 which is available at
9  * http://www.eclipse.org/legal/epl-2.0
10  *
11  * SPDX-License-Identifier: EPL-2.0
12  */
13 package org.openhab.voice.googlestt.internal;
14
15 import static org.openhab.voice.googlestt.internal.GoogleSTTConstants.*;
16
17 import java.io.IOException;
18 import java.util.Comparator;
19 import java.util.Dictionary;
20 import java.util.List;
21 import java.util.Locale;
22 import java.util.Map;
23 import java.util.Set;
24 import java.util.concurrent.Future;
25 import java.util.concurrent.ScheduledExecutorService;
26 import java.util.concurrent.atomic.AtomicBoolean;
27
28 import org.eclipse.jdt.annotation.NonNullByDefault;
29 import org.eclipse.jdt.annotation.Nullable;
30 import org.openhab.core.audio.AudioFormat;
31 import org.openhab.core.audio.AudioStream;
32 import org.openhab.core.auth.client.oauth2.AccessTokenResponse;
33 import org.openhab.core.auth.client.oauth2.OAuthClientService;
34 import org.openhab.core.auth.client.oauth2.OAuthException;
35 import org.openhab.core.auth.client.oauth2.OAuthFactory;
36 import org.openhab.core.auth.client.oauth2.OAuthResponseException;
37 import org.openhab.core.common.ThreadPoolManager;
38 import org.openhab.core.config.core.ConfigurableService;
39 import org.openhab.core.config.core.Configuration;
40 import org.openhab.core.voice.RecognitionStartEvent;
41 import org.openhab.core.voice.RecognitionStopEvent;
42 import org.openhab.core.voice.STTListener;
43 import org.openhab.core.voice.STTService;
44 import org.openhab.core.voice.STTServiceHandle;
45 import org.openhab.core.voice.SpeechRecognitionErrorEvent;
46 import org.openhab.core.voice.SpeechRecognitionEvent;
47 import org.osgi.framework.Constants;
48 import org.osgi.service.cm.ConfigurationAdmin;
49 import org.osgi.service.component.annotations.Activate;
50 import org.osgi.service.component.annotations.Component;
51 import org.osgi.service.component.annotations.Deactivate;
52 import org.osgi.service.component.annotations.Modified;
53 import org.osgi.service.component.annotations.Reference;
54 import org.slf4j.Logger;
55 import org.slf4j.LoggerFactory;
56
57 import com.google.api.gax.rpc.ClientStream;
58 import com.google.api.gax.rpc.ResponseObserver;
59 import com.google.api.gax.rpc.StreamController;
60 import com.google.auth.Credentials;
61 import com.google.auth.oauth2.AccessToken;
62 import com.google.auth.oauth2.OAuth2Credentials;
63 import com.google.cloud.speech.v1.RecognitionConfig;
64 import com.google.cloud.speech.v1.SpeechClient;
65 import com.google.cloud.speech.v1.SpeechRecognitionAlternative;
66 import com.google.cloud.speech.v1.SpeechSettings;
67 import com.google.cloud.speech.v1.StreamingRecognitionConfig;
68 import com.google.cloud.speech.v1.StreamingRecognitionResult;
69 import com.google.cloud.speech.v1.StreamingRecognizeRequest;
70 import com.google.cloud.speech.v1.StreamingRecognizeResponse;
71 import com.google.protobuf.ByteString;
72
73 import io.grpc.LoadBalancerRegistry;
74 import io.grpc.internal.PickFirstLoadBalancerProvider;
75
76 /**
77  * The {@link GoogleSTTService} class is a service implementation to use Google Cloud Speech-to-Text features.
78  *
79  * @author Miguel Álvarez - Initial contribution
80  */
81 @NonNullByDefault
82 @Component(configurationPid = SERVICE_PID, property = Constants.SERVICE_PID + "=" + SERVICE_PID)
83 @ConfigurableService(category = SERVICE_CATEGORY, label = SERVICE_NAME
84         + " Speech-to-Text", description_uri = SERVICE_CATEGORY + ":" + SERVICE_ID)
85 public class GoogleSTTService implements STTService {
86
87     private static final String GCP_AUTH_URI = "https://accounts.google.com/o/oauth2/auth";
88     private static final String GCP_TOKEN_URI = "https://accounts.google.com/o/oauth2/token";
89     private static final String GCP_REDIRECT_URI = "https://www.google.com";
90     private static final String GCP_SCOPE = "https://www.googleapis.com/auth/cloud-platform";
91
92     private final Logger logger = LoggerFactory.getLogger(GoogleSTTService.class);
93     private final ScheduledExecutorService executor = ThreadPoolManager.getScheduledPool("OH-voice-googlestt");
94     private final OAuthFactory oAuthFactory;
95     private final ConfigurationAdmin configAdmin;
96
97     private GoogleSTTConfiguration config = new GoogleSTTConfiguration();
98     private @Nullable OAuthClientService oAuthService;
99
100     @Activate
101     public GoogleSTTService(final @Reference OAuthFactory oAuthFactory,
102             final @Reference ConfigurationAdmin configAdmin) {
103         LoadBalancerRegistry.getDefaultRegistry().register(new PickFirstLoadBalancerProvider());
104         this.oAuthFactory = oAuthFactory;
105         this.configAdmin = configAdmin;
106     }
107
108     @Activate
109     protected void activate(Map<String, Object> config) {
110         this.config = new Configuration(config).as(GoogleSTTConfiguration.class);
111         executor.submit(() -> GoogleSTTLocale.loadLocales(this.config.refreshSupportedLocales));
112         updateConfig();
113     }
114
115     @Modified
116     protected void modified(Map<String, Object> config) {
117         this.config = new Configuration(config).as(GoogleSTTConfiguration.class);
118         updateConfig();
119     }
120
121     @Deactivate
122     protected void dispose() {
123         if (oAuthService != null) {
124             oAuthFactory.ungetOAuthService(SERVICE_PID);
125             oAuthService = null;
126         }
127     }
128
129     @Override
130     public String getId() {
131         return SERVICE_ID;
132     }
133
134     @Override
135     public String getLabel(@Nullable Locale locale) {
136         return SERVICE_NAME;
137     }
138
139     @Override
140     public Set<Locale> getSupportedLocales() {
141         return GoogleSTTLocale.getSupportedLocales();
142     }
143
144     @Override
145     public Set<AudioFormat> getSupportedFormats() {
146         return Set.of(
147                 new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L),
148                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 8000L),
149                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 12000L),
150                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 16000L),
151                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 24000L),
152                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 48000L));
153     }
154
155     @Override
156     public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale,
157             Set<String> set) {
158         AtomicBoolean aborted = new AtomicBoolean(false);
159         backgroundRecognize(sttListener, audioStream, aborted, locale, set);
160         return new STTServiceHandle() {
161             @Override
162             public void abort() {
163                 aborted.set(true);
164             }
165         };
166     }
167
168     private void updateConfig() {
169         if (oAuthService != null) {
170             oAuthFactory.ungetOAuthService(SERVICE_PID);
171             oAuthService = null;
172         }
173         String clientId = this.config.clientId;
174         String clientSecret = this.config.clientSecret;
175         if (!clientId.isBlank() && !clientSecret.isBlank()) {
176             var oAuthService = oAuthFactory.createOAuthClientService(SERVICE_PID, GCP_TOKEN_URI, GCP_AUTH_URI, clientId,
177                     clientSecret, GCP_SCOPE, false);
178             this.oAuthService = oAuthService;
179             if (!this.config.oauthCode.isEmpty()) {
180                 getAccessToken(oAuthService, this.config.oauthCode);
181                 deleteAuthCode();
182             }
183         } else {
184             logger.warn("Missing authentication configuration to access Google Cloud STT API.");
185         }
186     }
187
188     private void getAccessToken(OAuthClientService oAuthService, String oauthCode) {
189         logger.debug("Trying to get access and refresh tokens.");
190         try {
191             AccessTokenResponse response = oAuthService.getAccessTokenResponseByAuthorizationCode(oauthCode,
192                     GCP_REDIRECT_URI);
193             if (response.getRefreshToken() == null || response.getRefreshToken().isEmpty()) {
194                 logger.warn("Error fetching refresh token. Please try to reauthorize.");
195             }
196         } catch (OAuthException | OAuthResponseException e) {
197             if (logger.isDebugEnabled()) {
198                 logger.debug("Error fetching access token: {}", e.getMessage(), e);
199             } else {
200                 logger.warn("Error fetching access token. Invalid oauth code? Please generate a new one.");
201             }
202         } catch (IOException e) {
203             logger.warn("An unexpected IOException occurred when fetching access token: {}", e.getMessage());
204         }
205     }
206
207     private void deleteAuthCode() {
208         try {
209             org.osgi.service.cm.Configuration serviceConfig = configAdmin.getConfiguration(SERVICE_PID);
210             Dictionary<String, Object> configProperties = serviceConfig.getProperties();
211             if (configProperties != null) {
212                 configProperties.put("oauthCode", "");
213                 serviceConfig.update(configProperties);
214             }
215         } catch (IOException e) {
216             logger.warn("Failed to delete current oauth code, please delete it manually.");
217         }
218     }
219
220     private Future<?> backgroundRecognize(STTListener sttListener, AudioStream audioStream, AtomicBoolean aborted,
221             Locale locale, Set<String> set) {
222         Credentials credentials = getCredentials();
223         return executor.submit(() -> {
224             logger.debug("Background recognize starting");
225             ClientStream<StreamingRecognizeRequest> clientStream = null;
226             try (SpeechClient client = SpeechClient
227                     .create(SpeechSettings.newBuilder().setCredentialsProvider(() -> credentials).build())) {
228                 TranscriptionListener responseObserver = new TranscriptionListener(sttListener, config, aborted);
229                 clientStream = client.streamingRecognizeCallable().splitCall(responseObserver);
230                 streamAudio(clientStream, audioStream, responseObserver, aborted, locale);
231                 clientStream.closeSend();
232                 logger.debug("Background recognize done");
233             } catch (IOException e) {
234                 if (clientStream != null && clientStream.isSendReady()) {
235                     clientStream.closeSendWithError(e);
236                 } else if (!config.errorMessage.isBlank()) {
237                     logger.warn("Error running speech to text: {}", e.getMessage());
238                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
239                 }
240             }
241         });
242     }
243
244     private void streamAudio(ClientStream<StreamingRecognizeRequest> clientStream, AudioStream audioStream,
245             TranscriptionListener responseObserver, AtomicBoolean aborted, Locale locale) throws IOException {
246         // Gather stream info and send config
247         AudioFormat streamFormat = audioStream.getFormat();
248         RecognitionConfig.AudioEncoding streamEncoding;
249         if (AudioFormat.WAV.isCompatible(streamFormat)) {
250             streamEncoding = RecognitionConfig.AudioEncoding.LINEAR16;
251         } else if (AudioFormat.OGG.isCompatible(streamFormat)) {
252             streamEncoding = RecognitionConfig.AudioEncoding.OGG_OPUS;
253         } else {
254             logger.debug("Unsupported format {}", streamFormat);
255             return;
256         }
257         Integer channelsObject = streamFormat.getChannels();
258         int channels = channelsObject != null ? channelsObject : 1;
259         Long longFrequency = streamFormat.getFrequency();
260         if (longFrequency == null) {
261             logger.debug("Missing frequency info");
262             return;
263         }
264         int frequency = Math.toIntExact(longFrequency);
265         // First thing we need to send the stream config
266         sendStreamConfig(clientStream, streamEncoding, frequency, channels, locale);
267         // Loop sending audio data
268         long startTime = System.currentTimeMillis();
269         long maxTranscriptionMillis = (config.maxTranscriptionSeconds * 1000L);
270         long maxSilenceMillis = (config.maxSilenceSeconds * 1000L);
271         final int bufferSize = 6400;
272         int numBytesRead;
273         int remaining = bufferSize;
274         byte[] audioBuffer = new byte[bufferSize];
275         while (!aborted.get() && !responseObserver.isDone()) {
276             numBytesRead = audioStream.read(audioBuffer, bufferSize - remaining, remaining);
277             if (aborted.get()) {
278                 logger.debug("Stops listening, aborted");
279                 break;
280             }
281             if (numBytesRead == -1) {
282                 logger.debug("End of stream");
283                 break;
284             }
285             if (isExpiredInterval(maxTranscriptionMillis, startTime)) {
286                 logger.debug("Stops listening, max transcription time reached");
287                 break;
288             }
289             if (!config.singleUtteranceMode
290                     && isExpiredInterval(maxSilenceMillis, responseObserver.getLastInputTime())) {
291                 logger.debug("Stops listening, max silence time reached");
292                 break;
293             }
294             if (numBytesRead != remaining) {
295                 remaining = remaining - numBytesRead;
296                 continue;
297             }
298             remaining = bufferSize;
299             StreamingRecognizeRequest dataRequest = StreamingRecognizeRequest.newBuilder()
300                     .setAudioContent(ByteString.copyFrom(audioBuffer)).build();
301             logger.debug("Sending audio data {}", bufferSize);
302             clientStream.send(dataRequest);
303         }
304         audioStream.close();
305     }
306
307     private void sendStreamConfig(ClientStream<StreamingRecognizeRequest> clientStream,
308             RecognitionConfig.AudioEncoding encoding, int sampleRate, int channels, Locale locale) {
309         RecognitionConfig recognitionConfig = RecognitionConfig.newBuilder().setEncoding(encoding)
310                 .setAudioChannelCount(channels).setLanguageCode(locale.toLanguageTag()).setSampleRateHertz(sampleRate)
311                 .build();
312
313         StreamingRecognitionConfig streamingRecognitionConfig = StreamingRecognitionConfig.newBuilder()
314                 .setConfig(recognitionConfig).setInterimResults(false).setSingleUtterance(config.singleUtteranceMode)
315                 .build();
316
317         clientStream
318                 .send(StreamingRecognizeRequest.newBuilder().setStreamingConfig(streamingRecognitionConfig).build());
319     }
320
321     private @Nullable Credentials getCredentials() {
322         String accessToken = null;
323         String refreshToken = null;
324         try {
325             OAuthClientService oAuthService = this.oAuthService;
326             if (oAuthService != null) {
327                 AccessTokenResponse response = oAuthService.getAccessTokenResponse();
328                 if (response != null) {
329                     accessToken = response.getAccessToken();
330                     refreshToken = response.getRefreshToken();
331                 }
332             }
333         } catch (OAuthException | IOException | OAuthResponseException e) {
334             logger.warn("Access token error: {}", e.getMessage());
335         }
336         if (accessToken == null || refreshToken == null) {
337             logger.warn("Missed google cloud access and/or refresh token");
338             return null;
339         }
340         return OAuth2Credentials.create(new AccessToken(accessToken, null));
341     }
342
343     private boolean isExpiredInterval(long interval, long referenceTime) {
344         return System.currentTimeMillis() - referenceTime > interval;
345     }
346
347     private static class TranscriptionListener implements ResponseObserver<StreamingRecognizeResponse> {
348         private final Logger logger = LoggerFactory.getLogger(TranscriptionListener.class);
349         private final StringBuilder transcriptBuilder = new StringBuilder();
350         private final STTListener sttListener;
351         GoogleSTTConfiguration config;
352         private final AtomicBoolean aborted;
353         private float confidenceSum = 0;
354         private int responseCount = 0;
355         private long lastInputTime = 0;
356         private boolean done = false;
357
358         public TranscriptionListener(STTListener sttListener, GoogleSTTConfiguration config, AtomicBoolean aborted) {
359             this.sttListener = sttListener;
360             this.config = config;
361             this.aborted = aborted;
362         }
363
364         @Override
365         public void onStart(@Nullable StreamController controller) {
366             sttListener.sttEventReceived(new RecognitionStartEvent());
367             lastInputTime = System.currentTimeMillis();
368         }
369
370         @Override
371         public void onResponse(StreamingRecognizeResponse response) {
372             lastInputTime = System.currentTimeMillis();
373             List<StreamingRecognitionResult> results = response.getResultsList();
374             logger.debug("Got {} results", response.getResultsList().size());
375             if (results.isEmpty()) {
376                 logger.debug("No results");
377                 return;
378             }
379             results.forEach(result -> {
380                 List<SpeechRecognitionAlternative> alternatives = result.getAlternativesList();
381                 logger.debug("Got {} alternatives", alternatives.size());
382                 SpeechRecognitionAlternative alternative = alternatives.stream()
383                         .max(Comparator.comparing(SpeechRecognitionAlternative::getConfidence)).orElse(null);
384                 if (alternative == null) {
385                     return;
386                 }
387                 String transcript = alternative.getTranscript();
388                 logger.debug("Alternative transcript: {}", transcript);
389                 logger.debug("Alternative confidence: {}", alternative.getConfidence());
390                 if (result.getIsFinal()) {
391                     transcriptBuilder.append(transcript);
392                     confidenceSum += alternative.getConfidence();
393                     responseCount++;
394                     // when in single utterance mode we can just get one final result so complete
395                     if (config.singleUtteranceMode) {
396                         done = true;
397                     }
398                 }
399             });
400         }
401
402         @Override
403         public void onComplete() {
404             if (!aborted.getAndSet(true)) {
405                 sttListener.sttEventReceived(new RecognitionStopEvent());
406                 float averageConfidence = confidenceSum / responseCount;
407                 String transcript = transcriptBuilder.toString();
408                 if (!transcript.isBlank()) {
409                     sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence));
410                 } else if (!config.noResultsMessage.isBlank()) {
411                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
412                 } else {
413                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results"));
414                 }
415             }
416         }
417
418         @Override
419         public void onError(@Nullable Throwable t) {
420             logger.warn("Recognition error: ", t);
421             if (!aborted.getAndSet(true)) {
422                 sttListener.sttEventReceived(new RecognitionStopEvent());
423                 if (!config.errorMessage.isBlank()) {
424                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
425                 } else {
426                     String errorMessage = t.getMessage();
427                     sttListener.sttEventReceived(
428                             new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
429                 }
430             }
431         }
432
433         public boolean isDone() {
434             return done;
435         }
436
437         public long getLastInputTime() {
438             return lastInputTime;
439         }
440     }
441 }