]> git.basschouten.com Git - openhab-addons.git/blob
a217c6a4170683e32a5804fbdb15d39db13e9b7e
[openhab-addons.git] /
1 /**
2  * Copyright (c) 2010-2024 Contributors to the openHAB project
3  *
4  * See the NOTICE file(s) distributed with this work for additional
5  * information.
6  *
7  * This program and the accompanying materials are made available under the
8  * terms of the Eclipse Public License 2.0 which is available at
9  * http://www.eclipse.org/legal/epl-2.0
10  *
11  * SPDX-License-Identifier: EPL-2.0
12  */
13 package org.openhab.voice.googlestt.internal;
14
15 import static org.openhab.voice.googlestt.internal.GoogleSTTConstants.*;
16
17 import java.io.IOException;
18 import java.util.Comparator;
19 import java.util.Dictionary;
20 import java.util.List;
21 import java.util.Locale;
22 import java.util.Map;
23 import java.util.Set;
24 import java.util.concurrent.Future;
25 import java.util.concurrent.ScheduledExecutorService;
26 import java.util.concurrent.atomic.AtomicBoolean;
27
28 import org.eclipse.jdt.annotation.NonNullByDefault;
29 import org.eclipse.jdt.annotation.Nullable;
30 import org.openhab.core.audio.AudioFormat;
31 import org.openhab.core.audio.AudioStream;
32 import org.openhab.core.audio.utils.AudioWaveUtils;
33 import org.openhab.core.auth.client.oauth2.AccessTokenResponse;
34 import org.openhab.core.auth.client.oauth2.OAuthClientService;
35 import org.openhab.core.auth.client.oauth2.OAuthException;
36 import org.openhab.core.auth.client.oauth2.OAuthFactory;
37 import org.openhab.core.auth.client.oauth2.OAuthResponseException;
38 import org.openhab.core.common.ThreadPoolManager;
39 import org.openhab.core.config.core.ConfigurableService;
40 import org.openhab.core.config.core.Configuration;
41 import org.openhab.core.voice.RecognitionStartEvent;
42 import org.openhab.core.voice.RecognitionStopEvent;
43 import org.openhab.core.voice.STTListener;
44 import org.openhab.core.voice.STTService;
45 import org.openhab.core.voice.STTServiceHandle;
46 import org.openhab.core.voice.SpeechRecognitionErrorEvent;
47 import org.openhab.core.voice.SpeechRecognitionEvent;
48 import org.osgi.framework.Constants;
49 import org.osgi.service.cm.ConfigurationAdmin;
50 import org.osgi.service.component.annotations.Activate;
51 import org.osgi.service.component.annotations.Component;
52 import org.osgi.service.component.annotations.Deactivate;
53 import org.osgi.service.component.annotations.Modified;
54 import org.osgi.service.component.annotations.Reference;
55 import org.slf4j.Logger;
56 import org.slf4j.LoggerFactory;
57
58 import com.google.api.gax.rpc.ClientStream;
59 import com.google.api.gax.rpc.ResponseObserver;
60 import com.google.api.gax.rpc.StreamController;
61 import com.google.auth.Credentials;
62 import com.google.auth.oauth2.AccessToken;
63 import com.google.auth.oauth2.OAuth2Credentials;
64 import com.google.cloud.speech.v1.RecognitionConfig;
65 import com.google.cloud.speech.v1.SpeechClient;
66 import com.google.cloud.speech.v1.SpeechRecognitionAlternative;
67 import com.google.cloud.speech.v1.SpeechSettings;
68 import com.google.cloud.speech.v1.StreamingRecognitionConfig;
69 import com.google.cloud.speech.v1.StreamingRecognitionResult;
70 import com.google.cloud.speech.v1.StreamingRecognizeRequest;
71 import com.google.cloud.speech.v1.StreamingRecognizeResponse;
72 import com.google.protobuf.ByteString;
73
74 import io.grpc.LoadBalancerRegistry;
75 import io.grpc.internal.PickFirstLoadBalancerProvider;
76
77 /**
78  * The {@link GoogleSTTService} class is a service implementation to use Google Cloud Speech-to-Text features.
79  *
80  * @author Miguel Álvarez - Initial contribution
81  */
82 @NonNullByDefault
83 @Component(configurationPid = SERVICE_PID, property = Constants.SERVICE_PID + "=" + SERVICE_PID)
84 @ConfigurableService(category = SERVICE_CATEGORY, label = SERVICE_NAME
85         + " Speech-to-Text", description_uri = SERVICE_CATEGORY + ":" + SERVICE_ID)
86 public class GoogleSTTService implements STTService {
87
88     private static final String GCP_AUTH_URI = "https://accounts.google.com/o/oauth2/auth";
89     private static final String GCP_TOKEN_URI = "https://accounts.google.com/o/oauth2/token";
90     private static final String GCP_REDIRECT_URI = "https://www.google.com";
91     private static final String GCP_SCOPE = "https://www.googleapis.com/auth/cloud-platform";
92
93     private final Logger logger = LoggerFactory.getLogger(GoogleSTTService.class);
94     private final ScheduledExecutorService executor = ThreadPoolManager.getScheduledPool("OH-voice-googlestt");
95     private final OAuthFactory oAuthFactory;
96     private final ConfigurationAdmin configAdmin;
97
98     private GoogleSTTConfiguration config = new GoogleSTTConfiguration();
99     private @Nullable OAuthClientService oAuthService;
100
101     @Activate
102     public GoogleSTTService(final @Reference OAuthFactory oAuthFactory,
103             final @Reference ConfigurationAdmin configAdmin) {
104         LoadBalancerRegistry.getDefaultRegistry().register(new PickFirstLoadBalancerProvider());
105         this.oAuthFactory = oAuthFactory;
106         this.configAdmin = configAdmin;
107     }
108
109     @Activate
110     protected void activate(Map<String, Object> config) {
111         this.config = new Configuration(config).as(GoogleSTTConfiguration.class);
112         executor.submit(() -> GoogleSTTLocale.loadLocales(this.config.refreshSupportedLocales));
113         updateConfig();
114     }
115
116     @Modified
117     protected void modified(Map<String, Object> config) {
118         this.config = new Configuration(config).as(GoogleSTTConfiguration.class);
119         updateConfig();
120     }
121
122     @Deactivate
123     protected void dispose() {
124         if (oAuthService != null) {
125             oAuthFactory.ungetOAuthService(SERVICE_PID);
126             oAuthService = null;
127         }
128     }
129
130     @Override
131     public String getId() {
132         return SERVICE_ID;
133     }
134
135     @Override
136     public String getLabel(@Nullable Locale locale) {
137         return SERVICE_NAME;
138     }
139
140     @Override
141     public Set<Locale> getSupportedLocales() {
142         return GoogleSTTLocale.getSupportedLocales();
143     }
144
145     @Override
146     public Set<AudioFormat> getSupportedFormats() {
147         return Set.of(
148                 new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L),
149                 new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L));
150     }
151
152     @Override
153     public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale,
154             Set<String> set) {
155         AtomicBoolean aborted = new AtomicBoolean(false);
156         backgroundRecognize(sttListener, audioStream, aborted, locale, set);
157         return new STTServiceHandle() {
158             @Override
159             public void abort() {
160                 aborted.set(true);
161             }
162         };
163     }
164
165     private void updateConfig() {
166         if (oAuthService != null) {
167             oAuthFactory.ungetOAuthService(SERVICE_PID);
168             oAuthService = null;
169         }
170         String clientId = this.config.clientId;
171         String clientSecret = this.config.clientSecret;
172         if (!clientId.isBlank() && !clientSecret.isBlank()) {
173             var oAuthService = oAuthFactory.createOAuthClientService(SERVICE_PID, GCP_TOKEN_URI, GCP_AUTH_URI, clientId,
174                     clientSecret, GCP_SCOPE, false);
175             this.oAuthService = oAuthService;
176             if (!this.config.oauthCode.isEmpty()) {
177                 getAccessToken(oAuthService, this.config.oauthCode);
178                 deleteAuthCode();
179             }
180         } else {
181             logger.warn("Missing authentication configuration to access Google Cloud STT API.");
182         }
183     }
184
185     private void getAccessToken(OAuthClientService oAuthService, String oauthCode) {
186         logger.debug("Trying to get access and refresh tokens.");
187         try {
188             AccessTokenResponse response = oAuthService.getAccessTokenResponseByAuthorizationCode(oauthCode,
189                     GCP_REDIRECT_URI);
190             if (response.getRefreshToken() == null || response.getRefreshToken().isEmpty()) {
191                 logger.warn("Error fetching refresh token. Please try to reauthorize.");
192             }
193         } catch (OAuthException | OAuthResponseException e) {
194             if (logger.isDebugEnabled()) {
195                 logger.debug("Error fetching access token: {}", e.getMessage(), e);
196             } else {
197                 logger.warn("Error fetching access token. Invalid oauth code? Please generate a new one.");
198             }
199         } catch (IOException e) {
200             logger.warn("An unexpected IOException occurred when fetching access token: {}", e.getMessage());
201         }
202     }
203
204     private void deleteAuthCode() {
205         try {
206             org.osgi.service.cm.Configuration serviceConfig = configAdmin.getConfiguration(SERVICE_PID);
207             Dictionary<String, Object> configProperties = serviceConfig.getProperties();
208             if (configProperties != null) {
209                 configProperties.put("oauthCode", "");
210                 serviceConfig.update(configProperties);
211             }
212         } catch (IOException e) {
213             logger.warn("Failed to delete current oauth code, please delete it manually.");
214         }
215     }
216
217     private Future<?> backgroundRecognize(STTListener sttListener, AudioStream audioStream, AtomicBoolean aborted,
218             Locale locale, Set<String> set) {
219         Credentials credentials = getCredentials();
220         return executor.submit(() -> {
221             logger.debug("Background recognize starting");
222             ClientStream<StreamingRecognizeRequest> clientStream = null;
223             try (SpeechClient client = SpeechClient
224                     .create(SpeechSettings.newBuilder().setCredentialsProvider(() -> credentials).build())) {
225                 TranscriptionListener responseObserver = new TranscriptionListener(sttListener, config, aborted);
226                 clientStream = client.streamingRecognizeCallable().splitCall(responseObserver);
227                 streamAudio(clientStream, audioStream, responseObserver, aborted, locale);
228                 clientStream.closeSend();
229                 logger.debug("Background recognize done");
230             } catch (IOException e) {
231                 if (clientStream != null && clientStream.isSendReady()) {
232                     clientStream.closeSendWithError(e);
233                 } else if (!config.errorMessage.isBlank()) {
234                     logger.warn("Error running speech to text: {}", e.getMessage());
235                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
236                 }
237             }
238         });
239     }
240
241     private void streamAudio(ClientStream<StreamingRecognizeRequest> clientStream, AudioStream audioStream,
242             TranscriptionListener responseObserver, AtomicBoolean aborted, Locale locale) throws IOException {
243         // Gather stream info and send config
244         AudioFormat streamFormat = audioStream.getFormat();
245         RecognitionConfig.AudioEncoding streamEncoding;
246         if (AudioFormat.PCM_SIGNED.isCompatible(streamFormat) || AudioFormat.WAV.isCompatible(streamFormat)) {
247             streamEncoding = RecognitionConfig.AudioEncoding.LINEAR16;
248         } else {
249             logger.debug("Unsupported format {}", streamFormat);
250             return;
251         }
252         Integer channelsObject = streamFormat.getChannels();
253         int channels = channelsObject != null ? channelsObject : 1;
254         Long longFrequency = streamFormat.getFrequency();
255         if (longFrequency == null) {
256             logger.debug("Missing frequency info");
257             return;
258         }
259         int frequency = Math.toIntExact(longFrequency);
260         // First thing we need to send the stream config
261         sendStreamConfig(clientStream, streamEncoding, frequency, channels, locale);
262         // Loop sending audio data
263         long startTime = System.currentTimeMillis();
264         long maxTranscriptionMillis = (config.maxTranscriptionSeconds * 1000L);
265         long maxSilenceMillis = (config.maxSilenceSeconds * 1000L);
266         final int bufferSize = 6400;
267         int numBytesRead;
268         int remaining = bufferSize;
269         if (AudioFormat.CONTAINER_WAVE.equals(streamFormat.getContainer())) {
270             AudioWaveUtils.removeFMT(audioStream);
271         }
272         byte[] audioBuffer = new byte[bufferSize];
273         while (!aborted.get() && !responseObserver.isDone()) {
274             numBytesRead = audioStream.read(audioBuffer, bufferSize - remaining, remaining);
275             if (aborted.get()) {
276                 logger.debug("Stops listening, aborted");
277                 break;
278             }
279             if (numBytesRead == -1) {
280                 logger.debug("End of stream");
281                 break;
282             }
283             if (isExpiredInterval(maxTranscriptionMillis, startTime)) {
284                 logger.debug("Stops listening, max transcription time reached");
285                 break;
286             }
287             if (!config.singleUtteranceMode
288                     && isExpiredInterval(maxSilenceMillis, responseObserver.getLastInputTime())) {
289                 logger.debug("Stops listening, max silence time reached");
290                 break;
291             }
292             if (numBytesRead != remaining) {
293                 remaining = remaining - numBytesRead;
294                 continue;
295             }
296             remaining = bufferSize;
297             StreamingRecognizeRequest dataRequest = StreamingRecognizeRequest.newBuilder()
298                     .setAudioContent(ByteString.copyFrom(audioBuffer)).build();
299             logger.debug("Sending audio data {}", bufferSize);
300             clientStream.send(dataRequest);
301         }
302         audioStream.close();
303     }
304
305     private void sendStreamConfig(ClientStream<StreamingRecognizeRequest> clientStream,
306             RecognitionConfig.AudioEncoding encoding, int sampleRate, int channels, Locale locale) {
307         RecognitionConfig recognitionConfig = RecognitionConfig.newBuilder().setEncoding(encoding)
308                 .setAudioChannelCount(channels).setLanguageCode(locale.toLanguageTag()).setSampleRateHertz(sampleRate)
309                 .build();
310
311         StreamingRecognitionConfig streamingRecognitionConfig = StreamingRecognitionConfig.newBuilder()
312                 .setConfig(recognitionConfig).setInterimResults(false).setSingleUtterance(config.singleUtteranceMode)
313                 .build();
314
315         clientStream
316                 .send(StreamingRecognizeRequest.newBuilder().setStreamingConfig(streamingRecognitionConfig).build());
317     }
318
319     private @Nullable Credentials getCredentials() {
320         String accessToken = null;
321         String refreshToken = null;
322         try {
323             OAuthClientService oAuthService = this.oAuthService;
324             if (oAuthService != null) {
325                 AccessTokenResponse response = oAuthService.getAccessTokenResponse();
326                 if (response != null) {
327                     accessToken = response.getAccessToken();
328                     refreshToken = response.getRefreshToken();
329                 }
330             }
331         } catch (OAuthException | IOException | OAuthResponseException e) {
332             logger.warn("Access token error: {}", e.getMessage());
333         }
334         if (accessToken == null || refreshToken == null) {
335             logger.warn("Missed google cloud access and/or refresh token");
336             return null;
337         }
338         return OAuth2Credentials.create(new AccessToken(accessToken, null));
339     }
340
341     private boolean isExpiredInterval(long interval, long referenceTime) {
342         return System.currentTimeMillis() - referenceTime > interval;
343     }
344
345     private static class TranscriptionListener implements ResponseObserver<StreamingRecognizeResponse> {
346         private final Logger logger = LoggerFactory.getLogger(TranscriptionListener.class);
347         private final StringBuilder transcriptBuilder = new StringBuilder();
348         private final STTListener sttListener;
349         GoogleSTTConfiguration config;
350         private final AtomicBoolean aborted;
351         private float confidenceSum = 0;
352         private int responseCount = 0;
353         private long lastInputTime = 0;
354         private boolean done = false;
355
356         public TranscriptionListener(STTListener sttListener, GoogleSTTConfiguration config, AtomicBoolean aborted) {
357             this.sttListener = sttListener;
358             this.config = config;
359             this.aborted = aborted;
360         }
361
362         @Override
363         public void onStart(@Nullable StreamController controller) {
364             sttListener.sttEventReceived(new RecognitionStartEvent());
365             lastInputTime = System.currentTimeMillis();
366         }
367
368         @Override
369         public void onResponse(StreamingRecognizeResponse response) {
370             lastInputTime = System.currentTimeMillis();
371             List<StreamingRecognitionResult> results = response.getResultsList();
372             logger.debug("Got {} results", response.getResultsList().size());
373             if (results.isEmpty()) {
374                 logger.debug("No results");
375                 return;
376             }
377             results.forEach(result -> {
378                 List<SpeechRecognitionAlternative> alternatives = result.getAlternativesList();
379                 logger.debug("Got {} alternatives", alternatives.size());
380                 SpeechRecognitionAlternative alternative = alternatives.stream()
381                         .max(Comparator.comparing(SpeechRecognitionAlternative::getConfidence)).orElse(null);
382                 if (alternative == null) {
383                     return;
384                 }
385                 String transcript = alternative.getTranscript();
386                 logger.debug("Alternative transcript: {}", transcript);
387                 logger.debug("Alternative confidence: {}", alternative.getConfidence());
388                 if (result.getIsFinal()) {
389                     transcriptBuilder.append(transcript);
390                     confidenceSum += alternative.getConfidence();
391                     responseCount++;
392                     // when in single utterance mode we can just get one final result so complete
393                     if (config.singleUtteranceMode) {
394                         done = true;
395                     }
396                 }
397             });
398         }
399
400         @Override
401         public void onComplete() {
402             if (!aborted.getAndSet(true)) {
403                 sttListener.sttEventReceived(new RecognitionStopEvent());
404                 float averageConfidence = confidenceSum / responseCount;
405                 String transcript = transcriptBuilder.toString();
406                 if (!transcript.isBlank()) {
407                     sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence));
408                 } else if (!config.noResultsMessage.isBlank()) {
409                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
410                 } else {
411                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results"));
412                 }
413             }
414         }
415
416         @Override
417         public void onError(@Nullable Throwable t) {
418             logger.warn("Recognition error: ", t);
419             if (!aborted.getAndSet(true)) {
420                 sttListener.sttEventReceived(new RecognitionStopEvent());
421                 if (!config.errorMessage.isBlank()) {
422                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
423                 } else {
424                     String errorMessage = t.getMessage();
425                     sttListener.sttEventReceived(
426                             new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
427                 }
428             }
429         }
430
431         public boolean isDone() {
432             return done;
433         }
434
435         public long getLastInputTime() {
436             return lastInputTime;
437         }
438     }
439 }