task_registry.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. from functools import wraps
  2. from typing import Dict, Type, List
  3. # 存储所有注册的任务类
  4. registered_tasks: Dict[str, Type] = {}
  5. def register_task(task_name: str):
  6. """
  7. 任务注册装饰器,用于将任务类注册到系统中。
  8. Args:
  9. task_name: 任务的唯一标识名称
  10. Returns:
  11. 装饰器函数
  12. """
  13. def decorator(task_class: Type):
  14. @wraps(task_class)
  15. def wrapper(*args, **kwargs):
  16. return task_class(*args, **kwargs)
  17. # 将任务类注册到字典中
  18. registered_tasks[task_name] = task_class
  19. return wrapper
  20. return decorator
  21. def get_registered_tasks() -> Dict[str, Type]:
  22. """
  23. 获取所有注册的任务类。
  24. Returns:
  25. Dict[str, Type]: 包含所有注册任务的字典,键为任务名称,值为任务类
  26. """
  27. return registered_tasks
  28. def get_task_class(task_name: str) -> Type:
  29. """
  30. 根据任务名称获取对应的任务类。
  31. Args:
  32. task_name: 任务的唯一标识名称
  33. Returns:
  34. Type: 任务类
  35. Raises:
  36. KeyError: 如果任务名称不存在
  37. """
  38. if task_name not in registered_tasks:
  39. raise KeyError(f"任务 '{task_name}' 未注册")
  40. return registered_tasks[task_name]