OllamaProxy/ollama_proxy.py

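"""
A small FastAPI reverse proxy that sits in front of an Ollama server which may
be asleep. It forwards every request to OLLAMA_URL, pings WAKE_URL when the
backend looks unresponsive (or when WAKE_INTERVAL minutes have passed since the
last wake), and keeps a 30-minute cache of the /api/tags model list so a cached
copy can be served while the backend wakes up.
"""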

from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.responses import JSONResponse
import httpx
import asyncio
import logging
import os
import argparse
import sys
from datetime import datetime, timedelta

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Parse command-line arguments
parser = argparse.ArgumentParser(description='Ollama proxy server')
parser.add_argument('--ollama-url', help='Ollama server URL')
parser.add_argument('--wake-url', help='Wake server URL')
parser.add_argument('--timeout', type=int, help='Timeout for simple requests (seconds)')
parser.add_argument('--model-timeout', type=int, help='Timeout for model inference requests (seconds)')
parser.add_argument('--port', type=int, help='Proxy server port')
parser.add_argument('--wake-interval', type=int, default=10, help='Wake interval (minutes)')
args = parser.parse_args()
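
# Example invocation (all values below are illustrative; the same settings can
# be supplied through the OLLAMA_URL, WAKE_URL, TIMEOUT_SECONDS,
# MODEL_TIMEOUT_SECONDS, PORT and WAKE_INTERVAL environment variables):
#   python ollama_proxy.py \
#       --ollama-url http://192.168.1.10:11434 \
#       --wake-url http://192.168.1.10:9000/wake \
#       --timeout 5 --model-timeout 300 --port 8000 --wake-interval 10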

# Configuration constants; environment variables take precedence, then command-line arguments
OLLAMA_URL = os.getenv('OLLAMA_URL') or args.ollama_url
WAKE_URL = os.getenv('WAKE_URL') or args.wake_url
TIMEOUT_SECONDS = os.getenv('TIMEOUT_SECONDS') or args.timeout
MODEL_TIMEOUT_SECONDS = int(os.getenv('MODEL_TIMEOUT_SECONDS') or args.model_timeout or 30)  # defaults to 30 seconds
PORT = os.getenv('PORT') or args.port
WAKE_INTERVAL = int(os.getenv('WAKE_INTERVAL') or args.wake_interval)

# Check that the required parameters are present
missing_params = []
if not OLLAMA_URL:
    missing_params.append("OLLAMA_URL")
if not WAKE_URL:
    missing_params.append("WAKE_URL")
if not TIMEOUT_SECONDS:
    missing_params.append("TIMEOUT_SECONDS")
if not PORT:
    missing_params.append("PORT")
if missing_params:
    logger.error(f"Missing required parameters: {', '.join(missing_params)}")
    logger.error("Please provide these values via environment variables or command-line arguments")
    sys.exit(1)

# Make sure the numeric parameters have the correct type
try:
    TIMEOUT_SECONDS = int(TIMEOUT_SECONDS)
    PORT = int(PORT)
except ValueError:
    logger.error("TIMEOUT_SECONDS and PORT must be integers")
    sys.exit(1)

# Global variable tracking the last wake time
last_wake_time = None

# Cache for the model list returned by /api/tags
models_cache = None
models_cache_time = None
CACHE_DURATION = timedelta(minutes=30)  # cache entries stay valid for 30 minutes


async def should_wake():
    """Check whether a wake request should be sent."""
    global last_wake_time
    if last_wake_time is None:
        return True
    return datetime.now() - last_wake_time > timedelta(minutes=WAKE_INTERVAL)


async def wake_ollama():
    """Wake up the Ollama server."""
    global last_wake_time
    try:
        async with httpx.AsyncClient() as client:
            await client.get(WAKE_URL)
            last_wake_time = datetime.now()
            logger.info(f"Wake request sent; last wake time updated to {last_wake_time}")
    except Exception as e:
        logger.error(f"Wake request failed: {str(e)}")


async def get_models_from_cache():
    """Return the cached model list, or None if the cache is empty or stale."""
    global models_cache, models_cache_time
    if models_cache is None or models_cache_time is None:
        return None
    if datetime.now() - models_cache_time > CACHE_DURATION:
        return None
    return models_cache


async def update_models_cache(data):
    """Update the model list cache."""
    global models_cache, models_cache_time
    models_cache = data
    models_cache_time = datetime.now()
    logger.info("Model list cache updated")


# Log the active configuration
logger.info("Using configuration:")
logger.info(f"OLLAMA_URL: {OLLAMA_URL}")
logger.info(f"WAKE_URL: {WAKE_URL}")
logger.info(f"TIMEOUT_SECONDS: {TIMEOUT_SECONDS}")
logger.info(f"MODEL_TIMEOUT_SECONDS: {MODEL_TIMEOUT_SECONDS}")
logger.info(f"PORT: {PORT}")
logger.info(f"WAKE_INTERVAL: {WAKE_INTERVAL} minutes")

app = FastAPI()
@app.get("/health")
async def health_check():
logger.info("收到健康检查请求")
return {"status": "healthy"}
@app.get("/api/tags")
async def list_models():
try:
# 首先尝试从缓存获取
cached_models = await get_models_from_cache()
async with httpx.AsyncClient() as client:
response = await client.get(
f"{OLLAMA_URL}/api/tags",
timeout=TIMEOUT_SECONDS # 使用较短的超时时间
)
# 更新缓存并返回最新数据
await update_models_cache(response.json())
return response.json()
except (httpx.TimeoutException, httpx.ConnectError) as e:
# 发生超时或连接错误时,触发唤醒
logger.warning(f"获取标签列表失败,正在唤醒服务器: {str(e)}")
asyncio.create_task(wake_ollama())
# 如果有缓存,返回缓存数据
if cached_models is not None:
logger.info("返回缓存的标签列表")
return JSONResponse(content=cached_models)
# 如果没有缓存返回503
return JSONResponse(
status_code=503,
content={"message": "服务器正在唤醒中,请稍后重试"}
)
except Exception as e:
logger.error(f"获取标签列表时发生未知错误: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(request: Request, path: str):
    # Do not proxy /health requests
    if path == "health":
        return await health_check()
    # Handling for all other requests
    if await should_wake():
        logger.info("Wake interval has elapsed since the last wake; sending a preventive wake request")
        await wake_ollama()
    async with httpx.AsyncClient() as client:
        try:
            target_url = f"{OLLAMA_URL}/{path}"
            body = await request.body()
            headers = dict(request.headers)
            headers.pop('host', None)
            headers.pop('connection', None)
            # Choose the timeout based on the request type
            timeout = TIMEOUT_SECONDS if path == "api/tags" else MODEL_TIMEOUT_SECONDS
            response = await client.request(
                method=request.method,
                url=target_url,
                content=body,
                headers=headers,
                timeout=timeout,  # dynamic timeout
                follow_redirects=True
            )
            # If this was a successful tag list request, update the cache
            if path == "api/tags" and request.method == "GET" and response.status_code == 200:
                await update_models_cache(response.json())
            # httpx already decodes the response body, so drop the encoding/length
            # headers to avoid forwarding a mismatched Content-Length or Content-Encoding
            response_headers = dict(response.headers)
            for header in ('content-encoding', 'content-length', 'transfer-encoding'):
                response_headers.pop(header, None)
            return Response(
                content=response.content,
                status_code=response.status_code,
                headers=response_headers
            )
        except httpx.TimeoutException:
            logger.warning("Ollama server timed out; sending a wake request")
            # For tag list requests, try to return cached data
            if path == "api/tags" and request.method == "GET":
                cached_models = await get_models_from_cache()
                if cached_models is not None:
                    logger.info("Returning the cached tag list")
                    return JSONResponse(content=cached_models)
            # Fire the wake request asynchronously without waiting for the result
            asyncio.create_task(wake_ollama())
            return JSONResponse(
                status_code=503,
                content={"message": "The server is waking up, please retry later"}
            )
        except httpx.RequestError as e:
            logger.error(f"Request error: {str(e)}")
            # For tag list requests, try to return cached data
            if path == "api/tags" and request.method == "GET":
                cached_models = await get_models_from_cache()
                if cached_models is not None:
                    logger.info("Returning the cached tag list")
                    return JSONResponse(content=cached_models)
            return JSONResponse(
                status_code=502,
                content={"message": f"Unable to connect to the Ollama server: {str(e)}"}
            )
        except Exception as e:
            logger.error(f"Proxy request failed: {str(e)}")
            return JSONResponse(
                status_code=500,
                content={"message": f"Proxy request failed: {str(e)}"}
            )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=PORT)
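
# Example requests through a running proxy (the port below is illustrative and
# must match the configured --port / PORT value; the model name is hypothetical):
#   curl http://localhost:8000/api/tags
#   curl http://localhost:8000/api/generate -d '{"model": "llama3", "prompt": "Hello"}'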