Initial commit: Add Ollama Proxy project files

yshtcn
2025-01-23 00:13:12 +08:00
commit 49b834ff93
11 changed files with 662 additions and 0 deletions

ollama_proxy.py (new file, 231 lines)

@@ -0,0 +1,231 @@
from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.responses import JSONResponse
import httpx
import asyncio
import logging
import os
import argparse
import sys
from datetime import datetime, timedelta
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Parse command-line arguments
parser = argparse.ArgumentParser(description='Ollama proxy server')
parser.add_argument('--ollama-url', help='Ollama server URL')
parser.add_argument('--wake-url', help='Wake-up server URL')
parser.add_argument('--timeout', type=int, help='Timeout for simple requests (seconds)')
parser.add_argument('--model-timeout', type=int, help='Timeout for model inference requests (seconds)')
parser.add_argument('--port', type=int, help='Proxy server port')
parser.add_argument('--wake-interval', type=int, default=10, help='Wake-up interval (minutes)')
args = parser.parse_args()
# Configuration constants: environment variables take precedence over command-line arguments
OLLAMA_URL = os.getenv('OLLAMA_URL') or args.ollama_url
WAKE_URL = os.getenv('WAKE_URL') or args.wake_url
TIMEOUT_SECONDS = os.getenv('TIMEOUT_SECONDS') or args.timeout
MODEL_TIMEOUT_SECONDS = int(os.getenv('MODEL_TIMEOUT_SECONDS') or args.model_timeout or 30)  # defaults to 30 seconds
PORT = os.getenv('PORT') or args.port
WAKE_INTERVAL = int(os.getenv('WAKE_INTERVAL') or args.wake_interval)
# Check required parameters
missing_params = []
if not OLLAMA_URL:
    missing_params.append("OLLAMA_URL")
if not WAKE_URL:
    missing_params.append("WAKE_URL")
if not TIMEOUT_SECONDS:
    missing_params.append("TIMEOUT_SECONDS")
if not PORT:
    missing_params.append("PORT")
if missing_params:
    logger.error(f"Missing required parameters: {', '.join(missing_params)}")
    logger.error("Please provide these values via environment variables or command-line arguments")
    sys.exit(1)

# Make sure the numeric settings are valid integers
try:
    TIMEOUT_SECONDS = int(TIMEOUT_SECONDS)
    PORT = int(PORT)
except ValueError:
    logger.error("TIMEOUT_SECONDS and PORT must be integers")
    sys.exit(1)
# Global variable tracking the last wake-up time
last_wake_time = None

# Cache-related state for the model list
models_cache = None
models_cache_time = None
CACHE_DURATION = timedelta(minutes=30)  # cache entries are valid for 30 minutes
async def should_wake():
    """Check whether a wake request needs to be sent"""
    global last_wake_time
    if last_wake_time is None:
        return True
    return datetime.now() - last_wake_time > timedelta(minutes=WAKE_INTERVAL)

async def wake_ollama():
    """Wake up the Ollama server"""
    global last_wake_time
    try:
        async with httpx.AsyncClient() as client:
            await client.get(WAKE_URL)
            last_wake_time = datetime.now()
            logger.info(f"Wake request sent, last wake time updated to: {last_wake_time}")
    except Exception as e:
        logger.error(f"Wake request failed: {str(e)}")

async def get_models_from_cache():
    """Return the model list from the cache, or None if the cache is empty or stale"""
    global models_cache, models_cache_time
    if models_cache is None or models_cache_time is None:
        return None
    if datetime.now() - models_cache_time > CACHE_DURATION:
        return None
    return models_cache

async def update_models_cache(data):
    """Update the model list cache"""
    global models_cache, models_cache_time
    models_cache = data
    models_cache_time = datetime.now()
    logger.info("Model list cache updated")
# Log the active configuration
logger.info("Using configuration:")
logger.info(f"OLLAMA_URL: {OLLAMA_URL}")
logger.info(f"WAKE_URL: {WAKE_URL}")
logger.info(f"TIMEOUT_SECONDS: {TIMEOUT_SECONDS}")
logger.info(f"MODEL_TIMEOUT_SECONDS: {MODEL_TIMEOUT_SECONDS}")
logger.info(f"PORT: {PORT}")
logger.info(f"WAKE_INTERVAL: {WAKE_INTERVAL} minutes")
app = FastAPI()
@app.get("/health")
async def health_check():
logger.info("收到健康检查请求")
return {"status": "healthy"}
@app.get("/api/tags")
async def list_models():
try:
# 首先尝试从缓存获取
cached_models = await get_models_from_cache()
async with httpx.AsyncClient() as client:
response = await client.get(
f"{OLLAMA_URL}/api/tags",
timeout=TIMEOUT_SECONDS # 使用较短的超时时间
)
# 更新缓存并返回最新数据
await update_models_cache(response.json())
return response.json()
except (httpx.TimeoutException, httpx.ConnectError) as e:
# 发生超时或连接错误时,触发唤醒
logger.warning(f"获取标签列表失败,正在唤醒服务器: {str(e)}")
asyncio.create_task(wake_ollama())
# 如果有缓存,返回缓存数据
if cached_models is not None:
logger.info("返回缓存的标签列表")
return JSONResponse(content=cached_models)
# 如果没有缓存返回503
return JSONResponse(
status_code=503,
content={"message": "服务器正在唤醒中,请稍后重试"}
)
except Exception as e:
logger.error(f"获取标签列表时发生未知错误: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(request: Request, path: str):
# 避免代理 /health 请求
if path == "health":
return await health_check()
# 其他请求的处理逻辑
if await should_wake():
logger.info("距离上次唤醒已超过设定时间,发送预防性唤醒请求")
await wake_ollama()
async with httpx.AsyncClient() as client:
try:
target_url = f"{OLLAMA_URL}/{path}"
body = await request.body()
headers = dict(request.headers)
headers.pop('host', None)
headers.pop('connection', None)
# 根据请求类型选择不同的超时时间
timeout = TIMEOUT_SECONDS if path == "api/tags" else MODEL_TIMEOUT_SECONDS
response = await client.request(
method=request.method,
url=target_url,
content=body,
headers=headers,
timeout=timeout, # 使用动态超时时间
follow_redirects=True
)
# 如果是标签列表请求且成功,更新缓存
if path == "api/tags" and request.method == "GET" and response.status_code == 200:
await update_models_cache(response.json())
return Response(
content=response.content,
status_code=response.status_code,
headers=dict(response.headers)
)
except httpx.TimeoutException:
logger.warning("Ollama服务器超时发送唤醒请求")
# 如果是标签列表请求,尝试返回缓存
if path == "api/tags" and request.method == "GET":
cached_models = await get_models_from_cache()
if cached_models is not None:
logger.info("返回缓存的标签列表")
return JSONResponse(content=cached_models)
# 直接异步发送唤醒请求,不等待结果
asyncio.create_task(wake_ollama())
return JSONResponse(
status_code=503,
content={"message": "服务器正在唤醒中,请稍后重试"}
)
except httpx.RequestError as e:
logger.error(f"请求错误: {str(e)}")
# 如果是标签列表请求,尝试返回缓存
if path == "api/tags" and request.method == "GET":
cached_models = await get_models_from_cache()
if cached_models is not None:
logger.info("返回缓存的标签列表")
return JSONResponse(content=cached_models)
return JSONResponse(
status_code=502,
content={"message": f"无法连接到Ollama服务器: {str(e)}"}
)
except Exception as e:
logger.error(f"代理请求失败: {str(e)}")
return JSONResponse(
status_code=500,
content={"message": f"代理请求失败: {str(e)}"}
)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=PORT)
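
A minimal smoke test for a running proxy instance, as a hedged sketch: PROXY_URL and the port below are illustrative placeholders, not values defined by this commit, and should match whatever --port / PORT the proxy was started with.

import httpx

# Hypothetical address of a locally running proxy; adjust to your --port / PORT setting.
PROXY_URL = "http://localhost:11435"

# Health check: answered by the proxy itself and never forwarded to Ollama.
print(httpx.get(f"{PROXY_URL}/health").json())

# Tag list: forwarded to Ollama, cached for 30 minutes, and served from that cache
# (or answered with 503) while the backend is still waking up.
print(httpx.get(f"{PROXY_URL}/api/tags", timeout=60).json())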