From a40fbadf7b68236f4d79f4f51947930844fda2f6 Mon Sep 17 00:00:00 2001
From: yshtcn
Date: Mon, 27 Jan 2025 18:35:30 +0800
Subject: [PATCH] feat: enhance caching
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1. Add a configurable cache duration parameter (--cache-duration)
2. Change the default cache duration from 30 minutes to 1 day (1440 minutes)
3. Support configuration via the CACHE_DURATION environment variable
4. Update the documentation and configuration examples
5. Fix a streaming issue

---
 .env.example    |   4 +-
 README.md       |  13 +++-
 ollama_proxy.py | 172 +++++++++++++++++++++++++++++-------------------
 3 files changed, 119 insertions(+), 70 deletions(-)

diff --git a/.env.example b/.env.example
index 40c13e0..0497173 100644
--- a/.env.example
+++ b/.env.example
@@ -2,4 +2,6 @@ OLLAMA_URL=http://your-ollama-server:11434
 WAKE_URL=http://your-wake-server:9090/wol?mac=XX:XX:XX:XX:XX:XX
 TIMEOUT_SECONDS=1
 PORT=11434
-MODEL_TIMEOUT_SECONDS=30 # Timeout for model inference requests (seconds)
\ No newline at end of file
+MODEL_TIMEOUT_SECONDS=30 # Timeout for model inference requests (seconds)
+WAKE_INTERVAL=10 # Wake-up interval (minutes)
+CACHE_DURATION=1440 # Model list cache duration (minutes, default 1 day)
\ No newline at end of file
diff --git a/README.md b/README.md
index c107c6a..5852d5f 100644
--- a/README.md
+++ b/README.md
@@ -43,11 +43,11 @@ Ollama Proxy is an intelligent proxy server designed for the Ollama service
 ### 3. Model List Caching
 - Caches the model list returned by the `/api/tags` endpoint
-- The cache is valid for 30 minutes
-- Returns cached data when the main service is unavailable
+- Configurable cache duration, defaulting to 1440 minutes (1 day)
+- Returns cached data when the main service is unavailable, so clients can always retrieve the model list
 
 ### 4. Health Check
 - Provides a `/health` endpoint for health status checks
 - The Docker container includes a health check configuration
 
 ## Configuration Parameters
 
@@ -62,6 +62,7 @@ Ollama Proxy is an intelligent proxy server designed for the Ollama service
 | `--model-timeout` | `MODEL_TIMEOUT_SECONDS` | Timeout for model inference requests (seconds) | 30 |
 | `--port` | `PORT` | Proxy server port | 11434 |
 | `--wake-interval` | `WAKE_INTERVAL` | Wake-up interval (minutes) | 10 |
+| `--cache-duration` | `CACHE_DURATION` | Model list cache duration (minutes) | 1440 |
 
 ## Deployment
 
@@ -81,6 +82,9 @@ docker run -d \
   -e OLLAMA_URL=http://localhost:11434 \
   -e WAKE_URL=http://localhost:11434/api/generate \
   -e TIMEOUT_SECONDS=10 \
+  -e MODEL_TIMEOUT_SECONDS=30 \
+  -e WAKE_INTERVAL=10 \
+  -e CACHE_DURATION=1440 \
   -e PORT=11434 \
   yshtcn/ollama-proxy:latest
 ```
@@ -98,6 +102,9 @@ python ollama_proxy.py \
   --ollama-url http://localhost:11434 \
   --wake-url http://localhost:11434/api/generate \
   --timeout 10 \
+  --model-timeout 30 \
+  --wake-interval 10 \
+  --cache-duration 1440 \
   --port 11434
 ```
diff --git a/ollama_proxy.py b/ollama_proxy.py
index 9dbced6..2c89c4e 100644
--- a/ollama_proxy.py
+++ b/ollama_proxy.py
@@ -1,5 +1,5 @@
 from fastapi import FastAPI, Request, Response, HTTPException
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 import httpx
 import asyncio
 import logging
@@ -7,6 +7,7 @@ import os
 import argparse
 import sys
 from datetime import datetime, timedelta
+import json
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -20,6 +21,7 @@ parser.add_argument('--timeout', type=int, help='Timeout for simple requests (seconds)')
 parser.add_argument('--model-timeout', type=int, help='Timeout for model inference requests (seconds)')
 parser.add_argument('--port', type=int, help='Proxy server port')
 parser.add_argument('--wake-interval', type=int, default=10, help='Wake-up interval (minutes)')
+parser.add_argument('--cache-duration', type=int, help='Model list cache duration (minutes), default 1440 minutes (1 day)')
 
 args = parser.parse_args()
 
@@ -30,6 +32,7 @@ TIMEOUT_SECONDS = os.getenv('TIMEOUT_SECONDS') or args.timeout
 MODEL_TIMEOUT_SECONDS = int(os.getenv('MODEL_TIMEOUT_SECONDS') or args.model_timeout or 30)  # default: 30 seconds
 PORT = os.getenv('PORT') or args.port
 WAKE_INTERVAL = int(os.getenv('WAKE_INTERVAL') or args.wake_interval)
+CACHE_DURATION = int(os.getenv('CACHE_DURATION') or args.cache_duration or 1440)  # default: 1 day (1440 minutes)
 
 # Check required parameters
 missing_params = []
@@ -61,7 +64,6 @@ last_wake_time = None
 # Cache-related variables
 models_cache = None
 models_cache_time = None
-CACHE_DURATION = timedelta(minutes=30)  # cache valid for 30 minutes
 
 async def should_wake():
     """Check whether a wake-up request needs to be sent"""
@@ -86,7 +88,7 @@ async def get_models_from_cache():
     global models_cache, models_cache_time
     if models_cache is None or models_cache_time is None:
         return None
-    if datetime.now() - models_cache_time > CACHE_DURATION:
+    if datetime.now() - models_cache_time > timedelta(minutes=CACHE_DURATION):
         return None
     return models_cache
 
@@ -105,6 +107,7 @@ logger.info(f"TIMEOUT_SECONDS: {TIMEOUT_SECONDS}")
 logger.info(f"MODEL_TIMEOUT_SECONDS: {MODEL_TIMEOUT_SECONDS}")
 logger.info(f"PORT: {PORT}")
 logger.info(f"WAKE_INTERVAL: {WAKE_INTERVAL} minutes")
+logger.info(f"CACHE_DURATION: {CACHE_DURATION} minutes")
 
 app = FastAPI()
 
@@ -159,72 +162,109 @@ async def proxy(request: Request, path: str):
         logger.info("Wake interval exceeded since the last wake-up; sending a preventive wake-up request")
         await wake_ollama()
 
-    async with httpx.AsyncClient() as client:
-        try:
-            target_url = f"{OLLAMA_URL}/{path}"
-            body = await request.body()
-            headers = dict(request.headers)
-            headers.pop('host', None)
-            headers.pop('connection', None)
-
-            # Choose the timeout based on the request type
-            timeout = TIMEOUT_SECONDS if path == "api/tags" else MODEL_TIMEOUT_SECONDS
-
-            response = await client.request(
-                method=request.method,
-                url=target_url,
-                content=body,
-                headers=headers,
-                timeout=timeout,  # use the dynamic timeout
-                follow_redirects=True
-            )
-
-            # If this is a successful tag list request, update the cache
-            if path == "api/tags" and request.method == "GET" and response.status_code == 200:
-                await update_models_cache(response.json())
-
-            return Response(
-                content=response.content,
-                status_code=response.status_code,
-                headers=dict(response.headers)
-            )
-
-        except httpx.TimeoutException:
-            logger.warning("Ollama server timed out; sending a wake-up request")
-            # For tag list requests, try to return the cached list
-            if path == "api/tags" and request.method == "GET":
-                cached_models = await get_models_from_cache()
-                if cached_models is not None:
-                    logger.info("Returning the cached tag list")
-                    return JSONResponse(content=cached_models)
-
-            # Fire the wake-up request asynchronously without waiting for the result
-            asyncio.create_task(wake_ollama())
-            return JSONResponse(
-                status_code=503,
-                content={"message": "The server is waking up, please retry later"}
-            )
+    try:
+        target_url = f"{OLLAMA_URL}/{path}"
+        headers = dict(request.headers)
+        headers.pop('host', None)
+        headers.pop('connection', None)
+        # Remove headers that can cause problems
+        headers.pop('content-length', None)
+        headers.pop('transfer-encoding', None)
 
-        except httpx.RequestError as e:
-            logger.error(f"Request error: {str(e)}")
-            # For tag list requests, try to return the cached list
-            if path == "api/tags" and request.method == "GET":
-                cached_models = await get_models_from_cache()
-                if cached_models is not None:
-                    logger.info("Returning the cached tag list")
-                    return JSONResponse(content=cached_models)
-
-            return JSONResponse(
-                status_code=502,
-                content={"message": f"Unable to connect to the Ollama server: {str(e)}"}
+        # Choose the timeout based on the request type
+        timeout = TIMEOUT_SECONDS if path == "api/tags" else MODEL_TIMEOUT_SECONDS
+
+        # Check whether this is a generation-related endpoint
+        is_generate_endpoint = path in ["api/generate", "api/chat"]
+
+        if is_generate_endpoint and request.method == "POST":
+            request_body = await request.json()
+            # Force stream to true to enable streaming
+            request_body["stream"] = True
+
+            async def generate_stream():
+                client = httpx.AsyncClient()
+                try:
+                    async with client.stream(
+                        method=request.method,
+                        url=target_url,
+                        json=request_body,
+                        headers=headers,
+                        timeout=timeout
+                    ) as response:
+                        async for line in response.aiter_lines():
+                            if line.strip():  # skip empty lines
+                                yield line.encode('utf-8') + b'\n'
+                except Exception as e:
+                    logger.error(f"Error during streaming: {str(e)}")
+                    raise
+                finally:
+                    await client.aclose()
+
+            return StreamingResponse(
+                generate_stream(),
+                media_type="application/x-ndjson",
+                headers={'Transfer-Encoding': 'chunked'}  # use chunked transfer encoding
             )
+        else:
+            # Handle non-generation requests
+            async with httpx.AsyncClient() as client:
+                body = await request.body()
+                response = await client.request(
+                    method=request.method,
+                    url=target_url,
+                    content=body,
+                    headers=headers,
+                    timeout=timeout,
+                    follow_redirects=True
+                )
 
-        except Exception as e:
-            logger.error(f"Proxy request failed: {str(e)}")
-            return JSONResponse(
-                status_code=500,
-                content={"message": f"Proxy request failed: {str(e)}"}
-            )
+                # If this is a successful tag list request, update the cache
+                if path == "api/tags" and request.method == "GET" and response.status_code == 200:
+                    await update_models_cache(response.json())
+
+                return Response(
+                    content=response.content,
+                    status_code=response.status_code,
+                    headers=dict(response.headers)
+                )
+
+    except httpx.TimeoutException:
+        logger.warning("Ollama server timed out; sending a wake-up request")
+        # For tag list requests, try to return the cached list
+        if path == "api/tags" and request.method == "GET":
+            cached_models = await get_models_from_cache()
+            if cached_models is not None:
+                logger.info("Returning the cached tag list")
+                return JSONResponse(content=cached_models)
+
+        # Fire the wake-up request asynchronously without waiting for the result
+        asyncio.create_task(wake_ollama())
+        return JSONResponse(
+            status_code=503,
+            content={"message": "The server is waking up, please retry later"}
+        )
+
+    except httpx.RequestError as e:
+        logger.error(f"Request error: {str(e)}")
+        # For tag list requests, try to return the cached list
+        if path == "api/tags" and request.method == "GET":
+            cached_models = await get_models_from_cache()
+            if cached_models is not None:
+                logger.info("Returning the cached tag list")
+                return JSONResponse(content=cached_models)
+
+        return JSONResponse(
+            status_code=502,
+            content={"message": f"Unable to connect to the Ollama server: {str(e)}"}
+        )
+
+    except Exception as e:
+        logger.error(f"Proxy request failed: {str(e)}")
+        return JSONResponse(
+            status_code=500,
+            content={"message": f"Proxy request failed: {str(e)}"}
+        )
 
 if __name__ == "__main__":
     import uvicorn
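
For reference, here is a minimal client-side sketch (not part of the patch) showing how the streaming behaviour enabled above can be consumed. It assumes the proxy is listening on http://localhost:11434 and that a model named `llama3` is available upstream; the `stream_generate` helper and both of those names are purely illustrative.

```python
import json

import httpx


def stream_generate(prompt: str, model: str = "llama3",
                    base_url: str = "http://localhost:11434") -> None:
    # The proxy forces "stream": true for api/generate and api/chat,
    # so the reply arrives as NDJSON lines instead of one buffered JSON body.
    payload = {"model": model, "prompt": prompt}
    with httpx.Client(timeout=None) as client:
        with client.stream("POST", f"{base_url}/api/generate", json=payload) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if not line.strip():
                    continue  # skip blank lines, if any
                chunk = json.loads(line)
                # For Ollama's /api/generate, each chunk carries a partial
                # "response" field until "done" becomes true.
                print(chunk.get("response", ""), end="", flush=True)
                if chunk.get("done"):
                    break


if __name__ == "__main__":
    stream_generate("Why is the sky blue?")
```

If the proxy instead answers 503 while the upstream machine is still waking up, a real client would typically wait briefly and retry.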