工具系统设计:给AI Agent装上“手和脚”
·
AI Agent实战系列·第3篇
从简单计算器到Docker沙箱的进化之路
写在前面
上期回顾: 第2篇我们学习了ReAct模式,让Agent能够"边想边做"。但我们只用了3个简单工具,远远不够。
本期你将收获:
- ✅ 掌握工具系统的设计原则
- ✅ 实现5个生产级实用工具
- ✅ 学会Docker沙箱隔离技术
- ✅ 理解工具编排和错误处理
一、翻车现场:我的Agent把服务器删了
上周,我给Agent加了个"文件操作"工具,结果…
# 我写的工具(危险版本)
def file_operation(command: str):
"""执行文件操作"""
return os.system(command) # ❌ 巨大的安全隐患
# 某次对话
我:帮我清理一下临时文件
Agent:好的
→ 使用工具:file_operation("rm -rf /tmp/*")
→ 哦不,我理解错了...
→ 使用工具:file_operation("rm -rf /*") # 💀
# 服务器:💥
这个惨痛教训告诉我们:
- ❌ 工具设计不能随便
- ❌ 必须要有安全边界
- ❌ 需要严格的输入验证
- ❌ 要有回滚机制
今天,我们就来设计一个安全、健壮、可扩展的工具系统。
二、工具系统的5大设计原则
原则1️⃣:单一职责
每个工具只做一件事,做好一件事。
# ❌ 不好:一个工具做太多事
def super_tool(action, **kwargs):
if action == "calculate":
return eval(kwargs['expression'])
elif action == "search":
return search_web(kwargs['query'])
elif action == "read_file":
return read_file(kwargs['filename'])
# ... 100行
# ✅ 好:每个工具专注一个功能
class Calculator(Tool):
"""只负责计算"""
class WebSearch(Tool):
"""只负责搜索"""
class FileReader(Tool):
"""只负责读文件"""
原则2️⃣:明确的输入输出
# ❌ 不好:模糊的接口
def process(data):
"""处理数据""" # 什么数据?怎么处理?
return result # 什么结果?
# ✅ 好:清晰的接口
def calculate(expression: str) -> float:
"""
执行数学计算
Args:
expression: 数学表达式,如 "2 + 3 * 4"
Returns:
float: 计算结果
Raises:
ValueError: 表达式无效时
Example:
>>> calculate("2 + 3")
5.0
"""
原则3️⃣:防御性编程
假设一切输入都是恶意的。
class SafeTool(Tool):
def run(self, **kwargs):
# 1. 验证输入
self._validate_input(kwargs)
# 2. 清理数据
cleaned = self._sanitize(kwargs)
# 3. 限制执行时间
with timeout(self.max_time):
result = self._execute(cleaned)
# 4. 验证输出
return self._validate_output(result)
原则4️⃣:优雅的错误处理
class RobustTool(Tool):
def run(self, **kwargs):
try:
return self._execute(**kwargs)
except FileNotFoundError as e:
return f"错误:文件不存在 - {e.filename}"
except PermissionError:
return "错误:没有权限执行此操作"
except TimeoutError:
return "错误:操作超时"
except Exception as e:
# 记录详细日志
logger.error(f"工具执行失败: {e}", exc_info=True)
# 返回用户友好的信息
return f"错误:{type(e).__name__}"
原则5️⃣:可观测性
每个工具都应该记录日志、指标。
class ObservableTool(Tool):
def run(self, **kwargs):
start_time = time.time()
logger.info(f"工具开始执行: {self.name}", extra={
"tool": self.name,
"inputs": kwargs
})
try:
result = self._execute(**kwargs)
# 记录成功指标
metrics.increment(f"tool.{self.name}.success")
return result
except Exception as e:
# 记录失败指标
metrics.increment(f"tool.{self.name}.error")
raise
finally:
duration = time.time() - start_time
metrics.timing(f"tool.{self.name}.duration", duration)
三、完整的工具基类设计
基于上述原则,我们设计一个完善的工具基类:
"""
tool_base.py - 工具基类
"""
import time
import logging
from typing import Any, Dict, Optional
from abc import ABC, abstractmethod
import json
from functools import wraps
logger = logging.getLogger(__name__)
class ToolExecutionError(Exception):
"""工具执行错误"""
pass
class ToolTimeoutError(ToolExecutionError):
"""工具超时错误"""
pass
class ToolValidationError(ToolExecutionError):
"""工具验证错误"""
pass
def timeout(seconds):
"""超时装饰器"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
import signal
def handler(signum, frame):
raise ToolTimeoutError(f"操作超时({seconds}秒)")
# 设置超时
signal.signal(signal.SIGALRM, handler)
signal.alarm(seconds)
try:
result = func(*args, **kwargs)
finally:
signal.alarm(0) # 取消超时
return result
return wrapper
return decorator
class Tool(ABC):
"""工具基类"""
def __init__(
self,
name: str,
description: str,
max_execution_time: int = 30,
require_confirmation: bool = False
):
self.name = name
self.description = description
self.max_execution_time = max_execution_time
self.require_confirmation = require_confirmation
# 统计信息
self.execution_count = 0
self.success_count = 0
self.error_count = 0
self.total_time = 0.0
def run(self, input_str: str) -> str:
"""
执行工具(对外接口)
Args:
input_str: 输入参数(JSON字符串或普通字符串)
Returns:
str: 执行结果
"""
start_time = time.time()
self.execution_count += 1
logger.info(f"[{self.name}] 开始执行", extra={
"tool": self.name,
"input": input_str[:100] # 只记录前100个字符
})
try:
# 1. 解析输入
parsed_input = self._parse_input(input_str)
# 2. 验证输入
self._validate_input(parsed_input)
# 3. 确认(如果需要)
if self.require_confirmation:
if not self._confirm_execution(parsed_input):
return "操作已取消"
# 4. 执行(带超时)
@timeout(self.max_execution_time)
def execute():
return self._execute(**parsed_input)
result = execute()
# 5. 验证输出
validated_result = self._validate_output(result)
# 6. 记录成功
self.success_count += 1
execution_time = time.time() - start_time
self.total_time += execution_time
logger.info(f"[{self.name}] 执行成功", extra={
"tool": self.name,
"duration": execution_time,
"result_length": len(str(validated_result))
})
return validated_result
except ToolTimeoutError as e:
self.error_count += 1
logger.error(f"[{self.name}] 超时: {e}")
return f"错误:操作超时(>{self.max_execution_time}秒)"
except ToolValidationError as e:
self.error_count += 1
logger.error(f"[{self.name}] 验证失败: {e}")
return f"错误:{str(e)}"
except Exception as e:
self.error_count += 1
logger.error(f"[{self.name}] 执行失败", exc_info=True)
return f"错误:{type(e).__name__} - {str(e)}"
def _parse_input(self, input_str: str) -> Dict[str, Any]:
"""解析输入"""
# 尝试解析JSON
input_str = input_str.strip()
if input_str.startswith('{'):
try:
return json.loads(input_str)
except json.JSONDecodeError:
pass
# 否则作为单个参数
return {"input": input_str}
def _validate_input(self, inputs: Dict[str, Any]):
"""
验证输入(子类可覆盖)
Raises:
ToolValidationError: 验证失败时
"""
pass
def _validate_output(self, result: Any) -> str:
"""
验证输出(子类可覆盖)
Args:
result: 执行结果
Returns:
str: 验证后的结果字符串
"""
result_str = str(result)
# 限制输出长度
max_length = 5000
if len(result_str) > max_length:
result_str = result_str[:max_length] + f"\n... (截断,总长度{len(result_str)})"
return result_str
def _confirm_execution(self, inputs: Dict[str, Any]) -> bool:
"""
确认执行(危险操作时使用)
实际应该通过某种机制让用户确认
这里简化处理
"""
print(f"⚠️ 工具 {self.name} 需要确认")
print(f"参数: {inputs}")
# 实际场景应该等待用户输入
return True
@abstractmethod
def _execute(self, **kwargs) -> Any:
"""
实际执行逻辑(子类必须实现)
Returns:
Any: 执行结果
"""
raise NotImplementedError
def get_stats(self) -> Dict[str, Any]:
"""获取统计信息"""
avg_time = self.total_time / self.execution_count if self.execution_count > 0 else 0
success_rate = self.success_count / self.execution_count if self.execution_count > 0 else 0
return {
"name": self.name,
"executions": self.execution_count,
"success": self.success_count,
"errors": self.error_count,
"success_rate": f"{success_rate:.1%}",
"avg_time": f"{avg_time:.2f}s",
"total_time": f"{self.total_time:.2f}s"
}
def __str__(self):
return f"{self.name}: {self.description}"
四、5个生产级工具实现
现在基于这个基类,我们实现5个实用工具。
工具1:安全的文件读取器
"""
file_tools.py - 文件操作工具
"""
import os
from pathlib import Path
from tool_base import Tool, ToolValidationError
class FileReader(Tool):
"""安全的文件读取工具"""
def __init__(self, allowed_dirs: list = None):
super().__init__(
name="FileReader",
description="读取文件内容。输入:文件路径",
max_execution_time=10
)
# 限制可访问的目录
self.allowed_dirs = [Path(d).resolve() for d in (allowed_dirs or ['.'])]
def _validate_input(self, inputs: Dict[str, Any]):
"""验证文件路径"""
filepath = inputs.get('input') or inputs.get('filepath')
if not filepath:
raise ToolValidationError("必须提供文件路径")
# 检查路径遍历攻击
filepath = Path(filepath).resolve()
# 检查是否在允许的目录内
if not any(str(filepath).startswith(str(allowed)) for allowed in self.allowed_dirs):
raise ToolValidationError(f"文件路径不在允许的目录内")
# 检查文件是否存在
if not filepath.exists():
raise ToolValidationError(f"文件不存在: {filepath}")
# 检查是否为文件
if not filepath.is_file():
raise ToolValidationError(f"不是文件: {filepath}")
# 检查文件大小
max_size = 10 * 1024 * 1024 # 10MB
if filepath.stat().st_size > max_size:
raise ToolValidationError(f"文件太大(>{max_size/1024/1024}MB)")
def _execute(self, **kwargs) -> str:
filepath = kwargs.get('input') or kwargs.get('filepath')
filepath = Path(filepath).resolve()
try:
with open(filepath, 'r', encoding='utf-8') as f:
content = f.read()
return f"文件 {filepath.name} 的内容:\n{content}"
except UnicodeDecodeError:
# 尝试其他编码
with open(filepath, 'r', encoding='gbk') as f:
content = f.read()
return f"文件 {filepath.name} 的内容:\n{content}"
工具2:Python代码执行器(沙箱)
"""
code_executor.py - 代码执行工具
"""
import subprocess
import tempfile
import os
from tool_base import Tool, ToolValidationError
class PythonExecutor(Tool):
"""Python代码执行器(Docker沙箱)"""
def __init__(self, use_docker: bool = True):
super().__init__(
name="PythonExecutor",
description="执行Python代码。输入:Python代码字符串",
max_execution_time=30,
require_confirmation=True
)
self.use_docker = use_docker
def _validate_input(self, inputs: Dict[str, Any]):
"""验证代码"""
code = inputs.get('input') or inputs.get('code')
if not code:
raise ToolValidationError("必须提供代码")
# 黑名单检查
dangerous_keywords = [
'os.system', 'subprocess', 'eval', 'exec',
'__import__', 'open', 'file', 'input'
]
for keyword in dangerous_keywords:
if keyword in code:
raise ToolValidationError(f"代码包含危险操作: {keyword}")
# 限制代码长度
if len(code) > 10000:
raise ToolValidationError("代码太长")
def _execute(self, **kwargs) -> str:
code = kwargs.get('input') or kwargs.get('code')
if self.use_docker:
return self._execute_in_docker(code)
else:
return self._execute_local(code)
def _execute_in_docker(self, code: str) -> str:
"""在Docker容器中执行"""
# 创建临时文件
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
f.write(code)
temp_file = f.name
try:
# 在Docker中执行
cmd = [
'docker', 'run', '--rm',
'--network', 'none', # 禁用网络
'--memory', '256m', # 限制内存
'--cpus', '0.5', # 限制CPU
'-v', f'{temp_file}:/code.py:ro', # 只读挂载
'python:3.9-slim',
'python', '/code.py'
]
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=self.max_execution_time
)
if result.returncode == 0:
return f"执行成功:\n{result.stdout}"
else:
return f"执行失败:\n{result.stderr}"
finally:
os.unlink(temp_file)
def _execute_local(self, code: str) -> str:
"""本地执行(受限环境)"""
# 创建受限的全局环境
safe_globals = {
'__builtins__': {
'abs': abs, 'round': round, 'len': len,
'min': min, 'max': max, 'sum': sum,
'range': range, 'list': list, 'dict': dict,
'str': str, 'int': int, 'float': float,
'print': print
}
}
# 捕获输出
from io import StringIO
import sys
old_stdout = sys.stdout
sys.stdout = StringIO()
try:
exec(code, safe_globals)
output = sys.stdout.getvalue()
return f"执行成功:\n{output}"
except Exception as e:
return f"执行失败:{type(e).__name__} - {str(e)}"
finally:
sys.stdout = old_stdout
工具3:HTTP请求工具
"""
http_tools.py - HTTP工具
"""
import requests
from urllib.parse import urlparse
from tool_base import Tool, ToolValidationError
class HTTPRequest(Tool):
"""HTTP请求工具"""
def __init__(self, allowed_domains: list = None):
super().__init__(
name="HTTPRequest",
description="发送HTTP请求。输入:{'url': '...', 'method': 'GET'}",
max_execution_time=30
)
self.allowed_domains = allowed_domains or []
def _validate_input(self, inputs: Dict[str, Any]):
"""验证请求"""
url = inputs.get('url')
if not url:
raise ToolValidationError("必须提供URL")
# 验证URL格式
try:
parsed = urlparse(url)
except Exception:
raise ToolValidationError("无效的URL")
# 检查协议
if parsed.scheme not in ['http', 'https']:
raise ToolValidationError("只支持HTTP/HTTPS协议")
# 检查域名白名单
if self.allowed_domains:
if not any(parsed.netloc.endswith(domain) for domain in self.allowed_domains):
raise ToolValidationError(f"域名不在白名单内: {parsed.netloc}")
# 禁止访问内网
if parsed.netloc in ['localhost', '127.0.0.1'] or parsed.netloc.startswith('192.168.'):
raise ToolValidationError("禁止访问内网地址")
def _execute(self, **kwargs) -> str:
url = kwargs.get('url')
method = kwargs.get('method', 'GET').upper()
headers = kwargs.get('headers', {})
data = kwargs.get('data')
try:
response = requests.request(
method=method,
url=url,
headers=headers,
json=data,
timeout=self.max_execution_time,
allow_redirects=True,
verify=True # 验证SSL证书
)
return f"""
请求成功
状态码: {response.status_code}
响应头: {dict(response.headers)}
响应体: {response.text[:1000]}
"""
except requests.exceptions.Timeout:
return "错误:请求超时"
except requests.exceptions.ConnectionError:
return "错误:连接失败"
except Exception as e:
return f"错误:{type(e).__name__} - {str(e)}"
工具4:数据库查询工具
"""
database_tools.py - 数据库工具
"""
import sqlite3
import re
from tool_base import Tool, ToolValidationError
class SQLQuery(Tool):
"""SQL查询工具"""
def __init__(self, db_path: str):
super().__init__(
name="SQLQuery",
description="执行SQL查询。输入:SQL语句",
max_execution_time=10,
require_confirmation=True
)
self.db_path = db_path
def _validate_input(self, inputs: Dict[str, Any]):
"""验证SQL"""
sql = inputs.get('input') or inputs.get('sql')
if not sql:
raise ToolValidationError("必须提供SQL语句")
sql = sql.strip().upper()
# 只允许SELECT
if not sql.startswith('SELECT'):
raise ToolValidationError("只允许SELECT查询")
# 禁止某些操作
forbidden = ['DROP', 'DELETE', 'UPDATE', 'INSERT', 'ALTER', 'CREATE']
for word in forbidden:
if word in sql:
raise ToolValidationError(f"禁止的操作: {word}")
def _execute(self, **kwargs) -> str:
sql = kwargs.get('input') or kwargs.get('sql')
try:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute(sql)
rows = cursor.fetchall()
# 获取列名
columns = [desc[0] for desc in cursor.description]
conn.close()
# 格式化输出
result = f"查询成功,返回 {len(rows)} 行\n\n"
result += " | ".join(columns) + "\n"
result += "-" * 50 + "\n"
for row in rows[:100]: # 最多显示100行
result += " | ".join(str(v) for v in row) + "\n"
if len(rows) > 100:
result += f"\n... (还有 {len(rows) - 100} 行未显示)"
return result
except sqlite3.Error as e:
return f"SQL错误:{str(e)}"
工具5:数据分析工具
"""
data_tools.py - 数据分析工具
"""
import pandas as pd
import numpy as np
from tool_base import Tool, ToolValidationError
class DataAnalyzer(Tool):
"""数据分析工具"""
def __init__(self):
super().__init__(
name="DataAnalyzer",
description="分析CSV数据。输入:{'file': '...', 'operation': 'describe'}",
max_execution_time=60
)
def _validate_input(self, inputs: Dict[str, Any]):
"""验证输入"""
file_path = inputs.get('file')
operation = inputs.get('operation')
if not file_path:
raise ToolValidationError("必须提供文件路径")
if not operation:
raise ToolValidationError("必须提供操作类型")
allowed_operations = ['describe', 'head', 'info', 'columns', 'shape']
if operation not in allowed_operations:
raise ToolValidationError(f"不支持的操作: {operation}")
def _execute(self, **kwargs) -> str:
file_path = kwargs.get('file')
operation = kwargs.get('operation')
try:
# 读取CSV
df = pd.read_csv(file_path)
# 执行操作
if operation == 'describe':
result = df.describe().to_string()
return f"数据统计信息:\n{result}"
elif operation == 'head':
n = kwargs.get('n', 5)
result = df.head(n).to_string()
return f"前{n}行数据:\n{result}"
elif operation == 'info':
import io
buffer = io.StringIO()
df.info(buf=buffer)
return f"数据信息:\n{buffer.getvalue()}"
elif operation == 'columns':
return f"列名:{', '.join(df.columns)}"
elif operation == 'shape':
return f"数据形状:{df.shape[0]}行 × {df.shape[1]}列"
except FileNotFoundError:
return f"错误:文件不存在 - {file_path}"
except pd.errors.EmptyDataError:
return "错误:文件为空"
except Exception as e:
return f"错误:{type(e).__name__} - {str(e)}"
五、工具编排:让工具协作起来
单个工具很强大,但真正的威力在于工具的组合使用。
5.1 工具链模式
"""
tool_chain.py - 工具链
"""
from typing import List
from tool_base import Tool
class ToolChain:
"""工具链:按顺序执行多个工具"""
def __init__(self, tools: List[Tool]):
self.tools = tools
def run(self, initial_input: str) -> str:
"""执行工具链"""
current_input = initial_input
for i, tool in enumerate(self.tools, 1):
print(f"\n{'='*50}")
print(f"步骤 {i}/{len(self.tools)}: {tool.name}")
print(f"{'='*50}")
result = tool.run(current_input)
print(f"结果: {result[:200]}...")
# 下一个工具的输入是上一个的输出
current_input = result
return current_input
# 使用示例
chain = ToolChain([
FileReader(allowed_dirs=['.']),
DataAnalyzer(),
])
result = chain.run("data.csv")
5.2 工具路由器
"""
tool_router.py - 工具路由
"""
class ToolRouter:
"""工具路由器:根据条件选择工具"""
def __init__(self, tools: Dict[str, Tool]):
self.tools = tools
def route(self, task: str) -> str:
"""智能路由到合适的工具"""
task_lower = task.lower()
# 简单的关键词匹配
if 'calculate' in task_lower or '计算' in task_lower:
return self.tools['calculator'].run(task)
elif 'file' in task_lower or '文件' in task_lower:
return self.tools['file_reader'].run(task)
elif 'http' in task_lower or 'api' in task_lower:
return self.tools['http_request'].run(task)
else:
return "无法确定使用哪个工具"
# 实际使用中,应该让LLM来做路由决策
六、实战案例:构建数据分析Agent
让我们把所有工具组合起来,打造一个数据分析Agent:
"""
data_analysis_agent.py -
```python
"""
data_analysis_agent.py - 数据分析Agent
"""
from react_agent import ReActAgent
from file_tools import FileReader
from data_tools import DataAnalyzer
from code_executor import PythonExecutor
from tool_base import Tool
class DataAnalysisAgent:
"""专门用于数据分析的Agent"""
def __init__(self):
# 初始化工具
self.tools = [
FileReader(allowed_dirs=['./data']),
DataAnalyzer(),
PythonExecutor(use_docker=True),
self._create_chart_tool()
]
# 创建Agent
self.agent = ReActAgent(
tools=self.tools,
model="gpt-4" # 数据分析用更强的模型
)
def _create_chart_tool(self) -> Tool:
"""创建图表生成工具"""
class ChartGenerator(Tool):
def __init__(self):
super().__init__(
name="ChartGenerator",
description="生成数据可视化图表。输入:{'file': '...', 'chart_type': 'bar/line/pie'}",
max_execution_time=30
)
def _execute(self, **kwargs):
import matplotlib.pyplot as plt
import pandas as pd
file_path = kwargs.get('file')
chart_type = kwargs.get('chart_type', 'bar')
df = pd.read_csv(file_path)
plt.figure(figsize=(10, 6))
if chart_type == 'bar':
df.plot(kind='bar')
elif chart_type == 'line':
df.plot(kind='line')
elif chart_type == 'pie':
df.plot(kind='pie', y=df.columns[0])
output_path = f'chart_{chart_type}.png'
plt.savefig(output_path)
plt.close()
return f"图表已保存到 {output_path}"
return ChartGenerator()
def analyze(self, task: str):
"""执行数据分析任务"""
# 增强的Prompt
enhanced_task = f"""
作为数据分析专家,请完成以下任务:
{task}
分析步骤建议:
1. 先查看数据基本信息(shape, columns, head)
2. 进行统计分析(describe)
3. 根据需要生成图表
4. 总结关键发现
请详细说明你的分析过程。
"""
return self.agent.run(enhanced_task)
# ========== 使用示例 ==========
if __name__ == "__main__":
# 准备测试数据
import pandas as pd
# 创建示例CSV
data = {
'month': ['Jan', 'Feb', 'Mar', 'Apr', 'May'],
'sales': [12000, 15000, 13000, 18000, 21000],
'profit': [3000, 3500, 2800, 4200, 5000]
}
df = pd.DataFrame(data)
df.to_csv('./data/sales.csv', index=False)
# 创建Agent
agent = DataAnalysisAgent()
# 执行分析
result = agent.analyze("""
分析 ./data/sales.csv 文件:
1. 总销售额和平均销售额是多少?
2. 哪个月份销售额最高?
3. 生成销售额趋势图
""")
print("\n" + "="*60)
print("分析结果:")
print("="*60)
print(result)
运行效果
============================================================
🎯 Question: 分析 ./data/sales.csv 文件...
============================================================
【第1轮】
Thought: 我需要先查看数据的基本信息,了解数据结构
Action: DataAnalyzer
Action Input: {"file": "./data/sales.csv", "operation": "head"}
📊 Observation: 前5行数据:
month sales profit
0 Jan 12000 3000
1 Feb 15000 3500
2 Mar 13000 2800
3 Apr 18000 4200
4 May 21000 5000
------------------------------------------------------------
【第2轮】
Thought: 数据包含3列:month, sales, profit
现在我需要计算总销售额和平均销售额
Action: PythonExecutor
Action Input:
import pandas as pd
df = pd.read_csv('./data/sales.csv')
total_sales = df['sales'].sum()
avg_sales = df['sales'].mean()
print(f"总销售额: {total_sales}")
print(f"平均销售额: {avg_sales}")
📊 Observation: 执行成功:
总销售额: 79000
平均销售额: 15800.0
------------------------------------------------------------
【第3轮】
Thought: 已知总销售额79000,平均15800
现在找出销售额最高的月份
Action: PythonExecutor
Action Input:
import pandas as pd
df = pd.read_csv('./data/sales.csv')
max_month = df.loc[df['sales'].idxmax(), 'month']
max_sales = df['sales'].max()
print(f"销售额最高月份: {max_month} ({max_sales})")
📊 Observation: 执行成功:
销售额最高月份: May (21000)
------------------------------------------------------------
【第4轮】
Thought: 最高销售月份是May
现在生成趋势图
Action: ChartGenerator
Action Input: {"file": "./data/sales.csv", "chart_type": "line"}
📊 Observation: 图表已保存到 chart_line.png
------------------------------------------------------------
【第5轮】
Thought: 我已经完成了所有分析任务
Final Answer:
数据分析报告:
1. 销售总览
- 总销售额: ¥79,000
- 平均销售额: ¥15,800/月
2. 最佳表现
- 销售额最高月份: May (¥21,000)
3. 趋势分析
- 销售额整体呈上升趋势
- 从1月的¥12,000增长到5月的¥21,000
- 增长率: 75%
4. 可视化
- 趋势图已保存: chart_line.png
============================================================
✅ 分析完成
============================================================
七、工具安全性最佳实践
7.1 输入验证清单
# ✅ 完整的验证示例
def _validate_input(self, inputs):
# 1. 类型检查
if not isinstance(inputs.get('param'), str):
raise ToolValidationError("参数必须是字符串")
# 2. 长度限制
if len(inputs.get('param')) > 1000:
raise ToolValidationError("参数过长")
# 3. 格式验证
import re
if not re.match(r'^[a-zA-Z0-9_]+$', inputs.get('param')):
raise ToolValidationError("参数包含非法字符")
# 4. 业务规则验证
if inputs.get('amount', 0) < 0:
raise ToolValidationError("金额不能为负")
# 5. 白名单验证
allowed_values = ['option1', 'option2']
if inputs.get('choice') not in allowed_values:
raise ToolValidationError("无效的选项")
7.2 资源限制
class ResourceLimitedTool(Tool):
"""资源受限的工具"""
def __init__(self):
super().__init__(
name="ResourceLimited",
description="资源受限的工具",
max_execution_time=30
)
# 限制
self.max_memory = 100 * 1024 * 1024 # 100MB
self.max_cpu_time = 10 # 10秒
self.max_file_size = 10 * 1024 * 1024 # 10MB
def _execute(self, **kwargs):
import resource
# 设置内存限制
resource.setrlimit(
resource.RLIMIT_AS,
(self.max_memory, self.max_memory)
)
# 设置CPU时间限制
resource.setrlimit(
resource.RLIMIT_CPU,
(self.max_cpu_time, self.max_cpu_time)
)
# 执行实际操作
return self._do_work(**kwargs)
7.3 审计日志
"""
audit.py - 审计日志
"""
import logging
import json
from datetime import datetime
class AuditLogger:
"""审计日志记录器"""
def __init__(self, log_file='audit.log'):
self.logger = logging.getLogger('audit')
handler = logging.FileHandler(log_file)
handler.setFormatter(logging.Formatter(
'%(asctime)s - %(message)s'
))
self.logger.addHandler(handler)
self.logger.setLevel(logging.INFO)
def log_tool_execution(
self,
tool_name: str,
inputs: dict,
result: str,
success: bool,
user_id: str = None
):
"""记录工具执行"""
log_entry = {
'timestamp': datetime.now().isoformat(),
'tool': tool_name,
'user': user_id,
'inputs': inputs,
'result': result[:100], # 只记录前100字符
'success': success
}
self.logger.info(json.dumps(log_entry))
# 在Tool基类中集成
class AuditedTool(Tool):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.audit_logger = AuditLogger()
def run(self, input_str: str) -> str:
result = super().run(input_str)
# 记录到审计日志
self.audit_logger.log_tool_execution(
tool_name=self.name,
inputs={'input': input_str},
result=result,
success='错误' not in result
)
return result
八、常见问题与解决方案
Q1: 工具执行太慢怎么办?
问题: 某些工具(如大文件处理)执行时间过长
解决方案:
# 1. 异步执行
import asyncio
class AsyncTool(Tool):
async def run_async(self, input_str: str):
"""异步执行"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
self.run,
input_str
)
# 2. 进度回调
class ProgressTool(Tool):
def __init__(self, progress_callback=None):
super().__init__()
self.progress_callback = progress_callback
def _execute(self, **kwargs):
total = 100
for i in range(total):
# 执行工作
time.sleep(0.01)
# 报告进度
if self.progress_callback:
self.progress_callback(i / total)
return "完成"
# 使用
def show_progress(percent):
print(f"\r进度: {percent:.0%}", end='')
tool = ProgressTool(progress_callback=show_progress)
Q2: 如何处理工具之间的依赖?
问题: 工具B需要工具A的输出
解决方案:
class ToolDependency:
"""工具依赖管理"""
def __init__(self, tools: List[Tool]):
self.tools = {t.name: t for t in tools}
self.results_cache = {}
def run_with_deps(self, tool_name: str, inputs: dict):
"""执行工具(处理依赖)"""
# 检查依赖
deps = self._get_dependencies(tool_name)
# 先执行依赖
for dep in deps:
if dep not in self.results_cache:
dep_result = self.tools[dep].run(inputs)
self.results_cache[dep] = dep_result
# 注入依赖的结果
enhanced_inputs = inputs.copy()
for dep in deps:
enhanced_inputs[f'{dep}_result'] = self.results_cache[dep]
# 执行目标工具
return self.tools[tool_name].run(enhanced_inputs)
Q3: 如何实现工具的版本管理?
解决方案:
class VersionedTool(Tool):
"""版本化的工具"""
VERSION = "1.0.0"
def __init__(self):
super().__init__(
name=f"{self.__class__.__name__}@{self.VERSION}",
description="..."
)
def run(self, input_str: str) -> str:
# 在结果中包含版本信息
result = super().run(input_str)
return f"[v{self.VERSION}] {result}"
# 工具注册表
class ToolRegistry:
"""工具注册表"""
def __init__(self):
self.tools = {}
def register(self, tool: Tool):
"""注册工具"""
key = f"{tool.name}@{getattr(tool, 'VERSION', '1.0.0')}"
self.tools[key] = tool
def get(self, name: str, version: str = None):
"""获取工具"""
if version:
key = f"{name}@{version}"
else:
# 获取最新版本
versions = [k for k in self.tools.keys() if k.startswith(name)]
key = max(versions) if versions else None
return self.tools.get(key)
九、工具测试最佳实践
9.1 单元测试
"""
test_tools.py - 工具测试
"""
import unittest
from file_tools import FileReader
from tool_base import ToolValidationError
class TestFileReader(unittest.TestCase):
"""FileReader工具测试"""
def setUp(self):
"""测试前准备"""
self.tool = FileReader(allowed_dirs=['/tmp'])
# 创建测试文件
with open('/tmp/test.txt', 'w') as f:
f.write("Hello, World!")
def test_read_file_success(self):
"""测试正常读取"""
result = self.tool.run("/tmp/test.txt")
self.assertIn("Hello, World!", result)
def test_read_file_not_found(self):
"""测试文件不存在"""
result = self.tool.run("/tmp/nonexistent.txt")
self.assertIn("错误", result)
def test_path_traversal_attack(self):
"""测试路径遍历攻击"""
result = self.tool.run("/tmp/../etc/passwd")
self.assertIn("错误", result)
def test_file_too_large(self):
"""测试文件过大"""
# 创建大文件
with open('/tmp/large.txt', 'w') as f:
f.write('x' * (20 * 1024 * 1024)) # 20MB
result = self.tool.run("/tmp/large.txt")
self.assertIn("太大", result)
def tearDown(self):
"""测试后清理"""
import os
if os.path.exists('/tmp/test.txt'):
os.remove('/tmp/test.txt')
if os.path.exists('/tmp/large.txt'):
os.remove('/tmp/large.txt')
if __name__ == '__main__':
unittest.main()
9.2 集成测试
"""
test_integration.py - 集成测试
"""
import unittest
from data_analysis_agent import DataAnalysisAgent
class TestDataAnalysisAgent(unittest.TestCase):
"""数据分析Agent集成测试"""
def setUp(self):
"""准备测试数据"""
import pandas as pd
data = {
'product': ['A', 'B', 'C'],
'sales': [100, 200, 150]
}
df = pd.DataFrame(data)
df.to_csv('./test_data.csv', index=False)
self.agent = DataAnalysisAgent()
def test_full_analysis(self):
"""测试完整分析流程"""
result = self.agent.analyze("""
分析 test_data.csv:
1. 总销售额
2. 最畅销产品
""")
# 验证结果
self.assertIn("450", result) # 总销售额
self.assertIn("B", result) # 最畅销产品
def tearDown(self):
"""清理"""
import os
if os.path.exists('./test_data.csv'):
os.remove('./test_data.csv')
十、性能优化技巧
技巧1:缓存结果
from functools import lru_cache
import hashlib
class CachedTool(Tool):
"""带缓存的工具"""
def __init__(self):
super().__init__()
self.cache = {}
def run(self, input_str: str) -> str:
# 计算输入的哈希
cache_key = hashlib.md5(input_str.encode()).hexdigest()
# 检查缓存
if cache_key in self.cache:
logger.info(f"[{self.name}] 缓存命中")
return self.cache[cache_key]
# 执行并缓存
result = super().run(input_str)
self.cache[cache_key] = result
return result
技巧2:并行执行
from concurrent.futures import ThreadPoolExecutor
class ParallelToolExecutor:
"""并行工具执行器"""
def __init__(self, max_workers=4):
self.executor = ThreadPoolExecutor(max_workers=max_workers)
def run_parallel(self, tasks: List[tuple]):
"""并行执行多个工具
Args:
tasks: [(tool, input), (tool, input), ...]
"""
futures = []
for tool, input_str in tasks:
future = self.executor.submit(tool.run, input_str)
futures.append(future)
# 等待所有任务完成
results = [f.result() for f in futures]
return results
# 使用示例
executor = ParallelToolExecutor()
results = executor.run_parallel([
(calculator, "1+1"),
(calculator, "2+2"),
(calculator, "3+3")
])
技巧3:懒加载
class LazyTool(Tool):
"""懒加载工具"""
def __init__(self):
super().__init__()
self._heavy_resource = None
@property
def heavy_resource(self):
"""懒加载重量级资源"""
if self._heavy_resource is None:
print("加载重量级资源...")
self._heavy_resource = self._load_resource()
return self._heavy_resource
def _load_resource(self):
"""加载资源(只在需要时调用)"""
# 比如加载大型模型
import time
time.sleep(2) # 模拟加载时间
return "Heavy Resource"
总结
今天我们深入学习了工具系统设计:
✅ 设计原则: 单一职责、防御性编程、优雅错误处理
✅ 工具基类: 完整的验证、超时、日志、统计
✅ 5个实用工具: 文件、代码、HTTP、数据库、数据分析
✅ 安全实践: 输入验证、资源限制、审计日志
✅ 性能优化: 缓存、并行、懒加载
💡 核心洞察: 工具是Agent的"手和脚",设计好工具系统,Agent才能真正发挥作用。安全性永远是第一位的。
下期预告
第4篇:《记忆系统实战:让Agent记住你说过的每句话》
下期我们将探讨:
- 短期记忆、工作记忆、长期记忆的设计
- 向量数据库的应用(Pinecone、Weaviate)
- 上下文管理策略
- 记忆检索和更新机制
练习题:
- 基础: 实现一个天气查询工具,调用真实的天气API
- 进阶: 给FileReader添加写入功能,确保安全性
- 挑战: 实现一个邮件发送工具,支持附件和HTML格式
欢迎在留言区分享你的实现!
如果这篇文章对你有帮助:
- 👍 点赞让更多人看到
- 🧐 关注架构之旅公众号
- 🔖 收藏方便查阅
- 💬 留言讨论你的想法
下期见!
本文是《AI Agent实战系列》第3 篇,后续还会更新AI Agent进阶玩法。关注公众号【架构之旅】,第一时间解锁全套实战教程,错过不再补~
更多推荐


所有评论(0)