feat: 增强缓存功能

1. 添加可配置的模型列表缓存时间参数(--cache-duration,单位:分钟)
2. 默认缓存时间从30分钟改为1天(1440分钟)
3. 支持通过环境变量CACHE_DURATION配置缓存时间(优先级高于命令行参数--cache-duration)
4. 更新文档和配置示例
5. 修复了流式传输的问题:对api/generate和api/chat的POST请求强制启用stream,并通过StreamingResponse以NDJSON逐行转发
This commit is contained in:
yshtcn
2025-01-27 18:35:30 +08:00
parent b8067e28a1
commit a40fbadf7b
3 changed files with 119 additions and 70 deletions

View File

@@ -1,5 +1,5 @@
from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse, StreamingResponse
import httpx
import asyncio
import logging
@@ -7,6 +7,7 @@ import os
import argparse
import sys
from datetime import datetime, timedelta
import json
# 配置日志
logging.basicConfig(level=logging.INFO)
@@ -20,6 +21,7 @@ parser.add_argument('--timeout', type=int, help='简单请求的超时时间(秒
parser.add_argument('--model-timeout', type=int, help='模型推理请求的超时时间(秒)')
parser.add_argument('--port', type=int, help='代理服务器端口')
parser.add_argument('--wake-interval', type=int, default=10, help='唤醒间隔时间(分钟)')
parser.add_argument('--cache-duration', type=int, help='模型列表缓存有效期(分钟)默认1440分钟(1天)')
args = parser.parse_args()
@@ -30,6 +32,7 @@ TIMEOUT_SECONDS = os.getenv('TIMEOUT_SECONDS') or args.timeout
MODEL_TIMEOUT_SECONDS = int(os.getenv('MODEL_TIMEOUT_SECONDS') or args.model_timeout or 30) # 默认30秒
PORT = os.getenv('PORT') or args.port
WAKE_INTERVAL = int(os.getenv('WAKE_INTERVAL') or args.wake_interval)
CACHE_DURATION = int(os.getenv('CACHE_DURATION') or args.cache_duration or 1440) # 默认1天
# 检查必要参数
missing_params = []
@@ -61,7 +64,6 @@ last_wake_time = None
# 添加缓存相关的变量
models_cache = None
models_cache_time = None
CACHE_DURATION = timedelta(minutes=30) # 缓存有效期30分钟
async def should_wake():
"""检查是否需要发送唤醒请求"""
@@ -86,7 +88,7 @@ async def get_models_from_cache():
global models_cache, models_cache_time
if models_cache is None or models_cache_time is None:
return None
if datetime.now() - models_cache_time > CACHE_DURATION:
if datetime.now() - models_cache_time > timedelta(minutes=CACHE_DURATION):
return None
return models_cache
@@ -105,6 +107,7 @@ logger.info(f"TIMEOUT_SECONDS: {TIMEOUT_SECONDS}")
logger.info(f"MODEL_TIMEOUT_SECONDS: {MODEL_TIMEOUT_SECONDS}")
logger.info(f"PORT: {PORT}")
logger.info(f"WAKE_INTERVAL: {WAKE_INTERVAL} minutes")
logger.info(f"CACHE_DURATION: {CACHE_DURATION} minutes")
app = FastAPI()
@@ -159,72 +162,109 @@ async def proxy(request: Request, path: str):
logger.info("距离上次唤醒已超过设定时间,发送预防性唤醒请求")
await wake_ollama()
async with httpx.AsyncClient() as client:
try:
target_url = f"{OLLAMA_URL}/{path}"
body = await request.body()
headers = dict(request.headers)
headers.pop('host', None)
headers.pop('connection', None)
# 根据请求类型选择不同的超时时间
timeout = TIMEOUT_SECONDS if path == "api/tags" else MODEL_TIMEOUT_SECONDS
response = await client.request(
method=request.method,
url=target_url,
content=body,
headers=headers,
timeout=timeout, # 使用动态超时时间
follow_redirects=True
)
# 如果是标签列表请求且成功,更新缓存
if path == "api/tags" and request.method == "GET" and response.status_code == 200:
await update_models_cache(response.json())
return Response(
content=response.content,
status_code=response.status_code,
headers=dict(response.headers)
)
except httpx.TimeoutException:
logger.warning("Ollama服务器超时发送唤醒请求")
# 如果是标签列表请求,尝试返回缓存
if path == "api/tags" and request.method == "GET":
cached_models = await get_models_from_cache()
if cached_models is not None:
logger.info("返回缓存的标签列表")
return JSONResponse(content=cached_models)
# 直接异步发送唤醒请求,不等待结果
asyncio.create_task(wake_ollama())
return JSONResponse(
status_code=503,
content={"message": "服务器正在唤醒中,请稍后重试"}
)
try:
target_url = f"{OLLAMA_URL}/{path}"
headers = dict(request.headers)
headers.pop('host', None)
headers.pop('connection', None)
# 移除可能导致问题的头部
headers.pop('content-length', None)
headers.pop('transfer-encoding', None)
except httpx.RequestError as e:
logger.error(f"请求错误: {str(e)}")
# 如果是标签列表请求,尝试返回缓存
if path == "api/tags" and request.method == "GET":
cached_models = await get_models_from_cache()
if cached_models is not None:
logger.info("返回缓存的标签列表")
return JSONResponse(content=cached_models)
return JSONResponse(
status_code=502,
content={"message": f"无法连接到Ollama服务器: {str(e)}"}
# 根据请求类型选择不同的超时时间
timeout = TIMEOUT_SECONDS if path == "api/tags" else MODEL_TIMEOUT_SECONDS
# 检查是否为生成相关的端点
is_generate_endpoint = path in ["api/generate", "api/chat"]
if is_generate_endpoint and request.method == "POST":
request_body = await request.json()
# 强制设置stream为true以启用流式传输
request_body["stream"] = True
async def generate_stream():
client = httpx.AsyncClient()
try:
async with client.stream(
method=request.method,
url=target_url,
json=request_body,
headers=headers,
timeout=timeout
) as response:
async for line in response.aiter_lines():
if line.strip(): # 忽略空行
yield line.encode('utf-8') + b'\n'
except Exception as e:
logger.error(f"流式传输时发生错误: {str(e)}")
raise
finally:
await client.aclose()
return StreamingResponse(
generate_stream(),
media_type="application/x-ndjson",
headers={'Transfer-Encoding': 'chunked'} # 使用分块传输编码
)
else:
# 非生成请求的处理
async with httpx.AsyncClient() as client:
body = await request.body()
response = await client.request(
method=request.method,
url=target_url,
content=body,
headers=headers,
timeout=timeout,
follow_redirects=True
)
except Exception as e:
logger.error(f"代理请求失败: {str(e)}")
return JSONResponse(
status_code=500,
content={"message": f"代理请求失败: {str(e)}"}
)
# 如果是标签列表请求且成功,更新缓存
if path == "api/tags" and request.method == "GET" and response.status_code == 200:
await update_models_cache(response.json())
return Response(
content=response.content,
status_code=response.status_code,
headers=dict(response.headers)
)
except httpx.TimeoutException:
logger.warning("Ollama服务器超时发送唤醒请求")
# 如果是标签列表请求,尝试返回缓存
if path == "api/tags" and request.method == "GET":
cached_models = await get_models_from_cache()
if cached_models is not None:
logger.info("返回缓存的标签列表")
return JSONResponse(content=cached_models)
# 直接异步发送唤醒请求,不等待结果
asyncio.create_task(wake_ollama())
return JSONResponse(
status_code=503,
content={"message": "服务器正在唤醒中,请稍后重试"}
)
except httpx.RequestError as e:
logger.error(f"请求错误: {str(e)}")
# 如果是标签列表请求,尝试返回缓存
if path == "api/tags" and request.method == "GET":
cached_models = await get_models_from_cache()
if cached_models is not None:
logger.info("返回缓存的标签列表")
return JSONResponse(content=cached_models)
return JSONResponse(
status_code=502,
content={"message": f"无法连接到Ollama服务器: {str(e)}"}
)
except Exception as e:
logger.error(f"代理请求失败: {str(e)}")
return JSONResponse(
status_code=500,
content={"message": f"代理请求失败: {str(e)}"}
)
if __name__ == "__main__":
import uvicorn