]> git.basschouten.com Git - openhab-addons.git/blob
5f82829cf14775dd92412d00756f9cd703d88b27
[openhab-addons.git] /
1 /**
2  * Copyright (c) 2010-2022 Contributors to the openHAB project
3  *
4  * See the NOTICE file(s) distributed with this work for additional
5  * information.
6  *
7  * This program and the accompanying materials are made available under the
8  * terms of the Eclipse Public License 2.0 which is available at
9  * http://www.eclipse.org/legal/epl-2.0
10  *
11  * SPDX-License-Identifier: EPL-2.0
12  */
13 package org.openhab.voice.googlestt.internal;
14
15 import static org.openhab.voice.googlestt.internal.GoogleSTTConstants.*;
16
17 import java.io.IOException;
18 import java.util.Comparator;
19 import java.util.Dictionary;
20 import java.util.List;
21 import java.util.Locale;
22 import java.util.Map;
23 import java.util.Set;
24 import java.util.concurrent.Future;
25 import java.util.concurrent.ScheduledExecutorService;
26 import java.util.concurrent.atomic.AtomicBoolean;
27 import java.util.function.Consumer;
28
29 import org.eclipse.jdt.annotation.NonNullByDefault;
30 import org.eclipse.jdt.annotation.Nullable;
31 import org.openhab.core.audio.AudioFormat;
32 import org.openhab.core.audio.AudioStream;
33 import org.openhab.core.auth.client.oauth2.AccessTokenResponse;
34 import org.openhab.core.auth.client.oauth2.OAuthClientService;
35 import org.openhab.core.auth.client.oauth2.OAuthException;
36 import org.openhab.core.auth.client.oauth2.OAuthFactory;
37 import org.openhab.core.auth.client.oauth2.OAuthResponseException;
38 import org.openhab.core.common.ThreadPoolManager;
39 import org.openhab.core.config.core.ConfigurableService;
40 import org.openhab.core.config.core.Configuration;
41 import org.openhab.core.voice.RecognitionStartEvent;
42 import org.openhab.core.voice.RecognitionStopEvent;
43 import org.openhab.core.voice.STTListener;
44 import org.openhab.core.voice.STTService;
45 import org.openhab.core.voice.STTServiceHandle;
46 import org.openhab.core.voice.SpeechRecognitionErrorEvent;
47 import org.openhab.core.voice.SpeechRecognitionEvent;
48 import org.osgi.framework.Constants;
49 import org.osgi.service.cm.ConfigurationAdmin;
50 import org.osgi.service.component.annotations.Activate;
51 import org.osgi.service.component.annotations.Component;
52 import org.osgi.service.component.annotations.Modified;
53 import org.osgi.service.component.annotations.Reference;
54 import org.slf4j.Logger;
55 import org.slf4j.LoggerFactory;
56
57 import com.google.api.gax.rpc.ClientStream;
58 import com.google.api.gax.rpc.ResponseObserver;
59 import com.google.api.gax.rpc.StreamController;
60 import com.google.auth.Credentials;
61 import com.google.auth.oauth2.AccessToken;
62 import com.google.auth.oauth2.OAuth2Credentials;
63 import com.google.cloud.speech.v1.RecognitionConfig;
64 import com.google.cloud.speech.v1.SpeechClient;
65 import com.google.cloud.speech.v1.SpeechRecognitionAlternative;
66 import com.google.cloud.speech.v1.SpeechSettings;
67 import com.google.cloud.speech.v1.StreamingRecognitionConfig;
68 import com.google.cloud.speech.v1.StreamingRecognitionResult;
69 import com.google.cloud.speech.v1.StreamingRecognizeRequest;
70 import com.google.cloud.speech.v1.StreamingRecognizeResponse;
71 import com.google.protobuf.ByteString;
72
73 import io.grpc.LoadBalancerRegistry;
74 import io.grpc.internal.PickFirstLoadBalancerProvider;
75
76 /**
77  * The {@link GoogleSTTService} class is a service implementation to use Google Cloud Speech-to-Text features.
78  *
79  * @author Miguel Álvarez - Initial contribution
80  */
81 @NonNullByDefault
82 @Component(configurationPid = SERVICE_PID, property = Constants.SERVICE_PID + "=" + SERVICE_PID)
83 @ConfigurableService(category = SERVICE_CATEGORY, label = SERVICE_NAME
84         + " Speech-to-Text", description_uri = SERVICE_CATEGORY + ":" + SERVICE_ID)
85 public class GoogleSTTService implements STTService {
86
87     private static final String GCP_AUTH_URI = "https://accounts.google.com/o/oauth2/auth";
88     private static final String GCP_TOKEN_URI = "https://accounts.google.com/o/oauth2/token";
89     private static final String GCP_REDIRECT_URI = "urn:ietf:wg:oauth:2.0:oob";
90     private static final String GCP_SCOPE = "https://www.googleapis.com/auth/cloud-platform";
91
92     private final Logger logger = LoggerFactory.getLogger(GoogleSTTService.class);
93     private final ScheduledExecutorService executor = ThreadPoolManager.getScheduledPool("OH-voice-googlestt");
94     private final OAuthFactory oAuthFactory;
95     private final ConfigurationAdmin configAdmin;
96
97     private GoogleSTTConfiguration config = new GoogleSTTConfiguration();
98     private @Nullable OAuthClientService oAuthService;
99
100     @Activate
101     public GoogleSTTService(final @Reference OAuthFactory oAuthFactory,
102             final @Reference ConfigurationAdmin configAdmin) {
103         LoadBalancerRegistry.getDefaultRegistry().register(new PickFirstLoadBalancerProvider());
104         this.oAuthFactory = oAuthFactory;
105         this.configAdmin = configAdmin;
106     }
107
108     @Activate
109     protected void activate(Map<String, Object> config) {
110         this.config = new Configuration(config).as(GoogleSTTConfiguration.class);
111         executor.submit(() -> GoogleSTTLocale.loadLocales(this.config.refreshSupportedLocales));
112         updateConfig();
113     }
114
115     @Modified
116     protected void modified(Map<String, Object> config) {
117         this.config = new Configuration(config).as(GoogleSTTConfiguration.class);
118         updateConfig();
119     }
120
121     @Override
122     public String getId() {
123         return SERVICE_ID;
124     }
125
126     @Override
127     public String getLabel(@Nullable Locale locale) {
128         return SERVICE_NAME;
129     }
130
131     @Override
132     public Set<Locale> getSupportedLocales() {
133         return GoogleSTTLocale.getSupportedLocales();
134     }
135
136     @Override
137     public Set<AudioFormat> getSupportedFormats() {
138         return Set.of(
139                 new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L),
140                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 8000L),
141                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 12000L),
142                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 16000L),
143                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 24000L),
144                 new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 48000L));
145     }
146
147     @Override
148     public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale,
149             Set<String> set) {
150         AtomicBoolean keepStreaming = new AtomicBoolean(true);
151         Future scheduledTask = backgroundRecognize(sttListener, audioStream, keepStreaming, locale, set);
152         return new STTServiceHandle() {
153             @Override
154             public void abort() {
155                 keepStreaming.set(false);
156                 try {
157                     Thread.sleep(100);
158                 } catch (InterruptedException e) {
159                 }
160                 scheduledTask.cancel(true);
161             }
162         };
163     }
164
165     private void updateConfig() {
166         String clientId = this.config.clientId;
167         String clientSecret = this.config.clientSecret;
168         if (!clientId.isBlank() && !clientSecret.isBlank()) {
169             var oAuthService = oAuthFactory.createOAuthClientService(SERVICE_PID, GCP_TOKEN_URI, GCP_AUTH_URI, clientId,
170                     clientSecret, GCP_SCOPE, false);
171             this.oAuthService = oAuthService;
172             if (!this.config.oauthCode.isEmpty()) {
173                 getAccessToken(oAuthService, this.config.oauthCode);
174                 deleteAuthCode();
175             }
176         } else {
177             logger.warn("Missing authentication configuration to access Google Cloud STT API.");
178         }
179     }
180
181     private void getAccessToken(OAuthClientService oAuthService, String oauthCode) {
182         logger.debug("Trying to get access and refresh tokens.");
183         try {
184             oAuthService.getAccessTokenResponseByAuthorizationCode(oauthCode, GCP_REDIRECT_URI);
185         } catch (OAuthException | OAuthResponseException e) {
186             if (logger.isDebugEnabled()) {
187                 logger.debug("Error fetching access token: {}", e.getMessage(), e);
188             } else {
189                 logger.warn("Error fetching access token. Invalid oauth code? Please generate a new one.");
190             }
191         } catch (IOException e) {
192             logger.warn("An unexpected IOException occurred when fetching access token: {}", e.getMessage());
193         }
194     }
195
196     private void deleteAuthCode() {
197         try {
198             org.osgi.service.cm.Configuration serviceConfig = configAdmin.getConfiguration(SERVICE_PID);
199             Dictionary<String, Object> configProperties = serviceConfig.getProperties();
200             if (configProperties != null) {
201                 configProperties.put("oauthCode", "");
202                 serviceConfig.update(configProperties);
203             }
204         } catch (IOException e) {
205             logger.warn("Failed to delete current oauth code, please delete it manually.");
206         }
207     }
208
209     private Future<?> backgroundRecognize(STTListener sttListener, AudioStream audioStream, AtomicBoolean keepStreaming,
210             Locale locale, Set<String> set) {
211         Credentials credentials = getCredentials();
212         return executor.submit(() -> {
213             logger.debug("Background recognize starting");
214             ClientStream<StreamingRecognizeRequest> clientStream = null;
215             try (SpeechClient client = SpeechClient
216                     .create(SpeechSettings.newBuilder().setCredentialsProvider(() -> credentials).build())) {
217                 TranscriptionListener responseObserver = new TranscriptionListener(sttListener, config,
218                         (t) -> keepStreaming.set(false));
219                 clientStream = client.streamingRecognizeCallable().splitCall(responseObserver);
220                 streamAudio(clientStream, audioStream, responseObserver, keepStreaming, locale);
221                 clientStream.closeSend();
222                 logger.debug("Background recognize done");
223             } catch (IOException e) {
224                 if (clientStream != null && clientStream.isSendReady()) {
225                     clientStream.closeSendWithError(e);
226                 } else if (!config.errorMessage.isBlank()) {
227                     logger.warn("Error running speech to text: {}", e.getMessage());
228                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
229                 }
230             }
231         });
232     }
233
234     private void streamAudio(ClientStream<StreamingRecognizeRequest> clientStream, AudioStream audioStream,
235             TranscriptionListener responseObserver, AtomicBoolean keepStreaming, Locale locale) throws IOException {
236         // Gather stream info and send config
237         AudioFormat streamFormat = audioStream.getFormat();
238         RecognitionConfig.AudioEncoding streamEncoding;
239         if (AudioFormat.WAV.isCompatible(streamFormat)) {
240             streamEncoding = RecognitionConfig.AudioEncoding.LINEAR16;
241         } else if (AudioFormat.OGG.isCompatible(streamFormat)) {
242             streamEncoding = RecognitionConfig.AudioEncoding.OGG_OPUS;
243         } else {
244             logger.debug("Unsupported format {}", streamFormat);
245             return;
246         }
247         Integer channelsObject = streamFormat.getChannels();
248         int channels = channelsObject != null ? channelsObject : 1;
249         Long longFrequency = streamFormat.getFrequency();
250         if (longFrequency == null) {
251             logger.debug("Missing frequency info");
252             return;
253         }
254         int frequency = Math.toIntExact(longFrequency);
255         // First thing we need to send the stream config
256         sendStreamConfig(clientStream, streamEncoding, frequency, channels, locale);
257         // Loop sending audio data
258         long startTime = System.currentTimeMillis();
259         long maxTranscriptionMillis = (config.maxTranscriptionSeconds * 1000L);
260         long maxSilenceMillis = (config.maxSilenceSeconds * 1000L);
261         int readBytes = 6400;
262         while (keepStreaming.get()) {
263             byte[] data = new byte[readBytes];
264             int dataN = audioStream.read(data);
265             if (!keepStreaming.get() || isExpiredInterval(maxTranscriptionMillis, startTime)) {
266                 logger.debug("Stops listening, max transcription time reached");
267                 break;
268             }
269             if (!config.singleUtteranceMode
270                     && isExpiredInterval(maxSilenceMillis, responseObserver.getLastInputTime())) {
271                 logger.debug("Stops listening, max silence time reached");
272                 break;
273             }
274             if (dataN != readBytes) {
275                 try {
276                     Thread.sleep(100);
277                 } catch (InterruptedException e) {
278                 }
279                 continue;
280             }
281             StreamingRecognizeRequest dataRequest = StreamingRecognizeRequest.newBuilder()
282                     .setAudioContent(ByteString.copyFrom(data)).build();
283             logger.debug("Sending audio data {}", dataN);
284             clientStream.send(dataRequest);
285         }
286     }
287
288     private void sendStreamConfig(ClientStream<StreamingRecognizeRequest> clientStream,
289             RecognitionConfig.AudioEncoding encoding, int sampleRate, int channels, Locale locale) {
290         RecognitionConfig recognitionConfig = RecognitionConfig.newBuilder().setEncoding(encoding)
291                 .setAudioChannelCount(channels).setLanguageCode(locale.toLanguageTag()).setSampleRateHertz(sampleRate)
292                 .build();
293
294         StreamingRecognitionConfig streamingRecognitionConfig = StreamingRecognitionConfig.newBuilder()
295                 .setConfig(recognitionConfig).setInterimResults(false).setSingleUtterance(config.singleUtteranceMode)
296                 .build();
297
298         clientStream
299                 .send(StreamingRecognizeRequest.newBuilder().setStreamingConfig(streamingRecognitionConfig).build());
300     }
301
302     private @Nullable Credentials getCredentials() {
303         String accessToken = null;
304         try {
305             OAuthClientService oAuthService = this.oAuthService;
306             if (oAuthService != null) {
307                 AccessTokenResponse response = oAuthService.getAccessTokenResponse();
308                 if (response != null) {
309                     accessToken = response.getAccessToken();
310                 }
311             }
312         } catch (OAuthException | IOException | OAuthResponseException e) {
313             logger.warn("Access token error: {}", e.getMessage());
314         }
315         if (accessToken == null) {
316             logger.warn("Missed google cloud access token");
317             return null;
318         }
319         return OAuth2Credentials.create(new AccessToken(accessToken, null));
320     }
321
322     private boolean isExpiredInterval(long interval, long referenceTime) {
323         return System.currentTimeMillis() - referenceTime > interval;
324     }
325
326     private static class TranscriptionListener implements ResponseObserver<StreamingRecognizeResponse> {
327         private final Logger logger = LoggerFactory.getLogger(TranscriptionListener.class);
328         private final StringBuilder transcriptBuilder = new StringBuilder();
329         private final STTListener sttListener;
330         GoogleSTTConfiguration config;
331         private final Consumer<@Nullable Throwable> completeListener;
332         private float confidenceSum = 0;
333         private int responseCount = 0;
334         private long lastInputTime = 0;
335
336         public TranscriptionListener(STTListener sttListener, GoogleSTTConfiguration config,
337                 Consumer<@Nullable Throwable> completeListener) {
338             this.sttListener = sttListener;
339             this.config = config;
340             this.completeListener = completeListener;
341         }
342
343         @Override
344         public void onStart(@Nullable StreamController controller) {
345             sttListener.sttEventReceived(new RecognitionStartEvent());
346             lastInputTime = System.currentTimeMillis();
347         }
348
349         @Override
350         public void onResponse(StreamingRecognizeResponse response) {
351             lastInputTime = System.currentTimeMillis();
352             List<StreamingRecognitionResult> results = response.getResultsList();
353             logger.debug("Got {} results", response.getResultsList().size());
354             if (results.isEmpty()) {
355                 logger.debug("No results");
356                 return;
357             }
358             results.forEach(result -> {
359                 List<SpeechRecognitionAlternative> alternatives = result.getAlternativesList();
360                 logger.debug("Got {} alternatives", alternatives.size());
361                 SpeechRecognitionAlternative alternative = alternatives.stream()
362                         .max(Comparator.comparing(SpeechRecognitionAlternative::getConfidence)).orElse(null);
363                 if (alternative == null) {
364                     return;
365                 }
366                 String transcript = alternative.getTranscript();
367                 logger.debug("Alternative transcript: {}", transcript);
368                 logger.debug("Alternative confidence: {}", alternative.getConfidence());
369                 if (result.getIsFinal()) {
370                     transcriptBuilder.append(transcript);
371                     confidenceSum += alternative.getConfidence();
372                     responseCount++;
373                     // when in single utterance mode we can just get one final result so complete
374                     if (config.singleUtteranceMode) {
375                         completeListener.accept(null);
376                     }
377                 }
378             });
379         }
380
381         @Override
382         public void onComplete() {
383             sttListener.sttEventReceived(new RecognitionStopEvent());
384             float averageConfidence = confidenceSum / responseCount;
385             String transcript = transcriptBuilder.toString();
386             if (!transcript.isBlank()) {
387                 sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence));
388             } else {
389                 if (!config.noResultsMessage.isBlank()) {
390                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
391                 } else {
392                     sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results"));
393                 }
394             }
395         }
396
397         @Override
398         public void onError(@Nullable Throwable t) {
399             logger.warn("Recognition error: ", t);
400             completeListener.accept(t);
401             sttListener.sttEventReceived(new RecognitionStopEvent());
402             if (!config.errorMessage.isBlank()) {
403                 sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
404             } else {
405                 String errorMessage = t.getMessage();
406                 sttListener.sttEventReceived(
407                         new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
408             }
409         }
410
411         public long getLastInputTime() {
412             return lastInputTime;
413         }
414     }
415 }