import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
+import java.util.stream.Stream;
import org.eclipse.jdt.annotation.NonNullByDefault;
import org.eclipse.jdt.annotation.Nullable;
public class WatsonSTTService implements STTService {
private final Logger logger = LoggerFactory.getLogger(WatsonSTTService.class);
private final ScheduledExecutorService executor = ThreadPoolManager.getScheduledPool("OH-voice-watsonstt");
- private final List<String> models = List.of("ar-AR_BroadbandModel", "de-DE_BroadbandModel", "en-AU_BroadbandModel",
- "en-GB_BroadbandModel", "en-US_BroadbandModel", "es-AR_BroadbandModel", "es-CL_BroadbandModel",
- "es-CO_BroadbandModel", "es-ES_BroadbandModel", "es-MX_BroadbandModel", "es-PE_BroadbandModel",
- "fr-CA_BroadbandModel", "fr-FR_BroadbandModel", "it-IT_BroadbandModel", "ja-JP_BroadbandModel",
- "ko-KR_BroadbandModel", "nl-NL_BroadbandModel", "pt-BR_BroadbandModel", "zh-CN_BroadbandModel");
- private final Set<Locale> supportedLocales = models.stream().map(name -> name.split("_")[0])
- .map(Locale::forLanguageTag).collect(Collectors.toSet());
+ private final List<String> telephonyModels = List.of("ar-MS_Telephony", "zh-CN_Telephony", "nl-BE_Telephony",
+ "nl-NL_Telephony", "en-AU_Telephony", "en-IN_Telephony", "en-GB_Telephony", "en-US_Telephony",
+ "fr-CA_Telephony", "fr-FR_Telephony", "hi-IN_Telephony", "pt-BR_Telephony", "es-ES_Telephony");
+ private final List<String> multimediaModels = List.of("en-AU_Multimedia", "en-GB_Multimedia", "en-US_Multimedia",
+ "fr-FR_Multimedia", "de-DE_Multimedia", "it-IT_Multimedia", "ja-JP_Multimedia", "ko-KR_Multimedia",
+ "pt-BR_Multimedia", "es-ES_Multimedia");
+ // models 'es-LA_Telephony' and 'en-WW_Medical_Telephony' will be used as fallbacks for es and en, respectively
+ private final List<Locale> fallbackLocales = List.of(Locale.forLanguageTag("es"), Locale.ENGLISH);
+ private final Set<Locale> supportedLocales = Stream
+ .concat(Stream.concat(telephonyModels.stream(), multimediaModels.stream()).map(name -> name.split("_")[0])
+ .distinct().map(Locale::forLanguageTag), fallbackLocales.stream())
+ .collect(Collectors.toSet());
private WatsonSTTConfiguration config = new WatsonSTTConfiguration();
private @Nullable SpeechToText speechToText = null;
logger.debug("Content-Type: {}", contentType);
RecognizeWithWebsocketsOptions wsOptions = new RecognizeWithWebsocketsOptions.Builder().audio(audioStream)
.contentType(contentType).redaction(config.redaction).smartFormatting(config.smartFormatting)
- .model(locale.toLanguageTag() + "_BroadbandModel").interimResults(true)
+ .model(getModel(locale)).interimResults(true)
.backgroundAudioSuppression(config.backgroundAudioSuppression)
.speechDetectorSensitivity(config.speechDetectorSensitivity).inactivityTimeout(config.maxSilenceSeconds)
.build();
};
}
+ /**
+  * Resolves the name of the Watson model to request for the given locale.
+  * <p>
+  * The telephony and multimedia model lists are searched for the first entry whose name starts with the
+  * locale's language tag; the multimedia list is searched first when {@code config.preferMultimediaModel}
+  * is set, otherwise the telephony list is. Because this is a prefix match, a bare language tag (for
+  * example {@code en}) resolves to the first regional model of that language in search order.
+  * <p>
+  * When no listed model matches, Spanish locales fall back to {@code es-LA_Telephony} and English locales
+  * fall back to {@code en-WW_Medical_Telephony}.
+  *
+  * @param locale the locale to find a model for
+  * @return the name of the model to use
+  * @throws STTException if neither a listed model nor a fallback model exists for the locale's language
+  */
+ private String getModel(Locale locale) throws STTException {
+     String languageTag = locale.toLanguageTag();
+     // The preferred model family goes first so it wins when both families support the language.
+     Stream<String> allModels = config.preferMultimediaModel
+             ? Stream.concat(multimediaModels.stream(), telephonyModels.stream())
+             : Stream.concat(telephonyModels.stream(), multimediaModels.stream());
+     var modelOption = allModels.filter(model -> model.startsWith(languageTag)).findFirst();
+     if (modelOption.isPresent()) {
+         var model = modelOption.get();
+         logger.debug("Using model: {}", model);
+         return model;
+     }
+     if ("es".equals(locale.getLanguage())) {
+         // fallback for latin american spanish languages
+         var model = "es-LA_Telephony";
+         logger.debug("Falling back to model: {}", model);
+         // Return the fallback; without this the method fell through to the exception below.
+         return model;
+     }
+     if ("en".equals(locale.getLanguage())) {
+         // fallback english dialects
+         var model = "en-WW_Medical_Telephony";
+         logger.debug("Falling back to model: {}", model);
+         // Return the fallback; without this the method fell through to the exception below.
+         return model;
+     }
+     throw new STTException("No compatible model for language " + languageTag);
+ }
+
private @Nullable String getContentType(AudioStream audioStream) throws STTException {
AudioFormat format = audioStream.getFormat();
String container = format.getContainer();