Skip to content

States

States

States are the components of Workflow state machines; each State is itself a container for a DAG of tasks.

State

User defined State in the workflow state machine.

Source code in germinate_ai/core/states/states.py
@attr.define(init=False, repr=False)
class State:
    """
    User defined State in the workflow state machine.

    A State holds the tasks that make up its tasks DAG and the Conditions used
    to define transitions from it to other states (``State & Condition >> State``).
    """

    name: str
    # Tasks in registration order, plus a lookup of the same tasks by name.
    _tasks: typ.Sequence["TaskDefinition"]
    _tasks_dict: dict[str, "TaskDefinition"]
    # Conditions registered on this state (see `condition` / `add_condition`).
    _conditions: typ.Sequence["Condition"]
    # Cached tasks DAG: built lazily by `dag`, reset whenever a task is added.
    _dag: typ.Any
    # Not populated by any method of this class.
    _phases: typ.Sequence[typ.Sequence["TaskDefinition"]]

    def __init__(self, name: str):
        self.name = name

        self._tasks = []
        self._tasks_dict = dict()
        self._conditions = []
        self._dag = None
        self._phases = []

    def task(self, namespace: str = "agent") -> typ.Callable:
        """Decorator that adds a task to the state's task DAG.

        Args:
            namespace: Namespace forwarded to ``task_decorator_factory``.

        Returns:
            The decorator built by ``task_decorator_factory`` for this state.
        """
        decorate = task_decorator_factory(namespace=namespace, state=self)
        return decorate

    def add_task(self, task: "TaskDefinition"):
        """Append ``task`` to this state's tasks and index it by name.

        A later task with the same name replaces the earlier one in the by-name
        lookup; both remain in the ordered task list.
        """
        self._tasks.append(task)
        self._tasks_dict[task.name] = task
        # The cached DAG no longer matches the task list; rebuild on next access.
        self._dag = None

    def condition(self) -> typ.Callable:
        """Decorator that adds a condition to the state so that it can be used to define state transitions from it to other states."""
        decorate = condition_decorator_factory(state=self)
        return decorate

    def add_condition(self, condition: "Condition"):
        """Register ``condition`` so that `build` can add its task to the DAG."""
        self._conditions.append(condition)

    @property
    def tasks(self) -> typ.Sequence["TaskDefinition"]:
        """Tasks registered on this state, in registration order."""
        return self._tasks

    # TODO
    def tree(self):
        """Print tasks DAG (not implemented yet; currently a no-op)."""
        pass

    def __and__(self, other: "Condition") -> "Transition":
        """Supports `State & Condition >> State` internal DSL by returning a new transition corresponding to `State & Condition`.

        Raises:
            TypeError: if ``other`` is not a Condition.
        """
        valid = isinstance(other, Condition)
        if not valid:
            raise TypeError("Protocol only supports `State & Condition`")
        transition = Transition(source=self, condition=other)
        return transition

    def __repr__(self) -> str:
        return f"<State: {self.name}>"

    @property
    def dag(self) -> typ.Any:
        """Tasks DAG built by ``build_tasks_dag``; cached until a task is added."""
        if self._dag is None:
            self._dag = build_tasks_dag(self._tasks)
        return self._dag

    def validate(self) -> bool:
        """Validate that the tasks DAG is acyclic.

        Returns:
            True when the tasks DAG is valid.

        Raises:
            InvalidTasksDagException: if ``is_dag`` rejects the tasks DAG.
        """
        if not is_dag(self.dag):
            raise InvalidTasksDagException("Invalid Tasks DAG")
        return True

    def sorted_tasks_dag(self) -> typ.Sequence[typ.Sequence[str]]:
        """Topologically sort tasks DAG and return task names for each phase."""
        return list(toposort_tasks_phases(self.tasks))

    @property
    def conditions(self) -> typ.Sequence["Condition"]:
        """Conditions registered on this state."""
        return self._conditions

    @property
    def transitions(self):
        """Yield each registered condition's transition (None if never wired)."""
        for condition in self._conditions:
            yield condition.transition

    def build(self):
        """Validate tasks DAG, add transition condition evaluation tasks to the DAG, figure out overall input and output schemas.

        Raises:
            InvalidTasksDagException: if the tasks DAG is invalid.
            NotImplementedError: if the DAG has more than one end task.
        """
        self.validate()

        phases = self.sorted_tasks_dag()

        # TODO: when several tasks end the DAG, add a task that merges their
        # outputs. Until then only a single end task is supported.
        if len(phases[-1]) > 1:
            raise NotImplementedError("Not implemented: more than one end task in DAG")

        # Condition tasks are chained after the single end task of the DAG.
        end_task_name = phases[-1][0]
        end_task = self._tasks_dict[end_task_name]

        for condition in self._conditions:
            transition = condition.transition
            # Only conditions wired into a complete transition get a task.
            if transition is not None and transition.is_valid():
                # Order the condition task as depending on the end task.
                end_task >> condition.task
                # add_task also invalidates the cached DAG.
                self.add_task(condition.task)

        phases = self.sorted_tasks_dag()
        logger.debug("%s DAG phases: %s", self.name, phases)

__and__(other)

Supports State & Condition >> State internal DSL by returning a new transition corresponding to State & Condition.

Source code in germinate_ai/core/states/states.py
def __and__(self, other: "Condition") -> "Transition":
    """Build the `State & Condition` half of the `State & Condition >> State` DSL.

    Returns a new Transition whose source is this state and whose condition is
    ``other``; the target is filled in later by ``Transition.__rshift__``.

    Raises:
        TypeError: if ``other`` is not a Condition.
    """
    if isinstance(other, Condition):
        return Transition(source=self, condition=other)
    raise TypeError("Protocol only supports `State & Condition`")

build()

Validate tasks DAG, add transition condition evaluation tasks to the DAG, figure out overall input and output schemas.

Source code in germinate_ai/core/states/states.py
def build(self):
    """Validate tasks DAG, add transition condition evaluation tasks to the DAG, figure out overall input and output schemas.

    Raises:
        InvalidTasksDagException: if the tasks DAG is invalid.
        NotImplementedError: if the DAG has more than one end task.
    """
    self.validate()

    phases = self.sorted_tasks_dag()

    # TODO: when several tasks end the DAG, add a task that merges their
    # outputs. Until then only a single end task is supported.
    if len(phases[-1]) > 1:
        raise NotImplementedError("Not implemented: more than one end task in DAG")

    # Condition tasks are chained after the single end task of the DAG.
    end_task_name = phases[-1][0]
    end_task = self._tasks_dict[end_task_name]

    for condition in self._conditions:
        transition = condition.transition
        # Only conditions wired into a complete transition get a task.
        if transition is not None and transition.is_valid():
            # Order the condition task as depending on the end task.
            end_task >> condition.task
            self.add_task(condition.task)

    # The DAG cached by validate() predates the condition tasks; drop it so
    # the next access to `dag` rebuilds it from the full task list.
    self._dag = None

    phases = self.sorted_tasks_dag()
    logger.debug("%s DAG phases: %s", self.name, phases)

condition()

Decorator that adds a condition to the state so that it can be used to define state transitions from it to other states.

Source code in germinate_ai/core/states/states.py
def condition(self) -> typ.Callable:
    """Return a decorator that turns a function into a transition Condition of this state.

    The decorated function is registered on this state as a Condition, which
    can then be used in ``State & Condition >> State`` transitions.
    """
    return condition_decorator_factory(state=self)

sorted_tasks_dag()

Topologically sort tasks DAG and return task names for each phase.

Source code in germinate_ai/core/states/states.py
def sorted_tasks_dag(self) -> typ.Sequence[typ.Sequence[str]]:
    """Return the tasks DAG topologically sorted into phases of task names."""
    phases = toposort_tasks_phases(self.tasks)
    return [phase for phase in phases]

task(namespace='agent')

Decorator that adds a task to the state's task DAG.

Source code in germinate_ai/core/states/states.py
def task(self, namespace: str = "agent") -> typ.Callable:
    """Return a decorator that registers a task in this state's task DAG.

    Args:
        namespace: Namespace forwarded to ``task_decorator_factory``.
    """
    return task_decorator_factory(namespace=namespace, state=self)

tree()

Print tasks DAG.

Source code in germinate_ai/core/states/states.py
def tree(self):
    """Print tasks DAG (not implemented yet; currently a no-op)."""
    return None

validate()

Validate that the tasks DAG is acyclic.

Source code in germinate_ai/core/states/states.py
def validate(self) -> bool:
    """Validate that the tasks DAG is acyclic.

    Returns:
        True when the tasks DAG is valid.

    Raises:
        InvalidTasksDagException: if ``is_dag`` rejects the tasks DAG.
    """
    if not is_dag(self.dag):
        raise InvalidTasksDagException("Invalid Tasks DAG")
    return True

Conditions

Task types that evaluate whether a state Transition should be triggered.

Condition

Conditions evaluate if a state transition should be triggered.

Source code in germinate_ai/core/states/conditions.py
@attr.define(init=False, repr=False)
class Condition:
    """Conditions evaluate if a state transition should be triggered.

    A Condition bundles the task (and its executor) that evaluates it, the
    State it belongs to, and the Transition it guards once it has been used in
    ``State & Condition >> State``.
    """

    name: str

    # State this condition is attached to.
    state: "State"

    # Task definition and executor that evaluate the condition.
    task: "TaskDefinition"
    executor: "TaskExecutor"

    # Set by Transition when this condition is used in `State & Condition`.
    transition: typ.Optional["Transition"]

    # Reserved for negation support (see `__invert__`); not populated yet.
    negated_condition: typ.Optional["Condition"]

    def __init__(
        self,
        name: str,
        state: typ.Optional["State"] = None,
        transition: typ.Optional["Transition"] = None,
        task: typ.Optional["TaskDefinition"] = None,
        executor: typ.Optional["TaskExecutor"] = None,
    ) -> None:
        self.name = name
        self.state = state
        self.task = task
        self.executor = executor
        self.transition = transition
        # Initialise explicitly so the declared attribute is always defined.
        self.negated_condition = None

    def __invert__(self) -> "Condition":
        """Return a copy of this condition with the condition negated.

        Raises:
            NotImplementedError: always; negation is not implemented yet.
        """
        raise NotImplementedError("Not implemented yet.")

    def __repr__(self) -> str:
        # state/task default to None, so guard before reading their names.
        state_name = self.state.name if self.state is not None else None
        task_name = self.task.name if self.task is not None else None
        return f"<Condition: state={state_name}, task={task_name}>"

__invert__()

Return a copy of this condition with the condition negated.

Source code in germinate_ai/core/states/conditions.py
def __invert__(self) -> "Condition":
    """Return a copy of this condition with the condition negated.

    Raises:
        NotImplementedError: always; negation is not implemented yet.
    """
    raise NotImplementedError("Not implemented yet.")

ConditionOutputSchema

Bases: BaseModel

Source code in germinate_ai/core/states/conditions.py
class ConditionOutputSchema(BaseModel):
    """Output schema of a condition task: the boolean outcome of the condition."""

    condition_evaluation: bool
    """Result of evaluating the condition."""

condition_evaluation: bool instance-attribute

Result of evaluating the condition.

condition_decorator_factory(state, namespace=None)

Creates a decorator that registers a state transition condition for a given State.

Source code in germinate_ai/core/states/conditions.py
def condition_decorator_factory(
    state: "State", namespace: typ.Optional[str] = None
) -> typ.Callable[[typ.Callable], typ.Callable]:
    """Creates a decorator that registers a state transition condition for a given State.

    Args:
        state: State the condition is attached to.
        namespace: Task-registry namespace for the condition's executor;
            defaults to ``"transition_conditions"``.

    Returns:
        A decorator that takes a bool-returning function (sync or async),
        wraps it in a TaskExecutor / TaskDefinition / Condition, registers the
        Condition on ``state`` and the executor in ``TaskRegistry``, and
        returns the Condition.
    """
    if namespace is None:
        namespace = "transition_conditions"

    def decorate(func: typ.Callable[..., bool]) -> "Condition":
        # Re-wrap the wrapped condition function in a function that matches the
        # task executor function protocol: resolve the function's dependencies
        # from the executor arguments and wrap its result in the output schema.
        if asyncio.iscoroutinefunction(func):

            @wraps(func)
            async def wrapper(*args: typ.Any, **kwargs: typ.Any) -> ConditionOutputSchema:
                bound = resolve_dependencies(func, *args, **kwargs)
                result = await func(*bound.args, **bound.kwargs)
                return ConditionOutputSchema(condition_evaluation=result)

        else:

            @wraps(func)
            def wrapper(*args: typ.Any, **kwargs: typ.Any) -> ConditionOutputSchema:
                bound = resolve_dependencies(func, *args, **kwargs)
                result = func(*bound.args, **bound.kwargs)
                return ConditionOutputSchema(condition_evaluation=result)

        # Compose in an executor callable; avoid doubling an "executor" suffix.
        executor_name = (
            func.__name__
            if func.__name__.endswith("executor")
            else f"{func.__name__}_executor"
        )
        executor = TaskExecutor(
            namespace=namespace,
            name=executor_name,
            # TODO
            # input schema for condition is the combined output of the final
            # phase of a state's tasks
            # output schema is always the eval schema
            # input_schema=input_schema,
            input_schema=ConditionInputSchema,
            output_schema=ConditionOutputSchema,
            callable=wrapper,
        )
        # Compose executor in a task definition so we can add it to
        # the end of a state's tasks DAG
        task = TaskDefinition(
            name=func.__name__,
            state=state,
            executor=executor,
        )
        # Compose the whole thing in a Condition
        condition = Condition(
            name=func.__name__, state=state, task=task, executor=executor
        )

        # Add condition (not the corresponding task which is added when building the DAG) to state
        state.add_condition(condition)

        # Register task executor so we can actually schedule and evaluate the condition on workers
        TaskRegistry.register(
            namespace=namespace, name=executor_name, executor=executor
        )
        return condition

    return decorate

Transitions

Transition from one Workflow State to another if a Condition is fulfilled.

Transition

Represents a transition from a source state to a target state when a condition is fulfilled.

Source code in germinate_ai/core/states/transitions.py
@attr.define(init=False, repr=False)
class Transition:
    """Represents a transition from a source state to a target state when a condition is fulfilled.

    Built by the ``State & Condition >> State`` DSL: ``State.__and__`` creates
    the transition with its source and condition, and ``__rshift__`` sets the
    target.
    """

    _source: "State"
    _target: "State"
    _condition: Condition

    def __init__(
        self,
        source: "State" = None,
        condition: Condition = None,
        target: "State" = None,
    ) -> None:
        self._source = source
        self._condition = condition
        self._target = target
        # Back-link the condition to the transition it guards (State.build reads
        # `condition.transition`). `condition` defaults to None, so only link
        # when one was actually supplied.
        if condition is not None:
            condition.transition = self

    @property
    def source(self) -> "State":
        """Source state of the transition (may be None)."""
        return self._source

    @property
    def target(self) -> "State":
        """Target state of the transition; None until set via ``>>``."""
        return self._target

    @property
    def condition(self) -> Condition:
        """Condition guarding the transition (may be None)."""
        return self._condition

    def __rshift__(self, other: "State") -> "Transition":
        """Support ``Transition >> State`` by setting ``other`` as the target.

        Returns:
            This transition, mutated in place.

        Raises:
            TypeError: if ``other`` is not a State.
        """
        # Local import, presumably to avoid a circular import with the states
        # module (which itself references Transition).
        from .states import State

        if not isinstance(other, State):
            raise TypeError("Protocol only supports `Transition >> State`")
        self._target = other
        # TODO register transition and use it
        return self

    def is_valid(self) -> bool:
        """Is this a valid transition? I.e. are source, target and condition all set?"""
        return (
            self._source is not None
            and self._condition is not None
            and self._target is not None
        )

    def __repr__(self) -> str:
        # Source/target may still be None (e.g. before `>>` sets the target).
        source_name = self._source.name if self._source is not None else None
        target_name = self._target.name if self._target is not None else None
        return f"<Transition: {source_name} -> {target_name}>"

is_valid()

Is this a valid transition? I.e. does it have a source, a target, and a condition?

Source code in germinate_ai/core/states/transitions.py
def is_valid(self):
    """Is this a valid transition? I.e. are source, target and condition all set?"""
    # Checked in the same order as the DSL fills them in; stop at the first gap.
    for attribute in ("_source", "_condition", "_target"):
        if getattr(self, attribute) is None:
            return False
    return True