feat: enhance caching

1. Add a configurable cache-duration parameter (--cache-duration)
2. Change the default cache duration from 30 minutes to 1 day (1440 minutes)
3. Support configuration via the CACHE_DURATION environment variable (see the sketch below the commit metadata)
4. Update the documentation and configuration examples
5. Fix a streaming issue
yshtcn 2025-01-27 18:35:30 +08:00
parent b8067e28a1
commit a40fbadf7b
3 changed files with 119 additions and 70 deletions
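
For context on item 3: the new CACHE_DURATION setting follows the same precedence the script already uses for its other settings (environment variable first, then CLI flag, then the built-in default). A minimal sketch of that resolution logic, mirroring the `int(os.getenv('CACHE_DURATION') or args.cache_duration or 1440)` line added in this diff; the helper name is illustrative, not part of the commit:

```python
import os

# Sketch (not part of the commit): how CACHE_DURATION is resolved.
# Note that `or` chains fall through on falsy values, so an explicit 0
# would also fall back to the default. Values are in minutes.
def resolve_cache_duration(cli_value=None):
    return int(os.getenv('CACHE_DURATION') or cli_value or 1440)

assert resolve_cache_duration() == 1440        # nothing set: 1-day default
assert resolve_cache_duration(60) == 60        # --cache-duration 60
os.environ['CACHE_DURATION'] = '30'
assert resolve_cache_duration(60) == 30        # env var wins over the flag
```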


@@ -2,4 +2,6 @@ OLLAMA_URL=http://your-ollama-server:11434
 WAKE_URL=http://your-wake-server:9090/wol?mac=XX:XX:XX:XX:XX:XX
 TIMEOUT_SECONDS=1
 PORT=11434
 MODEL_TIMEOUT_SECONDS=30 # Timeout for model inference requests (seconds)
+WAKE_INTERVAL=10 # Wake interval (minutes)
+CACHE_DURATION=1440 # Model list cache lifetime (minutes), default 1 day


@@ -43,11 +43,11 @@ Ollama Proxy is an intelligent proxy server designed for the Ollama service…
 ### 3. Model list cache
 - Caches the model list returned by the `/api/tags` endpoint
-- Cache lifetime is 30 minutes
-- Returns cached data when the main service is unavailable
+- Configurable cache lifetime, defaulting to 1440 minutes (1 day)
+- Returns cached data when the main service is unavailable, so clients can always retrieve the model list

 ### 4. Health check
 - Provides a `/health` endpoint for health-status checks
 - The Docker container ships with a health-check configuration

 ## Configuration parameters
@@ -62,6 +62,7 @@ Ollama Proxy is an intelligent proxy server designed for the Ollama service…
 | `--model-timeout` | `MODEL_TIMEOUT_SECONDS` | Model inference request timeout (seconds) | 30 |
 | `--port` | `PORT` | Proxy server port | 11434 |
 | `--wake-interval` | `WAKE_INTERVAL` | Wake interval (minutes) | 10 |
+| `--cache-duration` | `CACHE_DURATION` | Model list cache lifetime (minutes) | 1440 |

 ## Deployment
@@ -81,6 +82,9 @@ docker run -d \
 -e OLLAMA_URL=http://localhost:11434 \
 -e WAKE_URL=http://localhost:11434/api/generate \
 -e TIMEOUT_SECONDS=10 \
+-e MODEL_TIMEOUT_SECONDS=30 \
+-e WAKE_INTERVAL=10 \
+-e CACHE_DURATION=1440 \
 -e PORT=11434 \
 yshtcn/ollama-proxy:latest
 ```
@@ -98,6 +102,9 @@ python ollama_proxy.py \
 --ollama-url http://localhost:11434 \
 --wake-url http://localhost:11434/api/generate \
 --timeout 10 \
+--model-timeout 30 \
+--wake-interval 10 \
+--cache-duration 1440 \
 --port 11434
 ```
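
A quick way to confirm the documented cache behavior after starting the proxy; a sketch, not part of the commit, assuming the proxy listens on localhost:11434 as in the examples above, and relying on Ollama's standard `/api/tags` payload shape:

```python
import httpx

# Query the proxy's /api/tags endpoint. While the upstream host is asleep,
# a fresh cache should still yield a 200 with the model list.
resp = httpx.get("http://localhost:11434/api/tags", timeout=5)
print(resp.status_code)
print([m["name"] for m in resp.json().get("models", [])])
```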


@@ -1,5 +1,5 @@
 from fastapi import FastAPI, Request, Response, HTTPException
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 import httpx
 import asyncio
 import logging
@@ -7,6 +7,7 @@ import os
 import argparse
 import sys
 from datetime import datetime, timedelta
+import json

 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -20,6 +21,7 @@ parser.add_argument('--timeout', type=int, help='Timeout for simple requests (sec…
 parser.add_argument('--model-timeout', type=int, help='Timeout for model inference requests (seconds)')
 parser.add_argument('--port', type=int, help='Proxy server port')
 parser.add_argument('--wake-interval', type=int, default=10, help='Wake interval (minutes)')
+parser.add_argument('--cache-duration', type=int, help='Model list cache lifetime in minutes, default 1440 (1 day)')

 args = parser.parse_args()
@@ -30,6 +32,7 @@ TIMEOUT_SECONDS = os.getenv('TIMEOUT_SECONDS') or args.timeout
 MODEL_TIMEOUT_SECONDS = int(os.getenv('MODEL_TIMEOUT_SECONDS') or args.model_timeout or 30)  # default 30 seconds
 PORT = os.getenv('PORT') or args.port
 WAKE_INTERVAL = int(os.getenv('WAKE_INTERVAL') or args.wake_interval)
+CACHE_DURATION = int(os.getenv('CACHE_DURATION') or args.cache_duration or 1440)  # default 1 day

 # Check required parameters
 missing_params = []
@@ -61,7 +64,6 @@ last_wake_time = None
 # Cache-related variables
 models_cache = None
 models_cache_time = None
-CACHE_DURATION = timedelta(minutes=30)  # cache lifetime: 30 minutes

 async def should_wake():
     """Check whether a wake-up request should be sent"""
@@ -86,7 +88,7 @@ async def get_models_from_cache():
     global models_cache, models_cache_time
     if models_cache is None or models_cache_time is None:
         return None
-    if datetime.now() - models_cache_time > CACHE_DURATION:
+    if datetime.now() - models_cache_time > timedelta(minutes=CACHE_DURATION):
         return None
     return models_cache
@@ -105,6 +107,7 @@ logger.info(f"TIMEOUT_SECONDS: {TIMEOUT_SECONDS}")
 logger.info(f"MODEL_TIMEOUT_SECONDS: {MODEL_TIMEOUT_SECONDS}")
 logger.info(f"PORT: {PORT}")
 logger.info(f"WAKE_INTERVAL: {WAKE_INTERVAL} minutes")
+logger.info(f"CACHE_DURATION: {CACHE_DURATION} minutes")

 app = FastAPI()
@@ -159,72 +162,109 @@ async def proxy(request: Request, path: str):
         logger.info("Wake interval has elapsed since the last wake-up; sending a preventive wake-up request")
         await wake_ollama()

-    async with httpx.AsyncClient() as client:
-        try:
-            target_url = f"{OLLAMA_URL}/{path}"
-            body = await request.body()
-            headers = dict(request.headers)
-            headers.pop('host', None)
-            headers.pop('connection', None)
-
-            # Choose the timeout based on the request type
-            timeout = TIMEOUT_SECONDS if path == "api/tags" else MODEL_TIMEOUT_SECONDS
-
-            response = await client.request(
-                method=request.method,
-                url=target_url,
-                content=body,
-                headers=headers,
-                timeout=timeout,  # use the dynamic timeout
-                follow_redirects=True
-            )
-
-            # For a successful tag-list request, update the cache
-            if path == "api/tags" and request.method == "GET" and response.status_code == 200:
-                await update_models_cache(response.json())
-
-            return Response(
-                content=response.content,
-                status_code=response.status_code,
-                headers=dict(response.headers)
-            )
-        except httpx.TimeoutException:
-            logger.warning("Ollama server timed out; sending a wake-up request")
-            # For tag-list requests, try to serve the cache
-            if path == "api/tags" and request.method == "GET":
-                cached_models = await get_models_from_cache()
-                if cached_models is not None:
-                    logger.info("Returning the cached tag list")
-                    return JSONResponse(content=cached_models)
-            # Fire the wake-up request asynchronously without awaiting the result
-            asyncio.create_task(wake_ollama())
-            return JSONResponse(
-                status_code=503,
-                content={"message": "The server is waking up, please retry shortly"}
-            )
-        except httpx.RequestError as e:
-            logger.error(f"Request error: {str(e)}")
-            # For tag-list requests, try to serve the cache
-            if path == "api/tags" and request.method == "GET":
-                cached_models = await get_models_from_cache()
-                if cached_models is not None:
-                    logger.info("Returning the cached tag list")
-                    return JSONResponse(content=cached_models)
-            return JSONResponse(
-                status_code=502,
-                content={"message": f"Cannot reach the Ollama server: {str(e)}"}
-            )
-        except Exception as e:
-            logger.error(f"Proxy request failed: {str(e)}")
-            return JSONResponse(
-                status_code=500,
-                content={"message": f"Proxy request failed: {str(e)}"}
-            )
+    try:
+        target_url = f"{OLLAMA_URL}/{path}"
+        headers = dict(request.headers)
+        headers.pop('host', None)
+        headers.pop('connection', None)
+        # Remove headers that may cause problems
+        headers.pop('content-length', None)
+        headers.pop('transfer-encoding', None)
+
+        # Choose the timeout based on the request type
+        timeout = TIMEOUT_SECONDS if path == "api/tags" else MODEL_TIMEOUT_SECONDS
+
+        # Check whether this is a generation endpoint
+        is_generate_endpoint = path in ["api/generate", "api/chat"]
+
+        if is_generate_endpoint and request.method == "POST":
+            request_body = await request.json()
+            # Force stream=true to enable streaming
+            request_body["stream"] = True
+
+            async def generate_stream():
+                client = httpx.AsyncClient()
+                try:
+                    async with client.stream(
+                        method=request.method,
+                        url=target_url,
+                        json=request_body,
+                        headers=headers,
+                        timeout=timeout
+                    ) as response:
+                        async for line in response.aiter_lines():
+                            if line.strip():  # skip empty lines
+                                yield line.encode('utf-8') + b'\n'
+                except Exception as e:
+                    logger.error(f"Error during streaming: {str(e)}")
+                    raise
+                finally:
+                    await client.aclose()
+
+            return StreamingResponse(
+                generate_stream(),
+                media_type="application/x-ndjson",
+                headers={'Transfer-Encoding': 'chunked'}  # use chunked transfer encoding
+            )
+        else:
+            # Handle non-generation requests
+            async with httpx.AsyncClient() as client:
+                body = await request.body()
+                response = await client.request(
+                    method=request.method,
+                    url=target_url,
+                    content=body,
+                    headers=headers,
+                    timeout=timeout,
+                    follow_redirects=True
+                )
+
+                # For a successful tag-list request, update the cache
+                if path == "api/tags" and request.method == "GET" and response.status_code == 200:
+                    await update_models_cache(response.json())
+
+                return Response(
+                    content=response.content,
+                    status_code=response.status_code,
+                    headers=dict(response.headers)
+                )
+    except httpx.TimeoutException:
+        logger.warning("Ollama server timed out; sending a wake-up request")
+        # For tag-list requests, try to serve the cache
+        if path == "api/tags" and request.method == "GET":
+            cached_models = await get_models_from_cache()
+            if cached_models is not None:
+                logger.info("Returning the cached tag list")
+                return JSONResponse(content=cached_models)
+        # Fire the wake-up request asynchronously without awaiting the result
+        asyncio.create_task(wake_ollama())
+        return JSONResponse(
+            status_code=503,
+            content={"message": "The server is waking up, please retry shortly"}
+        )
+    except httpx.RequestError as e:
+        logger.error(f"Request error: {str(e)}")
+        # For tag-list requests, try to serve the cache
+        if path == "api/tags" and request.method == "GET":
+            cached_models = await get_models_from_cache()
+            if cached_models is not None:
+                logger.info("Returning the cached tag list")
+                return JSONResponse(content=cached_models)
+        return JSONResponse(
+            status_code=502,
+            content={"message": f"Cannot reach the Ollama server: {str(e)}"}
+        )
+    except Exception as e:
+        logger.error(f"Proxy request failed: {str(e)}")
+        return JSONResponse(
+            status_code=500,
+            content={"message": f"Proxy request failed: {str(e)}"}
+        )

 if __name__ == "__main__":
     import uvicorn
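
For reference, a minimal client sketch showing how the NDJSON stream produced by the new `generate_stream()` path can be consumed. Assumptions not taken from the commit: the proxy listens on localhost:11434 as in the README examples, a model named `llama3` is installed upstream, and each stream line is a standard Ollama generate chunk with a `response` field:

```python
import json
import httpx

# Consume the proxy's /api/generate NDJSON stream line by line. The proxy
# forces "stream": true upstream, so chunks arrive incrementally.
with httpx.stream(
    "POST",
    "http://localhost:11434/api/generate",
    json={"model": "llama3", "prompt": "Why is the sky blue?"},
    timeout=60,
) as response:
    for line in response.iter_lines():
        if line.strip():
            chunk = json.loads(line)
            print(chunk.get("response", ""), end="", flush=True)
```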