Mirror of https://github.com/yshtcn/OllamaProxy.git (synced 2025-12-11 16:50:27 +08:00)
feat: enhance caching functionality

1. Add a configurable cache-duration parameter (--cache-duration)
2. Change the default cache duration from 30 minutes to 1 day (1440 minutes)
3. Support configuration via the CACHE_DURATION environment variable
4. Update the documentation and configuration examples
5. Fix a streaming issue
This commit is contained in:
parent b8067e28a1 · commit a40fbadf7b
@@ -2,4 +2,6 @@ OLLAMA_URL=http://your-ollama-server:11434
 WAKE_URL=http://your-wake-server:9090/wol?mac=XX:XX:XX:XX:XX:XX
 TIMEOUT_SECONDS=1
 PORT=11434
 MODEL_TIMEOUT_SECONDS=30 # timeout for model inference requests (seconds)
+WAKE_INTERVAL=10 # wake interval (minutes)
+CACHE_DURATION=1440 # model list cache duration (minutes, default 1 day)
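
As a side note (not part of the commit), the proxy resolves each of these settings with the precedence sketched below: environment variable first, then CLI flag, then a hard-coded default. `resolve_cache_duration` is a hypothetical helper written only for this illustration, mirroring the `os.getenv(...) or args... or default` pattern in `ollama_proxy.py` further down.

```
import os

def resolve_cache_duration(cli_value=None):
    """Hypothetical helper: env var wins, then CLI flag, then the 1-day default."""
    return int(os.getenv('CACHE_DURATION') or cli_value or 1440)

assert resolve_cache_duration() == 1440             # nothing set -> 1-day default
os.environ['CACHE_DURATION'] = '30'
assert resolve_cache_duration(cli_value=60) == 30   # env var wins over the CLI flag
```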
README.md (13 changed lines)
@@ -43,11 +43,11 @@ Ollama Proxy is an intelligent proxy server designed for the Ollama service; it provides

 ### 3. Model List Caching

 - Caches the model list returned by the `/api/tags` endpoint
-- The cache is valid for 30 minutes
-- Returns cached data when the primary service is unavailable
+- The cache duration is configurable and defaults to 1440 minutes (1 day)
+- Returns cached data when the primary service is unavailable, so clients can always retrieve the model list

 ### 4. Health Check

 - Provides a `/health` endpoint for health status checks
 - The Docker container has a built-in health check configuration

 ## Configuration Parameters
@@ -62,6 +62,7 @@ Ollama Proxy is an intelligent proxy server designed for the Ollama service; it provides
 | `--model-timeout` | `MODEL_TIMEOUT_SECONDS` | Model inference request timeout (seconds) | 30 |
 | `--port` | `PORT` | Proxy server port | 11434 |
 | `--wake-interval` | `WAKE_INTERVAL` | Wake interval (minutes) | 10 |
+| `--cache-duration` | `CACHE_DURATION` | Model list cache duration (minutes) | 1440 |

 ## Deployment
@@ -81,6 +82,9 @@ docker run -d \
   -e OLLAMA_URL=http://localhost:11434 \
   -e WAKE_URL=http://localhost:11434/api/generate \
   -e TIMEOUT_SECONDS=10 \
+  -e MODEL_TIMEOUT_SECONDS=30 \
+  -e WAKE_INTERVAL=10 \
+  -e CACHE_DURATION=1440 \
   -e PORT=11434 \
   yshtcn/ollama-proxy:latest
 ```
@@ -98,6 +102,9 @@ python ollama_proxy.py \
   --ollama-url http://localhost:11434 \
   --wake-url http://localhost:11434/api/generate \
   --timeout 10 \
+  --model-timeout 30 \
+  --wake-interval 10 \
+  --cache-duration 1440 \
   --port 11434
 ```
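
For a quick client-side check of the caching behavior described above, a sketch (not from the repository) assuming the proxy listens on localhost:11434 as in the examples and that `/api/tags` returns the usual Ollama `{"models": [...]}` shape:

```
import httpx

# Once a /api/tags response has been cached, the proxy can keep answering
# this request even while the upstream Ollama server is asleep.
resp = httpx.get("http://localhost:11434/api/tags", timeout=10)
print(resp.status_code)                            # 200 when live or served from cache
print([m["name"] for m in resp.json()["models"]])  # list the model names
```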
ollama_proxy.py (172 changed lines)
@@ -1,5 +1,5 @@
 from fastapi import FastAPI, Request, Response, HTTPException
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 import httpx
 import asyncio
 import logging
@@ -7,6 +7,7 @@ import os
 import argparse
 import sys
 from datetime import datetime, timedelta
+import json

 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -20,6 +21,7 @@ parser.add_argument('--timeout', type=int, help='Timeout for simple requests (seconds
 parser.add_argument('--model-timeout', type=int, help='Timeout for model inference requests (seconds)')
 parser.add_argument('--port', type=int, help='Proxy server port')
 parser.add_argument('--wake-interval', type=int, default=10, help='Wake interval (minutes)')
+parser.add_argument('--cache-duration', type=int, help='Model list cache duration (minutes), default 1440 (1 day)')

 args = parser.parse_args()
@@ -30,6 +32,7 @@ TIMEOUT_SECONDS = os.getenv('TIMEOUT_SECONDS') or args.timeout
 MODEL_TIMEOUT_SECONDS = int(os.getenv('MODEL_TIMEOUT_SECONDS') or args.model_timeout or 30)  # default 30 seconds
 PORT = os.getenv('PORT') or args.port
 WAKE_INTERVAL = int(os.getenv('WAKE_INTERVAL') or args.wake_interval)
+CACHE_DURATION = int(os.getenv('CACHE_DURATION') or args.cache_duration or 1440)  # default 1 day

 # Check required parameters
 missing_params = []
@@ -61,7 +64,6 @@ last_wake_time = None
 # Cache-related variables
 models_cache = None
 models_cache_time = None
-CACHE_DURATION = timedelta(minutes=30)  # cache valid for 30 minutes

 async def should_wake():
     """Check whether a wake request needs to be sent"""
@@ -86,7 +88,7 @@ async def get_models_from_cache():
     global models_cache, models_cache_time
     if models_cache is None or models_cache_time is None:
         return None
-    if datetime.now() - models_cache_time > CACHE_DURATION:
+    if datetime.now() - models_cache_time > timedelta(minutes=CACHE_DURATION):
         return None
     return models_cache
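
The effect of this change, illustrated with invented values: `CACHE_DURATION` is now a plain integer number of minutes converted to a `timedelta` at comparison time, instead of a hard-coded 30-minute `timedelta`.

```
from datetime import datetime, timedelta

CACHE_DURATION = 1440  # minutes; the new default (1 day)
models_cache_time = datetime.now() - timedelta(hours=2)  # a 2-hour-old cache entry

expired = datetime.now() - models_cache_time > timedelta(minutes=CACHE_DURATION)
print(expired)  # False: still fresh under the 1-day window; the old 30-minute default would have expired it
```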
@@ -105,6 +107,7 @@ logger.info(f"TIMEOUT_SECONDS: {TIMEOUT_SECONDS}")
 logger.info(f"MODEL_TIMEOUT_SECONDS: {MODEL_TIMEOUT_SECONDS}")
 logger.info(f"PORT: {PORT}")
 logger.info(f"WAKE_INTERVAL: {WAKE_INTERVAL} minutes")
+logger.info(f"CACHE_DURATION: {CACHE_DURATION} minutes")

 app = FastAPI()
@@ -159,72 +162,109 @@ async def proxy(request: Request, path: str):
         logger.info("The configured interval has passed since the last wake; sending a preventive wake request")
         await wake_ollama()

-    async with httpx.AsyncClient() as client:
-        try:
-            target_url = f"{OLLAMA_URL}/{path}"
-            body = await request.body()
-            headers = dict(request.headers)
-            headers.pop('host', None)
-            headers.pop('connection', None)
-
-            # Pick a timeout based on the request type
-            timeout = TIMEOUT_SECONDS if path == "api/tags" else MODEL_TIMEOUT_SECONDS
-
-            response = await client.request(
-                method=request.method,
-                url=target_url,
-                content=body,
-                headers=headers,
-                timeout=timeout,  # use the dynamic timeout
-                follow_redirects=True
-            )
-
-            # If this is a successful tag-list request, update the cache
-            if path == "api/tags" and request.method == "GET" and response.status_code == 200:
-                await update_models_cache(response.json())
-
-            return Response(
-                content=response.content,
-                status_code=response.status_code,
-                headers=dict(response.headers)
-            )
-
-        except httpx.TimeoutException:
-            logger.warning("Ollama server timed out; sending a wake request")
-            # For tag-list requests, try returning the cache
-            if path == "api/tags" and request.method == "GET":
-                cached_models = await get_models_from_cache()
-                if cached_models is not None:
-                    logger.info("Returning the cached tag list")
-                    return JSONResponse(content=cached_models)
-
-            # Fire the wake request asynchronously without waiting for the result
-            asyncio.create_task(wake_ollama())
-            return JSONResponse(
-                status_code=503,
-                content={"message": "The server is waking up, please retry shortly"}
-            )
-
-        except httpx.RequestError as e:
-            logger.error(f"Request error: {str(e)}")
-            # For tag-list requests, try returning the cache
-            if path == "api/tags" and request.method == "GET":
-                cached_models = await get_models_from_cache()
-                if cached_models is not None:
-                    logger.info("Returning the cached tag list")
-                    return JSONResponse(content=cached_models)
-
-            return JSONResponse(
-                status_code=502,
-                content={"message": f"Unable to connect to the Ollama server: {str(e)}"}
-            )
-
-        except Exception as e:
-            logger.error(f"Proxy request failed: {str(e)}")
-            return JSONResponse(
-                status_code=500,
-                content={"message": f"Proxy request failed: {str(e)}"}
-            )
+    try:
+        target_url = f"{OLLAMA_URL}/{path}"
+        headers = dict(request.headers)
+        headers.pop('host', None)
+        headers.pop('connection', None)
+        # Remove headers that can cause problems
+        headers.pop('content-length', None)
+        headers.pop('transfer-encoding', None)
+
+        # Pick a timeout based on the request type
+        timeout = TIMEOUT_SECONDS if path == "api/tags" else MODEL_TIMEOUT_SECONDS
+
+        # Check whether this is a generation-related endpoint
+        is_generate_endpoint = path in ["api/generate", "api/chat"]
+
+        if is_generate_endpoint and request.method == "POST":
+            request_body = await request.json()
+            # Force stream to true to enable streaming
+            request_body["stream"] = True
+
+            async def generate_stream():
+                client = httpx.AsyncClient()
+                try:
+                    async with client.stream(
+                        method=request.method,
+                        url=target_url,
+                        json=request_body,
+                        headers=headers,
+                        timeout=timeout
+                    ) as response:
+                        async for line in response.aiter_lines():
+                            if line.strip():  # skip empty lines
+                                yield line.encode('utf-8') + b'\n'
+                except Exception as e:
+                    logger.error(f"Error during streaming: {str(e)}")
+                    raise
+                finally:
+                    await client.aclose()
+
+            return StreamingResponse(
+                generate_stream(),
+                media_type="application/x-ndjson",
+                headers={'Transfer-Encoding': 'chunked'}  # use chunked transfer encoding
+            )
+        else:
+            # Handle non-generation requests
+            async with httpx.AsyncClient() as client:
+                body = await request.body()
+                response = await client.request(
+                    method=request.method,
+                    url=target_url,
+                    content=body,
+                    headers=headers,
+                    timeout=timeout,
+                    follow_redirects=True
+                )
+
+                # If this is a successful tag-list request, update the cache
+                if path == "api/tags" and request.method == "GET" and response.status_code == 200:
+                    await update_models_cache(response.json())
+
+                return Response(
+                    content=response.content,
+                    status_code=response.status_code,
+                    headers=dict(response.headers)
+                )
+
+    except httpx.TimeoutException:
+        logger.warning("Ollama server timed out; sending a wake request")
+        # For tag-list requests, try returning the cache
+        if path == "api/tags" and request.method == "GET":
+            cached_models = await get_models_from_cache()
+            if cached_models is not None:
+                logger.info("Returning the cached tag list")
+                return JSONResponse(content=cached_models)
+
+        # Fire the wake request asynchronously without waiting for the result
+        asyncio.create_task(wake_ollama())
+        return JSONResponse(
+            status_code=503,
+            content={"message": "The server is waking up, please retry shortly"}
+        )
+
+    except httpx.RequestError as e:
+        logger.error(f"Request error: {str(e)}")
+        # For tag-list requests, try returning the cache
+        if path == "api/tags" and request.method == "GET":
+            cached_models = await get_models_from_cache()
+            if cached_models is not None:
+                logger.info("Returning the cached tag list")
+                return JSONResponse(content=cached_models)
+
+        return JSONResponse(
+            status_code=502,
+            content={"message": f"Unable to connect to the Ollama server: {str(e)}"}
+        )
+
+    except Exception as e:
+        logger.error(f"Proxy request failed: {str(e)}")
+        return JSONResponse(
+            status_code=500,
+            content={"message": f"Proxy request failed: {str(e)}"}
+        )

 if __name__ == "__main__":
     import uvicorn
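
To see the streaming fix from the client side, a minimal sketch (not from the repository; the model name and host are assumptions) consuming the NDJSON stream the proxy now relays:

```
import httpx

# POST to /api/generate through the proxy; the proxy forces stream=true
# and relays NDJSON lines as they arrive from Ollama.
with httpx.stream(
    "POST",
    "http://localhost:11434/api/generate",
    json={"model": "llama3", "prompt": "Hello"},
    timeout=60,
) as response:
    for line in response.iter_lines():
        if line.strip():
            print(line)  # each line is one JSON chunk of the streamed response
```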