]> git.basschouten.com Git - openhab-addons.git/blob
04e73cd3565cc8db71167f0866748330f45481b9
[openhab-addons.git] /
1 /**
2  * Copyright (c) 2010-2024 Contributors to the openHAB project
3  *
4  * See the NOTICE file(s) distributed with this work for additional
5  * information.
6  *
7  * This program and the accompanying materials are made available under the
8  * terms of the Eclipse Public License 2.0 which is available at
9  * http://www.eclipse.org/legal/epl-2.0
10  *
11  * SPDX-License-Identifier: EPL-2.0
12  */
13 package org.openhab.voice.pipertts.internal;
14
15 import static org.openhab.voice.pipertts.internal.PiperTTSConstants.SERVICE_CATEGORY;
16 import static org.openhab.voice.pipertts.internal.PiperTTSConstants.SERVICE_ID;
17 import static org.openhab.voice.pipertts.internal.PiperTTSConstants.SERVICE_NAME;
18 import static org.openhab.voice.pipertts.internal.PiperTTSConstants.SERVICE_PID;
19
20 import java.io.ByteArrayInputStream;
21 import java.io.ByteArrayOutputStream;
22 import java.io.IOException;
23 import java.nio.ByteBuffer;
24 import java.nio.ByteOrder;
25 import java.nio.file.Files;
26 import java.nio.file.Path;
27 import java.util.ArrayList;
28 import java.util.HashMap;
29 import java.util.List;
30 import java.util.Locale;
31 import java.util.Map;
32 import java.util.Objects;
33 import java.util.Optional;
34 import java.util.Set;
35 import java.util.concurrent.atomic.AtomicInteger;
36 import java.util.stream.Collectors;
37
38 import javax.sound.sampled.AudioFileFormat;
39 import javax.sound.sampled.AudioInputStream;
40 import javax.sound.sampled.AudioSystem;
41
42 import org.eclipse.jdt.annotation.NonNullByDefault;
43 import org.eclipse.jdt.annotation.Nullable;
44 import org.openhab.core.OpenHAB;
45 import org.openhab.core.audio.AudioFormat;
46 import org.openhab.core.audio.AudioStream;
47 import org.openhab.core.audio.ByteArrayAudioStream;
48 import org.openhab.core.config.core.ConfigurableService;
49 import org.openhab.core.config.core.Configuration;
50 import org.openhab.core.voice.AbstractCachedTTSService;
51 import org.openhab.core.voice.TTSCache;
52 import org.openhab.core.voice.TTSException;
53 import org.openhab.core.voice.TTSService;
54 import org.openhab.core.voice.Voice;
55 import org.osgi.framework.Constants;
56 import org.osgi.service.component.annotations.Activate;
57 import org.osgi.service.component.annotations.Component;
58 import org.osgi.service.component.annotations.Deactivate;
59 import org.osgi.service.component.annotations.Modified;
60 import org.osgi.service.component.annotations.Reference;
61 import org.slf4j.Logger;
62 import org.slf4j.LoggerFactory;
63
64 import com.fasterxml.jackson.databind.JsonNode;
65 import com.fasterxml.jackson.databind.ObjectMapper;
66
67 import io.github.givimad.piperjni.PiperJNI;
68 import io.github.givimad.piperjni.PiperVoice;
69
70 /**
71  * The {@link PiperTTSService} class is a service implementation to use Piper for Text-to-Speech.
72  *
73  * @author Miguel Álvarez - Initial contribution
74  */
75 @NonNullByDefault
76 @Component(service = TTSService.class, configurationPid = SERVICE_PID, property = Constants.SERVICE_PID + "="
77         + SERVICE_PID)
78 @ConfigurableService(category = SERVICE_CATEGORY, label = SERVICE_NAME
79         + " Text-to-Speech", description_uri = SERVICE_CATEGORY + ":" + SERVICE_ID)
80 public class PiperTTSService extends AbstractCachedTTSService {
    // Folder below the openHAB user data directory that is listed when looking for voice model files.
    private static final Path PIPER_FOLDER = Path.of(OpenHAB.getUserDataFolder(), "piper");
    private final Logger logger = LoggerFactory.getLogger(PiperTTSService.class);
    // Monitor used when the preloaded model reference is replaced or compared.
    private final Object modelLock = new Object();
    // Current service configuration; replaced on every configuration change.
    private PiperTTSConfiguration config = new PiperTTSConfiguration();
    // Voice model kept loaded between synthesis requests, or null when none is kept.
    private @Nullable VoiceModel preloadedModel;
    // Native Piper binding; assigned during activation and cleared on deactivation.
    private @Nullable PiperJNI piper;
    // Voices already parsed, keyed by model file path; rebuilt each time the available voices are listed.
    private Map<String, List<Voice>> cachedVoicesByModel = new HashMap<>();
88
89     @Activate
90     public PiperTTSService(final @Reference TTSCache ttsCache) {
91         super(ttsCache);
92     }
93
94     @Activate
95     protected void activate(Map<String, Object> config) {
96         try {
97             piper = new PiperJNI();
98             piper.initialize(true, false);
99             logger.debug("Using Piper version {}", piper.getPiperVersion());
100         } catch (IOException e) {
101             logger.warn("Piper registration failed, the add-on will not work: {}", e.getMessage());
102         }
103         tryCreatePiperDirectory();
104         configChange(config);
105     }
106
107     @Modified
108     protected void modified(Map<String, Object> config) {
109         configChange(config);
110     }
111
112     @Deactivate
113     protected void deactivate(Map<String, Object> config) {
114         try {
115             unloadModel();
116             getPiper().close();
117             piper = null;
118         } catch (IOException e) {
119             logger.warn("Exception unloading model: {}", e.getMessage());
120         } catch (LibraryNotLoaded ignored) {
121         }
122     }
123
124     private void configChange(Map<String, Object> config) {
125         this.config = new Configuration(config).as(PiperTTSConfiguration.class);
126         try {
127             unloadModel();
128         } catch (IOException e) {
129             logger.warn("IOException unloading model: {}", e.getMessage());
130         }
131     }
132
133     private PiperJNI getPiper() throws LibraryNotLoaded {
134         PiperJNI piper = this.piper;
135         if (piper == null) {
136             throw new LibraryNotLoaded();
137         }
138         return piper;
139     }
140
141     private void tryCreatePiperDirectory() {
142         if (!Files.exists(PIPER_FOLDER)) {
143             try {
144                 Files.createDirectory(PIPER_FOLDER);
145                 logger.info("Piper directory created at: {}", PIPER_FOLDER);
146             } catch (IOException e) {
147                 logger.warn("Unable to create piper directory at {}", PIPER_FOLDER);
148             }
149         }
150     }
151
152     @Override
153     public String getId() {
154         return SERVICE_ID;
155     }
156
157     @Override
158     public String getLabel(@Nullable Locale locale) {
159         return SERVICE_NAME;
160     }
161
162     @Override
163     public Set<Voice> getAvailableVoices() {
164         try (var filesStream = Files.list(PIPER_FOLDER)) {
165             HashMap<String, List<Voice>> newCachedVoices = new HashMap<>();
166             Set<Voice> voices = filesStream //
167                     .filter(filePath -> filePath.getFileName().toString().endsWith(".onnx")) //
168                     .map(filePath -> {
169                         List<Voice> modelVoices = getVoice(filePath);
170                         newCachedVoices.put(filePath.toString(), modelVoices);
171                         return modelVoices;
172                     }) //
173                     .flatMap(List::stream) //
174                     .collect(Collectors.toSet());
175             cachedVoicesByModel = newCachedVoices;
176             logger.debug("Available number of piper voices: {}", voices.size());
177             return voices;
178         } catch (IOException e) {
179             logger.warn("IOException getting piper voices: {}", e.getMessage());
180         }
181         return Set.of();
182     }
183
184     private List<Voice> getVoice(Path modelPath) {
185         try {
186             Path configFile = modelPath.getParent().resolve(modelPath.getFileName() + ".json");
187             if (!Files.exists(configFile) || Files.isDirectory(configFile)) {
188                 throw new IOException("Missed config file: " + configFile.toAbsolutePath());
189             }
190             List<Voice> cachedVoices = cachedVoicesByModel.get(modelPath.toString());
191             if (cachedVoices != null) {
192                 return cachedVoices;
193             }
194             String voiceData = Files.readString(configFile);
195             JsonNode voiceJsonRoot = new ObjectMapper().readTree(voiceData);
196             JsonNode datasetJsonNode = voiceJsonRoot.get("dataset");
197             JsonNode languageJsonNode = voiceJsonRoot.get("language");
198             JsonNode numSpeakersJsonNode = voiceJsonRoot.get("num_speakers");
199             if (datasetJsonNode == null || languageJsonNode == null) {
200                 throw new IOException("Unknown voice config structure");
201             }
202             JsonNode languageFamilyJsonNode = languageJsonNode.get("family");
203             JsonNode languageRegionJsonNode = languageJsonNode.get("region");
204             if (languageFamilyJsonNode == null || languageRegionJsonNode == null) {
205                 throw new IOException("Unknown voice config structure");
206             }
207             String voiceName = datasetJsonNode.textValue();
208             String voiceUID = voiceName.replace(" ", "_");
209             String languageFamily = languageFamilyJsonNode.textValue();
210             String languageRegion = languageRegionJsonNode.textValue();
211             int numSpeakers = numSpeakersJsonNode != null ? numSpeakersJsonNode.intValue() : 1;
212             JsonNode speakersIdsJsonNode = voiceJsonRoot.get("speaker_id_map");
213             if (numSpeakers != 1 && speakersIdsJsonNode != null) {
214                 List<Voice> voices = new ArrayList<>();
215                 speakersIdsJsonNode.fieldNames().forEachRemaining(field -> {
216                     JsonNode fieldNode = speakersIdsJsonNode.get(field);
217                     voices.add(new PiperTTSVoice( //
218                             voiceUID + "_" + field, //
219                             capitalize(voiceName + " " + field), //
220                             languageFamily, //
221                             languageRegion, //
222                             modelPath, //
223                             configFile, //
224                             Optional.of(fieldNode.longValue())));
225                 });
226                 return voices;
227             }
228             return List.of(new PiperTTSVoice(voiceUID, capitalize(voiceName), languageFamily, languageRegion, modelPath,
229                     configFile, Optional.empty()));
230         } catch (IOException e) {
231             logger.warn("IOException reading voice info: {}", e.getMessage());
232             return List.of();
233         }
234     }
235
236     @Override
237     public Set<AudioFormat> getSupportedFormats() {
238         return Set.of(new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, null, null, null,
239                 null));
240     }
241
242     @Override
243     public AudioStream synthesizeForCache(String text, Voice voice, AudioFormat audioFormat) throws TTSException {
244         if (!(voice instanceof PiperTTSVoice ttsVoice)) {
245             throw new TTSException("No piper voice provided");
246         }
247         VoiceModel voiceModel = null;
248         boolean usingPreloadedModel = false;
249         short[] buffer;
250         final VoiceModel preloadedModel = this.preloadedModel;
251         try {
252             try {
253                 if (preloadedModel != null && preloadedModel.ttsVoice.getUID().equals(ttsVoice.getUID())) {
254                     logger.debug("Using preloaded voice model");
255                     preloadedModel.consumers.incrementAndGet();
256                     voiceModel = preloadedModel;
257                     usingPreloadedModel = true;
258                 } else {
259                     unloadModel();
260                     logger.debug("Loading voice model...");
261                     voiceModel = loadModel(ttsVoice);
262                     synchronized (modelLock) {
263                         usingPreloadedModel = voiceModel.equals(this.preloadedModel);
264                     }
265                 }
266             } catch (IOException e) {
267                 throw new TTSException("Unable to load voice model: " + e.getMessage());
268             }
269             try {
270                 logger.debug("Generating audio for: '{}'", text);
271                 buffer = getPiper().textToAudio(voiceModel.piperVoice, text);
272                 logger.debug("Generated {} samples of audio", buffer.length);
273             } catch (IOException e) {
274                 throw new TTSException("Voice generation failed: " + e.getMessage());
275             }
276         } catch (PiperJNI.NotInitialized | LibraryNotLoaded e) {
277             throw new TTSException("Piper not initialized, try restarting the add-on.");
278         } catch (RuntimeException e) {
279             logger.warn("RuntimeException running text to audio: {}", e.getMessage());
280             throw new TTSException("There was an error running Piper");
281         } finally {
282             if (voiceModel != null) {
283                 if (!usingPreloadedModel
284                         || voiceModel.consumers.decrementAndGet() == 0 && !voiceModel.equals(this.preloadedModel)) {
285                     logger.debug("Unloading voice model");
286                     voiceModel.close();
287                 } else {
288                     logger.debug("Skipping voice model unload");
289                 }
290             }
291         }
292         try {
293             logger.debug("Return re-encoded audio stream");
294             return getAudioStream(buffer, voiceModel.sampleRate, audioFormat);
295         } catch (IOException e) {
296             throw new TTSException("Error while creating audio stream: " + e.getMessage());
297         }
298     }
299
300     private VoiceModel loadModel(PiperTTSVoice voice) throws IOException, PiperJNI.NotInitialized, LibraryNotLoaded {
301         if (!Files.exists(voice.voiceModelPath()) || !Files.exists(voice.voiceModelConfigPath())) {
302             throw new IOException("Missing voice files");
303         }
304         PiperJNI piper = getPiper();
305         PiperVoice piperVoice;
306         VoiceModel voiceModel;
307         piperVoice = piper.loadVoice(voice.voiceModelPath(), voice.voiceModelConfigPath(), voice.speakerId.orElse(-1L));
308         voiceModel = new VoiceModel(voice, piperVoice, piperVoice.getSampleRate(), new AtomicInteger(1), logger);
309         if (config.preloadModel) {
310             synchronized (modelLock) {
311                 if (preloadedModel == null) {
312                     logger.debug("Voice model will be kept preloaded");
313                     preloadedModel = voiceModel;
314                 } else {
315                     logger.debug("Another voice model already preloaded");
316                 }
317             }
318         }
319         return voiceModel;
320     }
321
322     private void unloadModel() throws IOException {
323         var model = preloadedModel;
324         if (model != null) {
325             synchronized (modelLock) {
326                 preloadedModel = null;
327                 if (model.consumers.get() == 0) {
328                     // Do not release the model memory if it's been used, it should be released by the consumer
329                     // when there is no other consumers and is not a ref of the preloaded model object.
330                     logger.debug("Unloading preloaded model");
331                     model.close();
332                 } else {
333                     logger.debug("Preloaded model in use, skip memory release");
334                 }
335             }
336         }
337     }
338
339     private ByteArrayAudioStream getAudioStream(short[] samples, long sampleRate, AudioFormat targetFormat)
340             throws IOException {
341         // Convert the i16 samples returned by piper to a byte buffer
342         ByteBuffer byteBuffer;
343         int numSamples = samples.length;
344         byteBuffer = ByteBuffer.allocate(numSamples * 2).order(ByteOrder.LITTLE_ENDIAN);
345         for (var sample : samples) {
346             byteBuffer.putShort(sample);
347         }
348         // Initialize a Java audio stream using the Piper output format with the byte buffer created.
349         byte[] bytes = byteBuffer.array();
350         javax.sound.sampled.AudioFormat jAudioFormat = new javax.sound.sampled.AudioFormat(sampleRate, 16, 1, true,
351                 false);
352         long audioLength = (long) Math.ceil(((double) bytes.length) / jAudioFormat.getFrameSize());
353         AudioInputStream audioInputStreamTemp = new AudioInputStream(new ByteArrayInputStream(bytes), jAudioFormat,
354                 audioLength);
355         // Move the audio data to another Java audio stream in the target format so the Java AudioSystem encoded it as
356         // needed.
357         javax.sound.sampled.AudioFormat jTargetFormat = new javax.sound.sampled.AudioFormat(
358                 Objects.requireNonNull(targetFormat.getFrequency()), Objects.requireNonNull(targetFormat.getBitDepth()),
359                 Objects.requireNonNull(targetFormat.getChannels()), true, false);
360         AudioInputStream convertedInputStream = AudioSystem.getAudioInputStream(jTargetFormat, audioInputStreamTemp);
361         // It's required to add the wav header to the byte array stream returned for it to work with all the sink
362         // implementations.
363         // It can not be done with the AudioInputStream returned by AudioSystem::getAudioInputStream because it missed
364         // the length property.
365         // Therefore, the following method creates another AudioInputStream instance and uses the Java AudioSystem to
366         // prepend
367         // the wav header bytes,
368         // and finally initializes an OpenHAB audio stream.
369         return getAudioStreamWithRIFFHeader(convertedInputStream.readAllBytes(), jTargetFormat, targetFormat);
370     }
371
372     private String capitalize(String text) {
373         return text.substring(0, 1).toUpperCase() + text.substring(1);
374     }
375
376     private ByteArrayAudioStream getAudioStreamWithRIFFHeader(byte[] audioBytes,
377             javax.sound.sampled.AudioFormat jAudioFormat, AudioFormat audioFormat) throws IOException {
378         AudioInputStream audioInputStreamTemp = new AudioInputStream(new ByteArrayInputStream(audioBytes), jAudioFormat,
379                 (long) Math.ceil(((double) audioBytes.length) / jAudioFormat.getFrameSize()));
380         ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
381         AudioSystem.write(audioInputStreamTemp, AudioFileFormat.Type.WAVE, outputStream);
382         return new ByteArrayAudioStream(outputStream.toByteArray(), audioFormat);
383     }
384
385     private record PiperTTSVoice(String voiceId, String voiceName, String languageFamily, String languageRegion,
386             Path voiceModelPath, Path voiceModelConfigPath, Optional<Long> speakerId) implements Voice {
387         @Override
388         public String getUID() {
389             // Voice uid should be prefixed by service id to be listed properly on the UI.
390             return SERVICE_ID + ":" + voiceId + "-" + languageFamily + "_" + languageRegion;
391         }
392
393         @Override
394         public String getLabel() {
395             return voiceName;
396         }
397
398         @Override
399         public Locale getLocale() {
400             return new Locale(languageFamily, languageRegion);
401         }
402     }
403
404     private static class LibraryNotLoaded extends Exception {
405         private LibraryNotLoaded() {
406             super("Library not loaded");
407         }
408     }
409
410     private record VoiceModel(PiperTTSVoice ttsVoice, PiperVoice piperVoice, int sampleRate, AtomicInteger consumers,
411             Logger logger) implements AutoCloseable {
412
413         @Override
414         public void close() {
415             piperVoice.close();
416         }
417     }
418 }