]> git.basschouten.com Git - openhab-addons.git/blob
52df7359dfa2d3591c02c9d6e3076e9271178cb9
[openhab-addons.git] /
1 /**
2  * Copyright (c) 2010-2023 Contributors to the openHAB project
3  *
4  * See the NOTICE file(s) distributed with this work for additional
5  * information.
6  *
7  * This program and the accompanying materials are made available under the
8  * terms of the Eclipse Public License 2.0 which is available at
9  * http://www.eclipse.org/legal/epl-2.0
10  *
11  * SPDX-License-Identifier: EPL-2.0
12  */
13 package org.openhab.voice.googlestt.internal;
14
15 import static org.openhab.voice.googlestt.internal.GoogleSTTConstants.*;
16
17 import java.io.IOException;
18 import java.util.Comparator;
19 import java.util.Dictionary;
20 import java.util.List;
21 import java.util.Locale;
22 import java.util.Map;
23 import java.util.Set;
24 import java.util.concurrent.Future;
25 import java.util.concurrent.ScheduledExecutorService;
26 import java.util.concurrent.atomic.AtomicBoolean;
27
28 import org.eclipse.jdt.annotation.NonNullByDefault;
29 import org.eclipse.jdt.annotation.Nullable;
30 import org.openhab.core.audio.AudioFormat;
31 import org.openhab.core.audio.AudioStream;
32 import org.openhab.core.auth.client.oauth2.AccessTokenResponse;
33 import org.openhab.core.auth.client.oauth2.OAuthClientService;
34 import org.openhab.core.auth.client.oauth2.OAuthException;
35 import org.openhab.core.auth.client.oauth2.OAuthFactory;
36 import org.openhab.core.auth.client.oauth2.OAuthResponseException;
37 import org.openhab.core.common.ThreadPoolManager;
38 import org.openhab.core.config.core.ConfigurableService;
39 import org.openhab.core.config.core.Configuration;
40 import org.openhab.core.voice.RecognitionStartEvent;
41 import org.openhab.core.voice.RecognitionStopEvent;
42 import org.openhab.core.voice.STTListener;
43 import org.openhab.core.voice.STTService;
44 import org.openhab.core.voice.STTServiceHandle;
45 import org.openhab.core.voice.SpeechRecognitionErrorEvent;
46 import org.openhab.core.voice.SpeechRecognitionEvent;
47 import org.osgi.framework.Constants;
48 import org.osgi.service.cm.ConfigurationAdmin;
49 import org.osgi.service.component.annotations.Activate;
50 import org.osgi.service.component.annotations.Component;
51 import org.osgi.service.component.annotations.Modified;
52 import org.osgi.service.component.annotations.Reference;
53 import org.slf4j.Logger;
54 import org.slf4j.LoggerFactory;
55
56 import com.google.api.gax.rpc.ClientStream;
57 import com.google.api.gax.rpc.ResponseObserver;
58 import com.google.api.gax.rpc.StreamController;
59 import com.google.auth.Credentials;
60 import com.google.auth.oauth2.AccessToken;
61 import com.google.auth.oauth2.OAuth2Credentials;
62 import com.google.cloud.speech.v1.RecognitionConfig;
63 import com.google.cloud.speech.v1.SpeechClient;
64 import com.google.cloud.speech.v1.SpeechRecognitionAlternative;
65 import com.google.cloud.speech.v1.SpeechSettings;
66 import com.google.cloud.speech.v1.StreamingRecognitionConfig;
67 import com.google.cloud.speech.v1.StreamingRecognitionResult;
68 import com.google.cloud.speech.v1.StreamingRecognizeRequest;
69 import com.google.cloud.speech.v1.StreamingRecognizeResponse;
70 import com.google.protobuf.ByteString;
71
72 import io.grpc.LoadBalancerRegistry;
73 import io.grpc.internal.PickFirstLoadBalancerProvider;
74
75 /**
76  * The {@link GoogleSTTService} class is a service implementation to use Google Cloud Speech-to-Text features.
77  *
78  * @author Miguel Álvarez - Initial contribution
79  */
80 @NonNullByDefault
81 @Component(configurationPid = SERVICE_PID, property = Constants.SERVICE_PID + "=" + SERVICE_PID)
82 @ConfigurableService(category = SERVICE_CATEGORY, label = SERVICE_NAME
83         + " Speech-to-Text", description_uri = SERVICE_CATEGORY + ":" + SERVICE_ID)
84 public class GoogleSTTService implements STTService {
85
86     private static final String GCP_AUTH_URI = "https://accounts.google.com/o/oauth2/auth";
87     private static final String GCP_TOKEN_URI = "https://accounts.google.com/o/oauth2/token";
88     private static final String GCP_REDIRECT_URI = "https://www.google.com";
89     private static final String GCP_SCOPE = "https://www.googleapis.com/auth/cloud-platform";
90
91     private final Logger logger = LoggerFactory.getLogger(GoogleSTTService.class);
92     private final ScheduledExecutorService executor = ThreadPoolManager.getScheduledPool("OH-voice-googlestt");
93     private final OAuthFactory oAuthFactory;
94     private final ConfigurationAdmin configAdmin;
95
96     private GoogleSTTConfiguration config = new GoogleSTTConfiguration();
97     private @Nullable OAuthClientService oAuthService;
98
99     @Activate
100     public GoogleSTTService(final @Reference OAuthFactory oAuthFactory,
101             final @Reference ConfigurationAdmin configAdmin) {
102         LoadBalancerRegistry.getDefaultRegistry().register(new PickFirstLoadBalancerProvider());
103         this.oAuthFactory = oAuthFactory;
104         this.configAdmin = configAdmin;
105     }
106
107     @Activate
108     protected void activate(Map<String, Object> config) {
109         this.config = new Configuration(config).as(GoogleSTTConfiguration.class);
110         executor.submit(() -> GoogleSTTLocale.loadLocales(this.config.refreshSupportedLocales));
111         updateConfig();
112     }
113
114     @Modified
115     protected void modified(Map<String, Object> config) {
116         this.config = new Configuration(config).as(GoogleSTTConfiguration.class);
117         updateConfig();
118     }
119
120     @Override
121     public String getId() {
122         return SERVICE_ID;
123     }
124
125     @Override
126     public String getLabel(@Nullable Locale locale) {
127         return SERVICE_NAME;
128     }
129
130     @Override
131     public Set<Locale> getSupportedLocales() {
132         return GoogleSTTLocale.getSupportedLocales();
133     }
134
135     @Override
136     public Set<AudioFormat> getSupportedFormats() {
137         return Set.of(
138                 new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L),
139                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 8000L),
140                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 12000L),
141                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 16000L),
142                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 24000L),
143                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 48000L));
144     }
145
146     @Override
147     public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale,
148             Set<String> set) {
149         AtomicBoolean aborted = new AtomicBoolean(false);
150         backgroundRecognize(sttListener, audioStream, aborted, locale, set);
151         return new STTServiceHandle() {
152             @Override
153             public void abort() {
154                 aborted.set(true);
155             }
156         };
157     }
158
159     private void updateConfig() {
160         String clientId = this.config.clientId;
161         String clientSecret = this.config.clientSecret;
162         if (!clientId.isBlank() && !clientSecret.isBlank()) {
163             var oAuthService = oAuthFactory.createOAuthClientService(SERVICE_PID, GCP_TOKEN_URI, GCP_AUTH_URI, clientId,
164                     clientSecret, GCP_SCOPE, false);
165             this.oAuthService = oAuthService;
166             if (!this.config.oauthCode.isEmpty()) {
167                 getAccessToken(oAuthService, this.config.oauthCode);
168                 deleteAuthCode();
169             }
170         } else {
171             logger.warn("Missing authentication configuration to access Google Cloud STT API.");
172         }
173     }
174
175     private void getAccessToken(OAuthClientService oAuthService, String oauthCode) {
176         logger.debug("Trying to get access and refresh tokens.");
177         try {
178             AccessTokenResponse response = oAuthService.getAccessTokenResponseByAuthorizationCode(oauthCode,
179                     GCP_REDIRECT_URI);
180             if (response.getRefreshToken() == null || response.getRefreshToken().isEmpty()) {
181                 logger.warn("Error fetching refresh token. Please try to reauthorize.");
182             }
183         } catch (OAuthException | OAuthResponseException e) {
184             if (logger.isDebugEnabled()) {
185                 logger.debug("Error fetching access token: {}", e.getMessage(), e);
186             } else {
187                 logger.warn("Error fetching access token. Invalid oauth code? Please generate a new one.");
188             }
189         } catch (IOException e) {
190             logger.warn("An unexpected IOException occurred when fetching access token: {}", e.getMessage());
191         }
192     }
193
194     private void deleteAuthCode() {
195         try {
196             org.osgi.service.cm.Configuration serviceConfig = configAdmin.getConfiguration(SERVICE_PID);
197             Dictionary<String, Object> configProperties = serviceConfig.getProperties();
198             if (configProperties != null) {
199                 configProperties.put("oauthCode", "");
200                 serviceConfig.update(configProperties);
201             }
202         } catch (IOException e) {
203             logger.warn("Failed to delete current oauth code, please delete it manually.");
204         }
205     }
206
207     private Future<?> backgroundRecognize(STTListener sttListener, AudioStream audioStream, AtomicBoolean aborted,
208             Locale locale, Set<String> set) {
209         Credentials credentials = getCredentials();
210         return executor.submit(() -> {
211             logger.debug("Background recognize starting");
212             ClientStream<StreamingRecognizeRequest> clientStream = null;
213             try (SpeechClient client = SpeechClient
214                     .create(SpeechSettings.newBuilder().setCredentialsProvider(() -> credentials).build())) {
215                 TranscriptionListener responseObserver = new TranscriptionListener(sttListener, config, aborted);
216                 clientStream = client.streamingRecognizeCallable().splitCall(responseObserver);
217                 streamAudio(clientStream, audioStream, responseObserver, aborted, locale);
218                 clientStream.closeSend();
219                 logger.debug("Background recognize done");
220             } catch (IOException e) {
221                 if (clientStream != null && clientStream.isSendReady()) {
222                     clientStream.closeSendWithError(e);
223                 } else if (!config.errorMessage.isBlank()) {
224                     logger.warn("Error running speech to text: {}", e.getMessage());
225                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
226                 }
227             }
228         });
229     }
230
231     private void streamAudio(ClientStream<StreamingRecognizeRequest> clientStream, AudioStream audioStream,
232             TranscriptionListener responseObserver, AtomicBoolean aborted, Locale locale) throws IOException {
233         // Gather stream info and send config
234         AudioFormat streamFormat = audioStream.getFormat();
235         RecognitionConfig.AudioEncoding streamEncoding;
236         if (AudioFormat.WAV.isCompatible(streamFormat)) {
237             streamEncoding = RecognitionConfig.AudioEncoding.LINEAR16;
238         } else if (AudioFormat.OGG.isCompatible(streamFormat)) {
239             streamEncoding = RecognitionConfig.AudioEncoding.OGG_OPUS;
240         } else {
241             logger.debug("Unsupported format {}", streamFormat);
242             return;
243         }
244         Integer channelsObject = streamFormat.getChannels();
245         int channels = channelsObject != null ? channelsObject : 1;
246         Long longFrequency = streamFormat.getFrequency();
247         if (longFrequency == null) {
248             logger.debug("Missing frequency info");
249             return;
250         }
251         int frequency = Math.toIntExact(longFrequency);
252         // First thing we need to send the stream config
253         sendStreamConfig(clientStream, streamEncoding, frequency, channels, locale);
254         // Loop sending audio data
255         long startTime = System.currentTimeMillis();
256         long maxTranscriptionMillis = (config.maxTranscriptionSeconds * 1000L);
257         long maxSilenceMillis = (config.maxSilenceSeconds * 1000L);
258         int readBytes = 6400;
259         while (!aborted.get()) {
260             byte[] data = new byte[readBytes];
261             int dataN = audioStream.read(data);
262             if (aborted.get()) {
263                 logger.debug("Stops listening, aborted");
264                 break;
265             }
266             if (isExpiredInterval(maxTranscriptionMillis, startTime)) {
267                 logger.debug("Stops listening, max transcription time reached");
268                 break;
269             }
270             if (!config.singleUtteranceMode
271                     && isExpiredInterval(maxSilenceMillis, responseObserver.getLastInputTime())) {
272                 logger.debug("Stops listening, max silence time reached");
273                 break;
274             }
275             if (dataN != readBytes) {
276                 try {
277                     Thread.sleep(100);
278                 } catch (InterruptedException e) {
279                 }
280                 continue;
281             }
282             StreamingRecognizeRequest dataRequest = StreamingRecognizeRequest.newBuilder()
283                     .setAudioContent(ByteString.copyFrom(data)).build();
284             logger.debug("Sending audio data {}", dataN);
285             clientStream.send(dataRequest);
286         }
287     }
288
289     private void sendStreamConfig(ClientStream<StreamingRecognizeRequest> clientStream,
290             RecognitionConfig.AudioEncoding encoding, int sampleRate, int channels, Locale locale) {
291         RecognitionConfig recognitionConfig = RecognitionConfig.newBuilder().setEncoding(encoding)
292                 .setAudioChannelCount(channels).setLanguageCode(locale.toLanguageTag()).setSampleRateHertz(sampleRate)
293                 .build();
294
295         StreamingRecognitionConfig streamingRecognitionConfig = StreamingRecognitionConfig.newBuilder()
296                 .setConfig(recognitionConfig).setInterimResults(false).setSingleUtterance(config.singleUtteranceMode)
297                 .build();
298
299         clientStream
300                 .send(StreamingRecognizeRequest.newBuilder().setStreamingConfig(streamingRecognitionConfig).build());
301     }
302
303     private @Nullable Credentials getCredentials() {
304         String accessToken = null;
305         String refreshToken = null;
306         try {
307             OAuthClientService oAuthService = this.oAuthService;
308             if (oAuthService != null) {
309                 AccessTokenResponse response = oAuthService.getAccessTokenResponse();
310                 if (response != null) {
311                     accessToken = response.getAccessToken();
312                     refreshToken = response.getRefreshToken();
313                 }
314             }
315         } catch (OAuthException | IOException | OAuthResponseException e) {
316             logger.warn("Access token error: {}", e.getMessage());
317         }
318         if (accessToken == null || refreshToken == null) {
319             logger.warn("Missed google cloud access and/or refresh token");
320             return null;
321         }
322         return OAuth2Credentials.create(new AccessToken(accessToken, null));
323     }
324
325     private boolean isExpiredInterval(long interval, long referenceTime) {
326         return System.currentTimeMillis() - referenceTime > interval;
327     }
328
329     private static class TranscriptionListener implements ResponseObserver<StreamingRecognizeResponse> {
330         private final Logger logger = LoggerFactory.getLogger(TranscriptionListener.class);
331         private final StringBuilder transcriptBuilder = new StringBuilder();
332         private final STTListener sttListener;
333         GoogleSTTConfiguration config;
334         private final AtomicBoolean aborted;
335         private float confidenceSum = 0;
336         private int responseCount = 0;
337         private long lastInputTime = 0;
338
339         public TranscriptionListener(STTListener sttListener, GoogleSTTConfiguration config, AtomicBoolean aborted) {
340             this.sttListener = sttListener;
341             this.config = config;
342             this.aborted = aborted;
343         }
344
345         @Override
346         public void onStart(@Nullable StreamController controller) {
347             sttListener.sttEventReceived(new RecognitionStartEvent());
348             lastInputTime = System.currentTimeMillis();
349         }
350
351         @Override
352         public void onResponse(StreamingRecognizeResponse response) {
353             lastInputTime = System.currentTimeMillis();
354             List<StreamingRecognitionResult> results = response.getResultsList();
355             logger.debug("Got {} results", response.getResultsList().size());
356             if (results.isEmpty()) {
357                 logger.debug("No results");
358                 return;
359             }
360             results.forEach(result -> {
361                 List<SpeechRecognitionAlternative> alternatives = result.getAlternativesList();
362                 logger.debug("Got {} alternatives", alternatives.size());
363                 SpeechRecognitionAlternative alternative = alternatives.stream()
364                         .max(Comparator.comparing(SpeechRecognitionAlternative::getConfidence)).orElse(null);
365                 if (alternative == null) {
366                     return;
367                 }
368                 String transcript = alternative.getTranscript();
369                 logger.debug("Alternative transcript: {}", transcript);
370                 logger.debug("Alternative confidence: {}", alternative.getConfidence());
371                 if (result.getIsFinal()) {
372                     transcriptBuilder.append(transcript);
373                     confidenceSum += alternative.getConfidence();
374                     responseCount++;
375                     // when in single utterance mode we can just get one final result so complete
376                     if (config.singleUtteranceMode) {
377                         onComplete();
378                     }
379                 }
380             });
381         }
382
383         @Override
384         public void onComplete() {
385             if (!aborted.getAndSet(true)) {
386                 sttListener.sttEventReceived(new RecognitionStopEvent());
387                 float averageConfidence = confidenceSum / responseCount;
388                 String transcript = transcriptBuilder.toString();
389                 if (!transcript.isBlank()) {
390                     sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence));
391                 } else if (!config.noResultsMessage.isBlank()) {
392                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
393                 } else {
394                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results"));
395                 }
396             }
397         }
398
399         @Override
400         public void onError(@Nullable Throwable t) {
401             logger.warn("Recognition error: ", t);
402             if (!aborted.getAndSet(true)) {
403                 sttListener.sttEventReceived(new RecognitionStopEvent());
404                 if (!config.errorMessage.isBlank()) {
405                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
406                 } else {
407                     String errorMessage = t.getMessage();
408                     sttListener.sttEventReceived(
409                             new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
410                 }
411             }
412         }
413
414         public long getLastInputTime() {
415             return lastInputTime;
416         }
417     }
418 }