/**
 * Copyright (c) 2010-2024 Contributors to the openHAB project
 *
 * See the NOTICE file(s) distributed with this work for additional
 * information.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0
 *
 * SPDX-License-Identifier: EPL-2.0
 */
13 package org.openhab.voice.rustpotterks.internal;
import static org.openhab.voice.rustpotterks.internal.RustpotterKSConstants.*;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;

import org.eclipse.jdt.annotation.NonNullByDefault;
import org.eclipse.jdt.annotation.Nullable;
import org.openhab.core.OpenHAB;
import org.openhab.core.audio.AudioFormat;
import org.openhab.core.audio.AudioStream;
import org.openhab.core.common.ThreadPoolManager;
import org.openhab.core.config.core.ConfigurableService;
import org.openhab.core.config.core.Configuration;
import org.openhab.core.voice.KSErrorEvent;
import org.openhab.core.voice.KSException;
import org.openhab.core.voice.KSListener;
import org.openhab.core.voice.KSService;
import org.openhab.core.voice.KSServiceHandle;
import org.openhab.core.voice.KSpottedEvent;
import org.osgi.framework.Constants;
import org.osgi.service.component.annotations.Activate;
import org.osgi.service.component.annotations.Component;
import org.osgi.service.component.annotations.Modified;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.github.givimad.rustpotter_java.Endianness;
import io.github.givimad.rustpotter_java.Rustpotter;
import io.github.givimad.rustpotter_java.RustpotterConfig;
import io.github.givimad.rustpotter_java.RustpotterDetection;
import io.github.givimad.rustpotter_java.SampleFormat;
import io.github.givimad.rustpotter_java.ScoreMode;
import io.github.givimad.rustpotter_java.VADMode;
/**
 * The {@link RustpotterKSService} is a keyword spotting implementation based on rustpotter.
 *
 * @author Miguel Álvarez - Initial contribution
 */
64 @Component(configurationPid = SERVICE_PID, property = Constants.SERVICE_PID + "=" + SERVICE_PID)
65 @ConfigurableService(category = SERVICE_CATEGORY, label = SERVICE_NAME
66 + " Keyword Spotter", description_uri = SERVICE_CATEGORY + ":" + SERVICE_ID)
67 public class RustpotterKSService implements KSService {
68 private static final Path RUSTPOTTER_FOLDER = Path.of(OpenHAB.getUserDataFolder(), "rustpotter");
69 private static final Path RUSTPOTTER_RECORDS_FOLDER = RUSTPOTTER_FOLDER.resolve("records");
70 private final Logger logger = LoggerFactory.getLogger(RustpotterKSService.class);
71 private final ExecutorService executor = ThreadPoolManager.getPool("voice-rustpotterks");
72 private RustpotterKSConfiguration config = new RustpotterKSConfiguration();
73 private final List<RustpotterMutex> runningInstances = new ArrayList<>();
76 protected void activate(Map<String, Object> config) {
77 logger.debug("Loading library");
78 tryCreateDir(RUSTPOTTER_FOLDER);
79 tryCreateDir(RUSTPOTTER_RECORDS_FOLDER);
81 Rustpotter.loadLibrary();
82 } catch (IOException e) {
83 logger.warn("Unable to load rustpotter native library: {}", e.getMessage());
89 protected void modified(Map<String, Object> config) {
90 this.config = new Configuration(config).as(RustpotterKSConfiguration.class);
91 asyncUpdateActiveInstances();
95 public String getId() {
100 public String getLabel(@Nullable Locale locale) {
105 public Set<Locale> getSupportedLocales() {
110 public Set<AudioFormat> getSupportedFormats() {
112 new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L),
113 new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, null, 16, null, null),
114 new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, null, 32, null, null));
118 public KSServiceHandle spot(KSListener ksListener, AudioStream audioStream, Locale locale, String keyword)
120 var audioFormat = audioStream.getFormat();
121 var frequency = audioFormat.getFrequency();
122 var bitDepth = audioFormat.getBitDepth();
123 var channels = audioFormat.getChannels();
124 var isBigEndian = audioFormat.isBigEndian();
125 if (frequency == null || bitDepth == null || channels == null || isBigEndian == null) {
126 throw new KSException(
127 "Missing stream metadata: frequency, bit depth, channels and endianness must be defined.");
129 var endianness = isBigEndian ? Endianness.BIG : Endianness.LITTLE;
130 logger.debug("Audio wav spec: sample rate {}, {} bits, {} channels, {}", frequency, bitDepth, channels,
131 isBigEndian ? "big-endian" : "little-endian");
132 var wakewordName = keyword.replaceAll("\\s", "_") + ".rpw";
134 var wakewordPath = RUSTPOTTER_FOLDER.resolve(wakewordName);
135 if (!Files.exists(wakewordPath)) {
136 throw new KSException("Missing wakeword file: " + wakewordPath);
138 Rustpotter rustpotter;
140 rustpotter = initRustpotter(frequency, bitDepth, channels, endianness);
141 } catch (Exception e) {
142 throw new KSException("Unable to start rustpotter: " + e.getMessage(), e);
145 rustpotter.addWakewordFile("w", wakewordPath.toString());
146 } catch (Exception e) {
147 throw new KSException("Unable to load wakeword file: " + e.getMessage());
149 logger.debug("Wakeword '{}' loaded", wakewordPath);
150 AtomicBoolean aborted = new AtomicBoolean(false);
151 int bufferSize = (int) rustpotter.getBytesPerFrame();
152 long bytesPerMs = frequency / 1000 * (long) bitDepth;
153 RustpotterMutex rustpotterMutex = new RustpotterMutex(rustpotter);
154 synchronized (this.runningInstances) {
155 this.runningInstances.add(rustpotterMutex);
158 () -> processAudioStream(rustpotterMutex, bufferSize, bytesPerMs, ksListener, audioStream, aborted));
160 logger.debug("Stopping service");
165 private Rustpotter initRustpotter(long frequency, int bitDepth, int channels, Endianness endianness)
167 var rustpotterConfig = initRustpotterConfig();
168 // audio format config just need to be set for initializing the instance, is ignored on config updates
169 rustpotterConfig.setSampleFormat(getIntSampleFormat(bitDepth));
170 rustpotterConfig.setSampleRate(frequency);
171 rustpotterConfig.setChannels(channels);
172 rustpotterConfig.setEndianness(endianness);
174 var rustpotter = new Rustpotter(rustpotterConfig);
175 rustpotterConfig.delete();
179 private RustpotterConfig initRustpotterConfig() {
180 var rustpotterConfig = new RustpotterConfig();
182 rustpotterConfig.setThreshold(config.threshold);
183 rustpotterConfig.setAveragedThreshold(config.averagedThreshold);
184 rustpotterConfig.setScoreMode(getScoreMode(config.scoreMode));
185 rustpotterConfig.setMinScores(config.minScores);
186 rustpotterConfig.setEager(config.eager);
187 rustpotterConfig.setScoreRef(config.scoreRef);
188 rustpotterConfig.setBandSize(config.bandSize);
189 rustpotterConfig.setVADMode(getVADMode(config.vadMode));
190 rustpotterConfig.setRecordPath(config.record ? RUSTPOTTER_RECORDS_FOLDER.toString() : null);
192 rustpotterConfig.setGainNormalizerEnabled(config.gainNormalizer);
193 rustpotterConfig.setMinGain(config.minGain);
194 rustpotterConfig.setMaxGain(config.maxGain);
195 rustpotterConfig.setGainRef(config.gainRef);
196 rustpotterConfig.setBandPassFilterEnabled(config.bandPass);
197 rustpotterConfig.setBandPassLowCutoff(config.lowCutoff);
198 rustpotterConfig.setBandPassHighCutoff(config.highCutoff);
200 return rustpotterConfig;
203 private void processAudioStream(RustpotterMutex rustpotter, int bufferSize, long bytesPerMs, KSListener ksListener,
204 AudioStream audioStream, AtomicBoolean aborted) {
206 byte[] audioBuffer = new byte[bufferSize];
207 int remaining = bufferSize;
208 boolean hasFailed = false;
209 while (!aborted.get()) {
211 numBytesRead = audioStream.read(audioBuffer, bufferSize - remaining, remaining);
212 if (aborted.get() || numBytesRead == -1) {
215 if (numBytesRead != remaining) {
216 remaining = remaining - numBytesRead;
218 Thread.sleep(remaining / bytesPerMs);
219 } catch (InterruptedException ignored) {
220 logger.warn("Thread interrupted while waiting for audio, aborting execution");
228 remaining = bufferSize;
229 var result = rustpotter.processBytes(audioBuffer);
231 if (result.isPresent()) {
232 var detection = result.get();
233 if (logger.isDebugEnabled()) {
234 ArrayList<String> scores = new ArrayList<>();
235 var scoreNames = detection.getScoreNames().split("\\|\\|");
236 var scoreValues = detection.getScores();
237 for (var i = 0; i < Integer.min(scoreNames.length, scoreValues.length); i++) {
238 scores.add("'" + scoreNames[i] + "': " + scoreValues[i]);
240 logger.debug("Detected '{}' with: Score: {}, AvgScore: {}, Count: {}, Gain: {}, Scores: {}",
241 detection.getName(), detection.getScore(), detection.getAvgScore(),
242 detection.getCounter(), detection.getGain(), String.join(", ", scores));
245 ksListener.ksEventReceived(new KSpottedEvent());
247 } catch (IOException e) {
248 String errorMessage = e.getMessage();
249 ksListener.ksEventReceived(new KSErrorEvent(errorMessage != null ? errorMessage : "Unexpected error"));
251 logger.warn("Multiple consecutive errors, stopping service");
257 synchronized (this.runningInstances) {
258 this.runningInstances.remove(rustpotter);
261 logger.debug("Rustpotter stopped");
264 private void asyncUpdateActiveInstances() {
266 synchronized (this.runningInstances) {
267 nInstances = this.runningInstances.size();
269 if (nInstances == 0) {
272 var rustpotterConfig = initRustpotterConfig();
273 executor.submit(() -> {
274 logger.debug("Updating running instances");
275 synchronized (this.runningInstances) {
276 for (RustpotterMutex rustpotter : this.runningInstances) {
277 rustpotter.updateConfig(rustpotterConfig);
279 logger.debug("{} running instances updated", this.runningInstances.size());
281 rustpotterConfig.delete();
285 private static SampleFormat getIntSampleFormat(int bitDepth) throws IOException {
286 return switch (bitDepth) {
287 case 8 -> SampleFormat.I8;
288 case 16 -> SampleFormat.I16;
289 case 32 -> SampleFormat.I32;
290 default -> throw new IOException("Unsupported audio bit depth: " + bitDepth);
294 private ScoreMode getScoreMode(String mode) {
295 return switch (mode) {
296 case "average" -> ScoreMode.AVG;
297 case "median" -> ScoreMode.MEDIAN;
298 case "p25" -> ScoreMode.P25;
299 case "p50" -> ScoreMode.P50;
300 case "p75" -> ScoreMode.P75;
301 case "p80" -> ScoreMode.P80;
302 case "p90" -> ScoreMode.P90;
303 case "p95" -> ScoreMode.P95;
304 default -> ScoreMode.MAX;
308 private @Nullable VADMode getVADMode(String mode) {
309 return switch (mode) {
310 case "easy" -> VADMode.EASY;
311 case "medium" -> VADMode.MEDIUM;
312 case "hard" -> VADMode.HARD;
317 private void tryCreateDir(Path rustpotterFolder) {
318 if (!Files.exists(rustpotterFolder) || !Files.isDirectory(rustpotterFolder)) {
320 Files.createDirectory(rustpotterFolder);
321 logger.info("Folder {} created", rustpotterFolder);
322 } catch (IOException e) {
323 logger.warn("Unable to create folder {}", rustpotterFolder);
328 private record RustpotterMutex(Rustpotter rustpotter) {
330 public Optional<RustpotterDetection> processBytes(byte[] bytes) {
331 synchronized (this.rustpotter) {
332 return this.rustpotter.processBytes(bytes);
336 public void updateConfig(RustpotterConfig config) {
337 synchronized (this.rustpotter) {
338 this.rustpotter.updateConfig(config);
342 public void delete() {
343 synchronized (this.rustpotter) {
344 this.rustpotter.delete();