MikuAI/api.py

53 lines
1.5 KiB
Python
Raw Normal View History

2024-03-31 19:52:44 +00:00
from fastapi import FastAPI, Response, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.responses import FileResponse
from config import TOKEN
import hmac
import model
from pydantic import BaseModel
from typing import List, Optional
import tempfile
from rvc.main import song_cover_pipeline
# Application instance; the @app.post / @app.get decorators in this module
# register their endpoints on it.
app = FastAPI()
class Message(BaseModel):
    """One chat message in the request body of the ``/`` endpoint."""
    # Who sent the message. Values are not constrained here (presumably
    # roles such as "user"/"assistant" — confirm against model.inference).
    role: str
    # Text of the message.
    content: str
@app.post("/")
async def root(token: str,
               messages: List[Message],
               response: Response,
               max_new_tokens: Optional[int] = 64,
               temperature: Optional[float] = 0.9,
               repetition_penalty: Optional[float] = 1.2):
    """Run ``model.inference`` on a list of chat messages.

    Args:
        token: Shared secret; must equal ``TOKEN`` from ``config``.
        messages: Conversation as a list of role/content messages.
        response: Injected by FastAPI so the status code can be set on
            authentication failure.
        max_new_tokens: Forwarded unchanged to ``model.inference``.
        temperature: Forwarded unchanged to ``model.inference``.
        repetition_penalty: Forwarded unchanged to ``model.inference``.

    Returns:
        ``{"raw": <inference output>}`` on success, or
        ``{"error": "Bad token"}`` with HTTP status 401 when the token does
        not match.
    """
    # compare_digest gives a timing-safe comparison, but with str arguments
    # it raises TypeError if either contains non-ASCII characters (which
    # would surface as a 500). Comparing the UTF-8 encodings keeps the
    # constant-time check and turns any mismatch into a clean 401.
    if not hmac.compare_digest(token.encode("utf-8"), TOKEN.encode("utf-8")):
        response.status_code = 401
        return {"error": "Bad token"}
    # Turn the pydantic models into JSON-compatible structures for the model layer.
    dict_in = jsonable_encoder(messages)
    # NOTE(review): model.inference is called synchronously inside an
    # `async def`, so it blocks the event loop (including /ping) while it
    # runs. Left as-is: offloading to a thread would allow concurrent
    # inference, and the model's thread-safety is unknown — confirm first.
    output = model.inference(dict_in,
                             max_new_tokens=max_new_tokens,
                             temperature=temperature,
                             repetition_penalty=repetition_penalty)
    return {"raw": output}
@app.post("/rvc")
async def rvc(token: str,
              file: UploadFile,
              response: Response,
              pitch_change_oct: Optional[int] = 1,
              pitch_change_sem: Optional[int] = 0):
    """Run an uploaded audio file through ``song_cover_pipeline`` with the 'miku' voice.

    Args:
        token: Shared secret; must equal ``TOKEN`` from ``config``.
        file: Uploaded input file; its bytes are copied to a temporary file
            whose path is handed to the pipeline.
        response: Injected by FastAPI so the status code can be set on
            authentication failure.
        pitch_change_oct: Passed positionally to ``song_cover_pipeline``.
        pitch_change_sem: Passed as ``pitch_change_sem`` to the pipeline.

    Returns:
        A ``FileResponse`` for the path returned by ``song_cover_pipeline``,
        or ``{"error": "Bad token"}`` with HTTP status 401 when the token
        does not match.
    """
    # Encode before comparing: compare_digest raises TypeError on non-ASCII
    # str input, which would otherwise turn a bad token into a 500.
    if not hmac.compare_digest(token.encode("utf-8"), TOKEN.encode("utf-8")):
        response.status_code = 401
        return {"error": "Bad token"}
    # The temporary input file is deleted when the `with` block exits; the
    # output path returned by the pipeline is not cleaned up here (whether
    # the pipeline manages it is not visible from this module).
    # NOTE(review): on Windows a NamedTemporaryFile cannot be reopened by name
    # while still open — this assumes a POSIX host.
    with tempfile.NamedTemporaryFile() as tmp:
        tmp.write(await file.read())
        # Flush Python's write buffer so the pipeline, which opens the file by
        # name, sees the complete upload rather than a truncated/empty file.
        tmp.flush()
        ai_vocals_path = song_cover_pipeline(tmp.name, pitch_change_oct,
                                             voice_model='miku',
                                             pitch_change_sem=pitch_change_sem)
    return FileResponse(ai_vocals_path)
@app.get("/ping")
def ping():
    """Liveness probe: reply with a constant pong payload."""
    pong_body = {"message": "pong"}
    return pong_body