]> git.basschouten.com Git - openhab-addons.git/blob
7a95c85fbe151e36b900cefba7e17a230fd04972
[openhab-addons.git] /
1 /**
2  * Copyright (c) 2010-2024 Contributors to the openHAB project
3  *
4  * See the NOTICE file(s) distributed with this work for additional
5  * information.
6  *
7  * This program and the accompanying materials are made available under the
8  * terms of the Eclipse Public License 2.0 which is available at
9  * http://www.eclipse.org/legal/epl-2.0
10  *
11  * SPDX-License-Identifier: EPL-2.0
12  */
13 package org.openhab.voice.rustpotterks.internal;
14
15 import static org.openhab.voice.rustpotterks.internal.RustpotterKSConstants.*;
16
17 import java.io.IOException;
18 import java.nio.file.Files;
19 import java.nio.file.Path;
20 import java.util.ArrayList;
21 import java.util.List;
22 import java.util.Locale;
23 import java.util.Map;
24 import java.util.Optional;
25 import java.util.Set;
26 import java.util.concurrent.ExecutorService;
27 import java.util.concurrent.atomic.AtomicBoolean;
28
29 import org.eclipse.jdt.annotation.NonNullByDefault;
30 import org.eclipse.jdt.annotation.Nullable;
31 import org.openhab.core.OpenHAB;
32 import org.openhab.core.audio.AudioFormat;
33 import org.openhab.core.audio.AudioStream;
34 import org.openhab.core.common.ThreadPoolManager;
35 import org.openhab.core.config.core.ConfigurableService;
36 import org.openhab.core.config.core.Configuration;
37 import org.openhab.core.voice.KSErrorEvent;
38 import org.openhab.core.voice.KSException;
39 import org.openhab.core.voice.KSListener;
40 import org.openhab.core.voice.KSService;
41 import org.openhab.core.voice.KSServiceHandle;
42 import org.openhab.core.voice.KSpottedEvent;
43 import org.osgi.framework.Constants;
44 import org.osgi.service.component.annotations.Activate;
45 import org.osgi.service.component.annotations.Component;
46 import org.osgi.service.component.annotations.Modified;
47 import org.slf4j.Logger;
48 import org.slf4j.LoggerFactory;
49
50 import io.github.givimad.rustpotter_java.Endianness;
51 import io.github.givimad.rustpotter_java.Rustpotter;
52 import io.github.givimad.rustpotter_java.RustpotterConfig;
53 import io.github.givimad.rustpotter_java.RustpotterDetection;
54 import io.github.givimad.rustpotter_java.SampleFormat;
55 import io.github.givimad.rustpotter_java.ScoreMode;
56 import io.github.givimad.rustpotter_java.VADMode;
57
58 /**
59  * The {@link RustpotterKSService} is a keyword spotting implementation based on rustpotter.
60  *
61  * @author Miguel Álvarez - Initial contribution
62  */
63 @NonNullByDefault
64 @Component(configurationPid = SERVICE_PID, property = Constants.SERVICE_PID + "=" + SERVICE_PID)
65 @ConfigurableService(category = SERVICE_CATEGORY, label = SERVICE_NAME
66         + " Keyword Spotter", description_uri = SERVICE_CATEGORY + ":" + SERVICE_ID)
67 public class RustpotterKSService implements KSService {
68     private static final Path RUSTPOTTER_FOLDER = Path.of(OpenHAB.getUserDataFolder(), "rustpotter");
69     private static final Path RUSTPOTTER_RECORDS_FOLDER = RUSTPOTTER_FOLDER.resolve("records");
70     private final Logger logger = LoggerFactory.getLogger(RustpotterKSService.class);
71     private final ExecutorService executor = ThreadPoolManager.getPool("voice-rustpotterks");
72     private RustpotterKSConfiguration config = new RustpotterKSConfiguration();
73     private final List<RustpotterMutex> runningInstances = new ArrayList<>();
74
75     @Activate
76     protected void activate(Map<String, Object> config) {
77         logger.debug("Loading library");
78         tryCreateDir(RUSTPOTTER_FOLDER);
79         tryCreateDir(RUSTPOTTER_RECORDS_FOLDER);
80         try {
81             Rustpotter.loadLibrary();
82         } catch (IOException e) {
83             logger.warn("Unable to load rustpotter native library: {}", e.getMessage());
84         }
85         modified(config);
86     }
87
88     @Modified
89     protected void modified(Map<String, Object> config) {
90         this.config = new Configuration(config).as(RustpotterKSConfiguration.class);
91         asyncUpdateActiveInstances();
92     }
93
94     @Override
95     public String getId() {
96         return SERVICE_ID;
97     }
98
99     @Override
100     public String getLabel(@Nullable Locale locale) {
101         return SERVICE_NAME;
102     }
103
104     @Override
105     public Set<Locale> getSupportedLocales() {
106         return Set.of();
107     }
108
109     @Override
110     public Set<AudioFormat> getSupportedFormats() {
111         return Set.of(
112                 new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L),
113                 new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, null, 16, null, null),
114                 new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, null, 32, null, null));
115     }
116
117     @Override
118     public KSServiceHandle spot(KSListener ksListener, AudioStream audioStream, Locale locale, String keyword)
119             throws KSException {
120         var audioFormat = audioStream.getFormat();
121         var frequency = audioFormat.getFrequency();
122         var bitDepth = audioFormat.getBitDepth();
123         var channels = audioFormat.getChannels();
124         var isBigEndian = audioFormat.isBigEndian();
125         if (frequency == null || bitDepth == null || channels == null || isBigEndian == null) {
126             throw new KSException(
127                     "Missing stream metadata: frequency, bit depth, channels and endianness must be defined.");
128         }
129         var endianness = isBigEndian ? Endianness.BIG : Endianness.LITTLE;
130         logger.debug("Audio wav spec: sample rate {}, {} bits, {} channels, {}", frequency, bitDepth, channels,
131                 isBigEndian ? "big-endian" : "little-endian");
132         var wakewordName = keyword.replaceAll("\\s", "_") + ".rpw";
133
134         var wakewordPath = RUSTPOTTER_FOLDER.resolve(wakewordName);
135         if (!Files.exists(wakewordPath)) {
136             throw new KSException("Missing wakeword file: " + wakewordPath);
137         }
138         Rustpotter rustpotter;
139         try {
140             rustpotter = initRustpotter(frequency, bitDepth, channels, endianness);
141         } catch (Exception e) {
142             throw new KSException("Unable to start rustpotter: " + e.getMessage(), e);
143         }
144         try {
145             rustpotter.addWakewordFile("w", wakewordPath.toString());
146         } catch (Exception e) {
147             throw new KSException("Unable to load wakeword file: " + e.getMessage());
148         }
149         logger.debug("Wakeword '{}' loaded", wakewordPath);
150         AtomicBoolean aborted = new AtomicBoolean(false);
151         int bufferSize = (int) rustpotter.getBytesPerFrame();
152         long bytesPerMs = frequency / 1000 * (long) bitDepth;
153         RustpotterMutex rustpotterMutex = new RustpotterMutex(rustpotter);
154         synchronized (this.runningInstances) {
155             this.runningInstances.add(rustpotterMutex);
156         }
157         executor.submit(
158                 () -> processAudioStream(rustpotterMutex, bufferSize, bytesPerMs, ksListener, audioStream, aborted));
159         return () -> {
160             logger.debug("Stopping service");
161             aborted.set(true);
162         };
163     }
164
165     private Rustpotter initRustpotter(long frequency, int bitDepth, int channels, Endianness endianness)
166             throws Exception {
167         var rustpotterConfig = initRustpotterConfig();
168         // audio format config just need to be set for initializing the instance, is ignored on config updates
169         rustpotterConfig.setSampleFormat(getIntSampleFormat(bitDepth));
170         rustpotterConfig.setSampleRate(frequency);
171         rustpotterConfig.setChannels(channels);
172         rustpotterConfig.setEndianness(endianness);
173         // init the detector
174         var rustpotter = new Rustpotter(rustpotterConfig);
175         rustpotterConfig.delete();
176         return rustpotter;
177     }
178
179     private RustpotterConfig initRustpotterConfig() {
180         var rustpotterConfig = new RustpotterConfig();
181         // detector configs
182         rustpotterConfig.setThreshold(config.threshold);
183         rustpotterConfig.setAveragedThreshold(config.averagedThreshold);
184         rustpotterConfig.setScoreMode(getScoreMode(config.scoreMode));
185         rustpotterConfig.setMinScores(config.minScores);
186         rustpotterConfig.setEager(config.eager);
187         rustpotterConfig.setScoreRef(config.scoreRef);
188         rustpotterConfig.setBandSize(config.bandSize);
189         rustpotterConfig.setVADMode(getVADMode(config.vadMode));
190         rustpotterConfig.setRecordPath(config.record ? RUSTPOTTER_RECORDS_FOLDER.toString() : null);
191         // filter configs
192         rustpotterConfig.setGainNormalizerEnabled(config.gainNormalizer);
193         rustpotterConfig.setMinGain(config.minGain);
194         rustpotterConfig.setMaxGain(config.maxGain);
195         rustpotterConfig.setGainRef(config.gainRef);
196         rustpotterConfig.setBandPassFilterEnabled(config.bandPass);
197         rustpotterConfig.setBandPassLowCutoff(config.lowCutoff);
198         rustpotterConfig.setBandPassHighCutoff(config.highCutoff);
199
200         return rustpotterConfig;
201     }
202
203     private void processAudioStream(RustpotterMutex rustpotter, int bufferSize, long bytesPerMs, KSListener ksListener,
204             AudioStream audioStream, AtomicBoolean aborted) {
205         int numBytesRead;
206         byte[] audioBuffer = new byte[bufferSize];
207         int remaining = bufferSize;
208         boolean hasFailed = false;
209         while (!aborted.get()) {
210             try {
211                 numBytesRead = audioStream.read(audioBuffer, bufferSize - remaining, remaining);
212                 if (aborted.get() || numBytesRead == -1) {
213                     break;
214                 }
215                 if (numBytesRead != remaining) {
216                     remaining = remaining - numBytesRead;
217                     try {
218                         Thread.sleep(remaining / bytesPerMs);
219                     } catch (InterruptedException ignored) {
220                         logger.warn("Thread interrupted while waiting for audio, aborting execution");
221                         aborted.set(true);
222                     }
223                     if (aborted.get()) {
224                         break;
225                     }
226                     continue;
227                 }
228                 remaining = bufferSize;
229                 var result = rustpotter.processBytes(audioBuffer);
230                 hasFailed = false;
231                 if (result.isPresent()) {
232                     var detection = result.get();
233                     if (logger.isDebugEnabled()) {
234                         ArrayList<String> scores = new ArrayList<>();
235                         var scoreNames = detection.getScoreNames().split("\\|\\|");
236                         var scoreValues = detection.getScores();
237                         for (var i = 0; i < Integer.min(scoreNames.length, scoreValues.length); i++) {
238                             scores.add("'" + scoreNames[i] + "': " + scoreValues[i]);
239                         }
240                         logger.debug("Detected '{}' with: Score: {}, AvgScore: {}, Count: {}, Gain: {}, Scores: {}",
241                                 detection.getName(), detection.getScore(), detection.getAvgScore(),
242                                 detection.getCounter(), detection.getGain(), String.join(", ", scores));
243                     }
244                     detection.delete();
245                     ksListener.ksEventReceived(new KSpottedEvent());
246                 }
247             } catch (IOException e) {
248                 String errorMessage = e.getMessage();
249                 ksListener.ksEventReceived(new KSErrorEvent(errorMessage != null ? errorMessage : "Unexpected error"));
250                 if (hasFailed) {
251                     logger.warn("Multiple consecutive errors, stopping service");
252                     break;
253                 }
254                 hasFailed = true;
255             }
256         }
257         synchronized (this.runningInstances) {
258             this.runningInstances.remove(rustpotter);
259         }
260         rustpotter.delete();
261         logger.debug("Rustpotter stopped");
262     }
263
264     private void asyncUpdateActiveInstances() {
265         int nInstances;
266         synchronized (this.runningInstances) {
267             nInstances = this.runningInstances.size();
268         }
269         if (nInstances == 0) {
270             return;
271         }
272         var rustpotterConfig = initRustpotterConfig();
273         executor.submit(() -> {
274             logger.debug("Updating running instances");
275             synchronized (this.runningInstances) {
276                 for (RustpotterMutex rustpotter : this.runningInstances) {
277                     rustpotter.updateConfig(rustpotterConfig);
278                 }
279                 logger.debug("{} running instances updated", this.runningInstances.size());
280             }
281             rustpotterConfig.delete();
282         });
283     }
284
285     private static SampleFormat getIntSampleFormat(int bitDepth) throws IOException {
286         return switch (bitDepth) {
287             case 8 -> SampleFormat.I8;
288             case 16 -> SampleFormat.I16;
289             case 32 -> SampleFormat.I32;
290             default -> throw new IOException("Unsupported audio bit depth: " + bitDepth);
291         };
292     }
293
294     private ScoreMode getScoreMode(String mode) {
295         return switch (mode) {
296             case "average" -> ScoreMode.AVG;
297             case "median" -> ScoreMode.MEDIAN;
298             case "p25" -> ScoreMode.P25;
299             case "p50" -> ScoreMode.P50;
300             case "p75" -> ScoreMode.P75;
301             case "p80" -> ScoreMode.P80;
302             case "p90" -> ScoreMode.P90;
303             case "p95" -> ScoreMode.P95;
304             default -> ScoreMode.MAX;
305         };
306     }
307
308     private @Nullable VADMode getVADMode(String mode) {
309         return switch (mode) {
310             case "easy" -> VADMode.EASY;
311             case "medium" -> VADMode.MEDIUM;
312             case "hard" -> VADMode.HARD;
313             default -> null;
314         };
315     }
316
317     private void tryCreateDir(Path rustpotterFolder) {
318         if (!Files.exists(rustpotterFolder) || !Files.isDirectory(rustpotterFolder)) {
319             try {
320                 Files.createDirectory(rustpotterFolder);
321                 logger.info("Folder {} created", rustpotterFolder);
322             } catch (IOException e) {
323                 logger.warn("Unable to create folder {}", rustpotterFolder);
324             }
325         }
326     }
327
328     private record RustpotterMutex(Rustpotter rustpotter) {
329
330         public Optional<RustpotterDetection> processBytes(byte[] bytes) {
331             synchronized (this.rustpotter) {
332                 return this.rustpotter.processBytes(bytes);
333             }
334         }
335
336         public void updateConfig(RustpotterConfig config) {
337             synchronized (this.rustpotter) {
338                 this.rustpotter.updateConfig(config);
339             }
340         }
341
342         public void delete() {
343             synchronized (this.rustpotter) {
344                 this.rustpotter.delete();
345             }
346         }
347     }
348 }