import os
# pick the torch backend for Keras before the Moonshine model code is imported
os.environ['KERAS_BACKEND'] = 'torch'
from transformers import AutoModelForSpeechSeq2Seq, AutoConfig, PreTrainedTokenizerFast
import torchaudio
import sys
import time
import torch
from appPublic.worker import awaitify
from appPublic.jsonConfig import getConfig
from ahserver.serverenv import ServerEnv
from ahserver.webapp import webapp

class Moonshine:
    def __init__(self, modelname):
        # default modelname 'usefulsensors/moonshine-tiny'
        if modelname is None:
            modelname = 'usefulsensors/moonshine-tiny'
        # the checkpoint ships custom modeling code on the Hub,
        # so trust_remote_code=True is needed to load it
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(modelname,
                                                               trust_remote_code=True)
        self.tokenizer = PreTrainedTokenizerFast.from_pretrained(modelname)

    def inference(self, audiofile):
        # load the audio and resample to the 16 kHz rate the model expects
        audio, sr = torchaudio.load(audiofile)
        if sr != 16000:
            audio = torchaudio.functional.resample(audio, sr, 16000)
        # the model call returns token ids, which the tokenizer decodes to text
        tokens = self.model(audio)
        return self.tokenizer.decode(tokens[0], skip_special_tokens=True)
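
# Minimal local smoke test (illustrative sketch, not part of the web app flow):
# it exercises Moonshine.inference() directly on one file. The function name
# _smoke_test and the 'sample.wav' path are placeholders, not project code.
def _smoke_test(audiofile='sample.wav'):
    m = Moonshine(None)            # None falls back to 'usefulsensors/moonshine-tiny'
    print(m.inference(audiofile))  # prints the decoded transcript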

def main():
    config = getConfig()
    modelname = config.modelname
    m = Moonshine(modelname)
    g = ServerEnv()
    # wrap the blocking inference call so async request handlers can await it
    g.inference = awaitify(m.inference)

if __name__ == '__main__':
    webapp(main)
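
# Illustrative usage notes (assumptions, not taken from the project itself):
# - getConfig() is expected to supply a 'modelname' entry in the app's json
#   config, e.g. {"modelname": "usefulsensors/moonshine-tiny"}; if it resolves
#   to None, Moonshine falls back to the tiny checkpoint.
# - awaitify(m.inference) exposes an awaitable entry point on ServerEnv, so a
#   server-side handler could call something like:
#       text = await inference('/path/to/audio.wav')
#   The handler wiring and the audio path here are hypothetical placeholders.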