Source code for OmniSimulator.utils.task_verifier

"""
任务验证器 - 负责验证子任务的完成情况

该模块提供了TaskVerifier类,用于根据task.json文件中的validation_checks字段
检查子任务是否已完成,并返回详细的验证结果。
"""

from typing import Dict, List, Any, Optional, Tuple
import logging

logger = logging.getLogger(__name__)


[docs] class TaskVerificationResult: """任务验证结果类"""
[docs] def __init__(self, task_id: str, task_description: str): self.task_id = task_id self.task_description = task_description self.is_completed = False self.completion_details = {} self.error_message = None
[docs] def mark_completed(self, details: Dict[str, Any] = None): """标记任务为已完成""" self.is_completed = True self.completion_details = details or {}
[docs] def mark_failed(self, error_message: str): """标记任务验证失败""" self.is_completed = False self.error_message = error_message
[docs] def to_dict(self) -> Dict[str, Any]: """转换为字典格式""" return { "task_id": self.task_id, "task_description": self.task_description, "is_completed": self.is_completed, "completion_details": self.completion_details, "error_message": self.error_message }
[docs] class TaskVerifier: """任务验证器 - 验证子任务完成情况"""
[docs] def __init__(self, task_data: Dict[str, Any], env_manager, config: Dict[str, Any] = None): """ 初始化任务验证器 Args: task_data: 任务数据,来自task.json文件 env_manager: 环境管理器,用于获取物体状态 config: 验证配置 """ self.task_data = task_data self.env_manager = env_manager self.config = config or {} # 解析任务数据中的验证信息 self.tasks = task_data.get("tasks", []) # 验证结果缓存 self.verification_cache = {} # 任务完成状态持久化存储(维护递增性) self.completed_tasks = set() # 存储已完成的任务ID
[docs] def verify_all_tasks(self) -> List[TaskVerificationResult]: """ 验证所有任务 Returns: List[TaskVerificationResult]: 验证结果列表 """ results = [] # 验证所有任务 for task in self.tasks: result = self._verify_single_task(task) results.append(result) return results
[docs] def verify_task_category(self, category: str) -> List[TaskVerificationResult]: """ 验证特定类别的任务 Args: category: 任务类别,如 "direct_command", "tool_use"等 Returns: List[TaskVerificationResult]: 验证结果列表 """ results = [] # 筛选指定类别的任务 filtered_tasks = [task for task in self.tasks if task.get('task_category') == category] for task in filtered_tasks: result = self._verify_single_task(task) results.append(result) return results
def _verify_single_task(self, task: Dict[str, Any]) -> TaskVerificationResult: """ 验证单个任务 Args: task: 任务定义,包含验证条件 Returns: TaskVerificationResult: 验证结果 """ # 使用任务描述作为ID(因为新格式没有单独的ID字段) task_description = task.get("task_description", "") task_id = f"task_{hash(task_description) % 10000}" # 生成简短的任务ID result = TaskVerificationResult(task_id, task_description) # 如果任务已经完成过,直接返回完成状态(维护递增性) if task_id in self.completed_tasks: result.mark_completed({"previously_completed": True}) logger.debug(f"任务已完成(缓存): {task_id}") return result try: # 获取验证检查列表 validation_checks = task.get("validation_checks", []) if not validation_checks: result.mark_failed("任务没有验证条件") return result # 检查所有验证条件 verification_passed = True completion_details = {} for check in validation_checks: check_id = check.get("id") if not check_id: verification_passed = False logger.debug("验证检查缺少id字段") continue # 获取目标物体 obj = self.env_manager.get_object_by_id(check_id) if not obj: verification_passed = False logger.debug(f"目标物体不存在: {check_id}") continue # 检查各种验证条件 for state_key, expected_value in check.items(): if state_key == "id": continue if state_key == "location_id": # 检查位置 current_location = obj.get("location_id") location_match = self._check_location_match(current_location, expected_value) if not location_match: verification_passed = False logger.debug(f"位置验证失败 - 物体: {check_id}, 期望: {expected_value}, 实际: {current_location}") else: completion_details[f"{check_id}_location_verified"] = True elif state_key.startswith("is_"): # 检查状态属性 current_value = obj.get("states", {}).get(state_key) if current_value != expected_value: verification_passed = False logger.debug(f"状态验证失败 - 物体: {check_id}, {state_key}: 期望 {expected_value}, 实际 {current_value}") else: # 检查是否为合作任务,如果是则需要验证合作标记 if self._is_cooperative_task(task): coop_attrs = obj.get("states", {}).get("cooperative_modified_attributes", []) if state_key in coop_attrs: completion_details[f"{check_id}_{state_key}_verified"] = True else: verification_passed = False logger.debug(f"合作任务验证失败 - 物体: {check_id}, 属性 {state_key} 未通过合作方式修改") else: completion_details[f"{check_id}_{state_key}_verified"] = True if verification_passed: result.mark_completed(completion_details) # 记录已完成的任务(维护递增性) self.completed_tasks.add(task_id) logger.debug(f"任务验证成功: {task_id}") else: result.mark_failed("验证条件不满足") logger.debug(f"任务验证失败: {task_id}") except Exception as e: result.mark_failed(f"验证过程中发生错误: {str(e)}") logger.error(f"验证任务 {task_id} 时发生错误: {e}") return result def _is_cooperative_task(self, task: Dict[str, Any]) -> bool: """ 判断任务是否为合作任务 Args: task: 任务定义 Returns: bool: 如果是合作任务返回True,否则返回False """ # 通过task_category判断 task_category = task.get("task_category", "") cooperative_categories = { "explicit_collaboration", "implicit_collaboration", "compound_collaboration" } if task_category in cooperative_categories: return True # 通过任务描述中的关键词判断 task_description = task.get("task_description", "").lower() cooperative_keywords = [ "cooperate", "cooperation", "cooperatively", "work together", "collaborate", "collaboration", "together", "jointly", "team up" ] for keyword in cooperative_keywords: if keyword in task_description: return True return False
[docs] def get_completion_summary(self) -> Dict[str, Any]: """ 获取任务完成情况摘要 Returns: Dict[str, Any]: 完成情况摘要 """ all_results = self.verify_all_tasks() summary = { "total_tasks": len(all_results), "completed_tasks": sum(1 for r in all_results if r.is_completed), "completion_rate": 0.0, "categories": {} } # 按类别统计 category_stats = {} for result in all_results: # 从任务数据中获取类别信息 task_category = None for task in self.tasks: if task.get("task_description") == result.task_description: task_category = task.get("task_category", "unknown") break if task_category not in category_stats: category_stats[task_category] = {"total": 0, "completed": 0} category_stats[task_category]["total"] += 1 if result.is_completed: category_stats[task_category]["completed"] += 1 # 计算各类别完成率 for category, stats in category_stats.items(): summary["categories"][category] = { "total": stats["total"], "completed": stats["completed"], "completion_rate": stats["completed"] / stats["total"] if stats["total"] > 0 else 0.0 } if summary["total_tasks"] > 0: summary["completion_rate"] = summary["completed_tasks"] / summary["total_tasks"] return summary
[docs] def get_subtask_completion_list(self) -> List[bool]: """ 获取所有子任务的完成状态列表 Returns: List[bool]: 按顺序返回每个子任务的完成状态 [True, False, True, ...] """ all_results = self.verify_all_tasks() completion_list = [] # 按任务顺序收集完成状态 for result in all_results: completion_list.append(result.is_completed) return completion_list
[docs] def verify_single_subtask(self, subtask: Dict[str, Any]) -> TaskVerificationResult: """ 验证单个子任务 Args: subtask: 子任务定义,包含验证条件 Returns: TaskVerificationResult: 验证结果 """ return self._verify_single_task(subtask)
[docs] def get_current_completion_status(self) -> Dict[str, Any]: """ 获取当前所有任务的完成状态 Returns: Dict[str, Any]: 包含完成状态的详细信息 """ all_results = self.verify_all_tasks() completed_tasks = [] for result in all_results: if result.is_completed: completed_tasks.append({ 'task_id': result.task_id, 'task_description': result.task_description, 'completion_details': result.completion_details }) return { 'total_tasks': len(all_results), 'completed_tasks': len(completed_tasks), 'completion_rate': len(completed_tasks) / len(all_results) if all_results else 0.0, 'completed_task_details': completed_tasks }
def _check_location_match(self, current_location: str, expected_location: str) -> bool: """ 检查位置是否匹配,支持灵活的in/on判定 支持的格式: - "in:location" - 精确匹配in前缀 - "on:location" - 精确匹配on前缀 - ":location" - 空前缀,匹配任何前缀的location - "location" - 无前缀,匹配任何前缀的location Args: current_location: 当前位置,如 "in:restoration_lab" expected_location: 期望位置,如 ":restoration_lab" 或 "in:restoration_lab" Returns: bool: 位置是否匹配 """ if not current_location or not expected_location: return current_location == expected_location # 解析期望位置 if expected_location.startswith(("in:", "on:")): # 有明确前缀(in: 或 on:),进行精确匹配 return current_location == expected_location elif expected_location.startswith(":"): # 空前缀格式 ":location",提取基础位置名 expected_base = expected_location[1:] # 去掉":"前缀 else: # 没有前缀,直接使用原值作为基础位置名 expected_base = expected_location # 解析当前位置,提取基础位置名 if current_location.startswith("in:"): current_base = current_location[3:] # 去掉"in:"前缀 elif current_location.startswith("on:"): current_base = current_location[3:] # 去掉"on:"前缀 elif current_location.startswith(":"): current_base = current_location[1:] # 去掉":"前缀 else: current_base = current_location # 比较基础位置名 return current_base == expected_base