mirror of
https://github.com/yshtcn/OllamaProxy.git
synced 2025-12-11 16:50:27 +08:00
- Add LM Studio server support; an LM Studio server can be specified via the --lms-url parameter
- Make the wake-url parameter optional; configuring a wake server is no longer required
- Adjust API endpoint paths dynamically based on the server type
- Improve error messages and log output to better distinguish between server types
- Refactor the configuration validation logic to ensure OLLAMA_URL and LMS_URL are never configured at the same time
308 lines
11 KiB
Python
from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
import httpx
import asyncio
import logging
import os
import argparse
import sys
from datetime import datetime, timedelta
import json

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Parse command-line arguments
parser = argparse.ArgumentParser(description='Proxy server')
parser.add_argument('--ollama-url', help='Ollama server URL')
parser.add_argument('--lms-url', help='LM Studio server URL')
parser.add_argument('--wake-url', help='Wake server URL (optional)')
parser.add_argument('--timeout', type=int, help='Timeout for simple requests (seconds)')
parser.add_argument('--model-timeout', type=int, help='Timeout for model inference requests (seconds)')
parser.add_argument('--port', type=int, help='Proxy server port')
parser.add_argument('--wake-interval', type=int, default=10, help='Wake interval (minutes); only used when wake-url is configured')
parser.add_argument('--cache-duration', type=int, help='Model list cache lifetime (minutes); defaults to 1440 minutes (1 day)')

args = parser.parse_args()

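# Example invocations (a sketch: the script name, hosts and ports below are
# placeholders, not defined by this file):
#
#   python ollamaproxy.py --ollama-url http://localhost:11434 --timeout 5 --port 8000
#   python ollamaproxy.py --lms-url http://localhost:1234 --timeout 5 --port 8000 \
#       --wake-url http://nas.local/wake --wake-interval 15
#
# The same settings can come from the environment instead (OLLAMA_URL, LMS_URL,
# WAKE_URL, TIMEOUT_SECONDS, MODEL_TIMEOUT_SECONDS, PORT, WAKE_INTERVAL,
# CACHE_DURATION), and environment variables win over flags, as the next block shows.
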
# Configuration constants; environment variables take precedence over command-line arguments
OLLAMA_URL = os.getenv('OLLAMA_URL') or args.ollama_url
LMS_URL = os.getenv('LMS_URL') or args.lms_url
WAKE_URL = os.getenv('WAKE_URL') or args.wake_url
TIMEOUT_SECONDS = os.getenv('TIMEOUT_SECONDS') or args.timeout
MODEL_TIMEOUT_SECONDS = int(os.getenv('MODEL_TIMEOUT_SECONDS') or args.model_timeout or 30)  # default: 30 seconds
PORT = os.getenv('PORT') or args.port
WAKE_INTERVAL = int(os.getenv('WAKE_INTERVAL') or args.wake_interval)
CACHE_DURATION = int(os.getenv('CACHE_DURATION') or args.cache_duration or 1440)  # default: 1 day

# Validate the URL configuration
if OLLAMA_URL and LMS_URL:
    logger.error("OLLAMA_URL and LMS_URL cannot be configured at the same time; please choose only one")
    sys.exit(1)
elif not (OLLAMA_URL or LMS_URL):
    logger.error("Either OLLAMA_URL or LMS_URL must be configured")
    sys.exit(1)

# Set the server type and base URL
if OLLAMA_URL:
    server_type = 'ollama'
    BASE_URL = OLLAMA_URL
    MODEL_LIST_PATH = 'api/tags'
    GENERATE_ENDPOINTS = ["api/generate", "api/chat"]
else:
    server_type = 'lmstudio'
    BASE_URL = LMS_URL
    MODEL_LIST_PATH = 'v1/models'
    GENERATE_ENDPOINTS = ["v1/chat/completions"]

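# For illustration (assumed default ports, not requirements): with
# OLLAMA_URL=http://localhost:11434 the model list is fetched from
# http://localhost:11434/api/tags, whereas with LMS_URL=http://localhost:1234
# it comes from http://localhost:1234/v1/models; generation requests go to the
# endpoints listed in GENERATE_ENDPOINTS under the same base URL.
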
# Check the remaining required parameters
missing_params = []
if not TIMEOUT_SECONDS:
    missing_params.append("TIMEOUT_SECONDS")
if not PORT:
    missing_params.append("PORT")

if missing_params:
    logger.error(f"Missing required parameters: {', '.join(missing_params)}")
    logger.error("Please provide these values via environment variables or command-line arguments")
    sys.exit(1)

# Make sure the numeric values have the correct type
try:
    TIMEOUT_SECONDS = int(TIMEOUT_SECONDS)
    PORT = int(PORT)
except ValueError:
    logger.error("TIMEOUT_SECONDS and PORT must be integers")
    sys.exit(1)

# Global variable tracking the time of the last wake request
last_wake_time = None

# Cache-related variables
models_cache = None
models_cache_time = None

async def should_wake():
    """Check whether a wake request should be sent."""
    if not WAKE_URL:  # no WAKE_URL configured, so waking is never needed
        return False

    global last_wake_time
    if last_wake_time is None:
        return True
    return datetime.now() - last_wake_time > timedelta(minutes=WAKE_INTERVAL)

async def wake_ollama():
    """Wake the upstream server."""
    if not WAKE_URL:  # no WAKE_URL configured, nothing to do
        return

    global last_wake_time
    try:
        async with httpx.AsyncClient() as client:
            await client.get(WAKE_URL)
            last_wake_time = datetime.now()
            logger.info(f"Wake request sent; last wake time updated to {last_wake_time}")
    except Exception as e:
        logger.error(f"Wake request failed: {str(e)}")

async def get_models_from_cache():
    """Return the cached model list, or None if the cache is empty or expired."""
    global models_cache, models_cache_time
    if models_cache is None or models_cache_time is None:
        return None
    if datetime.now() - models_cache_time > timedelta(minutes=CACHE_DURATION):
        return None
    return models_cache

async def update_models_cache(data):
    """Update the model list cache."""
    global models_cache, models_cache_time
    models_cache = data
    models_cache_time = datetime.now()
    logger.info("Model list cache updated")


app = FastAPI()

@app.get("/health")
|
||
async def health_check():
|
||
logger.info("收到健康检查请求")
|
||
return {"status": "healthy"}
|
||
|
||
@app.get("/api/tags")
|
||
async def list_models():
|
||
try:
|
||
# 首先尝试从缓存获取
|
||
cached_models = await get_models_from_cache()
|
||
|
||
async with httpx.AsyncClient() as client:
|
||
response = await client.get(
|
||
f"{BASE_URL}/{MODEL_LIST_PATH}",
|
||
timeout=TIMEOUT_SECONDS # 使用较短的超时时间
|
||
)
|
||
# 更新缓存并返回最新数据
|
||
await update_models_cache(response.json())
|
||
return response.json()
|
||
|
||
except (httpx.TimeoutException, httpx.ConnectError) as e:
|
||
# 发生超时或连接错误时,触发唤醒
|
||
logger.warning(f"获取标签列表失败,正在唤醒服务器: {str(e)}")
|
||
asyncio.create_task(wake_ollama())
|
||
|
||
# 如果有缓存,返回缓存数据
|
||
if cached_models is not None:
|
||
logger.info("返回缓存的标签列表")
|
||
return JSONResponse(content=cached_models)
|
||
|
||
# 如果没有缓存,返回503
|
||
return JSONResponse(
|
||
status_code=503,
|
||
content={"message": "服务器正在唤醒中,请稍后重试"}
|
||
)
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取标签列表时发生未知错误: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(request: Request, path: str):
    # Do not proxy /health requests
    if path == "health":
        return await health_check()

    # Handle all other requests
    if WAKE_URL and await should_wake():
        logger.info("Wake interval exceeded; sending a preventive wake request")
        await wake_ollama()

    try:
        target_url = f"{BASE_URL}/{path}"
        headers = dict(request.headers)
        headers.pop('host', None)
        headers.pop('connection', None)
        # Remove headers that could cause problems downstream
        headers.pop('content-length', None)
        headers.pop('transfer-encoding', None)

        # Choose the timeout based on the request type
        timeout = TIMEOUT_SECONDS if path == MODEL_LIST_PATH else MODEL_TIMEOUT_SECONDS

        # Check whether this is a generation endpoint
        is_generate_endpoint = path in GENERATE_ENDPOINTS

        if is_generate_endpoint and request.method == "POST":
            request_body = await request.json()
            # Force stream to true so the response is streamed back to the client
            request_body["stream"] = True

            async def generate_stream():
                client = httpx.AsyncClient()
                try:
                    async with client.stream(
                        method=request.method,
                        url=target_url,
                        json=request_body,
                        headers=headers,
                        timeout=None  # no overall timeout for streaming
                    ) as response:
                        async for line in response.aiter_lines():
                            if line.strip():  # skip empty lines
                                yield line.encode('utf-8') + b'\n'
                except httpx.TimeoutException as e:
                    logger.error(f"Streaming timed out: {str(e)}")
                    raise
                except Exception as e:
                    logger.error(f"Error during streaming: {str(e)}")
                    raise
                finally:
                    await client.aclose()

            return StreamingResponse(
                generate_stream(),
                media_type="application/x-ndjson",
                headers={'Transfer-Encoding': 'chunked'}  # use chunked transfer encoding
            )

        else:
            # Handle non-generation requests
            async with httpx.AsyncClient() as client:
                body = await request.body()
                response = await client.request(
                    method=request.method,
                    url=target_url,
                    content=body,
                    headers=headers,
                    timeout=timeout,
                    follow_redirects=True
                )

                # If this was a successful model list request, update the cache
                if path == MODEL_LIST_PATH and request.method == "GET" and response.status_code == 200:
                    await update_models_cache(response.json())

                return Response(
                    content=response.content,
                    status_code=response.status_code,
                    headers=dict(response.headers)
                )

    except httpx.TimeoutException:
        error_msg = "Upstream server timed out"
        if WAKE_URL:
            error_msg += ", attempting to wake it"
            logger.warning(error_msg)
            # For model list requests, try to fall back to the cache
            if path == MODEL_LIST_PATH and request.method == "GET":
                cached_models = await get_models_from_cache()
                if cached_models is not None:
                    logger.info("Returning the cached model list")
                    return JSONResponse(content=cached_models)

            # Fire the wake request asynchronously without waiting for it
            asyncio.create_task(wake_ollama())
        else:
            logger.warning(error_msg)

        return JSONResponse(
            status_code=503,
            content={"message": f"{error_msg}, please retry shortly"}
        )

    except httpx.RequestError as e:
        logger.error(f"Request error: {str(e)}")
        # For model list requests, try to fall back to the cache
        if path == MODEL_LIST_PATH and request.method == "GET":
            cached_models = await get_models_from_cache()
            if cached_models is not None:
                logger.info("Returning the cached model list")
                return JSONResponse(content=cached_models)

        return JSONResponse(
            status_code=502,
            content={"message": f"Unable to connect to the server: {str(e)}"}
        )

    except Exception as e:
        logger.error(f"Proxy request failed: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"message": f"Proxy request failed: {str(e)}"}
        )

# Log the active configuration
logger.info("Active configuration:")
logger.info(f"Server type: {server_type}")
logger.info(f"BASE_URL: {BASE_URL}")
if WAKE_URL:
    logger.info(f"WAKE_URL: {WAKE_URL}")
    logger.info(f"WAKE_INTERVAL: {WAKE_INTERVAL} minutes")
else:
    logger.info("Wake feature not configured")
logger.info(f"TIMEOUT_SECONDS: {TIMEOUT_SECONDS}")
logger.info(f"MODEL_TIMEOUT_SECONDS: {MODEL_TIMEOUT_SECONDS}")
logger.info(f"PORT: {PORT}")
logger.info(f"CACHE_DURATION: {CACHE_DURATION} minutes")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=PORT)
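
A minimal client sketch (assumptions: the proxy listens on localhost:8000 in front of an Ollama server, and "llama3" stands in for an installed model) showing how the NDJSON stream that the proxy forces on generation endpoints can be consumed:

import httpx
import json

payload = {"model": "llama3", "messages": [{"role": "user", "content": "Hi"}]}
# The proxy rewrites the request body with "stream": true, so the response
# arrives as newline-delimited JSON chunks even if the client did not ask
# for streaming.
with httpx.stream("POST", "http://localhost:8000/api/chat", json=payload, timeout=None) as r:
    for line in r.iter_lines():
        if line.strip():
            chunk = json.loads(line)
            print(chunk.get("message", {}).get("content", ""), end="", flush=True)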