AI Agent实战系列·第3篇
从简单计算器到Docker沙箱的进化之路

写在前面

上期回顾: 第2篇我们学习了ReAct模式,让Agent能够"边想边做"。但我们只用了3个简单工具,远远不够。

本期你将收获:

  • ✅ 掌握工具系统的设计原则
  • ✅ 实现5个生产级实用工具
  • ✅ 学会Docker沙箱隔离技术
  • ✅ 理解工具编排和错误处理

一、翻车现场:我的Agent把服务器删了

上周,我给Agent加了个"文件操作"工具,结果…

# 我写的工具(危险版本)
def file_operation(command: str):
    """执行文件操作"""
    return os.system(command)  # ❌ 巨大的安全隐患

# 某次对话
我:帮我清理一下临时文件
Agent:好的
  → 使用工具:file_operation("rm -rf /tmp/*")
  → 哦不,我理解错了...
  → 使用工具:file_operation("rm -rf /*")  # 💀

# 服务器:💥

这个惨痛教训告诉我们:

  • ❌ 工具设计不能随便
  • ❌ 必须要有安全边界
  • ❌ 需要严格的输入验证
  • ❌ 要有回滚机制

今天,我们就来设计一个安全、健壮、可扩展的工具系统。


二、工具系统的5大设计原则

原则1️⃣:单一职责

每个工具只做一件事,做好一件事。

# ❌ 不好:一个工具做太多事
def super_tool(action, **kwargs):
    if action == "calculate":
        return eval(kwargs['expression'])
    elif action == "search":
        return search_web(kwargs['query'])
    elif action == "read_file":
        return read_file(kwargs['filename'])
    # ... 100行

# ✅ 好:每个工具专注一个功能
class Calculator(Tool):
    """只负责计算"""
    
class WebSearch(Tool):
    """只负责搜索"""
    
class FileReader(Tool):
    """只负责读文件"""

原则2️⃣:明确的输入输出

# ❌ 不好:模糊的接口
def process(data):
    """处理数据"""  # 什么数据?怎么处理?
    return result   # 什么结果?

# ✅ 好:清晰的接口
def calculate(expression: str) -> float:
    """
    执行数学计算
    
    Args:
        expression: 数学表达式,如 "2 + 3 * 4"
        
    Returns:
        float: 计算结果
        
    Raises:
        ValueError: 表达式无效时
        
    Example:
        >>> calculate("2 + 3")
        5.0
    """

原则3️⃣:防御性编程

假设一切输入都是恶意的。

class SafeTool(Tool):
    def run(self, **kwargs):
        # 1. 验证输入
        self._validate_input(kwargs)
        
        # 2. 清理数据
        cleaned = self._sanitize(kwargs)
        
        # 3. 限制执行时间
        with timeout(self.max_time):
            result = self._execute(cleaned)
        
        # 4. 验证输出
        return self._validate_output(result)

原则4️⃣:优雅的错误处理

class RobustTool(Tool):
    def run(self, **kwargs):
        try:
            return self._execute(**kwargs)
        except FileNotFoundError as e:
            return f"错误:文件不存在 - {e.filename}"
        except PermissionError:
            return "错误:没有权限执行此操作"
        except TimeoutError:
            return "错误:操作超时"
        except Exception as e:
            # 记录详细日志
            logger.error(f"工具执行失败: {e}", exc_info=True)
            # 返回用户友好的信息
            return f"错误:{type(e).__name__}"

原则5️⃣:可观测性

每个工具都应该记录日志、指标。

class ObservableTool(Tool):
    def run(self, **kwargs):
        start_time = time.time()
        
        logger.info(f"工具开始执行: {self.name}", extra={
            "tool": self.name,
            "inputs": kwargs
        })
        
        try:
            result = self._execute(**kwargs)
            
            # 记录成功指标
            metrics.increment(f"tool.{self.name}.success")
            
            return result
        except Exception as e:
            # 记录失败指标
            metrics.increment(f"tool.{self.name}.error")
            raise
        finally:
            duration = time.time() - start_time
            metrics.timing(f"tool.{self.name}.duration", duration)

三、完整的工具基类设计

基于上述原则,我们设计一个完善的工具基类:

"""
tool_base.py - 工具基类
"""

import time
import logging
from typing import Any, Dict, Optional
from abc import ABC, abstractmethod
import json
from functools import wraps

logger = logging.getLogger(__name__)

class ToolExecutionError(Exception):
    """工具执行错误"""
    pass

class ToolTimeoutError(ToolExecutionError):
    """工具超时错误"""
    pass

class ToolValidationError(ToolExecutionError):
    """工具验证错误"""
    pass


def timeout(seconds):
    """超时装饰器"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            import signal
            
            def handler(signum, frame):
                raise ToolTimeoutError(f"操作超时({seconds}秒)")
            
            # 设置超时
            signal.signal(signal.SIGALRM, handler)
            signal.alarm(seconds)
            
            try:
                result = func(*args, **kwargs)
            finally:
                signal.alarm(0)  # 取消超时
            
            return result
        return wrapper
    return decorator


class Tool(ABC):
    """工具基类"""
    
    def __init__(
        self,
        name: str,
        description: str,
        max_execution_time: int = 30,
        require_confirmation: bool = False
    ):
        self.name = name
        self.description = description
        self.max_execution_time = max_execution_time
        self.require_confirmation = require_confirmation
        
        # 统计信息
        self.execution_count = 0
        self.success_count = 0
        self.error_count = 0
        self.total_time = 0.0
    
    def run(self, input_str: str) -> str:
        """
        执行工具(对外接口)
        
        Args:
            input_str: 输入参数(JSON字符串或普通字符串)
            
        Returns:
            str: 执行结果
        """
        start_time = time.time()
        self.execution_count += 1
        
        logger.info(f"[{self.name}] 开始执行", extra={
            "tool": self.name,
            "input": input_str[:100]  # 只记录前100个字符
        })
        
        try:
            # 1. 解析输入
            parsed_input = self._parse_input(input_str)
            
            # 2. 验证输入
            self._validate_input(parsed_input)
            
            # 3. 确认(如果需要)
            if self.require_confirmation:
                if not self._confirm_execution(parsed_input):
                    return "操作已取消"
            
            # 4. 执行(带超时)
            @timeout(self.max_execution_time)
            def execute():
                return self._execute(**parsed_input)
            
            result = execute()
            
            # 5. 验证输出
            validated_result = self._validate_output(result)
            
            # 6. 记录成功
            self.success_count += 1
            execution_time = time.time() - start_time
            self.total_time += execution_time
            
            logger.info(f"[{self.name}] 执行成功", extra={
                "tool": self.name,
                "duration": execution_time,
                "result_length": len(str(validated_result))
            })
            
            return validated_result
            
        except ToolTimeoutError as e:
            self.error_count += 1
            logger.error(f"[{self.name}] 超时: {e}")
            return f"错误:操作超时(>{self.max_execution_time}秒)"
            
        except ToolValidationError as e:
            self.error_count += 1
            logger.error(f"[{self.name}] 验证失败: {e}")
            return f"错误:{str(e)}"
            
        except Exception as e:
            self.error_count += 1
            logger.error(f"[{self.name}] 执行失败", exc_info=True)
            return f"错误:{type(e).__name__} - {str(e)}"
    
    def _parse_input(self, input_str: str) -> Dict[str, Any]:
        """解析输入"""
        # 尝试解析JSON
        input_str = input_str.strip()
        
        if input_str.startswith('{'):
            try:
                return json.loads(input_str)
            except json.JSONDecodeError:
                pass
        
        # 否则作为单个参数
        return {"input": input_str}
    
    def _validate_input(self, inputs: Dict[str, Any]):
        """
        验证输入(子类可覆盖)
        
        Raises:
            ToolValidationError: 验证失败时
        """
        pass
    
    def _validate_output(self, result: Any) -> str:
        """
        验证输出(子类可覆盖)
        
        Args:
            result: 执行结果
            
        Returns:
            str: 验证后的结果字符串
        """
        result_str = str(result)
        
        # 限制输出长度
        max_length = 5000
        if len(result_str) > max_length:
            result_str = result_str[:max_length] + f"\n... (截断,总长度{len(result_str)})"
        
        return result_str
    
    def _confirm_execution(self, inputs: Dict[str, Any]) -> bool:
        """
        确认执行(危险操作时使用)
        
        实际应该通过某种机制让用户确认
        这里简化处理
        """
        print(f"⚠️  工具 {self.name} 需要确认")
        print(f"参数: {inputs}")
        # 实际场景应该等待用户输入
        return True
    
    @abstractmethod
    def _execute(self, **kwargs) -> Any:
        """
        实际执行逻辑(子类必须实现)
        
        Returns:
            Any: 执行结果
        """
        raise NotImplementedError
    
    def get_stats(self) -> Dict[str, Any]:
        """获取统计信息"""
        avg_time = self.total_time / self.execution_count if self.execution_count > 0 else 0
        success_rate = self.success_count / self.execution_count if self.execution_count > 0 else 0
        
        return {
            "name": self.name,
            "executions": self.execution_count,
            "success": self.success_count,
            "errors": self.error_count,
            "success_rate": f"{success_rate:.1%}",
            "avg_time": f"{avg_time:.2f}s",
            "total_time": f"{self.total_time:.2f}s"
        }
    
    def __str__(self):
        return f"{self.name}: {self.description}"

四、5个生产级工具实现

现在基于这个基类,我们实现5个实用工具。

工具1:安全的文件读取器

"""
file_tools.py - 文件操作工具
"""

import os
from pathlib import Path
from tool_base import Tool, ToolValidationError

class FileReader(Tool):
    """安全的文件读取工具"""
    
    def __init__(self, allowed_dirs: list = None):
        super().__init__(
            name="FileReader",
            description="读取文件内容。输入:文件路径",
            max_execution_time=10
        )
        # 限制可访问的目录
        self.allowed_dirs = [Path(d).resolve() for d in (allowed_dirs or ['.'])]
    
    def _validate_input(self, inputs: Dict[str, Any]):
        """验证文件路径"""
        filepath = inputs.get('input') or inputs.get('filepath')
        
        if not filepath:
            raise ToolValidationError("必须提供文件路径")
        
        # 检查路径遍历攻击
        filepath = Path(filepath).resolve()
        
        # 检查是否在允许的目录内
        if not any(str(filepath).startswith(str(allowed)) for allowed in self.allowed_dirs):
            raise ToolValidationError(f"文件路径不在允许的目录内")
        
        # 检查文件是否存在
        if not filepath.exists():
            raise ToolValidationError(f"文件不存在: {filepath}")
        
        # 检查是否为文件
        if not filepath.is_file():
            raise ToolValidationError(f"不是文件: {filepath}")
        
        # 检查文件大小
        max_size = 10 * 1024 * 1024  # 10MB
        if filepath.stat().st_size > max_size:
            raise ToolValidationError(f"文件太大(>{max_size/1024/1024}MB)")
    
    def _execute(self, **kwargs) -> str:
        filepath = kwargs.get('input') or kwargs.get('filepath')
        filepath = Path(filepath).resolve()
        
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                content = f.read()
            
            return f"文件 {filepath.name} 的内容:\n{content}"
        
        except UnicodeDecodeError:
            # 尝试其他编码
            with open(filepath, 'r', encoding='gbk') as f:
                content = f.read()
            return f"文件 {filepath.name} 的内容:\n{content}"

工具2:Python代码执行器(沙箱)

"""
code_executor.py - 代码执行工具
"""

import subprocess
import tempfile
import os
from tool_base import Tool, ToolValidationError

class PythonExecutor(Tool):
    """Python代码执行器(Docker沙箱)"""
    
    def __init__(self, use_docker: bool = True):
        super().__init__(
            name="PythonExecutor",
            description="执行Python代码。输入:Python代码字符串",
            max_execution_time=30,
            require_confirmation=True
        )
        self.use_docker = use_docker
    
    def _validate_input(self, inputs: Dict[str, Any]):
        """验证代码"""
        code = inputs.get('input') or inputs.get('code')
        
        if not code:
            raise ToolValidationError("必须提供代码")
        
        # 黑名单检查
        dangerous_keywords = [
            'os.system', 'subprocess', 'eval', 'exec',
            '__import__', 'open', 'file', 'input'
        ]
        
        for keyword in dangerous_keywords:
            if keyword in code:
                raise ToolValidationError(f"代码包含危险操作: {keyword}")
        
        # 限制代码长度
        if len(code) > 10000:
            raise ToolValidationError("代码太长")
    
    def _execute(self, **kwargs) -> str:
        code = kwargs.get('input') or kwargs.get('code')
        
        if self.use_docker:
            return self._execute_in_docker(code)
        else:
            return self._execute_local(code)
    
    def _execute_in_docker(self, code: str) -> str:
        """在Docker容器中执行"""
        # 创建临时文件
        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
            f.write(code)
            temp_file = f.name
        
        try:
            # 在Docker中执行
            cmd = [
                'docker', 'run', '--rm',
                '--network', 'none',  # 禁用网络
                '--memory', '256m',   # 限制内存
                '--cpus', '0.5',      # 限制CPU
                '-v', f'{temp_file}:/code.py:ro',  # 只读挂载
                'python:3.9-slim',
                'python', '/code.py'
            ]
            
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=self.max_execution_time
            )
            
            if result.returncode == 0:
                return f"执行成功:\n{result.stdout}"
            else:
                return f"执行失败:\n{result.stderr}"
        
        finally:
            os.unlink(temp_file)
    
    def _execute_local(self, code: str) -> str:
        """本地执行(受限环境)"""
        # 创建受限的全局环境
        safe_globals = {
            '__builtins__': {
                'abs': abs, 'round': round, 'len': len,
                'min': min, 'max': max, 'sum': sum,
                'range': range, 'list': list, 'dict': dict,
                'str': str, 'int': int, 'float': float,
                'print': print
            }
        }
        
        # 捕获输出
        from io import StringIO
        import sys
        
        old_stdout = sys.stdout
        sys.stdout = StringIO()
        
        try:
            exec(code, safe_globals)
            output = sys.stdout.getvalue()
            return f"执行成功:\n{output}"
        except Exception as e:
            return f"执行失败:{type(e).__name__} - {str(e)}"
        finally:
            sys.stdout = old_stdout

工具3:HTTP请求工具

"""
http_tools.py - HTTP工具
"""

import requests
from urllib.parse import urlparse
from tool_base import Tool, ToolValidationError

class HTTPRequest(Tool):
    """HTTP请求工具"""
    
    def __init__(self, allowed_domains: list = None):
        super().__init__(
            name="HTTPRequest",
            description="发送HTTP请求。输入:{'url': '...', 'method': 'GET'}",
            max_execution_time=30
        )
        self.allowed_domains = allowed_domains or []
    
    def _validate_input(self, inputs: Dict[str, Any]):
        """验证请求"""
        url = inputs.get('url')
        
        if not url:
            raise ToolValidationError("必须提供URL")
        
        # 验证URL格式
        try:
            parsed = urlparse(url)
        except Exception:
            raise ToolValidationError("无效的URL")
        
        # 检查协议
        if parsed.scheme not in ['http', 'https']:
            raise ToolValidationError("只支持HTTP/HTTPS协议")
        
        # 检查域名白名单
        if self.allowed_domains:
            if not any(parsed.netloc.endswith(domain) for domain in self.allowed_domains):
                raise ToolValidationError(f"域名不在白名单内: {parsed.netloc}")
        
        # 禁止访问内网
        if parsed.netloc in ['localhost', '127.0.0.1'] or parsed.netloc.startswith('192.168.'):
            raise ToolValidationError("禁止访问内网地址")
    
    def _execute(self, **kwargs) -> str:
        url = kwargs.get('url')
        method = kwargs.get('method', 'GET').upper()
        headers = kwargs.get('headers', {})
        data = kwargs.get('data')
        
        try:
            response = requests.request(
                method=method,
                url=url,
                headers=headers,
                json=data,
                timeout=self.max_execution_time,
                allow_redirects=True,
                verify=True  # 验证SSL证书
            )
            
            return f"""
请求成功
状态码: {response.status_code}
响应头: {dict(response.headers)}
响应体: {response.text[:1000]}
"""
        
        except requests.exceptions.Timeout:
            return "错误:请求超时"
        except requests.exceptions.ConnectionError:
            return "错误:连接失败"
        except Exception as e:
            return f"错误:{type(e).__name__} - {str(e)}"

工具4:数据库查询工具

"""
database_tools.py - 数据库工具
"""

import sqlite3
import re
from tool_base import Tool, ToolValidationError

class SQLQuery(Tool):
    """SQL查询工具"""
    
    def __init__(self, db_path: str):
        super().__init__(
            name="SQLQuery",
            description="执行SQL查询。输入:SQL语句",
            max_execution_time=10,
            require_confirmation=True
        )
        self.db_path = db_path
    
    def _validate_input(self, inputs: Dict[str, Any]):
        """验证SQL"""
        sql = inputs.get('input') or inputs.get('sql')
        
        if not sql:
            raise ToolValidationError("必须提供SQL语句")
        
        sql = sql.strip().upper()
        
        # 只允许SELECT
        if not sql.startswith('SELECT'):
            raise ToolValidationError("只允许SELECT查询")
        
        # 禁止某些操作
        forbidden = ['DROP', 'DELETE', 'UPDATE', 'INSERT', 'ALTER', 'CREATE']
        for word in forbidden:
            if word in sql:
                raise ToolValidationError(f"禁止的操作: {word}")
    
    def _execute(self, **kwargs) -> str:
        sql = kwargs.get('input') or kwargs.get('sql')
        
        try:
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()
            
            cursor.execute(sql)
            rows = cursor.fetchall()
            
            # 获取列名
            columns = [desc[0] for desc in cursor.description]
            
            conn.close()
            
            # 格式化输出
            result = f"查询成功,返回 {len(rows)} 行\n\n"
            result += " | ".join(columns) + "\n"
            result += "-" * 50 + "\n"
            
            for row in rows[:100]:  # 最多显示100行
                result += " | ".join(str(v) for v in row) + "\n"
            
            if len(rows) > 100:
                result += f"\n... (还有 {len(rows) - 100} 行未显示)"
            
            return result
        
        except sqlite3.Error as e:
            return f"SQL错误:{str(e)}"

工具5:数据分析工具

"""
data_tools.py - 数据分析工具
"""

import pandas as pd
import numpy as np
from tool_base import Tool, ToolValidationError

class DataAnalyzer(Tool):
    """数据分析工具"""
    
    def __init__(self):
        super().__init__(
            name="DataAnalyzer",
            description="分析CSV数据。输入:{'file': '...', 'operation': 'describe'}",
            max_execution_time=60
        )
    
    def _validate_input(self, inputs: Dict[str, Any]):
        """验证输入"""
        file_path = inputs.get('file')
        operation = inputs.get('operation')
        
        if not file_path:
            raise ToolValidationError("必须提供文件路径")
        
        if not operation:
            raise ToolValidationError("必须提供操作类型")
        
        allowed_operations = ['describe', 'head', 'info', 'columns', 'shape']
        if operation not in allowed_operations:
            raise ToolValidationError(f"不支持的操作: {operation}")
    
    def _execute(self, **kwargs) -> str:
        file_path = kwargs.get('file')
        operation = kwargs.get('operation')
        
        try:
            # 读取CSV
            df = pd.read_csv(file_path)
            
            # 执行操作
            if operation == 'describe':
                result = df.describe().to_string()
                return f"数据统计信息:\n{result}"
            
            elif operation == 'head':
                n = kwargs.get('n', 5)
                result = df.head(n).to_string()
                return f"前{n}行数据:\n{result}"
            
            elif operation == 'info':
                import io
                buffer = io.StringIO()
                df.info(buf=buffer)
                return f"数据信息:\n{buffer.getvalue()}"
            
            elif operation == 'columns':
                return f"列名:{', '.join(df.columns)}"
            
            elif operation == 'shape':
                return f"数据形状:{df.shape[0]}行 × {df.shape[1]}列"
        
        except FileNotFoundError:
            return f"错误:文件不存在 - {file_path}"
        except pd.errors.EmptyDataError:
            return "错误:文件为空"
        except Exception as e:
            return f"错误:{type(e).__name__} - {str(e)}"

五、工具编排:让工具协作起来

单个工具很强大,但真正的威力在于工具的组合使用

5.1 工具链模式

"""
tool_chain.py - 工具链
"""

from typing import List
from tool_base import Tool

class ToolChain:
    """工具链:按顺序执行多个工具"""
    
    def __init__(self, tools: List[Tool]):
        self.tools = tools
    
    def run(self, initial_input: str) -> str:
        """执行工具链"""
        current_input = initial_input
        
        for i, tool in enumerate(self.tools, 1):
            print(f"\n{'='*50}")
            print(f"步骤 {i}/{len(self.tools)}: {tool.name}")
            print(f"{'='*50}")
            
            result = tool.run(current_input)
            print(f"结果: {result[:200]}...")
            
            # 下一个工具的输入是上一个的输出
            current_input = result
        
        return current_input

# 使用示例
chain = ToolChain([
    FileReader(allowed_dirs=['.']),
    DataAnalyzer(),
])

result = chain.run("data.csv")

5.2 工具路由器

"""
tool_router.py - 工具路由
"""

class ToolRouter:
    """工具路由器:根据条件选择工具"""
    
    def __init__(self, tools: Dict[str, Tool]):
        self.tools = tools
    
    def route(self, task: str) -> str:
        """智能路由到合适的工具"""
        task_lower = task.lower()
        
        # 简单的关键词匹配
        if 'calculate' in task_lower or '计算' in task_lower:
            return self.tools['calculator'].run(task)
        
        elif 'file' in task_lower or '文件' in task_lower:
            return self.tools['file_reader'].run(task)
        
        elif 'http' in task_lower or 'api' in task_lower:
            return self.tools['http_request'].run(task)
        
        else:
            return "无法确定使用哪个工具"

# 实际使用中,应该让LLM来做路由决策

六、实战案例:构建数据分析Agent

让我们把所有工具组合起来,打造一个数据分析Agent:

"""
data_analysis_agent.py - 


```python
"""
data_analysis_agent.py - 数据分析Agent
"""

from react_agent import ReActAgent
from file_tools import FileReader
from data_tools import DataAnalyzer
from code_executor import PythonExecutor
from tool_base import Tool

class DataAnalysisAgent:
    """专门用于数据分析的Agent"""
    
    def __init__(self):
        # 初始化工具
        self.tools = [
            FileReader(allowed_dirs=['./data']),
            DataAnalyzer(),
            PythonExecutor(use_docker=True),
            self._create_chart_tool()
        ]
        
        # 创建Agent
        self.agent = ReActAgent(
            tools=self.tools,
            model="gpt-4"  # 数据分析用更强的模型
        )
    
    def _create_chart_tool(self) -> Tool:
        """创建图表生成工具"""
        class ChartGenerator(Tool):
            def __init__(self):
                super().__init__(
                    name="ChartGenerator",
                    description="生成数据可视化图表。输入:{'file': '...', 'chart_type': 'bar/line/pie'}",
                    max_execution_time=30
                )
            
            def _execute(self, **kwargs):
                import matplotlib.pyplot as plt
                import pandas as pd
                
                file_path = kwargs.get('file')
                chart_type = kwargs.get('chart_type', 'bar')
                
                df = pd.read_csv(file_path)
                
                plt.figure(figsize=(10, 6))
                
                if chart_type == 'bar':
                    df.plot(kind='bar')
                elif chart_type == 'line':
                    df.plot(kind='line')
                elif chart_type == 'pie':
                    df.plot(kind='pie', y=df.columns[0])
                
                output_path = f'chart_{chart_type}.png'
                plt.savefig(output_path)
                plt.close()
                
                return f"图表已保存到 {output_path}"
        
        return ChartGenerator()
    
    def analyze(self, task: str):
        """执行数据分析任务"""
        # 增强的Prompt
        enhanced_task = f"""
作为数据分析专家,请完成以下任务:

{task}

分析步骤建议:
1. 先查看数据基本信息(shape, columns, head)
2. 进行统计分析(describe)
3. 根据需要生成图表
4. 总结关键发现

请详细说明你的分析过程。
"""
        return self.agent.run(enhanced_task)

# ========== 使用示例 ==========

if __name__ == "__main__":
    # 准备测试数据
    import pandas as pd
    
    # 创建示例CSV
    data = {
        'month': ['Jan', 'Feb', 'Mar', 'Apr', 'May'],
        'sales': [12000, 15000, 13000, 18000, 21000],
        'profit': [3000, 3500, 2800, 4200, 5000]
    }
    df = pd.DataFrame(data)
    df.to_csv('./data/sales.csv', index=False)
    
    # 创建Agent
    agent = DataAnalysisAgent()
    
    # 执行分析
    result = agent.analyze("""
    分析 ./data/sales.csv 文件:
    1. 总销售额和平均销售额是多少?
    2. 哪个月份销售额最高?
    3. 生成销售额趋势图
    """)
    
    print("\n" + "="*60)
    print("分析结果:")
    print("="*60)
    print(result)

运行效果

============================================================
🎯 Question: 分析 ./data/sales.csv 文件...
============================================================

【第1轮】
Thought: 我需要先查看数据的基本信息,了解数据结构
Action: DataAnalyzer
Action Input: {"file": "./data/sales.csv", "operation": "head"}

📊 Observation: 前5行数据:
  month  sales  profit
0   Jan  12000    3000
1   Feb  15000    3500
2   Mar  13000    2800
3   Apr  18000    4200
4   May  21000    5000
------------------------------------------------------------

【第2轮】
Thought: 数据包含3列:month, sales, profit
现在我需要计算总销售额和平均销售额
Action: PythonExecutor
Action Input: 
import pandas as pd
df = pd.read_csv('./data/sales.csv')
total_sales = df['sales'].sum()
avg_sales = df['sales'].mean()
print(f"总销售额: {total_sales}")
print(f"平均销售额: {avg_sales}")

📊 Observation: 执行成功:
总销售额: 79000
平均销售额: 15800.0
------------------------------------------------------------

【第3轮】
Thought: 已知总销售额79000,平均15800
现在找出销售额最高的月份
Action: PythonExecutor
Action Input:
import pandas as pd
df = pd.read_csv('./data/sales.csv')
max_month = df.loc[df['sales'].idxmax(), 'month']
max_sales = df['sales'].max()
print(f"销售额最高月份: {max_month} ({max_sales})")

📊 Observation: 执行成功:
销售额最高月份: May (21000)
------------------------------------------------------------

【第4轮】
Thought: 最高销售月份是May
现在生成趋势图
Action: ChartGenerator
Action Input: {"file": "./data/sales.csv", "chart_type": "line"}

📊 Observation: 图表已保存到 chart_line.png
------------------------------------------------------------

【第5轮】
Thought: 我已经完成了所有分析任务
Final Answer: 
数据分析报告:

1. 销售总览
   - 总销售额: ¥79,000
   - 平均销售额: ¥15,800/月

2. 最佳表现
   - 销售额最高月份: May (¥21,000)
   
3. 趋势分析
   - 销售额整体呈上升趋势
   - 从1月的¥12,000增长到5月的¥21,000
   - 增长率: 75%
   
4. 可视化
   - 趋势图已保存: chart_line.png

============================================================
✅ 分析完成
============================================================

七、工具安全性最佳实践

7.1 输入验证清单

# ✅ 完整的验证示例
def _validate_input(self, inputs):
    # 1. 类型检查
    if not isinstance(inputs.get('param'), str):
        raise ToolValidationError("参数必须是字符串")
    
    # 2. 长度限制
    if len(inputs.get('param')) > 1000:
        raise ToolValidationError("参数过长")
    
    # 3. 格式验证
    import re
    if not re.match(r'^[a-zA-Z0-9_]+$', inputs.get('param')):
        raise ToolValidationError("参数包含非法字符")
    
    # 4. 业务规则验证
    if inputs.get('amount', 0) < 0:
        raise ToolValidationError("金额不能为负")
    
    # 5. 白名单验证
    allowed_values = ['option1', 'option2']
    if inputs.get('choice') not in allowed_values:
        raise ToolValidationError("无效的选项")

7.2 资源限制

class ResourceLimitedTool(Tool):
    """资源受限的工具"""
    
    def __init__(self):
        super().__init__(
            name="ResourceLimited",
            description="资源受限的工具",
            max_execution_time=30
        )
        
        # 限制
        self.max_memory = 100 * 1024 * 1024  # 100MB
        self.max_cpu_time = 10  # 10秒
        self.max_file_size = 10 * 1024 * 1024  # 10MB
    
    def _execute(self, **kwargs):
        import resource
        
        # 设置内存限制
        resource.setrlimit(
            resource.RLIMIT_AS,
            (self.max_memory, self.max_memory)
        )
        
        # 设置CPU时间限制
        resource.setrlimit(
            resource.RLIMIT_CPU,
            (self.max_cpu_time, self.max_cpu_time)
        )
        
        # 执行实际操作
        return self._do_work(**kwargs)

7.3 审计日志

"""
audit.py - 审计日志
"""

import logging
import json
from datetime import datetime

class AuditLogger:
    """审计日志记录器"""
    
    def __init__(self, log_file='audit.log'):
        self.logger = logging.getLogger('audit')
        handler = logging.FileHandler(log_file)
        handler.setFormatter(logging.Formatter(
            '%(asctime)s - %(message)s'
        ))
        self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)
    
    def log_tool_execution(
        self,
        tool_name: str,
        inputs: dict,
        result: str,
        success: bool,
        user_id: str = None
    ):
        """记录工具执行"""
        log_entry = {
            'timestamp': datetime.now().isoformat(),
            'tool': tool_name,
            'user': user_id,
            'inputs': inputs,
            'result': result[:100],  # 只记录前100字符
            'success': success
        }
        
        self.logger.info(json.dumps(log_entry))

# 在Tool基类中集成
class AuditedTool(Tool):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.audit_logger = AuditLogger()
    
    def run(self, input_str: str) -> str:
        result = super().run(input_str)
        
        # 记录到审计日志
        self.audit_logger.log_tool_execution(
            tool_name=self.name,
            inputs={'input': input_str},
            result=result,
            success='错误' not in result
        )
        
        return result

八、常见问题与解决方案

Q1: 工具执行太慢怎么办?

问题: 某些工具(如大文件处理)执行时间过长

解决方案:

# 1. 异步执行
import asyncio

class AsyncTool(Tool):
    async def run_async(self, input_str: str):
        """异步执行"""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(
            None,
            self.run,
            input_str
        )

# 2. 进度回调
class ProgressTool(Tool):
    def __init__(self, progress_callback=None):
        super().__init__()
        self.progress_callback = progress_callback
    
    def _execute(self, **kwargs):
        total = 100
        for i in range(total):
            # 执行工作
            time.sleep(0.01)
            
            # 报告进度
            if self.progress_callback:
                self.progress_callback(i / total)
        
        return "完成"

# 使用
def show_progress(percent):
    print(f"\r进度: {percent:.0%}", end='')

tool = ProgressTool(progress_callback=show_progress)

Q2: 如何处理工具之间的依赖?

问题: 工具B需要工具A的输出

解决方案:

class ToolDependency:
    """工具依赖管理"""
    
    def __init__(self, tools: List[Tool]):
        self.tools = {t.name: t for t in tools}
        self.results_cache = {}
    
    def run_with_deps(self, tool_name: str, inputs: dict):
        """执行工具(处理依赖)"""
        # 检查依赖
        deps = self._get_dependencies(tool_name)
        
        # 先执行依赖
        for dep in deps:
            if dep not in self.results_cache:
                dep_result = self.tools[dep].run(inputs)
                self.results_cache[dep] = dep_result
        
        # 注入依赖的结果
        enhanced_inputs = inputs.copy()
        for dep in deps:
            enhanced_inputs[f'{dep}_result'] = self.results_cache[dep]
        
        # 执行目标工具
        return self.tools[tool_name].run(enhanced_inputs)

Q3: 如何实现工具的版本管理?

解决方案:

class VersionedTool(Tool):
    """版本化的工具"""
    
    VERSION = "1.0.0"
    
    def __init__(self):
        super().__init__(
            name=f"{self.__class__.__name__}@{self.VERSION}",
            description="..."
        )
    
    def run(self, input_str: str) -> str:
        # 在结果中包含版本信息
        result = super().run(input_str)
        return f"[v{self.VERSION}] {result}"

# 工具注册表
class ToolRegistry:
    """工具注册表"""
    
    def __init__(self):
        self.tools = {}
    
    def register(self, tool: Tool):
        """注册工具"""
        key = f"{tool.name}@{getattr(tool, 'VERSION', '1.0.0')}"
        self.tools[key] = tool
    
    def get(self, name: str, version: str = None):
        """获取工具"""
        if version:
            key = f"{name}@{version}"
        else:
            # 获取最新版本
            versions = [k for k in self.tools.keys() if k.startswith(name)]
            key = max(versions) if versions else None
        
        return self.tools.get(key)

九、工具测试最佳实践

9.1 单元测试

"""
test_tools.py - 工具测试
"""

import unittest
from file_tools import FileReader
from tool_base import ToolValidationError

class TestFileReader(unittest.TestCase):
    """FileReader工具测试"""
    
    def setUp(self):
        """测试前准备"""
        self.tool = FileReader(allowed_dirs=['/tmp'])
        
        # 创建测试文件
        with open('/tmp/test.txt', 'w') as f:
            f.write("Hello, World!")
    
    def test_read_file_success(self):
        """测试正常读取"""
        result = self.tool.run("/tmp/test.txt")
        self.assertIn("Hello, World!", result)
    
    def test_read_file_not_found(self):
        """测试文件不存在"""
        result = self.tool.run("/tmp/nonexistent.txt")
        self.assertIn("错误", result)
    
    def test_path_traversal_attack(self):
        """测试路径遍历攻击"""
        result = self.tool.run("/tmp/../etc/passwd")
        self.assertIn("错误", result)
    
    def test_file_too_large(self):
        """测试文件过大"""
        # 创建大文件
        with open('/tmp/large.txt', 'w') as f:
            f.write('x' * (20 * 1024 * 1024))  # 20MB
        
        result = self.tool.run("/tmp/large.txt")
        self.assertIn("太大", result)
    
    def tearDown(self):
        """测试后清理"""
        import os
        if os.path.exists('/tmp/test.txt'):
            os.remove('/tmp/test.txt')
        if os.path.exists('/tmp/large.txt'):
            os.remove('/tmp/large.txt')

if __name__ == '__main__':
    unittest.main()

9.2 集成测试

"""
test_integration.py - 集成测试
"""

import unittest
from data_analysis_agent import DataAnalysisAgent

class TestDataAnalysisAgent(unittest.TestCase):
    """数据分析Agent集成测试"""
    
    def setUp(self):
        """准备测试数据"""
        import pandas as pd
        
        data = {
            'product': ['A', 'B', 'C'],
            'sales': [100, 200, 150]
        }
        df = pd.DataFrame(data)
        df.to_csv('./test_data.csv', index=False)
        
        self.agent = DataAnalysisAgent()
    
    def test_full_analysis(self):
        """测试完整分析流程"""
        result = self.agent.analyze("""
        分析 test_data.csv:
        1. 总销售额
        2. 最畅销产品
        """)
        
        # 验证结果
        self.assertIn("450", result)  # 总销售额
        self.assertIn("B", result)    # 最畅销产品
    
    def tearDown(self):
        """清理"""
        import os
        if os.path.exists('./test_data.csv'):
            os.remove('./test_data.csv')

十、性能优化技巧

技巧1:缓存结果

from functools import lru_cache
import hashlib

class CachedTool(Tool):
    """带缓存的工具"""
    
    def __init__(self):
        super().__init__()
        self.cache = {}
    
    def run(self, input_str: str) -> str:
        # 计算输入的哈希
        cache_key = hashlib.md5(input_str.encode()).hexdigest()
        
        # 检查缓存
        if cache_key in self.cache:
            logger.info(f"[{self.name}] 缓存命中")
            return self.cache[cache_key]
        
        # 执行并缓存
        result = super().run(input_str)
        self.cache[cache_key] = result
        
        return result

技巧2:并行执行

from concurrent.futures import ThreadPoolExecutor

class ParallelToolExecutor:
    """并行工具执行器"""
    
    def __init__(self, max_workers=4):
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
    
    def run_parallel(self, tasks: List[tuple]):
        """并行执行多个工具
        
        Args:
            tasks: [(tool, input), (tool, input), ...]
        """
        futures = []
        for tool, input_str in tasks:
            future = self.executor.submit(tool.run, input_str)
            futures.append(future)
        
        # 等待所有任务完成
        results = [f.result() for f in futures]
        return results

# 使用示例
executor = ParallelToolExecutor()
results = executor.run_parallel([
    (calculator, "1+1"),
    (calculator, "2+2"),
    (calculator, "3+3")
])

技巧3:懒加载

class LazyTool(Tool):
    """懒加载工具"""
    
    def __init__(self):
        super().__init__()
        self._heavy_resource = None
    
    @property
    def heavy_resource(self):
        """懒加载重量级资源"""
        if self._heavy_resource is None:
            print("加载重量级资源...")
            self._heavy_resource = self._load_resource()
        return self._heavy_resource
    
    def _load_resource(self):
        """加载资源(只在需要时调用)"""
        # 比如加载大型模型
        import time
        time.sleep(2)  # 模拟加载时间
        return "Heavy Resource"

总结

今天我们深入学习了工具系统设计:

设计原则: 单一职责、防御性编程、优雅错误处理
工具基类: 完整的验证、超时、日志、统计
5个实用工具: 文件、代码、HTTP、数据库、数据分析
安全实践: 输入验证、资源限制、审计日志
性能优化: 缓存、并行、懒加载

💡 核心洞察: 工具是Agent的"手和脚",设计好工具系统,Agent才能真正发挥作用。安全性永远是第一位的。


下期预告

第4篇:《记忆系统实战:让Agent记住你说过的每句话》

下期我们将探讨:

  • 短期记忆、工作记忆、长期记忆的设计
  • 向量数据库的应用(Pinecone、Weaviate)
  • 上下文管理策略
  • 记忆检索和更新机制

练习题:

  1. 基础: 实现一个天气查询工具,调用真实的天气API
  2. 进阶: 给FileReader添加写入功能,确保安全性
  3. 挑战: 实现一个邮件发送工具,支持附件和HTML格式

欢迎在留言区分享你的实现!


如果这篇文章对你有帮助:

  • 👍 点赞让更多人看到
  • 🧐 关注架构之旅公众号
  • 🔖 收藏方便查阅
  • 💬 留言讨论你的想法

下期见!
本文是《AI Agent实战系列》第3 篇,后续还会更新AI Agent进阶玩法。关注公众号【架构之旅】,第一时间解锁全套实战教程,错过不再补~

Logo

这里是“一人公司”的成长家园。我们提供从产品曝光、技术变现到法律财税的全栈内容,并连接云服务、办公空间等稀缺资源,助你专注创造,无忧运营。

更多推荐