2 * Copyright (c) 2010-2024 Contributors to the openHAB project
4 * See the NOTICE file(s) distributed with this work for additional
7 * This program and the accompanying materials are made available under the
8 * terms of the Eclipse Public License 2.0 which is available at
9 * http://www.eclipse.org/legal/epl-2.0
11 * SPDX-License-Identifier: EPL-2.0
13 package org.openhab.voice.googlestt.internal;
15 import static org.openhab.voice.googlestt.internal.GoogleSTTConstants.*;
17 import java.io.IOException;
18 import java.util.Comparator;
19 import java.util.Dictionary;
20 import java.util.List;
21 import java.util.Locale;
24 import java.util.concurrent.Future;
25 import java.util.concurrent.ScheduledExecutorService;
26 import java.util.concurrent.atomic.AtomicBoolean;
28 import org.eclipse.jdt.annotation.NonNullByDefault;
29 import org.eclipse.jdt.annotation.Nullable;
30 import org.openhab.core.audio.AudioFormat;
31 import org.openhab.core.audio.AudioStream;
32 import org.openhab.core.audio.utils.AudioWaveUtils;
33 import org.openhab.core.auth.client.oauth2.AccessTokenResponse;
34 import org.openhab.core.auth.client.oauth2.OAuthClientService;
35 import org.openhab.core.auth.client.oauth2.OAuthException;
36 import org.openhab.core.auth.client.oauth2.OAuthFactory;
37 import org.openhab.core.auth.client.oauth2.OAuthResponseException;
38 import org.openhab.core.common.ThreadPoolManager;
39 import org.openhab.core.config.core.ConfigurableService;
40 import org.openhab.core.config.core.Configuration;
41 import org.openhab.core.voice.RecognitionStartEvent;
42 import org.openhab.core.voice.RecognitionStopEvent;
43 import org.openhab.core.voice.STTListener;
44 import org.openhab.core.voice.STTService;
45 import org.openhab.core.voice.STTServiceHandle;
46 import org.openhab.core.voice.SpeechRecognitionErrorEvent;
47 import org.openhab.core.voice.SpeechRecognitionEvent;
48 import org.osgi.framework.Constants;
49 import org.osgi.service.cm.ConfigurationAdmin;
50 import org.osgi.service.component.annotations.Activate;
51 import org.osgi.service.component.annotations.Component;
52 import org.osgi.service.component.annotations.Deactivate;
53 import org.osgi.service.component.annotations.Modified;
54 import org.osgi.service.component.annotations.Reference;
55 import org.slf4j.Logger;
56 import org.slf4j.LoggerFactory;
58 import com.google.api.gax.rpc.ClientStream;
59 import com.google.api.gax.rpc.ResponseObserver;
60 import com.google.api.gax.rpc.StreamController;
61 import com.google.auth.Credentials;
62 import com.google.auth.oauth2.AccessToken;
63 import com.google.auth.oauth2.OAuth2Credentials;
64 import com.google.cloud.speech.v1.RecognitionConfig;
65 import com.google.cloud.speech.v1.SpeechClient;
66 import com.google.cloud.speech.v1.SpeechRecognitionAlternative;
67 import com.google.cloud.speech.v1.SpeechSettings;
68 import com.google.cloud.speech.v1.StreamingRecognitionConfig;
69 import com.google.cloud.speech.v1.StreamingRecognitionResult;
70 import com.google.cloud.speech.v1.StreamingRecognizeRequest;
71 import com.google.cloud.speech.v1.StreamingRecognizeResponse;
72 import com.google.protobuf.ByteString;
74 import io.grpc.LoadBalancerRegistry;
75 import io.grpc.internal.PickFirstLoadBalancerProvider;
/**
 * The {@link GoogleSTTService} class is a service implementation to use Google Cloud Speech-to-Text features.
 *
 * @author Miguel Álvarez - Initial contribution
 */
@Component(configurationPid = SERVICE_PID, property = Constants.SERVICE_PID + "=" + SERVICE_PID)
@ConfigurableService(category = SERVICE_CATEGORY, label = SERVICE_NAME
        + " Speech-to-Text", description_uri = SERVICE_CATEGORY + ":" + SERVICE_ID)
public class GoogleSTTService implements STTService {

    // Google OAuth2 endpoints used when creating the OAuth client in updateConfig().
    private static final String GCP_AUTH_URI = "https://accounts.google.com/o/oauth2/auth";
    private static final String GCP_TOKEN_URI = "https://accounts.google.com/o/oauth2/token";
    // Redirect URI for the manual OAuth2 authorization-code flow.
    private static final String GCP_REDIRECT_URI = "https://www.google.com";
    // Scope requested for the access token; grants access to Google Cloud Platform APIs.
    private static final String GCP_SCOPE = "https://www.googleapis.com/auth/cloud-platform";

    private final Logger logger = LoggerFactory.getLogger(GoogleSTTService.class);
    // Shared scheduled pool; recognition jobs and locale loading run here, off the caller thread.
    private final ScheduledExecutorService executor = ThreadPoolManager.getScheduledPool("OH-voice-googlestt");
    private final OAuthFactory oAuthFactory;
    private final ConfigurationAdmin configAdmin;

    // Current service configuration; replaced wholesale on activate/modified.
    private GoogleSTTConfiguration config = new GoogleSTTConfiguration();
    // OAuth client registered under SERVICE_PID; null until configured with client id/secret.
    private @Nullable OAuthClientService oAuthService;
    /**
     * Creates the service.
     *
     * @param oAuthFactory factory used to create/release the OAuth2 client for Google authentication
     * @param configAdmin used to read and update the persisted service configuration
     */
    public GoogleSTTService(final @Reference OAuthFactory oAuthFactory,
            final @Reference ConfigurationAdmin configAdmin) {
        // Register the pick-first gRPC load balancer so the SpeechClient channel can resolve one.
        LoadBalancerRegistry.getDefaultRegistry().register(new PickFirstLoadBalancerProvider());
        this.oAuthFactory = oAuthFactory;
        this.configAdmin = configAdmin;
    /**
     * Applies the initial configuration and kicks off loading of the supported locale list
     * in the background (optionally refreshing it, per {@code refreshSupportedLocales}).
     */
    protected void activate(Map<String, Object> config) {
        this.config = new Configuration(config).as(GoogleSTTConfiguration.class);
        executor.submit(() -> GoogleSTTLocale.loadLocales(this.config.refreshSupportedLocales));
    /**
     * Re-reads the service configuration when it is modified at runtime.
     */
    protected void modified(Map<String, Object> config) {
        this.config = new Configuration(config).as(GoogleSTTConfiguration.class);
    /**
     * Releases the OAuth client registered under this service's PID, if one was created.
     */
    protected void dispose() {
        if (oAuthService != null) {
            oAuthFactory.ungetOAuthService(SERVICE_PID);
    /** Returns this service's identifier (STTService contract). */
    public String getId() {
    /** Returns the human-readable label for this service (STTService contract). */
    public String getLabel(@Nullable Locale locale) {
    /**
     * Returns the locales supported for recognition; the list is loaded by
     * {@code GoogleSTTLocale.loadLocales(...)} during activation.
     */
    public Set<Locale> getSupportedLocales() {
        return GoogleSTTLocale.getSupportedLocales();
    /**
     * Returns the accepted audio formats: 16-bit signed PCM at 16 kHz, either raw or in a WAV container.
     */
    public Set<AudioFormat> getSupportedFormats() {
                new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L),
                new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L));
    /**
     * Starts speech recognition over the given audio stream; recognition runs asynchronously
     * and events are delivered to the listener. The returned handle lets the caller abort.
     */
    public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale,
        // Shared flag polled by the background job; flipping it stops audio streaming.
        AtomicBoolean aborted = new AtomicBoolean(false);
        backgroundRecognize(sttListener, audioStream, aborted, locale, set);
        return new STTServiceHandle() {
            public void abort() {
    /**
     * (Re)creates the OAuth client from the current configuration. Requires a client id and
     * secret; when an oauth code is configured it is exchanged for access/refresh tokens.
     */
    private void updateConfig() {
        if (oAuthService != null) {
            // Release any previously registered OAuth client before creating a new one.
            oAuthFactory.ungetOAuthService(SERVICE_PID);
        String clientId = this.config.clientId;
        String clientSecret = this.config.clientSecret;
        if (!clientId.isBlank() && !clientSecret.isBlank()) {
            var oAuthService = oAuthFactory.createOAuthClientService(SERVICE_PID, GCP_TOKEN_URI, GCP_AUTH_URI, clientId,
                    clientSecret, GCP_SCOPE, false);
            this.oAuthService = oAuthService;
            if (!this.config.oauthCode.isEmpty()) {
                // Exchange the one-time authorization code for tokens.
                getAccessToken(oAuthService, this.config.oauthCode);
            logger.warn("Missing authentication configuration to access Google Cloud STT API.");
    /**
     * Exchanges the user-provided OAuth2 authorization code for access and refresh tokens,
     * logging warnings when the exchange fails or yields no refresh token.
     */
    private void getAccessToken(OAuthClientService oAuthService, String oauthCode) {
        logger.debug("Trying to get access and refresh tokens.");
            AccessTokenResponse response = oAuthService.getAccessTokenResponseByAuthorizationCode(oauthCode,
            if (response.getRefreshToken() == null || response.getRefreshToken().isEmpty()) {
                // Without a refresh token the service cannot renew access later.
                logger.warn("Error fetching refresh token. Please try to reauthorize.");
        } catch (OAuthException | OAuthResponseException e) {
            if (logger.isDebugEnabled()) {
                // Full stack trace only at debug level; warn level stays concise.
                logger.debug("Error fetching access token: {}", e.getMessage(), e);
            logger.warn("Error fetching access token. Invalid oauth code? Please generate a new one.");
        } catch (IOException e) {
            logger.warn("An unexpected IOException occurred when fetching access token: {}", e.getMessage());
    /**
     * Blanks the stored oauth code in the persisted service configuration; authorization
     * codes are single-use, so keeping one around would only produce failed exchanges.
     */
    private void deleteAuthCode() {
            org.osgi.service.cm.Configuration serviceConfig = configAdmin.getConfiguration(SERVICE_PID);
            Dictionary<String, Object> configProperties = serviceConfig.getProperties();
            if (configProperties != null) {
                configProperties.put("oauthCode", "");
                serviceConfig.update(configProperties);
        } catch (IOException e) {
            logger.warn("Failed to delete current oauth code, please delete it manually.");
    /**
     * Submits the streaming-recognition job to the executor: opens a {@link SpeechClient}
     * with the current credentials, streams audio via {@link #streamAudio}, and lets the
     * {@code TranscriptionListener} forward transcription events to the STT listener.
     */
    private Future<?> backgroundRecognize(STTListener sttListener, AudioStream audioStream, AtomicBoolean aborted,
            Locale locale, Set<String> set) {
        // Resolve credentials up front, outside the background task.
        Credentials credentials = getCredentials();
        return executor.submit(() -> {
            logger.debug("Background recognize starting");
            ClientStream<StreamingRecognizeRequest> clientStream = null;
            // try-with-resources guarantees the SpeechClient is closed when the task ends.
            try (SpeechClient client = SpeechClient
                    .create(SpeechSettings.newBuilder().setCredentialsProvider(() -> credentials).build())) {
                TranscriptionListener responseObserver = new TranscriptionListener(sttListener, config, aborted);
                clientStream = client.streamingRecognizeCallable().splitCall(responseObserver);
                streamAudio(clientStream, audioStream, responseObserver, aborted, locale);
                clientStream.closeSend();
                logger.debug("Background recognize done");
            } catch (IOException e) {
                if (clientStream != null && clientStream.isSendReady()) {
                    // Propagate the failure through the gRPC stream if it is still open.
                    clientStream.closeSendWithError(e);
                } else if (!config.errorMessage.isBlank()) {
                    logger.warn("Error running speech to text: {}", e.getMessage());
                    sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
    /**
     * Sends the recognition configuration followed by fixed-size chunks of audio until the
     * stream ends, the transcription completes, the caller aborts, or a time limit
     * (max transcription / max silence) is exceeded.
     *
     * @throws IOException if reading from the audio stream fails
     */
    private void streamAudio(ClientStream<StreamingRecognizeRequest> clientStream, AudioStream audioStream,
            TranscriptionListener responseObserver, AtomicBoolean aborted, Locale locale) throws IOException {
        // Gather stream info and send config
        AudioFormat streamFormat = audioStream.getFormat();
        RecognitionConfig.AudioEncoding streamEncoding;
        if (AudioFormat.PCM_SIGNED.isCompatible(streamFormat) || AudioFormat.WAV.isCompatible(streamFormat)) {
            // Both raw PCM and WAV map to LINEAR16 for the Google API.
            streamEncoding = RecognitionConfig.AudioEncoding.LINEAR16;
            logger.debug("Unsupported format {}", streamFormat);
        Integer channelsObject = streamFormat.getChannels();
        int channels = channelsObject != null ? channelsObject : 1; // default to mono when unknown
        Long longFrequency = streamFormat.getFrequency();
        if (longFrequency == null) {
            logger.debug("Missing frequency info");
        int frequency = Math.toIntExact(longFrequency);
        // First thing we need to send the stream config
        sendStreamConfig(clientStream, streamEncoding, frequency, channels, locale);
        // Loop sending audio data
        long startTime = System.currentTimeMillis();
        long maxTranscriptionMillis = (config.maxTranscriptionSeconds * 1000L);
        long maxSilenceMillis = (config.maxSilenceSeconds * 1000L);
        final int bufferSize = 6400;
        int remaining = bufferSize;
        if (AudioFormat.CONTAINER_WAVE.equals(streamFormat.getContainer())) {
            // Skip the WAV header so only raw samples are sent.
            AudioWaveUtils.removeFMT(audioStream);
        byte[] audioBuffer = new byte[bufferSize];
        while (!aborted.get() && !responseObserver.isDone()) {
            // Fill the remainder of the buffer; offset tracks how much is already filled.
            numBytesRead = audioStream.read(audioBuffer, bufferSize - remaining, remaining);
                logger.debug("Stops listening, aborted");
            if (numBytesRead == -1) {
                logger.debug("End of stream");
            if (isExpiredInterval(maxTranscriptionMillis, startTime)) {
                logger.debug("Stops listening, max transcription time reached");
            // Silence timeout is only enforced outside single-utterance mode.
            if (!config.singleUtteranceMode
                    && isExpiredInterval(maxSilenceMillis, responseObserver.getLastInputTime())) {
                logger.debug("Stops listening, max silence time reached");
            if (numBytesRead != remaining) {
                // Buffer not full yet; keep reading into the remainder before sending.
                remaining = remaining - numBytesRead;
            remaining = bufferSize;
            StreamingRecognizeRequest dataRequest = StreamingRecognizeRequest.newBuilder()
                    .setAudioContent(ByteString.copyFrom(audioBuffer)).build();
            logger.debug("Sending audio data {}", bufferSize);
            clientStream.send(dataRequest);
    /**
     * Sends the initial streaming request carrying the recognition configuration
     * (encoding, channels, language, sample rate); this must precede any audio data.
     */
    private void sendStreamConfig(ClientStream<StreamingRecognizeRequest> clientStream,
            RecognitionConfig.AudioEncoding encoding, int sampleRate, int channels, Locale locale) {
        RecognitionConfig recognitionConfig = RecognitionConfig.newBuilder().setEncoding(encoding)
                .setAudioChannelCount(channels).setLanguageCode(locale.toLanguageTag()).setSampleRateHertz(sampleRate)
        // Interim results are disabled; only final results reach the TranscriptionListener.
        StreamingRecognitionConfig streamingRecognitionConfig = StreamingRecognitionConfig.newBuilder()
                .setConfig(recognitionConfig).setInterimResults(false).setSingleUtterance(config.singleUtteranceMode)
                .send(StreamingRecognizeRequest.newBuilder().setStreamingConfig(streamingRecognitionConfig).build());
    /**
     * Builds Google credentials from the current OAuth access token.
     * Logs a warning when the access and/or refresh token is unavailable.
     */
    private @Nullable Credentials getCredentials() {
        String accessToken = null;
        String refreshToken = null;
            OAuthClientService oAuthService = this.oAuthService;
            if (oAuthService != null) {
                AccessTokenResponse response = oAuthService.getAccessTokenResponse();
                if (response != null) {
                    accessToken = response.getAccessToken();
                    refreshToken = response.getRefreshToken();
        } catch (OAuthException | IOException | OAuthResponseException e) {
            logger.warn("Access token error: {}", e.getMessage());
        if (accessToken == null || refreshToken == null) {
            logger.warn("Missed google cloud access and/or refresh token");
        // NOTE(review): expiry is passed as null — presumably refresh is handled by the
        // openHAB OAuth client rather than by these credentials; confirm.
        return OAuth2Credentials.create(new AccessToken(accessToken, null));
341 private boolean isExpiredInterval(long interval, long referenceTime) {
342 return System.currentTimeMillis() - referenceTime > interval;
    /**
     * Observer for Google streaming recognition responses; accumulates final transcripts and
     * forwards openHAB STT events to the wrapped {@link STTListener}.
     */
    private static class TranscriptionListener implements ResponseObserver<StreamingRecognizeResponse> {

        private final Logger logger = LoggerFactory.getLogger(TranscriptionListener.class);
        // Accumulates the text of all final results.
        private final StringBuilder transcriptBuilder = new StringBuilder();
        private final STTListener sttListener;
        GoogleSTTConfiguration config;
        // Shared with the audio-streaming loop; once set, terminal events are not emitted again.
        private final AtomicBoolean aborted;
        // Running sum of per-result confidences, averaged over responseCount on completion.
        private float confidenceSum = 0;
        private int responseCount = 0;
        // Epoch millis of the last stream activity; consulted for max-silence detection.
        private long lastInputTime = 0;
        private boolean done = false;
        /**
         * @param sttListener receiver of recognition events
         * @param config service configuration (utterance mode, user-facing messages)
         * @param aborted shared flag guarding one-time emission of terminal events
         */
        public TranscriptionListener(STTListener sttListener, GoogleSTTConfiguration config, AtomicBoolean aborted) {
            this.sttListener = sttListener;
            this.config = config;
            this.aborted = aborted;
        /** Signals recognition start to the listener and initializes the activity timestamp. */
        public void onStart(@Nullable StreamController controller) {
            sttListener.sttEventReceived(new RecognitionStartEvent());
            lastInputTime = System.currentTimeMillis();
        /**
         * Handles a streaming response: for each result, picks the highest-confidence
         * alternative and, when the result is final, appends its transcript and records
         * the confidence for later averaging.
         */
        public void onResponse(StreamingRecognizeResponse response) {
            // Any response counts as activity for max-silence tracking.
            lastInputTime = System.currentTimeMillis();
            List<StreamingRecognitionResult> results = response.getResultsList();
            logger.debug("Got {} results", response.getResultsList().size());
            if (results.isEmpty()) {
                logger.debug("No results");
            results.forEach(result -> {
                List<SpeechRecognitionAlternative> alternatives = result.getAlternativesList();
                logger.debug("Got {} alternatives", alternatives.size());
                // Keep only the most confident alternative of this result.
                SpeechRecognitionAlternative alternative = alternatives.stream()
                        .max(Comparator.comparing(SpeechRecognitionAlternative::getConfidence)).orElse(null);
                if (alternative == null) {
                String transcript = alternative.getTranscript();
                logger.debug("Alternative transcript: {}", transcript);
                logger.debug("Alternative confidence: {}", alternative.getConfidence());
                if (result.getIsFinal()) {
                    transcriptBuilder.append(transcript);
                    confidenceSum += alternative.getConfidence();
                    // when in single utterance mode we can just get one final result so complete
                    if (config.singleUtteranceMode) {
        /**
         * Emits the terminal events: stop, then either the aggregated transcript with the
         * average confidence, the configured no-results message, or a generic error.
         */
        public void onComplete() {
            // getAndSet guarantees terminal events are emitted at most once.
            if (!aborted.getAndSet(true)) {
                sttListener.sttEventReceived(new RecognitionStopEvent());
                float averageConfidence = confidenceSum / responseCount;
                String transcript = transcriptBuilder.toString();
                if (!transcript.isBlank()) {
                    sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence));
                } else if (!config.noResultsMessage.isBlank()) {
                    // Prefer the user-configured message when one is set.
                    sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
                    sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results"));
        /**
         * Logs the failure and, if no terminal event was emitted yet, sends a stop event
         * followed by the configured error message (or the throwable's message as fallback).
         */
        public void onError(@Nullable Throwable t) {
            logger.warn("Recognition error: ", t);
            if (!aborted.getAndSet(true)) {
                sttListener.sttEventReceived(new RecognitionStopEvent());
                if (!config.errorMessage.isBlank()) {
                    sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
                String errorMessage = t.getMessage();
                sttListener.sttEventReceived(
                        new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
        /** Whether transcription has finished; polled by the audio-streaming loop. */
        public boolean isDone() {
        /** Epoch millis of the last stream activity; used for max-silence detection. */
        public long getLastInputTime() {
            return lastInputTime;