OllamaProxy/ollama_proxy.py
yshtcn a40fbadf7b feat: 增强缓存功能
1. 添加可配置的缓存时间参数(--cache-duration)
2. 默认缓存时间从30分钟改为1天(1440分钟)
3. 支持通过环境变量CACHE_DURATION配置
4. 更新文档和配置示例
5. 修复了流式传输的问题
2025-01-27 18:35:30 +08:00

271 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
import httpx
import asyncio
import logging
import os
import argparse
import sys
from datetime import datetime, timedelta
import json
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 解析命令行参数
parser = argparse.ArgumentParser(description='Ollama代理服务器')
parser.add_argument('--ollama-url', help='Ollama服务器URL')
parser.add_argument('--wake-url', help='唤醒服务器URL')
parser.add_argument('--timeout', type=int, help='简单请求的超时时间(秒)')
parser.add_argument('--model-timeout', type=int, help='模型推理请求的超时时间(秒)')
parser.add_argument('--port', type=int, help='代理服务器端口')
parser.add_argument('--wake-interval', type=int, default=10, help='唤醒间隔时间(分钟)')
parser.add_argument('--cache-duration', type=int, help='模型列表缓存有效期(分钟)默认1440分钟(1天)')
args = parser.parse_args()
# 配置常量,优先使用环境变量,其次使用命令行参数
OLLAMA_URL = os.getenv('OLLAMA_URL') or args.ollama_url
WAKE_URL = os.getenv('WAKE_URL') or args.wake_url
TIMEOUT_SECONDS = os.getenv('TIMEOUT_SECONDS') or args.timeout
MODEL_TIMEOUT_SECONDS = int(os.getenv('MODEL_TIMEOUT_SECONDS') or args.model_timeout or 30) # 默认30秒
PORT = os.getenv('PORT') or args.port
WAKE_INTERVAL = int(os.getenv('WAKE_INTERVAL') or args.wake_interval)
CACHE_DURATION = int(os.getenv('CACHE_DURATION') or args.cache_duration or 1440) # 默认1天
# 检查必要参数
missing_params = []
if not OLLAMA_URL:
missing_params.append("OLLAMA_URL")
if not WAKE_URL:
missing_params.append("WAKE_URL")
if not TIMEOUT_SECONDS:
missing_params.append("TIMEOUT_SECONDS")
if not PORT:
missing_params.append("PORT")
if missing_params:
logger.error(f"缺少必要参数: {', '.join(missing_params)}")
logger.error("请通过环境变量或命令行参数指定这些值")
sys.exit(1)
# 确保数值类型正确
try:
TIMEOUT_SECONDS = int(TIMEOUT_SECONDS)
PORT = int(PORT)
except ValueError as e:
logger.error("TIMEOUT_SECONDS 和 PORT 必须是整数")
sys.exit(1)
# 添加上次唤醒时间的全局变量
last_wake_time = None
# 添加缓存相关的变量
models_cache = None
models_cache_time = None
async def should_wake():
"""检查是否需要发送唤醒请求"""
global last_wake_time
if last_wake_time is None:
return True
return datetime.now() - last_wake_time > timedelta(minutes=WAKE_INTERVAL)
async def wake_ollama():
"""唤醒 Ollama 服务器"""
global last_wake_time
try:
async with httpx.AsyncClient() as client:
await client.get(WAKE_URL)
last_wake_time = datetime.now()
logger.info(f"已发送唤醒请求,更新唤醒时间: {last_wake_time}")
except Exception as e:
logger.error(f"唤醒请求失败: {str(e)}")
async def get_models_from_cache():
"""从缓存获取模型列表"""
global models_cache, models_cache_time
if models_cache is None or models_cache_time is None:
return None
if datetime.now() - models_cache_time > timedelta(minutes=CACHE_DURATION):
return None
return models_cache
async def update_models_cache(data):
"""更新模型列表缓存"""
global models_cache, models_cache_time
models_cache = data
models_cache_time = datetime.now()
logger.info("模型列表缓存已更新")
# 输出当前配置
logger.info(f"使用配置:")
logger.info(f"OLLAMA_URL: {OLLAMA_URL}")
logger.info(f"WAKE_URL: {WAKE_URL}")
logger.info(f"TIMEOUT_SECONDS: {TIMEOUT_SECONDS}")
logger.info(f"MODEL_TIMEOUT_SECONDS: {MODEL_TIMEOUT_SECONDS}")
logger.info(f"PORT: {PORT}")
logger.info(f"WAKE_INTERVAL: {WAKE_INTERVAL} minutes")
logger.info(f"CACHE_DURATION: {CACHE_DURATION} minutes")
app = FastAPI()
@app.get("/health")
async def health_check():
logger.info("收到健康检查请求")
return {"status": "healthy"}
@app.get("/api/tags")
async def list_models():
try:
# 首先尝试从缓存获取
cached_models = await get_models_from_cache()
async with httpx.AsyncClient() as client:
response = await client.get(
f"{OLLAMA_URL}/api/tags",
timeout=TIMEOUT_SECONDS # 使用较短的超时时间
)
# 更新缓存并返回最新数据
await update_models_cache(response.json())
return response.json()
except (httpx.TimeoutException, httpx.ConnectError) as e:
# 发生超时或连接错误时,触发唤醒
logger.warning(f"获取标签列表失败,正在唤醒服务器: {str(e)}")
asyncio.create_task(wake_ollama())
# 如果有缓存,返回缓存数据
if cached_models is not None:
logger.info("返回缓存的标签列表")
return JSONResponse(content=cached_models)
# 如果没有缓存返回503
return JSONResponse(
status_code=503,
content={"message": "服务器正在唤醒中,请稍后重试"}
)
except Exception as e:
logger.error(f"获取标签列表时发生未知错误: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(request: Request, path: str):
# 避免代理 /health 请求
if path == "health":
return await health_check()
# 其他请求的处理逻辑
if await should_wake():
logger.info("距离上次唤醒已超过设定时间,发送预防性唤醒请求")
await wake_ollama()
try:
target_url = f"{OLLAMA_URL}/{path}"
headers = dict(request.headers)
headers.pop('host', None)
headers.pop('connection', None)
# 移除可能导致问题的头部
headers.pop('content-length', None)
headers.pop('transfer-encoding', None)
# 根据请求类型选择不同的超时时间
timeout = TIMEOUT_SECONDS if path == "api/tags" else MODEL_TIMEOUT_SECONDS
# 检查是否为生成相关的端点
is_generate_endpoint = path in ["api/generate", "api/chat"]
if is_generate_endpoint and request.method == "POST":
request_body = await request.json()
# 强制设置stream为true以启用流式传输
request_body["stream"] = True
async def generate_stream():
client = httpx.AsyncClient()
try:
async with client.stream(
method=request.method,
url=target_url,
json=request_body,
headers=headers,
timeout=timeout
) as response:
async for line in response.aiter_lines():
if line.strip(): # 忽略空行
yield line.encode('utf-8') + b'\n'
except Exception as e:
logger.error(f"流式传输时发生错误: {str(e)}")
raise
finally:
await client.aclose()
return StreamingResponse(
generate_stream(),
media_type="application/x-ndjson",
headers={'Transfer-Encoding': 'chunked'} # 使用分块传输编码
)
else:
# 非生成请求的处理
async with httpx.AsyncClient() as client:
body = await request.body()
response = await client.request(
method=request.method,
url=target_url,
content=body,
headers=headers,
timeout=timeout,
follow_redirects=True
)
# 如果是标签列表请求且成功,更新缓存
if path == "api/tags" and request.method == "GET" and response.status_code == 200:
await update_models_cache(response.json())
return Response(
content=response.content,
status_code=response.status_code,
headers=dict(response.headers)
)
except httpx.TimeoutException:
logger.warning("Ollama服务器超时发送唤醒请求")
# 如果是标签列表请求,尝试返回缓存
if path == "api/tags" and request.method == "GET":
cached_models = await get_models_from_cache()
if cached_models is not None:
logger.info("返回缓存的标签列表")
return JSONResponse(content=cached_models)
# 直接异步发送唤醒请求,不等待结果
asyncio.create_task(wake_ollama())
return JSONResponse(
status_code=503,
content={"message": "服务器正在唤醒中,请稍后重试"}
)
except httpx.RequestError as e:
logger.error(f"请求错误: {str(e)}")
# 如果是标签列表请求,尝试返回缓存
if path == "api/tags" and request.method == "GET":
cached_models = await get_models_from_cache()
if cached_models is not None:
logger.info("返回缓存的标签列表")
return JSONResponse(content=cached_models)
return JSONResponse(
status_code=502,
content={"message": f"无法连接到Ollama服务器: {str(e)}"}
)
except Exception as e:
logger.error(f"代理请求失败: {str(e)}")
return JSONResponse(
status_code=500,
content={"message": f"代理请求失败: {str(e)}"}
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=PORT)