git.basschouten.com Git - openhab-addons.git/commitdiff
[googlestt] lazy abort (#12317)
author GiviMAD <GiviMAD@users.noreply.github.com>
Sun, 20 Feb 2022 11:45:31 +0000 (12:45 +0100)
committer GitHub <noreply@github.com>
Sun, 20 Feb 2022 11:45:31 +0000 (12:45 +0100)
Signed-off-by: Miguel Álvarez Díez <miguelwork92@gmail.com>
bundles/org.openhab.voice.googlestt/src/main/java/org/openhab/voice/googlestt/internal/GoogleSTTService.java

index 5f82829cf14775dd92412d00756f9cd703d88b27..a0f0bfa0f81a90859c26cac3f6f26cdea082dc24 100644 (file)
@@ -24,7 +24,6 @@ import java.util.Set;
 import java.util.concurrent.Future;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.function.Consumer;
 
 import org.eclipse.jdt.annotation.NonNullByDefault;
 import org.eclipse.jdt.annotation.Nullable;
@@ -147,17 +146,12 @@ public class GoogleSTTService implements STTService {
     @Override
     public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale,
             Set<String> set) {
-        AtomicBoolean keepStreaming = new AtomicBoolean(true);
-        Future scheduledTask = backgroundRecognize(sttListener, audioStream, keepStreaming, locale, set);
+        AtomicBoolean aborted = new AtomicBoolean(false);
+        backgroundRecognize(sttListener, audioStream, aborted, locale, set);
         return new STTServiceHandle() {
             @Override
             public void abort() {
-                keepStreaming.set(false);
-                try {
-                    Thread.sleep(100);
-                } catch (InterruptedException e) {
-                }
-                scheduledTask.cancel(true);
+                aborted.set(true);
             }
         };
     }
@@ -206,7 +200,7 @@ public class GoogleSTTService implements STTService {
         }
     }
 
-    private Future<?> backgroundRecognize(STTListener sttListener, AudioStream audioStream, AtomicBoolean keepStreaming,
+    private Future<?> backgroundRecognize(STTListener sttListener, AudioStream audioStream, AtomicBoolean aborted,
             Locale locale, Set<String> set) {
         Credentials credentials = getCredentials();
         return executor.submit(() -> {
@@ -214,10 +208,9 @@ public class GoogleSTTService implements STTService {
             ClientStream<StreamingRecognizeRequest> clientStream = null;
             try (SpeechClient client = SpeechClient
                     .create(SpeechSettings.newBuilder().setCredentialsProvider(() -> credentials).build())) {
-                TranscriptionListener responseObserver = new TranscriptionListener(sttListener, config,
-                        (t) -> keepStreaming.set(false));
+                TranscriptionListener responseObserver = new TranscriptionListener(sttListener, config, aborted);
                 clientStream = client.streamingRecognizeCallable().splitCall(responseObserver);
-                streamAudio(clientStream, audioStream, responseObserver, keepStreaming, locale);
+                streamAudio(clientStream, audioStream, responseObserver, aborted, locale);
                 clientStream.closeSend();
                 logger.debug("Background recognize done");
             } catch (IOException e) {
@@ -232,7 +225,7 @@ public class GoogleSTTService implements STTService {
     }
 
     private void streamAudio(ClientStream<StreamingRecognizeRequest> clientStream, AudioStream audioStream,
-            TranscriptionListener responseObserver, AtomicBoolean keepStreaming, Locale locale) throws IOException {
+            TranscriptionListener responseObserver, AtomicBoolean aborted, Locale locale) throws IOException {
         // Gather stream info and send config
         AudioFormat streamFormat = audioStream.getFormat();
         RecognitionConfig.AudioEncoding streamEncoding;
@@ -259,10 +252,14 @@ public class GoogleSTTService implements STTService {
         long maxTranscriptionMillis = (config.maxTranscriptionSeconds * 1000L);
         long maxSilenceMillis = (config.maxSilenceSeconds * 1000L);
         int readBytes = 6400;
-        while (keepStreaming.get()) {
+        while (!aborted.get()) {
             byte[] data = new byte[readBytes];
             int dataN = audioStream.read(data);
-            if (!keepStreaming.get() || isExpiredInterval(maxTranscriptionMillis, startTime)) {
+            if (aborted.get()) {
+                logger.debug("Stops listening, aborted");
+                break;
+            }
+            if (isExpiredInterval(maxTranscriptionMillis, startTime)) {
                 logger.debug("Stops listening, max transcription time reached");
                 break;
             }
@@ -328,16 +325,15 @@ public class GoogleSTTService implements STTService {
         private final StringBuilder transcriptBuilder = new StringBuilder();
         private final STTListener sttListener;
         GoogleSTTConfiguration config;
-        private final Consumer<@Nullable Throwable> completeListener;
+        private final AtomicBoolean aborted;
         private float confidenceSum = 0;
         private int responseCount = 0;
         private long lastInputTime = 0;
 
-        public TranscriptionListener(STTListener sttListener, GoogleSTTConfiguration config,
-                Consumer<@Nullable Throwable> completeListener) {
+        public TranscriptionListener(STTListener sttListener, GoogleSTTConfiguration config, AtomicBoolean aborted) {
             this.sttListener = sttListener;
             this.config = config;
-            this.completeListener = completeListener;
+            this.aborted = aborted;
         }
 
         @Override
@@ -372,7 +368,7 @@ public class GoogleSTTService implements STTService {
                     responseCount++;
                     // when in single utterance mode we can just get one final result so complete
                     if (config.singleUtteranceMode) {
-                        completeListener.accept(null);
+                        onComplete();
                     }
                 }
             });
@@ -380,16 +376,18 @@ public class GoogleSTTService implements STTService {
 
         @Override
         public void onComplete() {
-            sttListener.sttEventReceived(new RecognitionStopEvent());
-            float averageConfidence = confidenceSum / responseCount;
-            String transcript = transcriptBuilder.toString();
-            if (!transcript.isBlank()) {
-                sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence));
-            } else {
-                if (!config.noResultsMessage.isBlank()) {
-                    sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
+            if (!aborted.getAndSet(true)) {
+                sttListener.sttEventReceived(new RecognitionStopEvent());
+                float averageConfidence = confidenceSum / responseCount;
+                String transcript = transcriptBuilder.toString();
+                if (!transcript.isBlank()) {
+                    sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence));
                 } else {
-                    sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results"));
+                    if (!config.noResultsMessage.isBlank()) {
+                        sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
+                    } else {
+                        sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results"));
+                    }
                 }
             }
         }
@@ -397,14 +395,15 @@ public class GoogleSTTService implements STTService {
         @Override
         public void onError(@Nullable Throwable t) {
             logger.warn("Recognition error: ", t);
-            completeListener.accept(t);
-            sttListener.sttEventReceived(new RecognitionStopEvent());
-            if (!config.errorMessage.isBlank()) {
-                sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
-            } else {
-                String errorMessage = t.getMessage();
-                sttListener.sttEventReceived(
-                        new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
+            if (!aborted.getAndSet(true)) {
+                sttListener.sttEventReceived(new RecognitionStopEvent());
+                if (!config.errorMessage.isBlank()) {
+                    sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
+                } else {
+                    String errorMessage = t.getMessage();
+                    sttListener.sttEventReceived(
+                            new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
+                }
             }
         }