
Worker

This module implements a Worker class that waits for task assignments from the message bus, hands them over to the TaskDispatcher for execution, and notifies listeners of task completions via the message bus.
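
For orientation, here is a minimal sketch of how the pieces might be wired together; the connect_nats and make_sessionmaker helpers are illustrative placeholders, not part of this module:

import asyncio

from germinate_ai.worker.worker import Worker
from germinate_ai.worker.task_dispatcher import TaskDispatcher

async def main():
    nc = await connect_nats()  # illustrative: obtain a nats.NatsConnection
    sessionmaker = make_sessionmaker()  # illustrative: SQLAlchemy session factory
    dispatcher = TaskDispatcher(nc=nc, sessionmaker=sessionmaker)
    worker = Worker(nc=nc, task_dispatcher=dispatcher, id=1, tick_interval=10)
    await worker.run()  # polls the assignments queue until cancelled

if __name__ == "__main__":
    asyncio.run(main())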

Worker

Polls and processes tasks from the assignments queue.

Source code in germinate_ai/worker/worker.py
class Worker:
    """Polls and processes tasks from the assignments queue."""

    def __init__(
        self,
        nc: nats.NatsConnection,
        task_dispatcher: TaskDispatcher,
        id: int = 0,
        tick_interval: int = 10,
    ):
        self.nc = nc
        self.id = id
        self.tick_interval = tick_interval
        self.task_dispatcher = task_dispatcher

    async def run(self):
        """Connect to messaging bus, wait for assignments, and execute them."""
        next_tick = get_next_tick(self.tick_interval)

        await self.connect()
        logger.success(f"Worker #{self.id}: connected to cluster! Waiting for tasks...")

        while True:
            last_tick = time.time()

            try:
                msg = await self.assignments_queue.dequeue()
                task_json = msg.data
                # acknowledge message so we don't see it again
                await msg.ack()
                success = await self._run_task(task_json)
                if not success:
                    logger.exception("Task execution failure", task_json)
            except TimeoutError:
                pass
            except asyncio.CancelledError:
                logger.debug(f"Worker #{self.id}: Cancelled! Shutting down worker...")
                break
            except Exception as e:
                logger.exception("Error while reading from NATS queue: ", e)

            await asyncio.sleep(next_tick(last_tick=last_tick))

    async def connect(self):
        """Connect to task assignments and completions queue so we can get assignments/send task completion notifications via the message bus."""
        self.assignments_queue = NATSQueue(
            stream="jobs",
            subject="jobs.task_assignments",
            durable_consumer="task_runner",
            connection=self.nc,
        )
        await self.assignments_queue.connect()
        # TODO write only:
        self.completions_queue = NATSQueue(
            connection=self.nc, stream="jobs", subject="jobs.task_completions"
        )
        await self.completions_queue.connect()

    async def _run_task(self, task_json: str) -> bool:
        """Validate task assignment data, delegate execution to TaskDispatcher, and notify listeners on completion via the messaging bus."""
        try:
            assignment = TaskAssignment.model_validate_json(task_json)
        except ValidationError as e:
            logger.error(f"Worker #{self.id}: Skipping invalid task `{task_json}`: {e}")
            return False

        logger.info(f"Worker #{self.id}: Starting task {assignment.name}")
        task = await self.task_dispatcher.execute(assignment=assignment)

        # Queue task completed message
        logger.debug(f"Queueing completed message {task.name}")
        await self.completions_queue.enqueue(assignment.model_dump_json())

        # return True => mark task as completed
        # TODO handle failures
        return True

connect() async

Connect to the task assignments and completions queues so the worker can receive assignments and send task completion notifications via the message bus.
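
Once connected, completion notifications are published to jobs.task_completions as serialized TaskAssignment JSON, so another component could watch for them roughly as follows (a sketch using the same NATSQueue helper; the completion_watcher consumer name is illustrative):

async def watch_completions(nc):
    """Sketch: log each task completion published to jobs.task_completions."""
    completions = NATSQueue(
        connection=nc,
        stream="jobs",
        subject="jobs.task_completions",
        durable_consumer="completion_watcher",  # illustrative durable consumer name
    )
    await completions.connect()
    while True:
        msg = await completions.dequeue()
        assignment = TaskAssignment.model_validate_json(msg.data)
        await msg.ack()
        logger.info(f"Task {assignment.name} completed")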

Source code in germinate_ai/worker/worker.py
async def connect(self):
    """Connect to task assignments and completions queue so we can get assignments/send task completion notifications via the message bus."""
    self.assignments_queue = NATSQueue(
        stream="jobs",
        subject="jobs.task_assignments",
        durable_consumer="task_runner",
        connection=self.nc,
    )
    await self.assignments_queue.connect()
    # TODO write only:
    self.completions_queue = NATSQueue(
        connection=self.nc, stream="jobs", subject="jobs.task_completions"
    )
    await self.completions_queue.connect()

run() async

Connect to messaging bus, wait for assignments, and execute them.
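
Because the loop handles asyncio.CancelledError and exits cleanly, run() can be launched as a background task and shut down by cancelling it, for example (a sketch, assuming a constructed worker inside an async context):

runner = asyncio.create_task(worker.run())
# ... later, during shutdown:
runner.cancel()
await asyncio.gather(runner, return_exceptions=True)  # wait for the worker loop to exit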

Source code in germinate_ai/worker/worker.py
async def run(self):
    """Connect to messaging bus, wait for assignments, and execute them."""
    next_tick = get_next_tick(self.tick_interval)

    await self.connect()
    logger.success(f"Worker #{self.id}: connected to cluster! Waiting for tasks...")

    while True:
        last_tick = time.time()

        try:
            msg = await self.assignments_queue.dequeue()
            task_json = msg.data
            # acknowledge message so we don't see it again
            await msg.ack()
            success = await self._run_task(task_json)
            if not success:
                logger.exception("Task execution failure", task_json)
        except TimeoutError:
            pass
        except asyncio.CancelledError:
            logger.debug(f"Worker #{self.id}: Cancelled! Shutting down worker...")
            break
        except Exception as e:
            logger.exception("Error while reading from NATS queue: ", e)

        await asyncio.sleep(next_tick(last_tick=last_tick))

TaskDispatcher

Dispatches tasks to appropriate executors, and updates task state accordingly.

Given enqueued task assignment data, TaskDispatcher gets the correct executor from TaskRegistry and uses it to run the corresponding Task.

TaskDispatcher also validates input and output schemas for the task, and updates the task's state before ("queued") and after ("completed"/"failed") execution. The executor interface this implies is small: a callable with Pydantic input_schema and output_schema attributes and an is_async() check, as sketched below.
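
A hedged sketch of a conforming executor; how it gets registered with TaskRegistry is not shown on this page, so the class below is purely illustrative:

from pydantic import BaseModel

class EchoInput(BaseModel):
    text: str

class EchoOutput(BaseModel):
    text: str

class EchoExecutor:
    """Illustrative executor: echoes its input back as its output."""

    input_schema = EchoInput
    output_schema = EchoOutput

    def is_async(self) -> bool:
        # Synchronous executor; TaskDispatcher calls it without awaiting
        return False

    def __call__(self, task_input: EchoInput) -> EchoOutput:
        return EchoOutput(text=task_input.text)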

Source code in germinate_ai/worker/task_dispatcher.py
class TaskDispatcher:
    """
    Dispatches tasks to appropriate executors, and updates task state accordingly.

    Given enqueued task assignment data, TaskDispatcher gets the correct executor from TaskRegistry and uses it to run the corresponding Task.

    TaskDispatcher also validates input and output schemas for the task, and updates the task's state before ("queued")
    and after ("completed"/"failed") execution.
    """

    def __init__(self, nc: nats.NatsConnection, sessionmaker: typ.Callable):
        self.nc = nc
        self.sessionmaker = sessionmaker

    async def execute(self, assignment: TaskAssignment) -> TaskInstance:
        """Execute the enqueued task.

        Args:
            assignment (TaskAssignment): Task assignment data

        Returns:
            TaskInstance: SQLAlchemy model representing persisted task state
        """
        with self.sessionmaker() as db:
            # Get corresponding task from DB
            task = get_task_instance_from_assignment(db, assignment)
            if task is None:
                logger.error(f"No such task: skipping `{assignment}`")
                return None

            # Get executor for the task
            executor = TaskRegistry.get_executor(task.task_executor_name)

            # Get task inputs from dependencies' outputs
            task_input = await self._get_task_inputs(task)

            # Validate task input
            task_input = executor.input_schema.model_validate(task_input)

            # Update task's input
            task.input = task_input.model_dump()

            # Update task's state
            task.state = TaskStateEnum.queued
            db.add(task)
            db.commit()

            # TODO Run task executor pre-exec hook, if any

            # Run the task with executor
            # TODO handle failures
            # TODO async tasks
            logger.debug(
                f"Executing task {task.name} with executor {task.task_executor_name}..."
            )
            if executor.is_async():
                output = await executor(task_input)
            else:
                output = executor(task_input)

            # Validate task output
            task_output = executor.output_schema.model_validate(output)
            task.output = task_output.model_dump()

            # TODO Run task executor post-exec hook, if any

            # Save task state
            logger.debug(f"Completed task {task.name}!")
            task.state = TaskStateEnum.completed
            db.add(task)
            db.commit()

            # Write output to message bus for children tasks
            await self._put_task_output(task)

            return task

    async def _get_task_inputs(self, task: TaskInstance) -> dict:
        """Get Task inputs (i.e. outputs from parent tasks) from message bus."""

        # TODO nats interface refactor -- not clean here
        logger.debug(
            f"Getting task {task.name}'s dependencies' outputs: `{task.depends_on}`"
        )

        task_inputs = []
        for dep in task.depends_on:
            input_queue = NATSQueue(
                connection=self.nc,
                stream="jobs",
                subject=f"jobs.{task.state_instance_id}.from_{dep}.to_descendant",
            )
            await input_queue.connect()
            msg = await input_queue.dequeue()
            message = Message.model_validate_json(msg.data)
            task_inputs.append(message.payload)
            await msg.ack()

        # Merge all dicts
        input = dict(ChainMap(*task_inputs))

        return input

    async def _put_task_output(self, task: TaskInstance):
        """Write Task output into message bus for input to any children Tasks."""

        logger.debug(f"Writing task {task.name}'s output")

        # TODO refactor
        output_queue = NATSQueue(
            connection=self.nc,
            stream="jobs",
            subject=f"jobs.{task.state_instance_id}.from_{task.name}.to_descendant",
        )
        await output_queue.connect()

        msg = Message(source=task.name, payload=task.output)
        await output_queue.enqueue(msg.model_dump_json())

execute(assignment) async

Execute the enqueued task.

Parameters:

Name        Type            Description           Default
assignment  TaskAssignment  Task assignment data  required

Returns:

Name          Type          Description
TaskInstance  TaskInstance  SQLAlchemy model representing persisted task state, or None if no matching task is found
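
Illustratively, this mirrors how the Worker invokes it once an assignment has been parsed from the queue (task_json being the raw JSON message body):

assignment = TaskAssignment.model_validate_json(task_json)
task = await dispatcher.execute(assignment=assignment)
if task is not None:
    print(task.state)  # TaskStateEnum.completed on success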

Source code in germinate_ai/worker/task_dispatcher.py
async def execute(self, assignment: TaskAssignment) -> TaskInstance:
    """Execute the enqueued task.

    Args:
        assignment (TaskAssignment): Task assignment data

    Returns:
        TaskInstance: SQLAlchemy model representing persisted task state
    """
    with self.sessionmaker() as db:
        # Get corresponding task from DB
        task = get_task_instance_from_assignment(db, assignment)
        if task is None:
            logger.error(f"No such task: skipping `{assignment}`")
            return None

        # Get executor for the task
        executor = TaskRegistry.get_executor(task.task_executor_name)

        # Get task inputs from dependencies' outputs
        task_input = await self._get_task_inputs(task)

        # Validate task input
        task_input = executor.input_schema.model_validate(task_input)

        # Update task's input
        task.input = task_input.model_dump()

        # Update task's state
        task.state = TaskStateEnum.queued
        db.add(task)
        db.commit()

        # TODO Run task executor pre-exec hook, if any

        # Run the task with executor
        # TODO handle failures
        # TODO async tasks
        logger.debug(
            f"Executing task {task.name} with executor {task.task_executor_name}..."
        )
        if executor.is_async():
            output = await executor(task_input)
        else:
            output = executor(task_input)

        # Validate task output
        task_output = executor.output_schema.model_validate(output)
        task.output = task_output.model_dump()

        # TODO Run task executor post-exec hook, if any

        # Save task state
        logger.debug(f"Completed task {task.name}!")
        task.state = TaskStateEnum.completed
        db.add(task)
        db.commit()

        # Write output to message bus for children tasks
        await self._put_task_output(task)

        return task