Skip to content

Tasks

Task Definitions

TaskDefinition

Represents a user's definition of a task.

Usually not meant to be instantiated directly.

The @<state>.task decorator wraps the user-defined task executor function in a TaskDefinition behind the scenes (more precisely, it wraps the function in a TaskExecutor, which is then composed into a TaskDefinition).

Source code in germinate_ai/core/tasks/definition.py
@attr.define(init=False, repr=False)
class TaskDefinition:
    """Represents a user's definition of a task.

    Usually not meant to be instantiated directly.

    The `@<state>.task` decorator wraps the user-defined task executor function in a `TaskDefinition` behind the scenes
    (more precisely, it wraps the function in a `TaskExecutor`, which is then composed into a `TaskDefinition`).

    Dependencies between tasks are declared with the shift operators; either
    side may be a single task or a collection (list, tuple, set, ...) of tasks::

        a >> b        # b depends on a
        a << b        # a depends on b
        a >> [b, c]   # b and c depend on a
        [b, c] >> d   # d depends on b and c

    Both directions are kept in sync: registering a parent on one task also
    registers the corresponding child on the other.
    """

    name: str
    state: typ.Optional["State"]

    # Define one of
    # - executor (define your own executor), or
    # - executor_name (use a preset executor)
    executor: typ.Optional[TaskExecutor]
    executor_name: typ.Optional[str]

    # Tasks this task depends on / tasks that depend on this task.
    _parents: set["TaskDefinition"]
    _children: set["TaskDefinition"]

    def __init__(
        self,
        name: str,
        *,
        executor: TaskExecutor = None,
        executor_name: str = None,
        state: "State" = None,
    ):
        self.name = name
        self.state = state

        self.executor = executor
        self.executor_name = executor_name

        self._parents = set()
        self._children = set()

    def __call__(self, *args: typ.Any, **kwargs: typ.Any) -> BaseModel:
        """Run the task by delegating to its executor.

        Raises:
            TypeError: if no executor object is attached to this definition.
        """
        if self.executor is None:
            # Nothing in this class resolves ``executor_name`` to an executor,
            # so fail with a descriptive error rather than
            # "'NoneType' object is not callable".
            raise TypeError(
                f"Task {self.name!r} has no executor attached "
                f"(executor_name={self.executor_name!r})"
            )
        return self.executor(*args, **kwargs)

    def __hash__(self):
        # Tasks are identified by "<state>.<task>" so they can be stored in the
        # parent/child sets. A task without a state uses an empty state prefix
        # instead of failing on ``None.name``.
        state_name = self.state.name if self.state is not None else ""
        return hash(f"{state_name}.{self.name}")

    @property
    def parents(self) -> set["TaskDefinition"]:
        """Tasks this task depends on (the internal set, not a copy)."""
        return self._parents

    @property
    def children(self) -> set["TaskDefinition"]:
        """Tasks that depend on this task (the internal set, not a copy)."""
        return self._children

    def add_parents(self, others: "TaskDefinition" | typ.Iterable["TaskDefinition"]):
        """This task depends on other tasks.

        Args:
            others: A single task, or any iterable of tasks (list, tuple, set,
                generator, ...).

        Each task in ``others`` also records this task as one of its children.
        """
        # A lone task is not iterable, so wrap it. Materialize any iterable
        # once so one-shot iterators feed both the update and the loop.
        tasks = list(others) if isinstance(others, typ.Iterable) else [others]
        self._parents.update(tasks)
        for other in tasks:
            other._children.add(self)

    def add_children(self, others: "TaskDefinition" | typ.Iterable["TaskDefinition"]):
        """Other tasks depend on this task.

        Args:
            others: A single task, or any iterable of tasks (list, tuple, set,
                generator, ...).

        Each task in ``others`` also records this task as one of its parents.
        """
        # A lone task is not iterable, so wrap it. Materialize any iterable
        # once so one-shot iterators feed both the update and the loop.
        tasks = list(others) if isinstance(others, typ.Iterable) else [others]
        self._children.update(tasks)
        for other in tasks:
            other._parents.add(self)

    def __lshift__(self, others: "TaskDefinition" | typ.Iterable["TaskDefinition"]):
        """
        Task << Task | [Task]

        This task depends on other task(s).
        """
        self.add_parents(others)
        # Note: Returning right hand side for fluent like DAG definition
        return others

    def __rshift__(self, others: "TaskDefinition" | typ.Iterable["TaskDefinition"]):
        """
        Task >> Task | [Task]

        Other task(s) depend on this task.
        """
        self.add_children(others)
        # Note: Returning right hand side for fluent like DAG definition
        return others

    def __rlshift__(self, others: "TaskDefinition" | typ.Iterable["TaskDefinition"]):
        """
        Task | [Task] << (this) Task

        Other task(s) depend on this task.
        """
        # Reached when the left operand (e.g. a list) does not implement ``<<``.
        self.__rshift__(others)
        return self

    def __rrshift__(self, others: "TaskDefinition" | typ.Iterable["TaskDefinition"]):
        """
        Task | [Task] >> (this) task

        This task depends on other task(s).
        """
        # Reached when the left operand (e.g. a list) does not implement ``>>``.
        self.__lshift__(others)
        return self

    def __repr__(self) -> str:
        state_name = f"{self.state.name}." if self.state is not None else ""
        prefix = f"<Task: {state_name}{self.name} "
        parents = ", ".join(t.name for t in self.parents)
        children = ", ".join(t.name for t in self.children)
        return prefix + f"(parents: [{parents}], children: [{children}])>"

__lshift__(others)

Task << Task | [Task]

This task depends on other task(s).

Source code in germinate_ai/core/tasks/definition.py
def __lshift__(self, others: "TaskDefinition" | typ.Sequence["TaskDefinition"]):
    """
    Task << Task | [Task]

    This task depends on other task(s).
    """
    self.add_parents(others)
    # Note: Returning right hand side for fluent like DAG definition
    return others

__rlshift__(others)

Task | [Task] << (this) Task

Other task(s) depend on this task.

Source code in germinate_ai/core/tasks/definition.py
def __rlshift__(self, others: "TaskDefinition" | typ.Sequence["TaskDefinition"]):
    """
    Task | [Task] << (this) Task

    Other task(s) depend on this task.
    """
    self.__rshift__(others)
    return self

__rrshift__(others)

Task | [Task] >> (this) task

This task depends on other task(s).

Source code in germinate_ai/core/tasks/definition.py
def __rrshift__(self, others: "TaskDefinition" | typ.Sequence["TaskDefinition"]):
    """
    Task | [Task] >> (this) task

    This task depends on other task(s).
    """
    self.__lshift__(others)
    return self

__rshift__(others)

Task >> Task | [Task]

Other task(s) depend on this task.

Source code in germinate_ai/core/tasks/definition.py
def __rshift__(self, others: "TaskDefinition" | typ.Sequence["TaskDefinition"]):
    """
    Task >> Task | [Task]

    Other task(s) depend on this task.
    """
    self.add_children(others)
    # Note: Returning right hand side for fluent like DAG definition
    return others

add_children(others)

Other tasks depend on this task.

Source code in germinate_ai/core/tasks/definition.py
def add_children(self, others: "TaskDefinition" | typ.Sequence["TaskDefinition"]):
    """Other tasks depend on this task."""
    if not isinstance(others, typ.Sequence):
        others = [others]

    self._children.update(others)
    for other in others:
        other._parents.add(self)

add_parents(others)

This task depends on other tasks.

Source code in germinate_ai/core/tasks/definition.py
def add_parents(self, others: "TaskDefinition" | typ.Sequence["TaskDefinition"]):
    """This task depends on other tasks."""
    if not isinstance(others, typ.Sequence):
        others = [others]

    self._parents.update(others)
    for other in others:
        other._children.add(self)

Task Executors

BaseTaskExecutor

Bases: Callable[Concatenate[...], Any]

Base class for task executors.

Source code in germinate_ai/core/tasks/executors/base.py
class BaseTaskExecutor(Callable[typ.Concatenate[...], typ.Any]):
    """Base class for task executors."""

    # The name used to find this executor from the registry
    name: str

    @abstractmethod
    def __call__(self, *args: typ.Any, **kwargs: typ.Any) -> typ.Any:
        pass

TaskExecutor

Bases: BaseTaskExecutor

TaskExecutors execute tasks on workers.

Task definitions, using a namespace and the executor name, specify the executor that workers should use to execute specific tasks.

Source code in germinate_ai/core/tasks/executors/task_executor.py
@attr.define(init=False, repr=False)
class TaskExecutor(BaseTaskExecutor):
    """`TaskExecutor`s execute tasks on workers.

    Task definitions, using a namespace and the executor name, specify the executor that workers should use to execute specific tasks.

    An executor is identified (and hashed) by its qualified name
    ``"<namespace>.<name>"``, exposed as `registered_name`.
    """

    # Registry namespace; ``__init__`` defaults it to "custom_tasks".
    namespace: str
    # Executor name within the namespace.
    name: str

    # Optional input/output schemas; ``None`` when not supplied.
    input_schema: BaseModel
    output_schema: BaseModel

    # The function that performs the work when the executor is called.
    callable: TaskExecutorCallable
    # Initialised to ``None``; not read anywhere in this class.
    _callable_sig: typ.Optional[inspect.Signature]

    def __init__(
        self,
        name: str,
        callable: TaskExecutorCallable,
        *,
        input_schema: BaseModel = None,
        output_schema: BaseModel = None,
        namespace: str = "custom_tasks",
    ):
        self.namespace = namespace
        self.name = name
        self.input_schema = input_schema
        self.output_schema = output_schema

        self.callable = callable
        self._callable_sig = None

    def __call__(self, *args: typ.Any, **kwargs: typ.Any) -> BaseModel:
        """Invoke the wrapped callable with the given arguments."""
        return self.callable(*args, **kwargs)

    def __hash__(self):
        # Hash on the qualified "<namespace>.<name>" so the hash matches the
        # name the executor is registered under.
        return hash(self.registered_name)

    def __repr__(self) -> str:
        return f"<Task Executor: {self.name} >"

    def is_async(self) -> bool:
        """Is the underlying callable a coroutine function?"""
        return asyncio.iscoroutinefunction(self.callable)

    @property
    def registered_name(self) -> str:
        """Name the task executor is registered under."""
        return f"{self.namespace}.{self.name}"

registered_name: str property

Name the task executor is registered under.

is_async()

Is the underlying callable a coroutine function?

Source code in germinate_ai/core/tasks/executors/task_executor.py
def is_async(self) -> bool:
    """Return ``True`` when the wrapped ``callable`` is a coroutine function."""
    target = self.callable
    return asyncio.iscoroutinefunction(target)

Task Registry

Registry for task executors.

TaskKey

Key used to find a task in the registry by namespace and task name.

Source code in germinate_ai/core/tasks/registry.py
@attr.define(frozen=True, slots=True)
class TaskKey:
    """Key used to find a task in the registry by namespace and task name.

    Declared frozen, so instances are immutable and usable as dictionary keys.
    """

    # Namespace the executor was registered in (e.g. "custom_tasks").
    namespace: str
    # Executor name within that namespace.
    name: str

TaskRegistry

Registry of registered task executors.

Source code in germinate_ai/core/tasks/registry.py
class TaskRegistry:
    """Registry of registered task executors.

    All state lives in class attributes and every method is a classmethod,
    so a single registry is shared by all callers in the process.
    """

    # Namespaces that have had at least one executor registered.
    _namespaces: set[str] = set()
    # Registered executors keyed by (namespace, name).
    _tasks: dict[TaskKey, TaskExecutor] = {}

    @classmethod
    def register(cls, namespace: str, name: str, executor: TaskExecutor):
        """Register a task executor under ``namespace`` and ``name``.

        Registering the same key again replaces the earlier executor.
        """
        cls._namespaces.add(namespace)
        k = TaskKey(namespace=namespace, name=name)
        # TODO check if overwrites
        cls._tasks[k] = executor

    @classmethod
    def get_executor(cls, executor_name: str) -> TaskExecutor:
        """Get a registered task executor by its qualified name.

        Args:
            executor_name: ``"<namespace>.<name>"``. The string is split at the
                first ``.``, so everything after it is treated as the name.

        Raises:
            ValueError: if ``executor_name`` contains no ``.``.
            KeyError: if the namespace or the executor is not registered.
        """
        namespace, sep, name = executor_name.partition(".")
        if not sep:
            # Report malformed names explicitly instead of through an opaque
            # tuple-unpacking error.
            raise ValueError(
                f"Invalid executor name {executor_name!r}: "
                "expected '<namespace>.<name>'"
            )
        return cls.get(namespace=namespace, name=name)

    @classmethod
    def get(cls, namespace: str, name: str) -> TaskExecutor:
        """Get a registered task executor.

        Raises:
            KeyError: if nothing was ever registered in ``namespace``, or the
                namespace has no executor called ``name``.
        """
        if namespace not in cls._namespaces:
            raise KeyError(f"No such namespace {namespace}")
        k = TaskKey(namespace=namespace, name=name)
        if k not in cls._tasks:
            raise KeyError(f"No task executor registered for {name}")
        return cls._tasks[k]

get(namespace, name) classmethod

Get a registered task executor.

Source code in germinate_ai/core/tasks/registry.py
@classmethod
def get(cls, namespace: str, name: str) -> TaskExecutor:
    """Return the executor registered under ``namespace`` and ``name``.

    Raises:
        KeyError: if nothing was ever registered in ``namespace``, or the
            namespace has no executor called ``name``.
    """
    if namespace not in cls._namespaces:
        raise KeyError(f"No such namespace {namespace}")
    lookup_key = TaskKey(namespace=namespace, name=name)
    # Local sentinel so that a stored ``None`` is still returned, not treated
    # as missing.
    missing = object()
    found = cls._tasks.get(lookup_key, missing)
    if found is missing:
        raise KeyError(f"No task executor registered for {name}")
    return found

register(namespace, name, executor) classmethod

Register a task executor.

Source code in germinate_ai/core/tasks/registry.py
@classmethod
def register(cls, namespace: str, name: str, executor: TaskExecutor):
    """Store ``executor`` so it can later be looked up by ``namespace`` and ``name``.

    Registering an existing key replaces the previously stored executor.
    """
    cls._namespaces.add(namespace)
    # TODO check if overwrites
    cls._tasks[TaskKey(namespace=namespace, name=name)] = executor

Decorators

Decorators to convert functions into task definitions.

task_decorator_factory(namespace, state)

Creates a decorator that registers a task executor in the given namespace.

Source code in germinate_ai/core/tasks/decorators.py
def task_decorator_factory(
    namespace: str, state: "State"
) -> typ.Callable[[typ.Callable], typ.Callable]:
    """Build a decorator that turns a plain function into a task of ``state``.

    For each decorated function the returned decorator:

    1. derives its input/output schemas with ``get_io_schemas``;
    2. wraps it with ``resolve_dependencies_wrapper``;
    3. wraps that in a ``TaskExecutor`` in ``namespace``, named after the
       function with an ``_executor`` suffix (added only when the name does not
       already end in ``executor``);
    4. wraps the executor in a ``TaskDefinition``, adds it to ``state`` and
       registers the executor with ``TaskRegistry``.

    The decorator returns the ``TaskDefinition``, not the original function.
    """

    def decorate(func: typ.Callable):
        schema_in, schema_out = get_io_schemas(func)
        injected = resolve_dependencies_wrapper(func)

        fn_name = func.__name__
        if fn_name.endswith("executor"):
            executor_name = fn_name
        else:
            executor_name = f"{fn_name}_executor"

        executor = TaskExecutor(
            name=executor_name,
            callable=injected,
            input_schema=schema_in,
            output_schema=schema_out,
            namespace=namespace,
        )
        task = TaskDefinition(name=fn_name, state=state, executor=executor)

        state.add_task(task)
        TaskRegistry.register(
            namespace=namespace, name=executor_name, executor=executor
        )
        return task

    return decorate

Algorithms

Topological sorting tasks DAG using networkx.

build_tasks_dag(tasks)

Build NetworkX Directed Graph from a list of tasks.

Source code in germinate_ai/core/tasks/algorithms.py
def build_tasks_dag(tasks: typ.Sequence[TaskDefinition]) -> nx.DiGraph:
    """Build a NetworkX directed graph of task dependencies.

    Nodes are task names, and each listed task is stored on its node under the
    ``"task"`` attribute. For every parent of every task an edge
    ``parent.name -> task.name`` is added, so edges point from a dependency
    to its dependent. Parents that are not themselves in ``tasks`` still become
    nodes, but without a ``"task"`` attribute.
    """
    graph = nx.DiGraph()
    for task in tasks:
        graph.add_node(task.name, task=task)
        graph.add_edges_from((parent.name, task.name) for parent in task.parents)
    return graph

is_dag(g)

Check that the graph is a DAG.

Source code in germinate_ai/core/tasks/algorithms.py
def is_dag(g: nx.DiGraph) -> bool:
    """Return whether ``g`` is a directed acyclic graph (has no directed cycle)."""
    acyclic = nx.is_directed_acyclic_graph(g)
    return acyclic

topological_generations(g)

Get topologically sorted generations from a DAG.

Source code in germinate_ai/core/tasks/algorithms.py
def topological_generations(g: nx.DiGraph) -> typ.Generator[list[str], None, None]:
    """Return the topological generations of ``g`` as produced by NetworkX.

    Raises:
        TypeError: if ``g`` is not a DAG. The check happens when this function
            is called, before any generation is consumed.
    """
    if nx.is_directed_acyclic_graph(g):
        return nx.topological_generations(g)
    raise TypeError("Invalid tasks specification: not a DAG")

toposort_tasks_phases(tasks)

Get "phases" of tasks from a list of tasks, so that all the tasks in a single phase can be run in parallel.

Source code in germinate_ai/core/tasks/algorithms.py
def toposort_tasks_phases(tasks: typ.Sequence[TaskDefinition]) -> typ.Generator[list[str], None, None]:
    """Group tasks into "phases" whose members can all run in parallel.

    The tasks are turned into a dependency graph with ``build_tasks_dag`` and
    split into topological generations. Each phase is a list of task names
    whose dependencies all lie in earlier phases.
    """
    return topological_generations(build_tasks_dag(tasks))