]> git.basschouten.com Git - openhab-addons.git/blob
8b3c28628ce9b3490e43a5cdc506ecb64971f188
[openhab-addons.git] /
1 /**
2  * Copyright (c) 2010-2023 Contributors to the openHAB project
3  *
4  * See the NOTICE file(s) distributed with this work for additional
5  * information.
6  *
7  * This program and the accompanying materials are made available under the
8  * terms of the Eclipse Public License 2.0 which is available at
9  * http://www.eclipse.org/legal/epl-2.0
10  *
11  * SPDX-License-Identifier: EPL-2.0
12  */
13 package org.openhab.voice.rustpotterks.internal;
14
15 import static org.openhab.voice.rustpotterks.internal.RustpotterKSConstants.*;
16
17 import java.io.File;
18 import java.io.IOException;
19 import java.nio.file.Path;
20 import java.util.Locale;
21 import java.util.Map;
22 import java.util.Set;
23 import java.util.concurrent.ScheduledExecutorService;
24 import java.util.concurrent.atomic.AtomicBoolean;
25
26 import org.eclipse.jdt.annotation.NonNullByDefault;
27 import org.eclipse.jdt.annotation.Nullable;
28 import org.openhab.core.OpenHAB;
29 import org.openhab.core.audio.AudioFormat;
30 import org.openhab.core.audio.AudioStream;
31 import org.openhab.core.common.ThreadPoolManager;
32 import org.openhab.core.config.core.ConfigurableService;
33 import org.openhab.core.config.core.Configuration;
34 import org.openhab.core.voice.KSErrorEvent;
35 import org.openhab.core.voice.KSException;
36 import org.openhab.core.voice.KSListener;
37 import org.openhab.core.voice.KSService;
38 import org.openhab.core.voice.KSServiceHandle;
39 import org.openhab.core.voice.KSpottedEvent;
40 import org.osgi.framework.Constants;
41 import org.osgi.service.component.ComponentContext;
42 import org.osgi.service.component.annotations.Activate;
43 import org.osgi.service.component.annotations.Component;
44 import org.osgi.service.component.annotations.Modified;
45 import org.slf4j.Logger;
46 import org.slf4j.LoggerFactory;
47
48 import io.github.givimad.rustpotter_java.Endianness;
49 import io.github.givimad.rustpotter_java.NoiseDetectionMode;
50 import io.github.givimad.rustpotter_java.RustpotterJava;
51 import io.github.givimad.rustpotter_java.RustpotterJavaBuilder;
52 import io.github.givimad.rustpotter_java.VadMode;
53
54 /**
55  * The {@link RustpotterKSService} is a keyword spotting implementation based on rustpotter.
56  *
57  * @author Miguel Álvarez - Initial contribution
58  */
59 @NonNullByDefault
60 @Component(configurationPid = SERVICE_PID, property = Constants.SERVICE_PID + "=" + SERVICE_PID)
61 @ConfigurableService(category = SERVICE_CATEGORY, label = SERVICE_NAME
62         + " Keyword Spotter", description_uri = SERVICE_CATEGORY + ":" + SERVICE_ID)
63 public class RustpotterKSService implements KSService {
64     private static final String RUSTPOTTER_FOLDER = Path.of(OpenHAB.getUserDataFolder(), "rustpotter").toString();
65     private final Logger logger = LoggerFactory.getLogger(RustpotterKSService.class);
66     private final ScheduledExecutorService executor = ThreadPoolManager.getScheduledPool("OH-voice-rustpotterks");
67     private RustpotterKSConfiguration config = new RustpotterKSConfiguration();
68     static {
69         Logger logger = LoggerFactory.getLogger(RustpotterKSService.class);
70         File directory = new File(RUSTPOTTER_FOLDER);
71         if (!directory.exists()) {
72             if (directory.mkdir()) {
73                 logger.info("rustpotter dir created {}", RUSTPOTTER_FOLDER);
74             }
75         }
76     }
77
78     @Activate
79     protected void activate(ComponentContext componentContext, Map<String, Object> config) {
80         modified(config);
81     }
82
83     @Modified
84     protected void modified(Map<String, Object> config) {
85         this.config = new Configuration(config).as(RustpotterKSConfiguration.class);
86     }
87
88     @Override
89     public String getId() {
90         return SERVICE_ID;
91     }
92
93     @Override
94     public String getLabel(@Nullable Locale locale) {
95         return SERVICE_NAME;
96     }
97
98     @Override
99     public Set<Locale> getSupportedLocales() {
100         return Set.of();
101     }
102
103     @Override
104     public Set<AudioFormat> getSupportedFormats() {
105         return Set
106                 .of(new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, null, null, null, null));
107     }
108
109     @Override
110     public KSServiceHandle spot(KSListener ksListener, AudioStream audioStream, Locale locale, String keyword)
111             throws KSException {
112         logger.debug("Loading library");
113         try {
114             RustpotterJava.loadLibrary();
115         } catch (IOException e) {
116             throw new KSException("Unable to load rustpotter lib: " + e.getMessage());
117         }
118         var audioFormat = audioStream.getFormat();
119         var frequency = audioFormat.getFrequency();
120         var bitDepth = audioFormat.getBitDepth();
121         var channels = audioFormat.getChannels();
122         var isBigEndian = audioFormat.isBigEndian();
123         if (frequency == null || bitDepth == null || channels == null || isBigEndian == null) {
124             throw new KSException(
125                     "Missing stream metadata: frequency, bit depth, channels and endianness must be defined.");
126         }
127         var endianness = isBigEndian ? Endianness.BIG : Endianness.LITTLE;
128         logger.debug("Audio wav spec: frequency '{}', bit depth '{}', channels '{}', '{}'", frequency, bitDepth,
129                 channels, audioFormat.isBigEndian() ? "big-endian" : "little-endian");
130         RustpotterJava rustpotter = initRustpotter(frequency, bitDepth, channels, endianness);
131         var modelName = keyword.replaceAll("\\s", "_") + ".rpw";
132         var modelPath = Path.of(RUSTPOTTER_FOLDER, modelName);
133         if (!modelPath.toFile().exists()) {
134             throw new KSException("Missing model " + modelName);
135         }
136         try {
137             rustpotter.addWakewordModelFile(modelPath.toString());
138         } catch (Exception e) {
139             throw new KSException("Unable to load wake word model: " + e.getMessage());
140         }
141         logger.debug("Model '{}' loaded", modelPath);
142         AtomicBoolean aborted = new AtomicBoolean(false);
143         executor.submit(() -> processAudioStream(rustpotter, ksListener, audioStream, aborted));
144         return new KSServiceHandle() {
145             @Override
146             public void abort() {
147                 logger.debug("Stopping service");
148                 aborted.set(true);
149             }
150         };
151     }
152
153     private RustpotterJava initRustpotter(long frequency, int bitDepth, int channels, Endianness endianness) {
154         var rustpotterBuilder = new RustpotterJavaBuilder();
155         // audio configs
156         rustpotterBuilder.setBitsPerSample(bitDepth);
157         rustpotterBuilder.setSampleRate(frequency);
158         rustpotterBuilder.setChannels(channels);
159         rustpotterBuilder.setEndianness(endianness);
160         // detector configs
161         rustpotterBuilder.setThreshold(config.threshold);
162         rustpotterBuilder.setAveragedThreshold(config.averagedThreshold);
163         rustpotterBuilder.setComparatorRef(config.comparatorRef);
164         rustpotterBuilder.setComparatorBandSize(config.comparatorBandSize);
165         @Nullable
166         VadMode vadMode = getVADMode(config.vadMode);
167         if (vadMode != null) {
168             rustpotterBuilder.setVADMode(vadMode);
169             rustpotterBuilder.setVADSensitivity(config.vadSensitivity);
170             rustpotterBuilder.setVADDelay(config.vadDelay);
171         }
172         @Nullable
173         NoiseDetectionMode noiseDetectionMode = getNoiseMode(config.noiseDetectionMode);
174         if (noiseDetectionMode != null) {
175             rustpotterBuilder.setNoiseMode(noiseDetectionMode);
176             rustpotterBuilder.setNoiseSensitivity(config.noiseSensitivity);
177         }
178         rustpotterBuilder.setEagerMode(config.eagerMode);
179         // init the detector
180         var rustpotter = rustpotterBuilder.build();
181         rustpotterBuilder.delete();
182         return rustpotter;
183     }
184
185     private void processAudioStream(RustpotterJava rustpotter, KSListener ksListener, AudioStream audioStream,
186             AtomicBoolean aborted) {
187         int numBytesRead;
188         var bufferSize = (int) rustpotter.getBytesPerFrame();
189         byte[] audioBuffer = new byte[bufferSize];
190         int remaining = bufferSize;
191         while (!aborted.get()) {
192             try {
193                 numBytesRead = audioStream.read(audioBuffer, bufferSize - remaining, remaining);
194                 if (aborted.get()) {
195                     break;
196                 }
197                 if (numBytesRead != remaining) {
198                     remaining = remaining - numBytesRead;
199                     Thread.sleep(100);
200                     continue;
201                 }
202                 remaining = bufferSize;
203                 var result = rustpotter.processBuffer(audioBuffer);
204                 if (result.isPresent()) {
205                     var detection = result.get();
206                     logger.debug("keyword '{}' detected with score {}!", detection.getName(), detection.getScore());
207                     detection.delete();
208                     ksListener.ksEventReceived(new KSpottedEvent());
209                 }
210             } catch (IOException | InterruptedException e) {
211                 String errorMessage = e.getMessage();
212                 ksListener.ksEventReceived(new KSErrorEvent(errorMessage != null ? errorMessage : "Unexpected error"));
213             }
214         }
215         rustpotter.delete();
216         logger.debug("rustpotter stopped");
217     }
218
219     private @Nullable VadMode getVADMode(String mode) {
220         switch (mode) {
221             case "low-bitrate":
222                 return VadMode.LOW_BITRATE;
223             case "quality":
224                 return VadMode.QUALITY;
225             case "aggressive":
226                 return VadMode.AGGRESSIVE;
227             case "very-aggressive":
228                 return VadMode.VERY_AGGRESSIVE;
229             default:
230                 return null;
231         }
232     }
233
234     private @Nullable NoiseDetectionMode getNoiseMode(String mode) {
235         switch (mode) {
236             case "easiest":
237                 return NoiseDetectionMode.EASIEST;
238             case "easy":
239                 return NoiseDetectionMode.EASY;
240             case "normal":
241                 return NoiseDetectionMode.NORMAL;
242             case "hard":
243                 return NoiseDetectionMode.HARD;
244             case "hardest":
245                 return NoiseDetectionMode.HARDEST;
246             default:
247                 return null;
248         }
249     }
250 }