]> git.basschouten.com Git - openhab-addons.git/blob
a2f7bbe74378bf8360a126ab1b0cf1299b5e4752
[openhab-addons.git] /
1 /**
2  * Copyright (c) 2010-2022 Contributors to the openHAB project
3  *
4  * See the NOTICE file(s) distributed with this work for additional
5  * information.
6  *
7  * This program and the accompanying materials are made available under the
8  * terms of the Eclipse Public License 2.0 which is available at
9  * http://www.eclipse.org/legal/epl-2.0
10  *
11  * SPDX-License-Identifier: EPL-2.0
12  */
13 package org.openhab.voice.googlestt.internal;
14
15 import static org.openhab.voice.googlestt.internal.GoogleSTTConstants.*;
16
17 import java.io.IOException;
18 import java.util.*;
19 import java.util.concurrent.Future;
20 import java.util.concurrent.ScheduledExecutorService;
21 import java.util.concurrent.atomic.AtomicBoolean;
22 import java.util.function.Consumer;
23
24 import org.eclipse.jdt.annotation.NonNullByDefault;
25 import org.eclipse.jdt.annotation.Nullable;
26 import org.openhab.core.audio.AudioFormat;
27 import org.openhab.core.audio.AudioStream;
28 import org.openhab.core.auth.client.oauth2.*;
29 import org.openhab.core.common.ThreadPoolManager;
30 import org.openhab.core.config.core.ConfigurableService;
31 import org.openhab.core.config.core.Configuration;
32 import org.openhab.core.voice.*;
33 import org.osgi.framework.Constants;
34 import org.osgi.service.cm.ConfigurationAdmin;
35 import org.osgi.service.component.annotations.Activate;
36 import org.osgi.service.component.annotations.Component;
37 import org.osgi.service.component.annotations.Modified;
38 import org.osgi.service.component.annotations.Reference;
39 import org.slf4j.Logger;
40 import org.slf4j.LoggerFactory;
41
42 import com.google.api.gax.rpc.ClientStream;
43 import com.google.api.gax.rpc.ResponseObserver;
44 import com.google.api.gax.rpc.StreamController;
45 import com.google.auth.Credentials;
46 import com.google.auth.oauth2.AccessToken;
47 import com.google.auth.oauth2.OAuth2Credentials;
48 import com.google.cloud.speech.v1.*;
49 import com.google.protobuf.ByteString;
50
51 import io.grpc.LoadBalancerRegistry;
52 import io.grpc.internal.PickFirstLoadBalancerProvider;
53
54 /**
55  * The {@link GoogleSTTService} class is a service implementation to use Google Cloud Speech-to-Text features.
56  *
57  * @author Miguel Álvarez - Initial contribution
58  */
59 @NonNullByDefault
60 @Component(configurationPid = SERVICE_PID, property = Constants.SERVICE_PID + "=" + SERVICE_PID)
61 @ConfigurableService(category = SERVICE_CATEGORY, label = SERVICE_NAME, description_uri = SERVICE_CATEGORY + ":"
62         + SERVICE_ID)
63 public class GoogleSTTService implements STTService {
64
65     private static final String GCP_AUTH_URI = "https://accounts.google.com/o/oauth2/auth";
66     private static final String GCP_TOKEN_URI = "https://accounts.google.com/o/oauth2/token";
67     private static final String GCP_REDIRECT_URI = "urn:ietf:wg:oauth:2.0:oob";
68     private static final String GCP_SCOPE = "https://www.googleapis.com/auth/cloud-platform";
69
70     private final Logger logger = LoggerFactory.getLogger(GoogleSTTService.class);
71     private final ScheduledExecutorService executor = ThreadPoolManager.getScheduledPool("OH-voice-googlestt");
72     private final OAuthFactory oAuthFactory;
73     private final ConfigurationAdmin configAdmin;
74
75     private GoogleSTTConfiguration config = new GoogleSTTConfiguration();
76     private @Nullable OAuthClientService oAuthService;
77
78     @Activate
79     public GoogleSTTService(final @Reference OAuthFactory oAuthFactory,
80             final @Reference ConfigurationAdmin configAdmin) {
81         LoadBalancerRegistry.getDefaultRegistry().register(new PickFirstLoadBalancerProvider());
82         this.oAuthFactory = oAuthFactory;
83         this.configAdmin = configAdmin;
84     }
85
86     @Activate
87     protected void activate(Map<String, Object> config) {
88         this.config = new Configuration(config).as(GoogleSTTConfiguration.class);
89         executor.submit(() -> GoogleSTTLocale.loadLocales(this.config.refreshSupportedLocales));
90         updateConfig();
91     }
92
93     @Modified
94     protected void modified(Map<String, Object> config) {
95         this.config = new Configuration(config).as(GoogleSTTConfiguration.class);
96         updateConfig();
97     }
98
99     @Override
100     public String getId() {
101         return SERVICE_ID;
102     }
103
104     @Override
105     public String getLabel(@Nullable Locale locale) {
106         return SERVICE_NAME;
107     }
108
109     @Override
110     public Set<Locale> getSupportedLocales() {
111         return GoogleSTTLocale.getSupportedLocales();
112     }
113
114     @Override
115     public Set<AudioFormat> getSupportedFormats() {
116         return Set.of(
117                 new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L),
118                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 8000L),
119                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 12000L),
120                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 16000L),
121                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 24000L),
122                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 48000L));
123     }
124
125     @Override
126     public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale,
127             Set<String> set) {
128         AtomicBoolean keepStreaming = new AtomicBoolean(true);
129         Future scheduledTask = backgroundRecognize(sttListener, audioStream, keepStreaming, locale, set);
130         return new STTServiceHandle() {
131             @Override
132             public void abort() {
133                 keepStreaming.set(false);
134                 try {
135                     Thread.sleep(100);
136                 } catch (InterruptedException e) {
137                 }
138                 scheduledTask.cancel(true);
139             }
140         };
141     }
142
143     private void updateConfig() {
144         String clientId = this.config.clientId;
145         String clientSecret = this.config.clientSecret;
146         if (!clientId.isBlank() && !clientSecret.isBlank()) {
147             var oAuthService = oAuthFactory.createOAuthClientService(SERVICE_PID, GCP_TOKEN_URI, GCP_AUTH_URI, clientId,
148                     clientSecret, GCP_SCOPE, false);
149             this.oAuthService = oAuthService;
150             if (!this.config.oauthCode.isEmpty()) {
151                 getAccessToken(oAuthService, this.config.oauthCode);
152                 deleteAuthCode();
153             }
154         } else {
155             logger.warn("Missing authentication configuration to access Google Cloud STT API.");
156         }
157     }
158
159     private void getAccessToken(OAuthClientService oAuthService, String oauthCode) {
160         logger.debug("Trying to get access and refresh tokens.");
161         try {
162             oAuthService.getAccessTokenResponseByAuthorizationCode(oauthCode, GCP_REDIRECT_URI);
163         } catch (OAuthException | OAuthResponseException e) {
164             if (logger.isDebugEnabled()) {
165                 logger.debug("Error fetching access token: {}", e.getMessage(), e);
166             } else {
167                 logger.warn("Error fetching access token. Invalid oauth code? Please generate a new one.");
168             }
169         } catch (IOException e) {
170             logger.warn("An unexpected IOException occurred when fetching access token: {}", e.getMessage());
171         }
172     }
173
174     private void deleteAuthCode() {
175         try {
176             org.osgi.service.cm.Configuration serviceConfig = configAdmin.getConfiguration(SERVICE_PID);
177             Dictionary<String, Object> configProperties = serviceConfig.getProperties();
178             if (configProperties != null) {
179                 configProperties.put("oauthCode", "");
180                 serviceConfig.update(configProperties);
181             }
182         } catch (IOException e) {
183             logger.warn("Failed to delete current oauth code, please delete it manually.");
184         }
185     }
186
187     private Future<?> backgroundRecognize(STTListener sttListener, AudioStream audioStream, AtomicBoolean keepStreaming,
188             Locale locale, Set<String> set) {
189         Credentials credentials = getCredentials();
190         return executor.submit(() -> {
191             logger.debug("Background recognize starting");
192             ClientStream<StreamingRecognizeRequest> clientStream = null;
193             try (SpeechClient client = SpeechClient
194                     .create(SpeechSettings.newBuilder().setCredentialsProvider(() -> credentials).build())) {
195                 TranscriptionListener responseObserver = new TranscriptionListener(sttListener, config,
196                         (t) -> keepStreaming.set(false));
197                 clientStream = client.streamingRecognizeCallable().splitCall(responseObserver);
198                 streamAudio(clientStream, audioStream, responseObserver, keepStreaming, locale);
199                 clientStream.closeSend();
200                 logger.debug("Background recognize done");
201             } catch (IOException e) {
202                 if (clientStream != null && clientStream.isSendReady()) {
203                     clientStream.closeSendWithError(e);
204                 } else if (!config.errorMessage.isBlank()) {
205                     logger.warn("Error running speech to text: {}", e.getMessage());
206                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
207                 }
208             }
209         });
210     }
211
212     private void streamAudio(ClientStream<StreamingRecognizeRequest> clientStream, AudioStream audioStream,
213             TranscriptionListener responseObserver, AtomicBoolean keepStreaming, Locale locale) throws IOException {
214         // Gather stream info and send config
215         AudioFormat streamFormat = audioStream.getFormat();
216         RecognitionConfig.AudioEncoding streamEncoding;
217         if (AudioFormat.WAV.isCompatible(streamFormat)) {
218             streamEncoding = RecognitionConfig.AudioEncoding.LINEAR16;
219         } else if (AudioFormat.OGG.isCompatible(streamFormat)) {
220             streamEncoding = RecognitionConfig.AudioEncoding.OGG_OPUS;
221         } else {
222             logger.debug("Unsupported format {}", streamFormat);
223             return;
224         }
225         Integer channelsObject = streamFormat.getChannels();
226         int channels = channelsObject != null ? channelsObject : 1;
227         Long longFrequency = streamFormat.getFrequency();
228         if (longFrequency == null) {
229             logger.debug("Missing frequency info");
230             return;
231         }
232         int frequency = Math.toIntExact(longFrequency);
233         // First thing we need to send the stream config
234         sendStreamConfig(clientStream, streamEncoding, frequency, channels, locale);
235         // Loop sending audio data
236         long startTime = System.currentTimeMillis();
237         long maxTranscriptionMillis = (config.maxTranscriptionSeconds * 1000L);
238         long maxSilenceMillis = (config.maxSilenceSeconds * 1000L);
239         int readBytes = 6400;
240         while (keepStreaming.get()) {
241             byte[] data = new byte[readBytes];
242             int dataN = audioStream.read(data);
243             if (!keepStreaming.get() || isExpiredInterval(maxTranscriptionMillis, startTime)) {
244                 logger.debug("Stops listening, max transcription time reached");
245                 break;
246             }
247             if (!config.singleUtteranceMode
248                     && isExpiredInterval(maxSilenceMillis, responseObserver.getLastInputTime())) {
249                 logger.debug("Stops listening, max silence time reached");
250                 break;
251             }
252             if (dataN != readBytes) {
253                 try {
254                     Thread.sleep(100);
255                 } catch (InterruptedException e) {
256                 }
257                 continue;
258             }
259             StreamingRecognizeRequest dataRequest = StreamingRecognizeRequest.newBuilder()
260                     .setAudioContent(ByteString.copyFrom(data)).build();
261             logger.debug("Sending audio data {}", dataN);
262             clientStream.send(dataRequest);
263         }
264     }
265
266     private void sendStreamConfig(ClientStream<StreamingRecognizeRequest> clientStream,
267             RecognitionConfig.AudioEncoding encoding, int sampleRate, int channels, Locale locale) {
268         RecognitionConfig recognitionConfig = RecognitionConfig.newBuilder().setEncoding(encoding)
269                 .setAudioChannelCount(channels).setLanguageCode(locale.toLanguageTag()).setSampleRateHertz(sampleRate)
270                 .build();
271
272         StreamingRecognitionConfig streamingRecognitionConfig = StreamingRecognitionConfig.newBuilder()
273                 .setConfig(recognitionConfig).setInterimResults(false).setSingleUtterance(config.singleUtteranceMode)
274                 .build();
275
276         clientStream
277                 .send(StreamingRecognizeRequest.newBuilder().setStreamingConfig(streamingRecognitionConfig).build());
278     }
279
280     private @Nullable Credentials getCredentials() {
281         String accessToken = null;
282         try {
283             OAuthClientService oAuthService = this.oAuthService;
284             if (oAuthService != null) {
285                 AccessTokenResponse response = oAuthService.getAccessTokenResponse();
286                 if (response != null) {
287                     accessToken = response.getAccessToken();
288                 }
289             }
290         } catch (OAuthException | IOException | OAuthResponseException e) {
291             logger.warn("Access token error: {}", e.getMessage());
292         }
293         if (accessToken == null) {
294             logger.warn("Missed google cloud access token");
295             return null;
296         }
297         return OAuth2Credentials.create(new AccessToken(accessToken, null));
298     }
299
300     private boolean isExpiredInterval(long interval, long referenceTime) {
301         return System.currentTimeMillis() - referenceTime > interval;
302     }
303
304     private static class TranscriptionListener implements ResponseObserver<StreamingRecognizeResponse> {
305         private final Logger logger = LoggerFactory.getLogger(TranscriptionListener.class);
306         private final StringBuilder transcriptBuilder = new StringBuilder();
307         private final STTListener sttListener;
308         GoogleSTTConfiguration config;
309         private final Consumer<@Nullable Throwable> completeListener;
310         private float confidenceSum = 0;
311         private int responseCount = 0;
312         private long lastInputTime = 0;
313
314         public TranscriptionListener(STTListener sttListener, GoogleSTTConfiguration config,
315                 Consumer<@Nullable Throwable> completeListener) {
316             this.sttListener = sttListener;
317             this.config = config;
318             this.completeListener = completeListener;
319         }
320
321         public void onStart(@Nullable StreamController controller) {
322             sttListener.sttEventReceived(new SpeechStartEvent());
323             lastInputTime = System.currentTimeMillis();
324         }
325
326         public void onResponse(StreamingRecognizeResponse response) {
327             lastInputTime = System.currentTimeMillis();
328             List<StreamingRecognitionResult> results = response.getResultsList();
329             logger.debug("Got {} results", response.getResultsList().size());
330             if (results.isEmpty()) {
331                 logger.debug("No results");
332                 return;
333             }
334             results.forEach(result -> {
335                 List<SpeechRecognitionAlternative> alternatives = result.getAlternativesList();
336                 logger.debug("Got {} alternatives", alternatives.size());
337                 SpeechRecognitionAlternative alternative = alternatives.stream()
338                         .max(Comparator.comparing(SpeechRecognitionAlternative::getConfidence)).orElse(null);
339                 if (alternative == null) {
340                     return;
341                 }
342                 String transcript = alternative.getTranscript();
343                 logger.debug("Alternative transcript: {}", transcript);
344                 logger.debug("Alternative confidence: {}", alternative.getConfidence());
345                 if (result.getIsFinal()) {
346                     transcriptBuilder.append(transcript);
347                     confidenceSum += alternative.getConfidence();
348                     responseCount++;
349                     // when in single utterance mode we can just get one final result so complete
350                     if (config.singleUtteranceMode) {
351                         completeListener.accept(null);
352                     }
353                 }
354             });
355         }
356
357         public void onComplete() {
358             sttListener.sttEventReceived(new SpeechStopEvent());
359             float averageConfidence = confidenceSum / (float) responseCount;
360             String transcript = transcriptBuilder.toString();
361             if (!transcript.isBlank()) {
362                 sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence));
363             } else {
364                 if (!config.noResultsMessage.isBlank()) {
365                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
366                 } else {
367                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results"));
368                 }
369             }
370         }
371
372         public void onError(@Nullable Throwable t) {
373             logger.warn("Recognition error: ", t);
374             completeListener.accept(t);
375             sttListener.sttEventReceived(new SpeechStopEvent());
376             if (!config.errorMessage.isBlank()) {
377                 sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
378             } else {
379                 String errorMessage = t.getMessage();
380                 sttListener.sttEventReceived(
381                         new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
382             }
383         }
384
385         public long getLastInputTime() {
386             return lastInputTime;
387         }
388     }
389 }