mirror of
https://github.com/yshtcn/OllamaProxy.git
synced 2026-03-18 00:26:27 +08:00
Initial commit: Add Ollama Proxy project files
This commit is contained in:
231
ollama_proxy.py
Normal file
231
ollama_proxy.py
Normal file
@@ -0,0 +1,231 @@
|
||||
from fastapi import FastAPI, Request, Response, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
import httpx
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import argparse
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Logging setup: INFO level, one module-scoped logger used throughout.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
||||
|
||||
# Command-line options, declared table-style; environment variables of the
# same names take precedence over these flags (see config section below).
parser = argparse.ArgumentParser(description='Ollama代理服务器')
for _flag, _kwargs in (
    ('--ollama-url', {'help': 'Ollama服务器URL'}),
    ('--wake-url', {'help': '唤醒服务器URL'}),
    ('--timeout', {'type': int, 'help': '简单请求的超时时间(秒)'}),
    ('--model-timeout', {'type': int, 'help': '模型推理请求的超时时间(秒)'}),
    ('--port', {'type': int, 'help': '代理服务器端口'}),
    ('--wake-interval', {'type': int, 'default': 10, 'help': '唤醒间隔时间(分钟)'}),
):
    parser.add_argument(_flag, **_kwargs)

args = parser.parse_args()
|
||||
|
||||
def _env_or(name, fallback):
    """Return environment variable *name* when set and non-empty, else *fallback*."""
    return os.getenv(name) or fallback


# Effective configuration: environment variables win over CLI flags.
OLLAMA_URL = _env_or('OLLAMA_URL', args.ollama_url)
WAKE_URL = _env_or('WAKE_URL', args.wake_url)
TIMEOUT_SECONDS = _env_or('TIMEOUT_SECONDS', args.timeout)
MODEL_TIMEOUT_SECONDS = int(_env_or('MODEL_TIMEOUT_SECONDS', args.model_timeout) or 30)  # default: 30 seconds
PORT = _env_or('PORT', args.port)
WAKE_INTERVAL = int(_env_or('WAKE_INTERVAL', args.wake_interval))
|
||||
|
||||
# --- Startup validation ---------------------------------------------------
# These four settings have no defaults; abort early with a clear message
# instead of failing later with an obscure runtime error.
missing_params = [
    name
    for name, value in (
        ("OLLAMA_URL", OLLAMA_URL),
        ("WAKE_URL", WAKE_URL),
        ("TIMEOUT_SECONDS", TIMEOUT_SECONDS),
        ("PORT", PORT),
    )
    if not value
]

if missing_params:
    logger.error(f"缺少必要参数: {', '.join(missing_params)}")
    logger.error("请通过环境变量或命令行参数指定这些值")
    sys.exit(1)

# Values sourced from the environment arrive as strings; normalize both to
# int here so the rest of the module can rely on numeric types.
try:
    TIMEOUT_SECONDS = int(TIMEOUT_SECONDS)
    PORT = int(PORT)
except ValueError:
    # The exception object was previously bound but never used; dropped.
    logger.error("TIMEOUT_SECONDS 和 PORT 必须是整数")
    sys.exit(1)
|
||||
|
||||
# Timestamp of the most recent wake-up request (None until the first one).
last_wake_time = None

# Fallback cache for the /api/tags payload, served while the backend sleeps.
models_cache = None
models_cache_time = None
CACHE_DURATION = timedelta(minutes=30)  # cached model list stays valid for 30 minutes
|
||||
|
||||
async def should_wake():
    """Return True when a wake-up request is due.

    Due means either no wake-up has been sent yet, or more than
    WAKE_INTERVAL minutes have elapsed since the last one.
    """
    if last_wake_time is None:
        return True
    elapsed = datetime.now() - last_wake_time
    return elapsed > timedelta(minutes=WAKE_INTERVAL)
|
||||
|
||||
async def wake_ollama():
    """Send a GET to WAKE_URL to wake the Ollama host and record the time.

    Best-effort: any failure is logged and swallowed so it never propagates
    into the request path that triggered the wake-up.
    """
    global last_wake_time
    try:
        async with httpx.AsyncClient() as http:
            await http.get(WAKE_URL)
            last_wake_time = datetime.now()
            logger.info(f"已发送唤醒请求,更新唤醒时间: {last_wake_time}")
    except Exception as e:
        logger.error(f"唤醒请求失败: {str(e)}")
|
||||
|
||||
async def get_models_from_cache():
    """Return the cached /api/tags payload, or None when absent or expired."""
    if models_cache is None or models_cache_time is None:
        return None
    age = datetime.now() - models_cache_time
    return None if age > CACHE_DURATION else models_cache
|
||||
|
||||
async def update_models_cache(data):
    """Store *data* as the current /api/tags payload, stamped with 'now'."""
    global models_cache, models_cache_time
    models_cache_time = datetime.now()
    models_cache = data
    logger.info("模型列表缓存已更新")
|
||||
|
||||
# Echo the effective configuration at startup so misconfiguration is
# immediately visible in the logs.
logger.info(f"使用配置:")
for _entry in (
    f"OLLAMA_URL: {OLLAMA_URL}",
    f"WAKE_URL: {WAKE_URL}",
    f"TIMEOUT_SECONDS: {TIMEOUT_SECONDS}",
    f"MODEL_TIMEOUT_SECONDS: {MODEL_TIMEOUT_SECONDS}",
    f"PORT: {PORT}",
    f"WAKE_INTERVAL: {WAKE_INTERVAL} minutes",
):
    logger.info(_entry)

app = FastAPI()
|
||||
|
||||
@app.get("/health")
async def health_check():
    """Liveness probe; answered locally, never proxied to Ollama."""
    logger.info("收到健康检查请求")
    return {"status": "healthy"}
|
||||
|
||||
@app.get("/api/tags")
async def list_models():
    """Proxy GET /api/tags with a cache-backed fallback.

    The payload is cached only for 200 responses (matching the generic proxy
    route below) so an upstream error page never poisons the cache. On a
    timeout or connection failure a wake-up is fired in the background and
    the cached list is served when available; otherwise 503 is returned.

    Raises:
        HTTPException: 500 for unexpected errors.
    """
    try:
        # Grab the cached copy up front so it is already bound in the
        # except-branch below when the upstream call fails.
        cached_models = await get_models_from_cache()

        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{OLLAMA_URL}/api/tags",
                timeout=TIMEOUT_SECONDS  # tag listing is cheap; use the short timeout
            )
        # Parse once (previously parsed twice) and cache only successes.
        payload = response.json()
        if response.status_code == 200:
            await update_models_cache(payload)
        return payload

    except (httpx.TimeoutException, httpx.ConnectError) as e:
        # Upstream unreachable: trigger a wake-up without blocking this request.
        logger.warning(f"获取标签列表失败,正在唤醒服务器: {str(e)}")
        asyncio.create_task(wake_ollama())

        # Serve the cached list when we have one.
        if cached_models is not None:
            logger.info("返回缓存的标签列表")
            return JSONResponse(content=cached_models)

        # No cache available: ask the client to retry shortly.
        return JSONResponse(
            status_code=503,
            content={"message": "服务器正在唤醒中,请稍后重试"}
        )

    except Exception as e:
        logger.error(f"获取标签列表时发生未知错误: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(request: Request, path: str):
    """Forward any other request to the Ollama server.

    Adds three conveniences on top of a plain pass-through:
    * a preventive wake-up when the last one is older than WAKE_INTERVAL,
    * a cache fallback for GET /api/tags on upstream failure,
    * 503/502 translations for timeout / connection errors.

    NOTE(review): the upstream body is fully buffered before being returned,
    so streaming generate/chat responses arrive only once complete.
    """
    # /health is answered locally, never proxied.
    if path == "health":
        return await health_check()

    # Preventive wake-up: keep the backend warm between real requests.
    if await should_wake():
        logger.info("距离上次唤醒已超过设定时间,发送预防性唤醒请求")
        await wake_ollama()

    async with httpx.AsyncClient() as client:
        try:
            target_url = f"{OLLAMA_URL}/{path}"
            body = await request.body()
            headers = dict(request.headers)
            # Drop hop-by-hop / derived request headers: httpx sets host and
            # content-length itself, and forwarding the originals can conflict.
            for hop_header in ('host', 'connection', 'content-length'):
                headers.pop(hop_header, None)

            # Model inference can be slow; everything else gets the short timeout.
            timeout = TIMEOUT_SECONDS if path == "api/tags" else MODEL_TIMEOUT_SECONDS

            response = await client.request(
                method=request.method,
                url=target_url,
                content=body,
                headers=headers,
                timeout=timeout,
                follow_redirects=True
            )

            # Successful tag listings refresh the fallback cache.
            if path == "api/tags" and request.method == "GET" and response.status_code == 200:
                await update_models_cache(response.json())

            # httpx has already decoded the body, so the upstream
            # content-length / content-encoding / transfer-encoding headers no
            # longer describe the bytes we send back; forwarding them verbatim
            # (as before) corrupts the response for the client. Strip them and
            # let the framework compute fresh framing headers.
            response_headers = {
                k: v
                for k, v in response.headers.items()
                if k.lower() not in ('content-length', 'content-encoding',
                                     'transfer-encoding', 'connection')
            }
            return Response(
                content=response.content,
                status_code=response.status_code,
                headers=response_headers
            )

        except httpx.TimeoutException:
            logger.warning("Ollama服务器超时,发送唤醒请求")
            # Tag listings can be served from cache while the server wakes up.
            if path == "api/tags" and request.method == "GET":
                cached_models = await get_models_from_cache()
                if cached_models is not None:
                    logger.info("返回缓存的标签列表")
                    return JSONResponse(content=cached_models)

            # Fire-and-forget wake-up; the client is asked to retry.
            asyncio.create_task(wake_ollama())
            return JSONResponse(
                status_code=503,
                content={"message": "服务器正在唤醒中,请稍后重试"}
            )

        except httpx.RequestError as e:
            logger.error(f"请求错误: {str(e)}")
            # Same cache fallback for connection-level failures.
            if path == "api/tags" and request.method == "GET":
                cached_models = await get_models_from_cache()
                if cached_models is not None:
                    logger.info("返回缓存的标签列表")
                    return JSONResponse(content=cached_models)

            return JSONResponse(
                status_code=502,
                content={"message": f"无法连接到Ollama服务器: {str(e)}"}
            )

        except Exception as e:
            logger.error(f"代理请求失败: {str(e)}")
            return JSONResponse(
                status_code=500,
                content={"message": f"代理请求失败: {str(e)}"}
            )
|
||||
|
||||
if __name__ == "__main__":
    # Imported lazily so importing this module does not require uvicorn.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=PORT)
|
||||
Reference in New Issue
Block a user