]> git.basschouten.com Git - openhab-addons.git/blob
c31982e508bb401895f66fafcb43d423d2883eda
[openhab-addons.git] /
1 /**
2  * Copyright (c) 2010-2023 Contributors to the openHAB project
3  *
4  * See the NOTICE file(s) distributed with this work for additional
5  * information.
6  *
7  * This program and the accompanying materials are made available under the
8  * terms of the Eclipse Public License 2.0 which is available at
9  * http://www.eclipse.org/legal/epl-2.0
10  *
11  * SPDX-License-Identifier: EPL-2.0
12  */
13 package org.openhab.voice.rustpotterks.internal;
14
15 import static org.openhab.voice.rustpotterks.internal.RustpotterKSConstants.*;
16
17 import java.io.File;
18 import java.io.IOException;
19 import java.nio.file.Path;
20 import java.util.ArrayList;
21 import java.util.Locale;
22 import java.util.Map;
23 import java.util.Set;
24 import java.util.concurrent.ScheduledExecutorService;
25 import java.util.concurrent.atomic.AtomicBoolean;
26
27 import org.eclipse.jdt.annotation.NonNullByDefault;
28 import org.eclipse.jdt.annotation.Nullable;
29 import org.openhab.core.OpenHAB;
30 import org.openhab.core.audio.AudioFormat;
31 import org.openhab.core.audio.AudioStream;
32 import org.openhab.core.common.ThreadPoolManager;
33 import org.openhab.core.config.core.ConfigurableService;
34 import org.openhab.core.config.core.Configuration;
35 import org.openhab.core.voice.KSErrorEvent;
36 import org.openhab.core.voice.KSException;
37 import org.openhab.core.voice.KSListener;
38 import org.openhab.core.voice.KSService;
39 import org.openhab.core.voice.KSServiceHandle;
40 import org.openhab.core.voice.KSpottedEvent;
41 import org.osgi.framework.Constants;
42 import org.osgi.service.component.annotations.Activate;
43 import org.osgi.service.component.annotations.Component;
44 import org.osgi.service.component.annotations.Modified;
45 import org.slf4j.Logger;
46 import org.slf4j.LoggerFactory;
47
48 import io.github.givimad.rustpotter_java.Endianness;
49 import io.github.givimad.rustpotter_java.Rustpotter;
50 import io.github.givimad.rustpotter_java.RustpotterBuilder;
51 import io.github.givimad.rustpotter_java.SampleFormat;
52 import io.github.givimad.rustpotter_java.ScoreMode;
53
54 /**
55  * The {@link RustpotterKSService} is a keyword spotting implementation based on rustpotter.
56  *
57  * @author Miguel Álvarez - Initial contribution
58  */
59 @NonNullByDefault
60 @Component(configurationPid = SERVICE_PID, property = Constants.SERVICE_PID + "=" + SERVICE_PID)
61 @ConfigurableService(category = SERVICE_CATEGORY, label = SERVICE_NAME
62         + " Keyword Spotter", description_uri = SERVICE_CATEGORY + ":" + SERVICE_ID)
63 public class RustpotterKSService implements KSService {
64     private static final String RUSTPOTTER_FOLDER = Path.of(OpenHAB.getUserDataFolder(), "rustpotter").toString();
65     private final Logger logger = LoggerFactory.getLogger(RustpotterKSService.class);
66     private final ScheduledExecutorService executor = ThreadPoolManager.getScheduledPool("OH-voice-rustpotterks");
67     private RustpotterKSConfiguration config = new RustpotterKSConfiguration();
68     static {
69         Logger logger = LoggerFactory.getLogger(RustpotterKSService.class);
70         File directory = new File(RUSTPOTTER_FOLDER);
71         if (!directory.exists()) {
72             if (directory.mkdir()) {
73                 logger.info("rustpotter dir created {}", RUSTPOTTER_FOLDER);
74             }
75         }
76     }
77
78     @Activate
79     protected void activate(Map<String, Object> config) {
80         modified(config);
81     }
82
83     @Modified
84     protected void modified(Map<String, Object> config) {
85         this.config = new Configuration(config).as(RustpotterKSConfiguration.class);
86     }
87
88     @Override
89     public String getId() {
90         return SERVICE_ID;
91     }
92
93     @Override
94     public String getLabel(@Nullable Locale locale) {
95         return SERVICE_NAME;
96     }
97
98     @Override
99     public Set<Locale> getSupportedLocales() {
100         return Set.of();
101     }
102
103     @Override
104     public Set<AudioFormat> getSupportedFormats() {
105         return Set
106                 .of(new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, null, null, null, null));
107     }
108
109     @Override
110     public KSServiceHandle spot(KSListener ksListener, AudioStream audioStream, Locale locale, String keyword)
111             throws KSException {
112         logger.debug("Loading library");
113         try {
114             Rustpotter.loadLibrary();
115         } catch (IOException e) {
116             throw new KSException("Unable to load rustpotter lib: " + e.getMessage());
117         }
118         var audioFormat = audioStream.getFormat();
119         var frequency = audioFormat.getFrequency();
120         var bitDepth = audioFormat.getBitDepth();
121         var channels = audioFormat.getChannels();
122         var isBigEndian = audioFormat.isBigEndian();
123         if (frequency == null || bitDepth == null || channels == null || isBigEndian == null) {
124             throw new KSException(
125                     "Missing stream metadata: frequency, bit depth, channels and endianness must be defined.");
126         }
127         var endianness = isBigEndian ? Endianness.BIG : Endianness.LITTLE;
128         logger.debug("Audio wav spec: frequency '{}', bit depth '{}', channels '{}', '{}'", frequency, bitDepth,
129                 channels, isBigEndian ? "big-endian" : "little-endian");
130         Rustpotter rustpotter;
131         try {
132             rustpotter = initRustpotter(frequency, bitDepth, channels, endianness);
133         } catch (Exception e) {
134             throw new KSException("Unable to configure rustpotter: " + e.getMessage(), e);
135         }
136         var modelName = keyword.replaceAll("\\s", "_") + ".rpw";
137         var modelPath = Path.of(RUSTPOTTER_FOLDER, modelName);
138         if (!modelPath.toFile().exists()) {
139             throw new KSException("Missing model " + modelName);
140         }
141         try {
142             rustpotter.addWakewordModelFile(modelPath.toString());
143         } catch (Exception e) {
144             throw new KSException("Unable to load wake word model: " + e.getMessage());
145         }
146         logger.debug("Model '{}' loaded", modelPath);
147         AtomicBoolean aborted = new AtomicBoolean(false);
148         executor.submit(() -> processAudioStream(rustpotter, ksListener, audioStream, aborted));
149         return () -> {
150             logger.debug("Stopping service");
151             aborted.set(true);
152         };
153     }
154
155     private Rustpotter initRustpotter(long frequency, int bitDepth, int channels, Endianness endianness)
156             throws Exception {
157         var rustpotterBuilder = new RustpotterBuilder();
158         // audio configs
159         rustpotterBuilder.setBitsPerSample(bitDepth);
160         rustpotterBuilder.setSampleRate(frequency);
161         rustpotterBuilder.setChannels(channels);
162         rustpotterBuilder.setSampleFormat(SampleFormat.INT);
163         rustpotterBuilder.setEndianness(endianness);
164         // detector configs
165         rustpotterBuilder.setThreshold(config.threshold);
166         rustpotterBuilder.setAveragedThreshold(config.averagedThreshold);
167         rustpotterBuilder.setScoreMode(getScoreMode(config.scoreMode));
168         rustpotterBuilder.setMinScores(config.minScores);
169         rustpotterBuilder.setComparatorRef(config.comparatorRef);
170         rustpotterBuilder.setComparatorBandSize(config.comparatorBandSize);
171         // filter configs
172         rustpotterBuilder.setGainNormalizerEnabled(config.gainNormalizer);
173         rustpotterBuilder.setMinGain(config.minGain);
174         rustpotterBuilder.setMaxGain(config.maxGain);
175         rustpotterBuilder.setGainRef(config.gainRef);
176         rustpotterBuilder.setBandPassFilterEnabled(config.bandPass);
177         rustpotterBuilder.setBandPassLowCutoff(config.lowCutoff);
178         rustpotterBuilder.setBandPassHighCutoff(config.highCutoff);
179         // init the detector
180         var rustpotter = rustpotterBuilder.build();
181         rustpotterBuilder.delete();
182         return rustpotter;
183     }
184
185     private void processAudioStream(Rustpotter rustpotter, KSListener ksListener, AudioStream audioStream,
186             AtomicBoolean aborted) {
187         int numBytesRead;
188         var bufferSize = (int) rustpotter.getBytesPerFrame();
189         byte[] audioBuffer = new byte[bufferSize];
190         int remaining = bufferSize;
191         while (!aborted.get()) {
192             try {
193                 numBytesRead = audioStream.read(audioBuffer, bufferSize - remaining, remaining);
194                 if (aborted.get()) {
195                     break;
196                 }
197                 if (numBytesRead != remaining) {
198                     remaining = remaining - numBytesRead;
199                     Thread.sleep(100);
200                     continue;
201                 }
202                 remaining = bufferSize;
203                 var result = rustpotter.processBytes(audioBuffer);
204                 if (result.isPresent()) {
205                     var detection = result.get();
206                     if (logger.isDebugEnabled()) {
207                         ArrayList<String> scores = new ArrayList<>();
208                         var scoreNames = detection.getScoreNames().split("\\|\\|");
209                         var scoreValues = detection.getScores();
210                         for (var i = 0; i < Integer.min(scoreNames.length, scoreValues.length); i++) {
211                             scores.add("'" + scoreNames[i] + "': " + scoreValues[i]);
212                         }
213                         logger.debug("Detected '{}' with: Score: {}, AvgScore: {}, Count: {}, Gain: {}, Scores: {}",
214                                 detection.getName(), detection.getScore(), detection.getAvgScore(),
215                                 detection.getCounter(), detection.getGain(), String.join(", ", scores));
216                     }
217                     detection.delete();
218                     ksListener.ksEventReceived(new KSpottedEvent());
219                 }
220             } catch (IOException | InterruptedException e) {
221                 String errorMessage = e.getMessage();
222                 ksListener.ksEventReceived(new KSErrorEvent(errorMessage != null ? errorMessage : "Unexpected error"));
223             }
224         }
225         rustpotter.delete();
226         logger.debug("rustpotter stopped");
227     }
228
229     private ScoreMode getScoreMode(String mode) {
230         switch (mode) {
231             case "average":
232                 return ScoreMode.AVG;
233             case "median":
234                 return ScoreMode.MEDIAN;
235             case "p25":
236                 return ScoreMode.P25;
237             case "p50":
238                 return ScoreMode.P50;
239             case "p75":
240                 return ScoreMode.P75;
241             case "p80":
242                 return ScoreMode.P80;
243             case "p90":
244                 return ScoreMode.P90;
245             case "p95":
246                 return ScoreMode.P95;
247             case "max":
248             default:
249                 return ScoreMode.MAX;
250         }
251     }
252 }