2 * Copyright (c) 2010-2023 Contributors to the openHAB project
4 * See the NOTICE file(s) distributed with this work for additional
7 * This program and the accompanying materials are made available under the
8 * terms of the Eclipse Public License 2.0 which is available at
9 * http://www.eclipse.org/legal/epl-2.0
11 * SPDX-License-Identifier: EPL-2.0
13 package org.openhab.voice.rustpotterks.internal;
15 import static org.openhab.voice.rustpotterks.internal.RustpotterKSConstants.*;
18 import java.io.IOException;
19 import java.nio.file.Path;
20 import java.util.ArrayList;
21 import java.util.Locale;
24 import java.util.concurrent.ScheduledExecutorService;
25 import java.util.concurrent.atomic.AtomicBoolean;
27 import org.eclipse.jdt.annotation.NonNullByDefault;
28 import org.eclipse.jdt.annotation.Nullable;
29 import org.openhab.core.OpenHAB;
30 import org.openhab.core.audio.AudioFormat;
31 import org.openhab.core.audio.AudioStream;
32 import org.openhab.core.common.ThreadPoolManager;
33 import org.openhab.core.config.core.ConfigurableService;
34 import org.openhab.core.config.core.Configuration;
35 import org.openhab.core.voice.KSErrorEvent;
36 import org.openhab.core.voice.KSException;
37 import org.openhab.core.voice.KSListener;
38 import org.openhab.core.voice.KSService;
39 import org.openhab.core.voice.KSServiceHandle;
40 import org.openhab.core.voice.KSpottedEvent;
41 import org.osgi.framework.Constants;
42 import org.osgi.service.component.annotations.Activate;
43 import org.osgi.service.component.annotations.Component;
44 import org.osgi.service.component.annotations.Modified;
45 import org.slf4j.Logger;
46 import org.slf4j.LoggerFactory;
48 import io.github.givimad.rustpotter_java.Endianness;
49 import io.github.givimad.rustpotter_java.Rustpotter;
50 import io.github.givimad.rustpotter_java.RustpotterBuilder;
51 import io.github.givimad.rustpotter_java.SampleFormat;
52 import io.github.givimad.rustpotter_java.ScoreMode;
55 * The {@link RustpotterKSService} is a keyword spotting implementation based on rustpotter.
57 * @author Miguel Álvarez - Initial contribution
60 @Component(configurationPid = SERVICE_PID, property = Constants.SERVICE_PID + "=" + SERVICE_PID)
61 @ConfigurableService(category = SERVICE_CATEGORY, label = SERVICE_NAME
62 + " Keyword Spotter", description_uri = SERVICE_CATEGORY + ":" + SERVICE_ID)
63 public class RustpotterKSService implements KSService {
// Directory under the openHAB userdata folder that holds the wake word models; spot() looks up
// "<keyword>.rpw" files here.
64 private static final String RUSTPOTTER_FOLDER = Path.of(OpenHAB.getUserDataFolder(), "rustpotter").toString();
65 private final Logger logger = LoggerFactory.getLogger(RustpotterKSService.class);
// Named pool obtained from ThreadPoolManager; spot() submits each session's audio-processing loop to it.
66 private final ScheduledExecutorService executor = ThreadPoolManager.getScheduledPool("OH-voice-rustpotterks");
// Current configuration. modified() replaces it and initRustpotter() reads it whenever a detector is
// built, so a configuration change only affects spotting sessions started afterwards.
67 private RustpotterKSConfiguration config = new RustpotterKSConfiguration();
// Ensures the model directory exists: when RUSTPOTTER_FOLDER is missing it is created with
// File.mkdir(), and successful creation is logged at info level.
// NOTE(review): mkdir() only creates the last path segment and its false result (failure) is not
// reported here; consider Files.createDirectories(...) and a warning when creation fails.
69 Logger logger = LoggerFactory.getLogger(RustpotterKSService.class);
70 File directory = new File(RUSTPOTTER_FOLDER);
71 if (!directory.exists()) {
72 if (directory.mkdir()) {
73 logger.info("rustpotter dir created {}", RUSTPOTTER_FOLDER);
/**
 * Component activation hook; receives the component's configuration properties.
 *
 * @param config the configuration properties supplied at activation
 */
79 protected void activate(Map<String, Object> config) {
/**
 * Replaces the active configuration with one bound from the given properties.
 * Detectors built after this call (by spot()) use the new values; running sessions keep the
 * detector they were started with.
 *
 * @param config the updated configuration properties
 */
84 protected void modified(Map<String, Object> config) {
85 this.config = new Configuration(config).as(RustpotterKSConfiguration.class);
/** Returns the identifier of this keyword spotting service. */
89 public String getId() {
/**
 * Returns the display label of this service.
 *
 * @param locale the requested locale; may be {@code null}
 */
94 public String getLabel(@Nullable Locale locale) {
/** Returns the set of locales this service reports as supported. */
99 public Set<Locale> getSupportedLocales() {
/**
 * Returns the accepted audio formats: a single WAVE-container, signed-PCM format whose remaining
 * properties are passed as {@code null} (presumably meaning "unconstrained" — confirm against
 * AudioFormat). The actual stream handed to spot() must still declare frequency, bit depth,
 * channels and endianness, which spot() checks.
 */
104 public Set<AudioFormat> getSupportedFormats() {
106 .of(new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, null, null, null, null));
/**
 * Starts spotting {@code keyword} on {@code audioStream}.
 * <p>
 * Loads the native rustpotter library and validates that the stream format declares frequency,
 * bit depth, channels and endianness. It then builds a detector from the current configuration and
 * loads the wake word model {@code <RUSTPOTTER_FOLDER>/<keyword with whitespace replaced by '_'>.rpw}.
 * Finally it submits the audio-processing loop to {@code executor}.
 *
 * @param ksListener receives spotted and error events from the processing loop
 * @param audioStream the audio to analyse
 * @param locale the requested locale; the model is selected by keyword only
 * @param keyword the wake word whose model file is loaded
 * @return the handle for this spotting session
 * @throws KSException if the library cannot be loaded, the stream metadata is incomplete, the
 *             detector cannot be configured, or the model is missing or fails to load. The original
 *             exception is attached as the cause where there is one.
 */
110 public KSServiceHandle spot(KSListener ksListener, AudioStream audioStream, Locale locale, String keyword)
112 logger.debug("Loading library");
114 Rustpotter.loadLibrary();
115 } catch (IOException e) {
// Keep the cause so the underlying stack trace is not lost.
116 throw new KSException("Unable to load rustpotter lib: " + e.getMessage(), e);
118 var audioFormat = audioStream.getFormat();
119 var frequency = audioFormat.getFrequency();
120 var bitDepth = audioFormat.getBitDepth();
121 var channels = audioFormat.getChannels();
122 var isBigEndian = audioFormat.isBigEndian();
// The detector must be told the exact sample layout, so every one of these must be known.
123 if (frequency == null || bitDepth == null || channels == null || isBigEndian == null) {
124 throw new KSException(
125 "Missing stream metadata: frequency, bit depth, channels and endianness must be defined.");
127 var endianness = isBigEndian ? Endianness.BIG : Endianness.LITTLE;
128 logger.debug("Audio wav spec: frequency '{}', bit depth '{}', channels '{}', '{}'", frequency, bitDepth,
129 channels, isBigEndian ? "big-endian" : "little-endian");
130 Rustpotter rustpotter;
132 rustpotter = initRustpotter(frequency, bitDepth, channels, endianness);
133 } catch (Exception e) {
134 throw new KSException("Unable to configure rustpotter: " + e.getMessage(), e);
// Whitespace in the keyword maps to '_' in the model file name.
// NOTE(review): the keyword is not checked for path separators, so the resolved path could leave
// RUSTPOTTER_FOLDER — confirm keywords come only from trusted configuration.
136 var modelName = keyword.replaceAll("\\s", "_") + ".rpw";
137 var modelPath = Path.of(RUSTPOTTER_FOLDER, modelName);
// NOTE(review): on the two failure paths below, the detector built above is not released.
// initRustpotter() explicitly delete()s its builder, so this instance likely owns native resources
// too — confirm, and free it before throwing.
138 if (!modelPath.toFile().exists()) {
139 throw new KSException("Missing model " + modelName);
142 rustpotter.addWakewordModelFile(modelPath.toString());
143 } catch (Exception e) {
// Keep the cause, consistent with the configuration failure above.
144 throw new KSException("Unable to load wake word model: " + e.getMessage(), e);
146 logger.debug("Model '{}' loaded", modelPath);
// Shared with the processing loop, which runs until this flag is set.
147 AtomicBoolean aborted = new AtomicBoolean(false);
148 executor.submit(() -> processAudioStream(rustpotter, ksListener, audioStream, aborted));
150 logger.debug("Stopping service");
/**
 * Builds a Rustpotter detector for an integer-sample ({@code SampleFormat.INT}) stream.
 * <p>
 * The stream has the given bit depth, sample rate, channel count and endianness. The detection,
 * gain-normalizer and band-pass settings are taken from the current {@code config}. The builder is
 * deleted after the detector has been built.
 *
 * @param frequency sample rate of the audio stream
 * @param bitDepth bits per sample
 * @param channels number of channels
 * @param endianness byte order of the samples
 */
155 private Rustpotter initRustpotter(long frequency, int bitDepth, int channels, Endianness endianness)
157 var rustpotterBuilder = new RustpotterBuilder();
// Input stream format.
159 rustpotterBuilder.setBitsPerSample(bitDepth);
160 rustpotterBuilder.setSampleRate(frequency);
161 rustpotterBuilder.setChannels(channels);
162 rustpotterBuilder.setSampleFormat(SampleFormat.INT);
163 rustpotterBuilder.setEndianness(endianness);
// Detection thresholds and scoring.
165 rustpotterBuilder.setThreshold(config.threshold);
166 rustpotterBuilder.setAveragedThreshold(config.averagedThreshold);
167 rustpotterBuilder.setScoreMode(getScoreMode(config.scoreMode));
168 rustpotterBuilder.setMinScores(config.minScores);
169 rustpotterBuilder.setComparatorRef(config.comparatorRef);
170 rustpotterBuilder.setComparatorBandSize(config.comparatorBandSize);
// Audio pre-processing: gain normalization and band-pass filter.
172 rustpotterBuilder.setGainNormalizerEnabled(config.gainNormalizer);
173 rustpotterBuilder.setMinGain(config.minGain);
174 rustpotterBuilder.setMaxGain(config.maxGain);
175 rustpotterBuilder.setGainRef(config.gainRef);
176 rustpotterBuilder.setBandPassFilterEnabled(config.bandPass);
177 rustpotterBuilder.setBandPassLowCutoff(config.lowCutoff);
178 rustpotterBuilder.setBandPassHighCutoff(config.highCutoff);
180 var rustpotter = rustpotterBuilder.build();
// NOTE(review): delete() is skipped if build() or any setter above throws, leaking the builder;
// consider releasing it in a finally block.
181 rustpotterBuilder.delete();
/**
 * Feeds the audio stream to the detector until {@code aborted} is set.
 * <p>
 * Audio is read in frames of {@code rustpotter.getBytesPerFrame()} bytes, and short reads are
 * accumulated until a frame is complete. Each full frame is passed to {@code rustpotter.processBytes}.
 * A detection sends a {@link KSpottedEvent} to the listener; its score details are logged at debug
 * level. IOException and InterruptedException are reported to the listener as a
 * {@link KSErrorEvent}.
 *
 * @param rustpotter the configured detector with its wake word model loaded
 * @param ksListener receives spotted and error events
 * @param audioStream the audio source
 * @param aborted flag checked by the loop; once set, processing stops
 */
185 private void processAudioStream(Rustpotter rustpotter, KSListener ksListener, AudioStream audioStream,
186 AtomicBoolean aborted) {
188 var bufferSize = (int) rustpotter.getBytesPerFrame();
189 byte[] audioBuffer = new byte[bufferSize];
// Bytes still needed to complete the current frame.
190 int remaining = bufferSize;
191 while (!aborted.get()) {
// Append to the partially filled frame.
// NOTE(review): read() returns -1 at end of stream, which would make `remaining` grow past
// bufferSize below; make sure a negative result is handled.
193 numBytesRead = audioStream.read(audioBuffer, bufferSize - remaining, remaining);
197 if (numBytesRead != remaining) {
198 remaining = remaining - numBytesRead;
// Full frame available: reset for the next frame and run detection on this one.
202 remaining = bufferSize;
203 var result = rustpotter.processBytes(audioBuffer);
204 if (result.isPresent()) {
205 var detection = result.get();
206 if (logger.isDebugEnabled()) {
207 ArrayList<String> scores = new ArrayList<>();
// Score names arrive as one "||"-separated string; pair them with the score values.
208 var scoreNames = detection.getScoreNames().split("\\|\\|");
209 var scoreValues = detection.getScores();
210 for (var i = 0; i < Integer.min(scoreNames.length, scoreValues.length); i++) {
211 scores.add("'" + scoreNames[i] + "': " + scoreValues[i]);
213 logger.debug("Detected '{}' with: Score: {}, AvgScore: {}, Count: {}, Gain: {}, Scores: {}",
214 detection.getName(), detection.getScore(), detection.getAvgScore(),
215 detection.getCounter(), detection.getGain(), String.join(", ", scores));
218 ksListener.ksEventReceived(new KSpottedEvent());
// NOTE(review): InterruptedException is turned into an error event without restoring the thread's
// interrupt status; consider Thread.currentThread().interrupt() and ending the loop.
220 } catch (IOException | InterruptedException e) {
221 String errorMessage = e.getMessage();
222 ksListener.ksEventReceived(new KSErrorEvent(errorMessage != null ? errorMessage : "Unexpected error"));
226 logger.debug("rustpotter stopped");
/**
 * Translates a configured score mode name (as passed from {@code config.scoreMode}) into a rustpotter
 * {@link ScoreMode}. The results are AVG, MEDIAN, P25, P50, P75, P80, P90, P95 or MAX.
 *
 * @param mode the configured score mode name
 * @return the matching score mode
 */
229 private ScoreMode getScoreMode(String mode) {
232 return ScoreMode.AVG;
234 return ScoreMode.MEDIAN;
236 return ScoreMode.P25;
238 return ScoreMode.P50;
240 return ScoreMode.P75;
242 return ScoreMode.P80;
244 return ScoreMode.P90;
246 return ScoreMode.P95;
249 return ScoreMode.MAX;