import httpx
import sys
import os
from emd.models.utils.constants import ModelType
import inspect
from backend.backend import OpenAICompitableProxyBackendBase
from emd.utils.logger_utils import get_logger

logger = get_logger(__name__)

class HiggsAudioBackend(OpenAICompitableProxyBackendBase):
    """
    Higgs Audio backend that uses the Docker image's native entrypoint
    instead of the standard vLLM serve command.

    This backend is specifically designed for the Higgs Audio v2 Generation 3B Base model,
    which provides its own API server via the pre-built Docker image with entrypoint:
    ["python3", "-m", "vllm.entrypoints.bosonai.api_server"]
    """

    def before_start(self, model_dir=None):
        logger.info(f"HiggsAudioBackend.before_start called, model_dir={model_dir}")

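    # The command below launches the OpenAI-compatible server bundled in the bosonai
    # image. Flag notes (standard vLLM serve options): --tensor-parallel-size 8 shards
    # the model across 8 GPUs, --limit-mm-per-prompt audio=50 caps the number of audio
    # items per prompt, --max-model-len 8192 sets the maximum sequence length, and
    # --gpu-memory-utilization 0.65 lets the engine use 65% of GPU memory.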
    def create_proxy_server_start_command(self, model_path):
        return (
            'python3 -m vllm.entrypoints.bosonai.api_server'
            ' --served-model-name higgs-audio-v2-generation-3B-base'
            ' --model bosonai/higgs-audio-v2-generation-3B-base'
            ' --audio-tokenizer-type bosonai/higgs-audio-v2-tokenizer'
            ' --limit-mm-per-prompt audio=50 --max-model-len 8192'
            ' --tensor-parallel-size 8 --pipeline-parallel-size 1'
            ' --port 8000 --gpu-memory-utilization 0.65 --disable-mm-preprocessor-cache'
        )

    def openai_create_helper(self, fn: callable, request: dict):
        """
        Helper method to handle OpenAI-compatible API calls with extra parameters:
        any key that is not a named parameter of `fn` is moved into `extra_body`.
        """
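        # Illustrative example (assuming the OpenAI Python client, whose
        # chat.completions.create has no `top_k` parameter): a request like
        #     {"model": "m", "messages": [...], "top_k": 20, "stream": False}
        # is rewritten in place to
        #     {"model": "m", "messages": [...], "stream": False, "extra_body": {"top_k": 20}}
        # so the extra sampling parameter still reaches the vLLM server.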
        sig = inspect.signature(fn)
        extra_body = request.get("extra_body", {})
        extra_params = {k: request.pop(k) for k in list(request.keys()) if k not in sig.parameters}
        extra_body.update(extra_params)
        request['extra_body'] = extra_body
        return fn(**request)

    def invoke(self, request):
        """
        Invoke the Higgs Audio model through the OpenAI-compatible API.
        Supports audio modalities for voice cloning, smart voice generation, and multi-speaker synthesis.
        """
        # Transform input to Higgs Audio format
        request = self._transform_request(request)

        logger.info(f"Higgs Audio request: {request}")

        # Handle different model types - Higgs Audio is primarily for audio generation
        if self.model_type == ModelType.AUDIO:
            # Use the chat completions endpoint for audio generation
            response = self.openai_create_helper(self.client.chat.completions.create, request)
        else:
            # Fall back to standard chat completions
            response = self.openai_create_helper(self.client.chat.completions.create, request)

        logger.info(f"Higgs Audio response: {response}, request: {request}")

        if request.get("stream", False):
            return self._transform_streaming_response(response)
        else:
            return self._transform_response(response)

    async def ainvoke(self, request):
        """
        Async invoke the Higgs Audio model through the OpenAI-compatible API.
        """
        # Transform input to Higgs Audio format
        request = self._transform_request(request)

        logger.info(f"Higgs Audio async request: {request}")

        # Handle different model types - Higgs Audio is primarily for audio generation
        if self.model_type == ModelType.AUDIO:
            # Use the chat completions endpoint for audio generation
            response = await self.openai_create_helper(self.async_client.chat.completions.create, request)
        else:
            # Fall back to standard chat completions
            response = await self.openai_create_helper(self.async_client.chat.completions.create, request)

        logger.info(f"Higgs Audio async response: {response}, request: {request}")

        if request.get("stream", False):
            logger.info(f"Higgs Audio streaming response: {response}")
            return await self._atransform_streaming_response(response)
        else:
            return await self._atransform_response(response)
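
# Illustrative usage sketch (not part of this module). How the backend is
# constructed depends on EMD's OpenAICompitableProxyBackendBase and is an
# assumption here; only the request/response handling above comes from this file.
#
#     backend = HiggsAudioBackend(...)  # hypothetical construction arguments
#     request = {
#         "model": "higgs-audio-v2-generation-3B-base",
#         "messages": [
#             {"role": "system", "content": "Convert the user's text to speech."},
#             {"role": "user", "content": "Hello from Higgs Audio."},
#         ],
#         "stream": False,
#     }
#     response = backend.invoke(request)  # or: await backend.ainvoke(request)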