git.basschouten.com Git - openhab-addons.git/commitdiff
[Audio] Fix PCM format and use PipedAudioStream in sources (#16111)
author: GiviMAD <GiviMAD@users.noreply.github.com>
Sun, 4 Feb 2024 21:07:54 +0000 (13:07 -0800)
committer: GitHub <noreply@github.com>
Sun, 4 Feb 2024 21:07:54 +0000 (22:07 +0100)
* [Audio] Fix pcm format and use PipedAudioStream
* fix rustpotter format changes

---------

Signed-off-by: Miguel Álvarez <miguelwork92@gmail.com>
bundles/org.openhab.binding.pulseaudio/src/main/java/org/openhab/binding/pulseaudio/internal/PulseAudioAudioSource.java
bundles/org.openhab.binding.pulseaudio/src/main/java/org/openhab/binding/pulseaudio/internal/handler/PulseaudioHandler.java
bundles/org.openhab.voice.googlestt/src/main/java/org/openhab/voice/googlestt/internal/GoogleSTTService.java
bundles/org.openhab.voice.rustpotterks/src/main/java/org/openhab/voice/rustpotterks/internal/RustpotterKSService.java
bundles/org.openhab.voice.voskstt/src/main/java/org/openhab/voice/voskstt/internal/VoskSTTService.java
bundles/org.openhab.voice.watsonstt/src/main/java/org/openhab/voice/watsonstt/internal/WatsonSTTService.java

index 863110143e470dcd504dd8b9d07a8a10cfc65300..3292aeb487bd3352d7c0e3fdb42c3a61a48881ab 100644 (file)
@@ -14,13 +14,8 @@ package org.openhab.binding.pulseaudio.internal;
 
 import java.io.IOException;
 import java.io.InputStream;
-import java.io.InterruptedIOException;
-import java.io.PipedInputStream;
-import java.io.PipedOutputStream;
 import java.net.Socket;
-import java.util.HashSet;
 import java.util.Set;
-import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.Future;
 import java.util.concurrent.ScheduledExecutorService;
 
@@ -31,6 +26,7 @@ import org.openhab.core.audio.AudioException;
 import org.openhab.core.audio.AudioFormat;
 import org.openhab.core.audio.AudioSource;
 import org.openhab.core.audio.AudioStream;
+import org.openhab.core.audio.PipedAudioStream;
 import org.openhab.core.common.ThreadPoolManager;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -45,25 +41,23 @@ import org.slf4j.LoggerFactory;
 public class PulseAudioAudioSource extends PulseaudioSimpleProtocolStream implements AudioSource {
 
     private final Logger logger = LoggerFactory.getLogger(PulseAudioAudioSource.class);
-    private final ConcurrentLinkedQueue<PipedOutputStream> pipeOutputs = new ConcurrentLinkedQueue<>();
+    private final PipedAudioStream.Group streamGroup;
     private final ScheduledExecutorService executor;
+    private final AudioFormat streamFormat;
 
     private @Nullable Future<?> pipeWriteTask;
 
     public PulseAudioAudioSource(PulseaudioHandler pulseaudioHandler, ScheduledExecutorService scheduler) {
         super(pulseaudioHandler, scheduler);
+        streamFormat = pulseaudioHandler.getSourceAudioFormat();
         executor = ThreadPoolManager
                 .getScheduledPool("OH-binding-" + pulseaudioHandler.getThing().getUID() + "-source");
+        streamGroup = PipedAudioStream.newGroup(streamFormat);
     }
 
     @Override
     public Set<AudioFormat> getSupportedFormats() {
-        var supportedFormats = new HashSet<AudioFormat>();
-        var audioFormat = pulseaudioHandler.getSourceAudioFormat();
-        if (audioFormat != null) {
-            supportedFormats.add(audioFormat);
-        }
-        return supportedFormats;
+        return Set.of(streamFormat);
     }
 
     @Override
@@ -76,27 +70,18 @@ public class PulseAudioAudioSource extends PulseaudioSimpleProtocolStream implem
                     if (clientSocketLocal == null) {
                         break;
                     }
-                    var sourceFormat = pulseaudioHandler.getSourceAudioFormat();
-                    if (sourceFormat == null) {
-                        throw new AudioException("Unable to get source audio format");
-                    }
-                    if (!audioFormat.isCompatible(sourceFormat)) {
+                    if (!audioFormat.isCompatible(streamFormat)) {
                         throw new AudioException("Incompatible audio format requested");
                     }
-                    var pipeOutput = new PipedOutputStream();
-                    var pipeInput = new PipedInputStream(pipeOutput, 1024 * 10) {
-                        @Override
-                        public void close() throws IOException {
-                            unregisterPipe(pipeOutput);
-                            super.close();
-                        }
-                    };
-                    registerPipe(pipeOutput);
-                    // get raw audio from the pulse audio socket
-                    return new PulseAudioStream(sourceFormat, pipeInput, () -> {
-                        // ensure pipe is writing
-                        startPipeWrite();
+                    var audioStream = streamGroup.getAudioStreamInGroup();
+                    audioStream.onClose(() -> {
+                        minusClientCount();
+                        stopPipeWriteTask();
                     });
+                    addClientCount();
+                    startPipeWrite();
+                    // get raw audio from the pulse audio socket
+                    return audioStream;
                 } catch (IOException e) {
                     disconnect(); // disconnect to force clear connection in case of socket not cleanly shutdown
                     if (countAttempt == 2) { // we won't retry : log and quit
@@ -120,14 +105,6 @@ public class PulseAudioAudioSource extends PulseaudioSimpleProtocolStream implem
         throw new AudioException("Unable to create input stream");
     }
 
-    private synchronized void registerPipe(PipedOutputStream pipeOutput) {
-        boolean isAdded = this.pipeOutputs.add(pipeOutput);
-        if (isAdded) {
-            addClientCount();
-        }
-        startPipeWrite();
-    }
-
     /**
      * As startPipeWrite is called for every chunk read,
      * this wrapper method make the test before effectively
@@ -143,35 +120,16 @@ public class PulseAudioAudioSource extends PulseaudioSimpleProtocolStream implem
         if (this.pipeWriteTask == null) {
             this.pipeWriteTask = executor.submit(() -> {
                 int lengthRead;
-                byte[] buffer = new byte[1024];
+                byte[] buffer = new byte[1200];
                 int readRetries = 3;
-                while (!pipeOutputs.isEmpty()) {
+                while (!streamGroup.isEmpty()) {
                     var stream = getSourceInputStream();
                     if (stream != null) {
                         try {
                             lengthRead = stream.read(buffer);
                             readRetries = 3;
-                            for (var output : pipeOutputs) {
-                                try {
-                                    output.write(buffer, 0, lengthRead);
-                                    if (pipeOutputs.contains(output)) {
-                                        output.flush();
-                                    }
-                                } catch (InterruptedIOException e) {
-                                    if (pipeOutputs.isEmpty()) {
-                                        // task has been ended while writing
-                                        return;
-                                    }
-                                    logger.warn("InterruptedIOException while writing from pulse source to pipe: {}",
-                                            getExceptionMessage(e));
-                                } catch (IOException e) {
-                                    logger.warn("IOException while writing from pulse source to pipe: {}",
-                                            getExceptionMessage(e));
-                                } catch (RuntimeException e) {
-                                    logger.warn("RuntimeException while writing from pulse source to pipe: {}",
-                                            getExceptionMessage(e));
-                                }
-                            }
+                            streamGroup.write(buffer, 0, lengthRead);
+                            streamGroup.flush();
                         } catch (IOException e) {
                             logger.warn("IOException while reading from pulse source: {}", getExceptionMessage(e));
                             if (readRetries == 0) {
@@ -192,25 +150,9 @@ public class PulseAudioAudioSource extends PulseaudioSimpleProtocolStream implem
         }
     }
 
-    private synchronized void unregisterPipe(PipedOutputStream pipeOutput) {
-        boolean isRemoved = this.pipeOutputs.remove(pipeOutput);
-        if (isRemoved) {
-            minusClientCount();
-        }
-        try {
-            Thread.sleep(0);
-        } catch (InterruptedException ignored) {
-        }
-        stopPipeWriteTask();
-        try {
-            pipeOutput.close();
-        } catch (IOException ignored) {
-        }
-    }
-
     private synchronized void stopPipeWriteTask() {
         var pipeWriteTask = this.pipeWriteTask;
-        if (pipeOutputs.isEmpty() && pipeWriteTask != null) {
+        if (streamGroup.isEmpty() && pipeWriteTask != null) {
             pipeWriteTask.cancel(true);
             this.pipeWriteTask = null;
         }
@@ -243,58 +185,4 @@ public class PulseAudioAudioSource extends PulseaudioSimpleProtocolStream implem
         stopPipeWriteTask();
         super.disconnect();
     }
-
-    static class PulseAudioStream extends AudioStream {
-        private final Logger logger = LoggerFactory.getLogger(PulseAudioAudioSource.class);
-        private final AudioFormat format;
-        private final InputStream input;
-        private final Runnable activity;
-        private boolean closed = false;
-
-        public PulseAudioStream(AudioFormat format, InputStream input, Runnable activity) {
-            this.input = input;
-            this.format = format;
-            this.activity = activity;
-        }
-
-        @Override
-        public AudioFormat getFormat() {
-            return format;
-        }
-
-        @Override
-        public int read() throws IOException {
-            byte[] b = new byte[1];
-            int bytesRead = read(b);
-            if (-1 == bytesRead) {
-                return bytesRead;
-            }
-            Byte bb = Byte.valueOf(b[0]);
-            return bb.intValue();
-        }
-
-        @Override
-        public int read(byte @Nullable [] b) throws IOException {
-            return read(b, 0, b == null ? 0 : b.length);
-        }
-
-        @Override
-        public int read(byte @Nullable [] b, int off, int len) throws IOException {
-            if (b == null) {
-                throw new IOException("Buffer is null");
-            }
-            logger.trace("reading from pulseaudio stream");
-            if (closed) {
-                throw new IOException("Stream is closed");
-            }
-            activity.run();
-            return input.read(b, off, len);
-        }
-
-        @Override
-        public void close() throws IOException {
-            closed = true;
-            input.close();
-        }
-    }
 }
index 6adc386c33ff5911f508c4213d776883dd8107db..f88b02407ee6ff6677237557a34e557fd0a45d25 100644 (file)
@@ -469,39 +469,50 @@ public class PulseaudioHandler extends BaseThingHandler {
                 .orElse(simpleTcpPort);
     }
 
-    public @Nullable AudioFormat getSourceAudioFormat() {
+    public AudioFormat getSourceAudioFormat() {
         String simpleFormat = ((String) getThing().getConfiguration().get(DEVICE_PARAMETER_AUDIO_SOURCE_FORMAT));
         BigDecimal simpleRate = ((BigDecimal) getThing().getConfiguration().get(DEVICE_PARAMETER_AUDIO_SOURCE_RATE));
         BigDecimal simpleChannels = ((BigDecimal) getThing().getConfiguration()
                 .get(DEVICE_PARAMETER_AUDIO_SOURCE_CHANNELS));
+        AudioFormat fallback = new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, false, 16,
+                16 * 16000, 16000L, 1);
         if (simpleFormat == null || simpleRate == null || simpleChannels == null) {
-            return null;
+            return fallback;
         }
+        int sampleRateAllChannels = simpleRate.intValue() * simpleChannels.intValue();
         switch (simpleFormat) {
-            case "u8":
-                return new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_UNSIGNED, null, 8, 1,
-                        simpleRate.longValue(), simpleChannels.intValue());
-            case "s16le":
-                return new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 16, 1,
-                        simpleRate.longValue(), simpleChannels.intValue());
-            case "s16be":
-                return new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, true, 16, 1,
-                        simpleRate.longValue(), simpleChannels.intValue());
-            case "s24le":
-                return new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 24, 1,
-                        simpleRate.longValue(), simpleChannels.intValue());
-            case "s24be":
-                return new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, true, 24, 1,
-                        simpleRate.longValue(), simpleChannels.intValue());
-            case "s32le":
-                return new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 32, 1,
-                        simpleRate.longValue(), simpleChannels.intValue());
-            case "s32be":
-                return new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, true, 32, 1,
-                        simpleRate.longValue(), simpleChannels.intValue());
-            default:
+            case "u8" -> {
+                return new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_UNSIGNED, null, 8,
+                        8 * sampleRateAllChannels, simpleRate.longValue(), simpleChannels.intValue());
+            }
+            case "s16le" -> {
+                return new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, false, 16,
+                        16 * sampleRateAllChannels, simpleRate.longValue(), simpleChannels.intValue());
+            }
+            case "s16be" -> {
+                return new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, true, 16,
+                        16 * sampleRateAllChannels, simpleRate.longValue(), simpleChannels.intValue());
+            }
+            case "s24le" -> {
+                return new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, false, 24,
+                        24 * sampleRateAllChannels, simpleRate.longValue(), simpleChannels.intValue());
+            }
+            case "s24be" -> {
+                return new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, true, 24,
+                        24 * sampleRateAllChannels, simpleRate.longValue(), simpleChannels.intValue());
+            }
+            case "s32le" -> {
+                return new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, false, 32,
+                        32 * sampleRateAllChannels, simpleRate.longValue(), simpleChannels.intValue());
+            }
+            case "s32be" -> {
+                return new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, true, 32,
+                        32 * sampleRateAllChannels, simpleRate.longValue(), simpleChannels.intValue());
+            }
+            default -> {
                 logger.warn("unsupported format {}", simpleFormat);
-                return null;
+                return fallback;
+            }
         }
     }
 
index 308e711253327cddfc4042008bdcd26268eaf4f3..0a7fef170cf7dcd44aa5057dfa10888c0375cced 100644 (file)
@@ -29,6 +29,7 @@ import org.eclipse.jdt.annotation.NonNullByDefault;
 import org.eclipse.jdt.annotation.Nullable;
 import org.openhab.core.audio.AudioFormat;
 import org.openhab.core.audio.AudioStream;
+import org.openhab.core.audio.utils.AudioWaveUtils;
 import org.openhab.core.auth.client.oauth2.AccessTokenResponse;
 import org.openhab.core.auth.client.oauth2.OAuthClientService;
 import org.openhab.core.auth.client.oauth2.OAuthException;
@@ -144,12 +145,8 @@ public class GoogleSTTService implements STTService {
     @Override
     public Set<AudioFormat> getSupportedFormats() {
         return Set.of(
-                new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L),
-                new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 8000L),
-                new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 12000L),
-                new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 16000L),
-                new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 24000L),
-                new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 48000L));
+                new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L),
+                new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L));
     }
 
     @Override
@@ -248,8 +245,6 @@ public class GoogleSTTService implements STTService {
         RecognitionConfig.AudioEncoding streamEncoding;
         if (AudioFormat.WAV.isCompatible(streamFormat)) {
             streamEncoding = RecognitionConfig.AudioEncoding.LINEAR16;
-        } else if (AudioFormat.OGG.isCompatible(streamFormat)) {
-            streamEncoding = RecognitionConfig.AudioEncoding.OGG_OPUS;
         } else {
             logger.debug("Unsupported format {}", streamFormat);
             return;
@@ -271,6 +266,9 @@ public class GoogleSTTService implements STTService {
         final int bufferSize = 6400;
         int numBytesRead;
         int remaining = bufferSize;
+        if (AudioFormat.CONTAINER_WAVE.equals(streamFormat.getContainer())) {
+            AudioWaveUtils.removeFMT(audioStream);
+        }
         byte[] audioBuffer = new byte[bufferSize];
         while (!aborted.get() && !responseObserver.isDone()) {
             numBytesRead = audioStream.read(audioBuffer, bufferSize - remaining, remaining);
index 7a95c85fbe151e36b900cefba7e17a230fd04972..758591564e50560fd60be5c851905c62264230a3 100644 (file)
@@ -109,6 +109,9 @@ public class RustpotterKSService implements KSService {
     @Override
     public Set<AudioFormat> getSupportedFormats() {
         return Set.of(
+                new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L),
+                new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, null, 16, null, null),
+                new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, null, 32, null, null),
                 new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L),
                 new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, null, 16, null, null),
                 new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, null, 32, null, null));
index 2d2b0ca89c1db6731eb547430b045b9c2b9b4efc..b4d38392140bb691b920fc69d01de31d482632e6 100644 (file)
@@ -30,6 +30,7 @@ import org.eclipse.jdt.annotation.Nullable;
 import org.openhab.core.OpenHAB;
 import org.openhab.core.audio.AudioFormat;
 import org.openhab.core.audio.AudioStream;
+import org.openhab.core.audio.utils.AudioWaveUtils;
 import org.openhab.core.common.ThreadPoolManager;
 import org.openhab.core.config.core.ConfigurableService;
 import org.openhab.core.config.core.Configuration;
@@ -159,6 +160,7 @@ public class VoskSTTService implements STTService {
     @Override
     public Set<AudioFormat> getSupportedFormats() {
         return Set.of(
+                new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, false, null, null, 16000L),
                 new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, null, null, 16000L));
     }
 
@@ -167,10 +169,14 @@ public class VoskSTTService implements STTService {
             throws STTException {
         AtomicBoolean aborted = new AtomicBoolean(false);
         try {
-            var frequency = audioStream.getFormat().getFrequency();
+            AudioFormat format = audioStream.getFormat();
+            var frequency = format.getFrequency();
             if (frequency == null) {
                 throw new IOException("missing audio stream frequency");
             }
+            if (AudioFormat.CONTAINER_WAVE.equals(format.getContainer())) {
+                AudioWaveUtils.removeFMT(audioStream);
+            }
             backgroundRecognize(sttListener, audioStream, frequency, aborted);
         } catch (IOException e) {
             throw new STTException(e);
index e90bc08e69d908fa22334beaa19b12dbfbe275a7..d5ea0029c8d515295806214c3ac54c802f79d19e 100644 (file)
@@ -14,6 +14,7 @@ package org.openhab.voice.watsonstt.internal;
 
 import static org.openhab.voice.watsonstt.internal.WatsonSTTConstants.*;
 
+import java.io.IOException;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -28,6 +29,7 @@ import org.eclipse.jdt.annotation.NonNullByDefault;
 import org.eclipse.jdt.annotation.Nullable;
 import org.openhab.core.audio.AudioFormat;
 import org.openhab.core.audio.AudioStream;
+import org.openhab.core.audio.utils.AudioWaveUtils;
 import org.openhab.core.common.ThreadPoolManager;
 import org.openhab.core.config.core.ConfigurableService;
 import org.openhab.core.config.core.Configuration;
@@ -122,8 +124,7 @@ public class WatsonSTTService implements STTService {
 
     @Override
     public Set<AudioFormat> getSupportedFormats() {
-        return Set.of(AudioFormat.WAV, AudioFormat.OGG, new AudioFormat("OGG", "OPUS", null, null, null, null),
-                AudioFormat.MP3);
+        return Set.of(AudioFormat.PCM_SIGNED, AudioFormat.WAV);
     }
 
     @Override
@@ -147,6 +148,13 @@ public class WatsonSTTService implements STTService {
         final AtomicReference<@Nullable WebSocket> socketRef = new AtomicReference<>();
         final AtomicBoolean aborted = new AtomicBoolean(false);
         executor.submit(() -> {
+            if (AudioFormat.CONTAINER_WAVE.equals(audioStream.getFormat().getContainer())) {
+                try {
+                    AudioWaveUtils.removeFMT(audioStream);
+                } catch (IOException e) {
+                    logger.warn("Error removing format header: {}", e.getMessage());
+                }
+            }
             socketRef.set(stt.recognizeUsingWebSocket(wsOptions,
                     new TranscriptionListener(socketRef, sttListener, config, aborted)));
         });