[ENG-1513] Better integration between Jobs and processing Actors (#1974)

* First draft on new task system

* Removing save to disk from task system

* Bunch of concurrency issues

* Solving Future impl issue when pausing tasks

* Fix cancel and abort

* Bunch of fixes on pause, suspend, resume, cancel and abort
Also better error handling on task completion for the user

* New capabilities to return an output on a task

* Introducing a simple way to linear backoff on failed steal

* Sample actor where tasks can dispatch more tasks

* Rustfmt

* Steal test to make sure

* Stale deps cleanup

* Removing unused utils

* Initial lib docs

* Docs ok

* Memory cleanup on idle

---------

Co-authored-by: Vítor Vasconcellos <vasconcellos.dev@gmail.com>
This commit is contained in:
Ericson "Fogo" Soares 2024-02-26 16:45:58 -03:00 committed by GitHub
parent 53713a9f59
commit dba85ebac3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 4064 additions and 7 deletions

View file

@ -74,7 +74,7 @@ To run the landing page:
If you encounter any issues, ensure that you are using the following versions of Rust, Node and Pnpm:
- Rust version: **1.73.0**
- Rust version: **1.75.0**
- Node version: **18.17**
- Pnpm version: **8.0.0**

39
Cargo.lock generated
View file

@ -2277,6 +2277,12 @@ version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f"
[[package]]
name = "downcast-rs"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ea835d29036a4087793836fa931b08837ad5e957da9e23886b29586fb9b6650"
[[package]]
name = "dtoa"
version = "1.0.9"
@ -4183,6 +4189,16 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
[[package]]
name = "lending-stream"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6df9b2ed75ba713c108ef896cf3f4cb07b4bbe42b074de3a2d0b0b0f874bac42"
dependencies = [
"futures-core",
"pin-project",
]
[[package]]
name = "libc"
version = "0.2.153"
@ -8089,6 +8105,29 @@ dependencies = [
"thiserror",
]
[[package]]
name = "sd-task-system"
version = "0.1.0"
dependencies = [
"async-channel",
"async-trait",
"downcast-rs",
"futures",
"futures-concurrency",
"lending-stream",
"pin-project",
"rand 0.8.5",
"rmp-serde",
"serde",
"tempfile",
"thiserror",
"tokio",
"tokio-stream",
"tracing",
"tracing-test",
"uuid",
]
[[package]]
name = "sd-utils"
version = "0.1.0"

View file

@ -52,6 +52,7 @@ swift-rs = { version = "1.0.6" }
# Third party dependencies used by one or more of our crates
anyhow = "1.0.75"
async-channel = "2.0.0"
async-trait = "0.1.77"
axum = "0.6.20"
base64 = "0.21.5"
blake3 = "1.5.0"

View file

@ -3,7 +3,7 @@ name = "sd-core"
version = "0.2.4"
description = "Virtual distributed filesystem engine that powers Spacedrive."
authors = ["Spacedrive Technology Inc."]
rust-version = "1.73.0"
rust-version = "1.75.0"
license = { workspace = true }
repository = { workspace = true }
edition = { workspace = true }
@ -51,6 +51,7 @@ sd-cloud-api = { version = "0.1.0", path = "../crates/cloud-api" }
# Workspace dependencies
async-channel = { workspace = true }
async-trait = { workspace = true }
axum = { workspace = true }
base64 = { workspace = true }
blake3 = { workspace = true }
@ -100,7 +101,6 @@ webp = { workspace = true }
# Specific Core dependencies
async-recursion = "1.0.5"
async-stream = "0.3.5"
async-trait = "^0.1.74"
bytes = "1.5.0"
ctor = "0.2.5"
directories = "5.0.1"

View file

@ -4,7 +4,7 @@ version = "0.1.0"
authors = ["Ericson Soares <ericson@spacedrive.com>"]
readme = "README.md"
description = "A simple library to generate video thumbnails using ffmpeg with the webp format"
rust-version = "1.73.0"
rust-version = "1.75.0"
license = { workspace = true }
repository = { workspace = true }
edition = { workspace = true }

View file

@ -3,7 +3,7 @@ name = "sd-file-path-helper"
version = "0.1.0"
authors = ["Ericson Soares <ericson@spacedrive.com>"]
readme = "README.md"
rust-version = "1.73.0"
rust-version = "1.75.0"
license = { workspace = true }
repository = { workspace = true }
edition = { workspace = true }

View file

@ -0,0 +1,42 @@
[package]
name = "sd-task-system"
version = "0.1.0"
authors = ["Ericson \"Fogo\" Soares <ericson.ds999@gmail.com>"]
rust-version = "1.75.0"
license.workspace = true
edition.workspace = true
repository.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
# Workspace deps
async-channel = { workspace = true }
async-trait = { workspace = true }
futures = { workspace = true }
futures-concurrency = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = [
"sync",
"parking_lot",
"rt-multi-thread",
"time",
] }
tokio-stream = { workspace = true }
tracing = { workspace = true }
uuid = { workspace = true, features = ["v4"] }
# External deps
downcast-rs = "1.2.0"
pin-project = "1.1.4"
[dev-dependencies]
tokio = { workspace = true, features = ["macros", "test-util", "fs"] }
tempfile = { workspace = true }
rand = "0.8.5"
tracing-test = { version = "^0.2.4", features = ["no-env-filter"] }
thiserror = { workspace = true }
lending-stream = "1.0.0"
serde = { workspace = true, features = ["derive"] }
rmp-serde = { workspace = true }
uuid = { workspace = true, features = ["serde"] }

View file

@ -0,0 +1,28 @@
use std::{error::Error, fmt};
use super::task::TaskId;
/// Task system's error type definition, representing internal errors of the system itself
/// (as opposed to errors returned by the tasks being run).
#[derive(Debug, thiserror::Error)]
pub enum SystemError {
/// No task with the given id was found.
#[error("task not found <task_id='{0}'>")]
TaskNotFound(TaskId),
/// The task with the given id was aborted.
#[error("task aborted <task_id='{0}'>")]
TaskAborted(TaskId),
/// Joining the task with the given id failed.
#[error("task join error <task_id='{0}'>")]
TaskJoin(TaskId),
/// A forced abortion was requested for the given task, but it timed out.
#[error("forced abortion for task <task_id='{0}'> timed out")]
TaskForcedAbortTimeout(TaskId),
}
/// Trait for errors that can be returned by tasks; we use this trait as a bound for the task system's
/// generic error type.
///
/// With this trait, we can have a unified error type across all the tasks in the system.
pub trait RunError: Error + fmt::Debug + Send + Sync + 'static {}
/// We provide a blanket implementation for all types that also implement
/// [`std::error::Error`](https://doc.rust-lang.org/std/error/trait.Error.html) and
/// [`std::fmt::Debug`](https://doc.rust-lang.org/std/fmt/trait.Debug.html) (and are `Send + Sync + 'static`).
/// So you will not need to implement this trait for your error type, just implement the `Error` and `Debug`
/// traits.
impl<T: Error + fmt::Debug + Send + Sync + 'static> RunError for T {}

View file

@ -0,0 +1,71 @@
//!
//! # Task System
//!
//! Spacedrive's Task System is a library that provides a way to manage and execute tasks in a concurrent
//! and parallel environment.
//!
//! Just bring your own unified error type and dispatch some tasks, the system will handle enqueueing,
//! parallel execution, and error handling for you. It also provides some niceties:
//! - Round robin scheduling between workers following the available CPU cores on the user machine;
//! - Work stealing between workers for better load balancing;
//! - Gracefully pause and cancel tasks;
//! - Forced abortion of tasks;
//! - Prioritizing tasks that will suspend running tasks without priority;
//! - When the system is shut down, it will return all pending and running tasks to their dispatchers, so the user can store them on disk or any other storage to be re-dispatched later.
//!
//!
//! ## Basic example
//!
//! ```
//! use sd_task_system::{TaskSystem, Task, TaskId, ExecStatus, TaskOutput, Interrupter, TaskStatus};
//! use async_trait::async_trait;
//! use thiserror::Error;
//!
//! #[derive(Debug, Error)]
//! pub enum SampleError {
//! #[error("Sample error")]
//! SampleError,
//! }
//!
//! #[derive(Debug)]
//! pub struct ReadyTask {
//! id: TaskId,
//! }
//!
//! #[async_trait]
//! impl Task<SampleError> for ReadyTask {
//! fn id(&self) -> TaskId {
//! self.id
//! }
//!
//! async fn run(&mut self, _interrupter: &Interrupter) -> Result<ExecStatus, SampleError> {
//! Ok(ExecStatus::Done(TaskOutput::Empty))
//! }
//! }
//!
//! #[tokio::main]
//! async fn main() {
//! let system = TaskSystem::new();
//!
//! let handle = system.dispatch(ReadyTask { id: TaskId::new_v4() }).await;
//!
//! assert!(matches!(
//! handle.await,
//! Ok(TaskStatus::Done(TaskOutput::Empty))
//! ));
//!
//! system.shutdown().await;
//! }
//! ```
mod error;
mod message;
mod system;
mod task;
mod worker;
pub use error::{RunError, SystemError as TaskSystemError};
pub use system::{Dispatcher as TaskDispatcher, System as TaskSystem};
pub use task::{
AnyTaskOutput, ExecStatus, Interrupter, InterrupterFuture, InterruptionKind, IntoAnyTaskOutput,
IntoTask, Task, TaskHandle, TaskId, TaskOutput, TaskStatus,
};

View file

@ -0,0 +1,63 @@
use tokio::sync::oneshot;
use super::{
error::{RunError, SystemError},
task::{TaskId, TaskWorkState},
worker::WorkerId,
};
/// Messages consumed by the task system's message processing loop (`System::run`).
///
/// Variants carrying an `ack` sender expect the outcome of the request to be reported back through it.
#[derive(Debug)]
pub(crate) enum SystemMessage {
/// Marks the given worker as idle in the system's idle workers table.
IdleReport(WorkerId),
/// Marks the given worker as not idle in the system's idle workers table.
WorkingReport(WorkerId),
/// Forwarded to the given worker as a resume request for the task.
ResumeTask {
task_id: TaskId,
worker_id: WorkerId,
ack: oneshot::Sender<Result<(), SystemError>>,
},
/// Forwarded to the given worker as a pause request for a task that is not currently running.
PauseNotRunningTask {
task_id: TaskId,
worker_id: WorkerId,
ack: oneshot::Sender<Result<(), SystemError>>,
},
/// Forwarded to the given worker as a cancel request for a task that is not currently running.
CancelNotRunningTask {
task_id: TaskId,
worker_id: WorkerId,
ack: oneshot::Sender<Result<(), SystemError>>,
},
/// Forwarded to the given worker as a forced abortion request for the task.
ForceAbortion {
task_id: TaskId,
worker_id: WorkerId,
ack: oneshot::Sender<Result<(), SystemError>>,
},
/// Asks the system to wake up idle workers, visiting at most `min(task_count, workers count)` workers
/// starting at index `start_from` and wrapping around.
NotifyIdleWorkers {
start_from: WorkerId,
task_count: usize,
},
/// Asks the message processing loop to stop; it acknowledges with `Ok(())` before returning.
ShutdownRequest(oneshot::Sender<Result<(), SystemError>>),
}
/// Messages addressed to a single worker.
///
/// NOTE(review): the code handling these messages lives in the worker module, which is not shown here;
/// the per-variant notes below are inferred from variant names and payload types — confirm against the worker.
#[derive(Debug)]
pub(crate) enum WorkerMessage<E: RunError> {
/// Hands a task's work state over to the worker.
NewTask(TaskWorkState<E>),
/// Requests a task count, answered as a `usize` through the sender.
TaskCountRequest(oneshot::Sender<usize>),
/// Resume request for the given task; outcome is reported through `ack`.
ResumeTask {
task_id: TaskId,
ack: oneshot::Sender<Result<(), SystemError>>,
},
/// Pause request for a task that is not running; outcome is reported through `ack`.
PauseNotRunningTask {
task_id: TaskId,
ack: oneshot::Sender<Result<(), SystemError>>,
},
/// Cancel request for a task that is not running; outcome is reported through `ack`.
CancelNotRunningTask {
task_id: TaskId,
ack: oneshot::Sender<Result<(), SystemError>>,
},
/// Forced abortion request for the given task; outcome is reported through `ack`.
ForceAbortion {
task_id: TaskId,
ack: oneshot::Sender<Result<(), SystemError>>,
},
/// Shutdown request; the sender is used to signal completion.
ShutdownRequest(oneshot::Sender<()>),
/// Work stealing request; answered with `Some(task work state)` or `None`.
StealRequest(oneshot::Sender<Option<TaskWorkState<E>>>),
/// Wake-up signal carrying no payload.
WakeUp,
}

View file

@ -0,0 +1,467 @@
use std::{
cell::RefCell,
collections::HashSet,
pin::pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use async_channel as chan;
use futures::StreamExt;
use futures_concurrency::future::Join;
use tokio::{spawn, sync::oneshot, task::JoinHandle};
use tracing::{error, info, trace, warn};
use super::{
error::{RunError, SystemError},
message::SystemMessage,
task::{IntoTask, Task, TaskHandle, TaskId},
worker::{AtomicWorkerId, WorkStealer, Worker, WorkerBuilder, WorkerId},
};
/// The task system is the main entry point for the library, it is responsible for creating and managing the workers
/// and dispatching tasks to them.
///
/// It also provides a way to shutdown the system returning all pending and running tasks.
/// It uses internal mutability so it can be shared without hassles using [`Arc`].
pub struct System<E: RunError> {
// Workers owned by the system; the same list is shared with the dispatcher and the message loop.
workers: Arc<Vec<Worker<E>>>,
// Sending side of the system message channel, used by `shutdown` to send the shutdown request.
msgs_tx: chan::Sender<SystemMessage>,
// Dispatcher handed out (cloned) by `get_dispatcher` and used by `dispatch` / `dispatch_many`.
dispatcher: Dispatcher<E>,
// Join handle of the message processing task; taken (left as `None`) by the first `shutdown` call.
handle: RefCell<Option<JoinHandle<()>>>,
}
impl<E: RunError> System<E> {
/// Created a new task system with a number of workers equal to the available parallelism in the user's machine.
pub fn new() -> Self {
let workers_count = std::thread::available_parallelism().map_or_else(
|e| {
error!("Failed to get available parallelism in the job system: {e:#?}");
1
},
|non_zero| non_zero.get(),
);
let (msgs_tx, msgs_rx) = chan::bounded(8);
let system_comm = SystemComm(msgs_tx.clone());
let (workers_builders, worker_comms) = (0..workers_count)
.map(WorkerBuilder::new)
.unzip::<_, _, Vec<_>, Vec<_>>();
let task_stealer = WorkStealer::new(worker_comms);
let idle_workers = Arc::new((0..workers_count).map(|_| AtomicBool::new(true)).collect());
let workers = Arc::new(
workers_builders
.into_iter()
.map(|builder| builder.build(system_comm.clone(), task_stealer.clone()))
.collect::<Vec<_>>(),
);
let handle = spawn({
let workers = Arc::clone(&workers);
let msgs_rx = msgs_rx.clone();
let idle_workers = Arc::clone(&idle_workers);
async move {
trace!("Task System message processing task starting...");
while let Err(e) = spawn(Self::run(
Arc::clone(&workers),
Arc::clone(&idle_workers),
msgs_rx.clone(),
))
.await
{
if e.is_panic() {
error!("Job system panicked: {e:#?}");
} else {
trace!("Task system received shutdown signal and will exit...");
break;
}
trace!("Restarting task system message processing task...")
}
info!("Task system gracefully shutdown");
}
});
trace!("Task system online!");
Self {
workers: Arc::clone(&workers),
msgs_tx,
dispatcher: Dispatcher {
workers,
idle_workers,
last_worker_id: Arc::new(AtomicWorkerId::new(0)),
},
handle: RefCell::new(Some(handle)),
}
}
/// Returns the number of workers in the system, i.e. the length of the internal workers list.
pub fn workers_count(&self) -> usize {
self.workers.len()
}
/// Dispatches a task to the system, the task will be assigned to a worker and executed as soon as possible.
///
/// Delegates to the system's internal [`Dispatcher`] and returns the [`TaskHandle`] it produces.
pub async fn dispatch(&self, into_task: impl IntoTask<E>) -> TaskHandle<E> {
self.dispatcher.dispatch(into_task).await
}
/// Dispatches many tasks to the system, the tasks will be assigned to workers and executed as soon as possible.
///
/// Delegates to the system's internal [`Dispatcher`] and returns the [`TaskHandle`]s it produces.
pub async fn dispatch_many(&self, into_tasks: Vec<impl IntoTask<E>>) -> Vec<TaskHandle<E>> {
self.dispatcher.dispatch_many(into_tasks).await
}
/// Returns a dispatcher that can be used to remotely dispatch tasks to the system.
///
/// The returned value is a clone of the system's own dispatcher.
pub fn get_dispatcher(&self) -> Dispatcher<E> {
self.dispatcher.clone()
}
/// The system message processing loop.
///
/// Receives [`SystemMessage`]s until the channel is closed or a shutdown request arrives:
/// - idle/working reports update the worker's flag in `idle_workers`;
/// - resume, pause, cancel and force abortion requests are forwarded to the addressed worker together
///   with the `ack` sender;
/// - a notify-idle-workers request wakes the idle workers among at most `min(task_count, workers.len())`
///   workers, starting at `start_from` and wrapping around;
/// - a shutdown request is acknowledged with `Ok(())` and ends the loop.
///
/// # Panics
///
/// Panics if the receiver of the shutdown acknowledgement was dropped, or if a message carries a
/// `worker_id` outside the workers list.
async fn run(
    workers: Arc<Vec<Worker<E>>>,
    idle_workers: Arc<Vec<AtomicBool>>,
    msgs_rx: chan::Receiver<SystemMessage>,
) {
    let mut msg_stream = pin!(msgs_rx);

    while let Some(msg) = msg_stream.next().await {
        match msg {
            SystemMessage::IdleReport(worker_id) => {
                trace!("Task system received a worker idle report request: <worker_id='{worker_id}'>");
                idle_workers[worker_id].store(true, Ordering::Relaxed);
            }

            SystemMessage::WorkingReport(worker_id) => {
                trace!(
                    "Task system received a working report request: <worker_id='{worker_id}'>"
                );
                idle_workers[worker_id].store(false, Ordering::Relaxed);
            }

            SystemMessage::ResumeTask {
                task_id,
                worker_id,
                ack,
            } => {
                trace!("Task system received a task resume request: <task_id='{task_id}', worker_id='{worker_id}'>");
                workers[worker_id].resume_task(task_id, ack).await;
            }

            SystemMessage::PauseNotRunningTask {
                task_id,
                worker_id,
                ack,
            } => {
                trace!(
                    "Task system received a pause not running task request: \
                    <task_id='{task_id}', worker_id='{worker_id}'>"
                );
                workers[worker_id]
                    .pause_not_running_task(task_id, ack)
                    .await;
            }

            SystemMessage::CancelNotRunningTask {
                task_id,
                worker_id,
                ack,
            } => {
                trace!(
                    "Task system received a cancel not running task request: \
                    <task_id='{task_id}', worker_id='{worker_id}'>"
                );
                workers[worker_id]
                    .cancel_not_running_task(task_id, ack)
                    .await;
            }

            SystemMessage::ForceAbortion {
                task_id,
                worker_id,
                ack,
            } => {
                trace!(
                    "Task system received a task force abortion request: \
                    <task_id='{task_id}', worker_id='{worker_id}'>"
                );
                workers[worker_id].force_task_abortion(task_id, ack).await;
            }

            SystemMessage::NotifyIdleWorkers {
                start_from,
                task_count,
            } => {
                trace!(
                    "Task system received a request to notify idle workers: \
                    <start_from='{start_from}', task_count='{task_count}'>"
                );

                for idx in (0..workers.len())
                    .cycle()
                    .skip(start_from)
                    .take(usize::min(task_count, workers.len()))
                {
                    if idle_workers[idx].load(Ordering::Relaxed) {
                        workers[idx].wake().await;
                        // we don't mark the worker as not idle because we wait for it to
                        // successfully steal a task and then report it back as active
                    }
                }
            }

            SystemMessage::ShutdownRequest(tx) => {
                trace!("Task system received a shutdown request");
                tx.send(Ok(()))
                    .expect("System channel closed trying to shutdown");
                return;
            }
        }
    }
}
/// Shuts down the system, returning all pending and running tasks to their respective handles.
///
/// Only the first call does any work: it takes the join handle of the message processing task, shuts
/// down every worker concurrently, asks the message loop to stop and finally waits for the
/// processing task to finish. Later calls just log a warning.
///
/// # Panics
///
/// Panics if the system message channel, or the shutdown acknowledgement channel, is closed.
pub async fn shutdown(&self) {
    // Grab the handle (leaving `None` behind) or bail out if it's gone or currently borrowed.
    let Some(handle) = self
        .handle
        .try_borrow_mut()
        .ok()
        .and_then(|mut slot| slot.take())
    else {
        warn!("Trying to shutdown the tasks system that was already shutdown");
        return;
    };

    // All workers are shut down concurrently before the message loop is stopped.
    let workers_shutdown = self
        .workers
        .iter()
        .map(|worker| async move { worker.shutdown().await })
        .collect::<Vec<_>>();
    workers_shutdown.join().await;

    let (ack_tx, ack_rx) = oneshot::channel();

    self.msgs_tx
        .send(SystemMessage::ShutdownRequest(ack_tx))
        .await
        .expect("Task system channel closed trying to shutdown");

    let ack = ack_rx
        .await
        .expect("Task system channel closed trying to shutdown");

    if let Err(e) = ack {
        error!("Task system failed to shutdown: {e:#?}");
    }

    if let Err(e) = handle.await {
        error!("Task system failed to shutdown on handle await: {e:#?}");
    }
}
}
/// The default implementation of the task system will create a system with a number of workers equal to the available
/// parallelism in the user's machine.
impl<E: RunError> Default for System<E> {
/// Equivalent to [`System::new`].
fn default() -> Self {
Self::new()
}
}
/// SAFETY: Due to usage of refcell we lost `Sync` impl, but we only use it to have a shutdown method
/// receiving `&self` which is called once, and we also use `try_borrow_mut` so we never panic
// NOTE(review): `RefCell`'s borrow flag is not atomic, so this is only sound as long as `shutdown` is
// never called concurrently from more than one thread — confirm callers uphold that, or consider
// holding the join handle in a `Mutex<Option<JoinHandle<()>>>` instead.
unsafe impl<E: RunError> Sync for System<E> {}
/// Thin wrapper around the sending side of the system message channel, exposing one method per kind of
/// request that can be sent to the task system.
#[derive(Clone, Debug)]
#[repr(transparent)]
pub(crate) struct SystemComm(chan::Sender<SystemMessage>);
/// Request helpers over the system message channel.
///
/// Every method panics if the system message channel is closed; methods that wait for an
/// acknowledgement also panic if the acknowledgement sender is dropped without answering.
impl SystemComm {
    /// Reports to the task system that the given worker is idle.
    pub async fn idle_report(&self, worker_id: WorkerId) {
        self.0
            .send(SystemMessage::IdleReport(worker_id))
            .await
            .expect("System channel closed trying to report idle");
    }

    /// Reports to the task system that the given worker is working.
    pub async fn working_report(&self, worker_id: WorkerId) {
        self.0
            .send(SystemMessage::WorkingReport(worker_id))
            .await
            .expect("System channel closed trying to report working");
    }

    /// Asks the task system to pause a task that is not currently running on the given worker,
    /// returning the acknowledged outcome.
    pub async fn pause_not_running_task(
        &self,
        task_id: TaskId,
        worker_id: WorkerId,
    ) -> Result<(), SystemError> {
        let (tx, rx) = oneshot::channel();

        self.0
            .send(SystemMessage::PauseNotRunningTask {
                task_id,
                worker_id,
                ack: tx,
            })
            .await
            .expect("System channel closed trying to pause not running task");

        rx.await
            .expect("System channel closed trying receive pause not running task response")
    }

    /// Asks the task system to cancel a task that is not currently running on the given worker,
    /// returning the acknowledged outcome.
    pub async fn cancel_not_running_task(
        &self,
        task_id: TaskId,
        worker_id: WorkerId,
    ) -> Result<(), SystemError> {
        let (tx, rx) = oneshot::channel();

        self.0
            .send(SystemMessage::CancelNotRunningTask {
                task_id,
                worker_id,
                ack: tx,
            })
            .await
            .expect("System channel closed trying to cancel a not running task");

        rx.await
            .expect("System channel closed trying receive cancel a not running task response")
    }

    /// Asks the task system to notify idle workers, starting from `worker_id`, that `task_count`
    /// tasks are available.
    pub async fn request_help(&self, worker_id: WorkerId, task_count: usize) {
        self.0
            .send(SystemMessage::NotifyIdleWorkers {
                start_from: worker_id,
                task_count,
            })
            .await
            .expect("System channel closed trying to request help");
    }

    /// Asks the task system to resume the given task on the given worker, returning the
    /// acknowledged outcome.
    pub async fn resume_task(
        &self,
        task_id: TaskId,
        worker_id: WorkerId,
    ) -> Result<(), SystemError> {
        let (tx, rx) = oneshot::channel();

        self.0
            .send(SystemMessage::ResumeTask {
                task_id,
                worker_id,
                ack: tx,
            })
            .await
            .expect("System channel closed trying to resume task");

        rx.await
            .expect("System channel closed trying receive resume task response")
    }

    /// Asks the task system to forcefully abort the given task on the given worker, returning the
    /// acknowledged outcome.
    pub async fn force_abortion(
        &self,
        task_id: TaskId,
        worker_id: WorkerId,
    ) -> Result<(), SystemError> {
        let (tx, rx) = oneshot::channel();

        self.0
            .send(SystemMessage::ForceAbortion {
                task_id,
                worker_id,
                ack: tx,
            })
            .await
            // These messages used to be copy-pasted from `resume_task`, which made panics misleading.
            .expect("System channel closed trying to force abortion of a task");

        rx.await
            .expect("System channel closed trying to receive force abortion response")
    }
}
/// A remote dispatcher of tasks.
///
/// It can be used to dispatch tasks to the system from other threads or tasks.
/// It uses [`Arc`] internally so it can be cheaply cloned and put inside tasks so tasks can dispatch other tasks.
#[derive(Debug)]
pub struct Dispatcher<E: RunError> {
// Workers that tasks are handed to.
workers: Arc<Vec<Worker<E>>>,
// One idle flag per worker, indexed by worker id; cleared when a task is dispatched to that worker.
idle_workers: Arc<Vec<AtomicBool>>,
// Round-robin cursor used by `dispatch`, shared between all clones of the dispatcher.
last_worker_id: Arc<AtomicWorkerId>,
}
/// Manual `Clone` implementation: cloning only bumps the reference counts of the shared state.
impl<E: RunError> Clone for Dispatcher<E> {
    fn clone(&self) -> Self {
        let Self {
            workers,
            idle_workers,
            last_worker_id,
        } = self;

        Self {
            workers: Arc::clone(workers),
            idle_workers: Arc::clone(idle_workers),
            last_worker_id: Arc::clone(last_worker_id),
        }
    }
}
impl<E: RunError> Dispatcher<E> {
/// Dispatches a task to the system, the task will be assigned to a worker and executed as soon as possible.
///
/// Workers are picked in round-robin order: the shared cursor is advanced (wrapping at the number of
/// workers) and the task goes to the worker the cursor pointed at before the update. That worker is then
/// flagged as not idle.
pub async fn dispatch(&self, into_task: impl IntoTask<E>) -> TaskHandle<E> {
let task = into_task.into_task();
// The actual work happens in an inner function taking the already boxed task, so it doesn't depend
// on the concrete `IntoTask` type (presumably to keep the generic outer shell minimal).
async fn inner<E: RunError>(this: &Dispatcher<E>, task: Box<dyn Task<E>>) -> TaskHandle<E> {
// `fetch_update` hands back the value stored before the update, which is the worker used here.
let worker_id = this
.last_worker_id
.fetch_update(Ordering::Release, Ordering::Acquire, |last_worker_id| {
Some((last_worker_id + 1) % this.workers.len())
})
.expect("we hardcoded the update function to always return Some(next_worker_id) through dispatcher");
trace!(
"Dispatching task to worker: <worker_id='{worker_id}', task_id='{}'>",
task.id()
);
let handle = this.workers[worker_id].add_task(task).await;
this.idle_workers[worker_id].store(false, Ordering::Relaxed);
handle
}
inner(self, task).await
}
/// Dispatches many tasks to the system, the tasks will be assigned to workers and executed as soon as possible.
///
/// The current task count of every worker is queried concurrently and workers are ordered from least to
/// most loaded; tasks are then handed out cycling through the workers in that order. Every worker that
/// received at least one task is flagged as not idle. The returned handles follow the order of `into_tasks`.
pub async fn dispatch_many(&self, into_tasks: Vec<impl IntoTask<E>>) -> Vec<TaskHandle<E>> {
// Snapshot of `(worker id, task count)` pairs, gathered concurrently.
let mut workers_task_count = self
.workers
.iter()
.map(|worker| async move { (worker.id, worker.task_count().await) })
.collect::<Vec<_>>()
.join()
.await;
// Least loaded workers first.
workers_task_count.sort_by_key(|(_id, count)| *count);
// Pair each task with a worker (cycling when there are more tasks than workers) and add them all
// concurrently, collecting the handles and the set of workers that were used.
let (handles, workers_ids_set) = into_tasks
.into_iter()
.map(IntoTask::into_task)
.zip(workers_task_count.into_iter().cycle())
.map(|(task, (worker_id, _))| async move {
(self.workers[worker_id].add_task(task).await, worker_id)
})
.collect::<Vec<_>>()
.join()
.await
.into_iter()
.unzip::<_, _, Vec<_>, HashSet<_>>();
workers_ids_set.into_iter().for_each(|worker_id| {
self.idle_workers[worker_id].store(false, Ordering::Relaxed);
});
handles
}
/// Returns the number of workers in the system, i.e. the length of the shared workers list.
pub fn workers_count(&self) -> usize {
self.workers.len()
}
}

View file

@ -0,0 +1,484 @@
use std::{
fmt,
future::{Future, IntoFuture},
pin::Pin,
sync::{
atomic::{AtomicBool, AtomicU8, Ordering},
Arc,
},
task::{Context, Poll},
};
use async_channel as chan;
use async_trait::async_trait;
use chan::{Recv, RecvError};
use downcast_rs::{impl_downcast, Downcast};
use tokio::sync::oneshot;
use tracing::{trace, warn};
use uuid::Uuid;
use super::{
error::{RunError, SystemError},
system::SystemComm,
worker::{AtomicWorkerId, WorkerId},
};
/// A unique identifier for a task using the [`uuid`](https://docs.rs/uuid) crate.
pub type TaskId = Uuid;
/// A trait that represents any kind of output that a task can return.
///
/// The user will downcast it to the concrete type that the task returns. Most of the time,
/// tasks will not return anything, so it isn't a costly abstraction, as only a heap allocation
/// is needed when the user wants to return a [`Box<dyn AnyTaskOutput>`].
pub trait AnyTaskOutput: Send + fmt::Debug + Downcast + 'static {}
// Generates the downcasting helpers for `dyn AnyTaskOutput`.
impl_downcast!(AnyTaskOutput);
/// Blanket implementation for all types that implement `std::fmt::Debug + Send + 'static`
impl<T: fmt::Debug + Send + 'static> AnyTaskOutput for T {}
/// A helper trait to convert any type that implements [`AnyTaskOutput`] into a [`TaskOutput`], boxing it.
pub trait IntoAnyTaskOutput {
/// Consumes `self` and wraps it in a [`TaskOutput`].
fn into_output(self) -> TaskOutput;
}
/// Blanket implementation for all types that implement [`AnyTaskOutput`]
impl<T: AnyTaskOutput + 'static> IntoAnyTaskOutput for T {
/// Boxes the value and returns it as [`TaskOutput::Out`].
fn into_output(self) -> TaskOutput {
TaskOutput::Out(Box::new(self))
}
}
/// An enum representing whether a task returned anything or not.
#[derive(Debug)]
pub enum TaskOutput {
/// The task produced an output, boxed as a trait object to be downcast by the user.
Out(Box<dyn AnyTaskOutput>),
/// The task produced no output.
Empty,
}
/// An enum representing all possible outcomes for a task.
#[derive(Debug)]
pub enum TaskStatus<E: RunError> {
/// The task has finished successfully and maybe has some output for the user.
Done(TaskOutput),
/// Task was gracefully cancelled by the user.
Canceled,
/// Task was forcefully aborted by the user.
ForcedAbortion,
/// The task system was shut down and we give back the task to the user so they can downcast it
/// back to the original concrete type and store it on disk or any other storage to be re-dispatched later.
Shutdown(Box<dyn Task<E>>),
/// Task had an error so we return it back and the user can handle it appropriately.
Error(E),
}
/// Represents whether the current [`Task::run`] method on a task finished successfully or was interrupted.
///
/// `Done` and `Canceled` variants can only happen once, while `Paused` can happen multiple times,
/// whenever the user wants to pause the task.
#[derive(Debug)]
pub enum ExecStatus {
/// The run finished, optionally carrying an output.
Done(TaskOutput),
/// The run stopped because of a pause.
Paused,
/// The run stopped because of a cancel.
Canceled,
}
/// Crate-internal counterpart of [`ExecStatus`]: it has the same `Done`, `Paused` and `Canceled` variants,
/// plus `Suspend` and `Error`, so the error case of a `Result<ExecStatus, E>` can be folded into one enum.
#[derive(Debug)]
pub(crate) enum InternalTaskExecStatus<E: RunError> {
Done(TaskOutput),
Paused,
Canceled,
// Never produced by the `From<Result<ExecStatus, E>>` conversion; presumably set by the worker when a
// task is suspended — TODO confirm against the worker module.
Suspend,
// The task's `run` returned an error.
Error(E),
}
/// Folds the result of a task run into the internal status representation: each successful
/// [`ExecStatus`] maps to the variant of the same name and an error becomes `Error`.
impl<E: RunError> From<Result<ExecStatus, E>> for InternalTaskExecStatus<E> {
    fn from(result: Result<ExecStatus, E>) -> Self {
        match result {
            Ok(ExecStatus::Done(out)) => Self::Done(out),
            Ok(ExecStatus::Paused) => Self::Paused,
            Ok(ExecStatus::Canceled) => Self::Canceled,
            Err(e) => Self::Error(e),
        }
    }
}
/// A helper trait to convert any type that implements [`Task<E>`] into a [`Box<dyn Task<E>>`], boxing it.
pub trait IntoTask<E> {
/// Consumes `self` and returns it as a boxed task trait object.
fn into_task(self) -> Box<dyn Task<E>>;
}
/// Blanket implementation for all types that implement [`Task<E>`] and are `'static`
impl<T: Task<E> + 'static, E: RunError> IntoTask<E> for T {
fn into_task(self) -> Box<dyn Task<E>> {
Box::new(self)
}
}
/// The main trait that represents a task that can be dispatched to the task system.
///
/// All tasks in the task system must return the same generic error type, so we can have a unified
/// error handling.
///
/// We're currently using the [`async_trait`](https://docs.rs/async-trait) crate to allow dyn async traits,
/// due to a limitation in the Rust language.
#[async_trait]
pub trait Task<E: RunError>: fmt::Debug + Downcast + Send + 'static {
/// This method represents the work that should be done by the worker, it will be called by the
/// worker when there is a slot available in its internal queue.
/// We receive a `&mut self` so any internal data can be mutated on each `run` invocation.
///
/// The [`interrupter`](Interrupter) is a helper object that can be used to check if the user requested a pause or a cancel,
/// so the user can decide the appropriate moment to pause or cancel the task, avoiding corrupted data or
/// inconsistent states.
async fn run(&mut self, interrupter: &Interrupter) -> Result<ExecStatus, E>;
/// This method defines whether a task should run with priority or not. The task system has a mechanism
/// to suspend non-priority tasks on any worker and run priority tasks ASAP. This is useful for tasks that
/// are more important than others, like a task that should be concluded and show results immediately to the user,
/// as thumbnails being generated for the current open directory or copy/paste operations.
///
/// Defaults to `false`.
fn with_priority(&self) -> bool {
false
}
/// A unique identifier for the task, it will be used to identify the task on the system and also to the user.
fn id(&self) -> TaskId;
}
// Generates the downcasting helpers for `dyn Task<E>`.
impl_downcast!(Task<E> where E: RunError);
/// Intermediate struct to wait until a pause or a cancel command is sent by the user.
#[must_use = "`InterrupterFuture` does nothing unless polled"]
#[pin_project::pin_project]
pub struct InterrupterFuture<'recv> {
// Pending receive operation on the interrupter's request channel.
#[pin]
fut: Recv<'recv, InterruptionRequest>,
// The interrupter's flag, updated with the received interruption kind once a request arrives.
has_interrupted: &'recv AtomicU8,
}
/// Resolves to the kind of interruption requested once an [`InterruptionRequest`] is received.
impl Future for InterrupterFuture<'_> {
type Output = InterruptionKind;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match this.fut.poll(cx) {
Poll::Ready(Ok(InterruptionRequest { kind, ack })) => {
// Acknowledge the request; a dropped ack receiver is only logged.
if ack.send(Ok(())).is_err() {
warn!("TaskInterrupter ack channel closed");
}
// Record the kind on the interrupter so later non-blocking checks can see it.
this.has_interrupted.store(kind as u8, Ordering::Relaxed);
Poll::Ready(kind)
}
Poll::Ready(Err(RecvError)) => {
// In case the task handle was dropped, we can't receive any more interrupt messages
// so we will never interrupt and the task will run freely until ended
warn!("Task interrupter channel closed, will run task until it finishes!");
Poll::Pending
}
Poll::Pending => Poll::Pending,
}
}
}
/// We use an [`IntoFuture`] implementation to allow the user to use the `await` syntax on the [`Interrupter`] object.
/// With this trait, we return an [`InterrupterFuture`] that will await until the user requests a pause or a cancel.
impl<'recv> IntoFuture for &'recv Interrupter {
type Output = InterruptionKind;
type IntoFuture = InterrupterFuture<'recv>;
/// Builds the future from a receive operation on the request channel plus a reference to the
/// interrupter's flag.
fn into_future(self) -> Self::IntoFuture {
InterrupterFuture {
fut: self.interrupt_rx.recv(),
has_interrupted: &self.has_interrupted,
}
}
}
/// A helper object that can be used to check if the user requested a pause or a cancel, so the task `run`
/// implementation can decide the appropriate moment to pause or cancel the task.
#[derive(Debug)]
pub struct Interrupter {
// Receiving side of the interruption requests channel.
interrupt_rx: chan::Receiver<InterruptionRequest>,
// Last interruption kind received, stored as `InterruptionKind as u8`; `0` means none.
has_interrupted: AtomicU8,
}
impl Interrupter {
    /// Builds an interrupter around the receiving side of the interruption requests channel, starting
    /// with no interruption recorded.
    pub(crate) fn new(interrupt_rx: chan::Receiver<InterruptionRequest>) -> Self {
        Self {
            interrupt_rx,
            has_interrupted: AtomicU8::new(0),
        }
    }

    /// Check if the user requested a pause or a cancel, returning the kind of interruption that was requested
    /// in a non-blocking manner.
    ///
    /// An interruption that was already recorded is returned right away; otherwise a single non-blocking
    /// receive is attempted and, when a request is available, it is acknowledged and recorded.
    pub fn try_check_interrupt(&self) -> Option<InterruptionKind> {
        if let Some(already_requested) = InterruptionKind::load(&self.has_interrupted) {
            return Some(already_requested);
        }

        let InterruptionRequest { kind, ack } = self.interrupt_rx.try_recv().ok()?;

        if ack.send(Ok(())).is_err() {
            warn!("TaskInterrupter ack channel closed");
        }

        self.has_interrupted.store(kind as u8, Ordering::Relaxed);

        Some(kind)
    }

    /// Clears a recorded pause so the task can be interrupted again.
    ///
    /// # Panics
    ///
    /// Panics if the recorded interruption is anything other than a pause.
    pub(super) fn reset(&self) {
        let swapped = self.has_interrupted.compare_exchange(
            InterruptionKind::Pause as u8,
            0,
            Ordering::Release,
            Ordering::Relaxed,
        );

        swapped.expect("we must only reset paused tasks");
    }
}
/// The kind of interruption that can be requested by the user, a pause or a cancel.
///
/// The discriminants are what gets stored in the interrupter's atomic flag, where `0` is reserved
/// for "no interruption".
#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum InterruptionKind {
    Pause = 1,
    Cancel = 2,
}

impl InterruptionKind {
    /// Reads the flag (relaxed ordering) and maps it back to an interruption kind, returning `None`
    /// for any value that isn't a known discriminant.
    fn load(kind: &AtomicU8) -> Option<Self> {
        let raw = kind.load(Ordering::Relaxed);

        if raw == Self::Pause as u8 {
            Some(Self::Pause)
        } else if raw == Self::Cancel as u8 {
            Some(Self::Cancel)
        } else {
            None
        }
    }
}
/// A pause or cancel request delivered to a task's [`Interrupter`].
#[derive(Debug)]
pub(crate) struct InterruptionRequest {
// Which interruption is being requested.
kind: InterruptionKind,
// Used by the receiving side to acknowledge that the request was received.
ack: oneshot::Sender<Result<(), SystemError>>,
}
/// A handle returned when a task is dispatched to the task system, it can be used to pause, cancel, resume, or wait
/// until the task gets completed.
#[derive(Debug)]
pub struct TaskHandle<E: RunError> {
// Shared state flags of the task, also holding the id of the worker currently responsible for it.
pub(crate) worktable: Arc<TaskWorktable>,
// Resolves with the final status of the task; polled by the `Future` implementation of the handle.
pub(crate) done_rx: oneshot::Receiver<Result<TaskStatus<E>, SystemError>>,
// Channel used to send requests about this task to the task system.
pub(crate) system_comm: SystemComm,
pub(crate) task_id: TaskId,
}
/// Awaiting the handle waits for the task's final status.
///
/// # Panics
///
/// Panics if the sending side of the done channel is dropped without sending a status.
impl<E: RunError> Future for TaskHandle<E> {
type Output = Result<TaskStatus<E>, SystemError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.done_rx)
.poll(cx)
.map(|res| res.expect("TaskHandle done channel unexpectedly closed"))
}
}
impl<E: RunError> TaskHandle<E> {
/// Get the unique identifier of the task
pub fn task_id(&self) -> TaskId {
self.task_id
}
/// Gracefully pause the task at a safe point defined by the user using the [`Interrupter`]
///
/// Does nothing (returning `Ok(())`) when the task is already paused, canceled or done.
/// A running task receives a pause request and this method waits for its acknowledgement; a task
/// that is not running is flagged as paused and the task system is asked to pause it on its current worker.
///
/// # Panics
///
/// Panics if the pause request is dropped without being acknowledged.
pub async fn pause(&self) -> Result<(), SystemError> {
let is_paused = self.worktable.is_paused.load(Ordering::Relaxed);
let is_canceled = self.worktable.is_canceled.load(Ordering::Relaxed);
let is_done = self.worktable.is_done.load(Ordering::Relaxed);
trace!("Received pause command task: <is_canceled={is_canceled}, is_done={is_done}>");
if !is_paused && !is_canceled && !is_done {
if self.worktable.is_running.load(Ordering::Relaxed) {
let (tx, rx) = oneshot::channel();
trace!("Task is running, sending pause request");
self.worktable.pause(tx).await;
rx.await.expect("Worker failed to ack pause request")?;
} else {
trace!("Task is not running, setting is_paused flag");
self.worktable.is_paused.store(true, Ordering::Relaxed);
return self
.system_comm
.pause_not_running_task(
self.task_id,
self.worktable.current_worker_id.load(Ordering::Relaxed),
)
.await;
}
}
Ok(())
}
/// Gracefully cancel the task at a safe point defined by the user using the [`Interrupter`]
///
/// Does nothing (returning `Ok(())`) when the task is already canceled or done.
/// A running task receives a cancel request and this method waits for its acknowledgement; a task
/// that is not running is flagged as canceled and the task system is asked to cancel it on its current worker.
///
/// # Panics
///
/// Panics if the cancel request is dropped without being acknowledged.
pub async fn cancel(&self) -> Result<(), SystemError> {
let is_canceled = self.worktable.is_canceled.load(Ordering::Relaxed);
let is_done = self.worktable.is_done.load(Ordering::Relaxed);
trace!("Received cancel command task: <is_canceled={is_canceled}, is_done={is_done}>");
if !is_canceled && !is_done {
if self.worktable.is_running.load(Ordering::Relaxed) {
let (tx, rx) = oneshot::channel();
trace!("Task is running, sending cancel request");
self.worktable.cancel(tx).await;
rx.await.expect("Worker failed to ack cancel request")?;
} else {
trace!("Task is not running, setting is_canceled flag");
self.worktable.is_canceled.store(true, Ordering::Relaxed);
return self
.system_comm
.cancel_not_running_task(
self.task_id,
self.worktable.current_worker_id.load(Ordering::Relaxed),
)
.await;
}
}
Ok(())
}
/// Forcefully abort the task, this can lead to corrupted data or inconsistent states, so use it with caution.
///
/// Marks the task as aborted on its worktable and then asks the task system to force the abortion on
/// the task's current worker, returning the acknowledged outcome.
pub async fn force_abortion(&self) -> Result<(), SystemError> {
self.worktable.set_aborted();
self.system_comm
.force_abortion(
self.task_id,
self.worktable.current_worker_id.load(Ordering::Relaxed),
)
.await
}
/// Marks the task to be resumed by the task system, the worker will start processing it if there is a slot
/// available or will be enqueued otherwise.
///
/// The request is sent to the task's current worker through the task system and the acknowledged
/// outcome is returned.
pub async fn resume(&self) -> Result<(), SystemError> {
self.system_comm
.resume_task(
self.task_id,
self.worktable.current_worker_id.load(Ordering::Relaxed),
)
.await
}
}
/// Shared state of a dispatched task: a set of atomic status flags, the sender used to deliver
/// interruption requests, and the id of the worker currently associated with the task.
#[derive(Debug)]
pub(crate) struct TaskWorktable {
// Set by `set_started`.
started: AtomicBool,
// Set by `set_started`; cleared by `set_completed`, `pause` and `cancel`.
is_running: AtomicBool,
// Set by `set_completed`.
is_done: AtomicBool,
// Set by `pause`; cleared by `set_unpause`.
is_paused: AtomicBool,
// Set by `cancel`.
is_canceled: AtomicBool,
// Set by `set_aborted`.
is_aborted: AtomicBool,
// Sending half of the channel used to deliver pause/cancel requests.
interrupt_tx: chan::Sender<InterruptionRequest>,
// Id of the worker currently associated with the task.
current_worker_id: AtomicWorkerId,
}
impl TaskWorktable {
    /// Creates a worktable with every status flag cleared, associated with the given worker.
    pub fn new(worker_id: WorkerId, interrupt_tx: chan::Sender<InterruptionRequest>) -> Self {
        Self {
            started: AtomicBool::new(false),
            is_running: AtomicBool::new(false),
            is_done: AtomicBool::new(false),
            is_paused: AtomicBool::new(false),
            is_canceled: AtomicBool::new(false),
            is_aborted: AtomicBool::new(false),
            interrupt_tx,
            current_worker_id: AtomicWorkerId::new(worker_id),
        }
    }

    /// Marks the task as started and running.
    pub fn set_started(&self) {
        self.started.store(true, Ordering::Relaxed);
        self.is_running.store(true, Ordering::Relaxed);
    }

    /// Marks the task as done and no longer running.
    pub fn set_completed(&self) {
        self.is_done.store(true, Ordering::Relaxed);
        self.is_running.store(false, Ordering::Relaxed);
    }

    /// Clears the paused flag.
    pub fn set_unpause(&self) {
        self.is_paused.store(false, Ordering::Relaxed);
    }

    /// Sets the aborted flag.
    pub fn set_aborted(&self) {
        self.is_aborted.store(true, Ordering::Relaxed);
    }

    /// Flags the task as paused and not running, then sends a pause request carrying `tx` as its
    /// acknowledgement channel.
    ///
    /// # Panics
    ///
    /// Panics if the interruption channel is closed.
    pub async fn pause(&self, tx: oneshot::Sender<Result<(), SystemError>>) {
        self.is_paused.store(true, Ordering::Relaxed);
        self.is_running.store(false, Ordering::Relaxed);

        trace!("Sending pause signal to Interrupter object on task");

        self.interrupt_tx
            .send(InterruptionRequest {
                kind: InterruptionKind::Pause,
                ack: tx,
            })
            .await
            .expect("Worker channel closed trying to pause task");
    }

    /// Flags the task as canceled and not running, then sends a cancel request carrying `tx` as
    /// its acknowledgement channel.
    ///
    /// # Panics
    ///
    /// Panics if the interruption channel is closed.
    pub async fn cancel(&self, tx: oneshot::Sender<Result<(), SystemError>>) {
        self.is_canceled.store(true, Ordering::Relaxed);
        self.is_running.store(false, Ordering::Relaxed);

        trace!("Sending cancel signal to Interrupter object on task");

        self.interrupt_tx
            .send(InterruptionRequest {
                kind: InterruptionKind::Cancel,
                ack: tx,
            })
            .await
            // This message used to say "pause", a copy-paste slip from `pause` above
            .expect("Worker channel closed trying to cancel task");
    }

    /// Returns the current value of the paused flag.
    pub fn is_paused(&self) -> bool {
        self.is_paused.load(Ordering::Relaxed)
    }

    /// Returns the current value of the canceled flag.
    pub fn is_canceled(&self) -> bool {
        self.is_canceled.load(Ordering::Relaxed)
    }

    /// Returns the current value of the aborted flag.
    pub fn is_aborted(&self) -> bool {
        self.is_aborted.load(Ordering::Relaxed)
    }
}
/// Everything needed to execute a dispatched task, bundled so it can be sent between workers:
/// the boxed task, its shared worktable, the completion sender and the interrupter.
#[derive(Debug)]
pub(crate) struct TaskWorkState<E: RunError> {
// The user task itself.
pub(crate) task: Box<dyn Task<E>>,
// Shared state flags; the `TaskHandle` holds another clone of this `Arc`.
pub(crate) worktable: Arc<TaskWorktable>,
// Sender for the final result; its receiver lives in the `TaskHandle`.
pub(crate) done_tx: oneshot::Sender<Result<TaskStatus<E>, SystemError>>,
// Interrupter for the task, built from the receiving half of its interruption channel.
pub(crate) interrupter: Arc<Interrupter>,
}
impl<E: RunError> TaskWorkState<E> {
    /// Records `new_worker_id` as the worker currently associated with this task.
    pub fn change_worker(&self, new_worker_id: WorkerId) {
        let current_worker_id = &self.worktable.current_worker_id;

        current_worker_id.store(new_worker_id, Ordering::Relaxed);
    }
}

View file

@ -0,0 +1,328 @@
use std::{
cell::RefCell,
sync::{atomic::AtomicUsize, Arc},
time::Duration,
};
use async_channel as chan;
use tokio::{spawn, sync::oneshot, task::JoinHandle};
use tracing::{error, info, trace, warn};
use super::{
error::{RunError, SystemError},
message::WorkerMessage,
system::SystemComm,
task::{
InternalTaskExecStatus, Interrupter, Task, TaskHandle, TaskId, TaskWorkState, TaskWorktable,
},
};
mod run;
mod runner;
use run::run;
/// A duration of exactly one second.
const ONE_SECOND: Duration = Duration::from_secs(1);
/// Identifier of a worker.
pub(crate) type WorkerId = usize;
/// Atomic counterpart of [`WorkerId`], for worker ids that are shared and updated concurrently.
pub(crate) type AtomicWorkerId = AtomicUsize;
/// Intermediate state of a worker: it owns both halves of the worker message channel until
/// `build` spawns the message processing task.
pub(crate) struct WorkerBuilder<E: RunError> {
// Id of the worker being built.
id: usize,
// Sending half of the worker message channel, handed to the built `Worker`.
msgs_tx: chan::Sender<WorkerMessage<E>>,
// Receiving half of the worker message channel, handed to the message processing task.
msgs_rx: chan::Receiver<WorkerMessage<E>>,
}
impl<E: RunError> WorkerBuilder<E> {
pub fn new(id: WorkerId) -> (Self, WorkerComm<E>) {
let (msgs_tx, msgs_rx) = chan::bounded(8);
let worker_comm = WorkerComm {
worker_id: id,
msgs_tx: msgs_tx.clone(),
};
(
Self {
id,
msgs_tx,
msgs_rx,
},
worker_comm,
)
}
pub fn build(self, system_comm: SystemComm, task_stealer: WorkStealer<E>) -> Worker<E> {
let Self {
id,
msgs_tx,
msgs_rx,
} = self;
let handle = spawn({
let msgs_rx = msgs_rx.clone();
let system_comm = system_comm.clone();
let task_stealer = task_stealer.clone();
async move {
trace!("Worker <worker_id='{id}'> message processing task starting...");
while let Err(e) = spawn(run(
id,
system_comm.clone(),
task_stealer.clone(),
msgs_rx.clone(),
))
.await
{
if e.is_panic() {
error!(
"Worker <worker_id='{id}'> critically failed and will restart: \
{e:#?}"
);
} else {
trace!(
"Worker <worker_id='{id}'> received shutdown signal and will exit..."
);
break;
}
}
info!("Worker <worker_id='{id}'> gracefully shutdown");
}
});
Worker {
id,
system_comm,
msgs_tx,
handle: RefCell::new(Some(handle)),
}
}
}
/// Front end of a worker: every method sends a [`WorkerMessage`] to the message processing task
/// spawned by `WorkerBuilder::build`.
#[derive(Debug)]
pub(crate) struct Worker<E: RunError> {
// Id of this worker.
pub id: usize,
// Cloned into every `TaskHandle` created by `add_task`.
system_comm: SystemComm,
// Sending half of the worker message channel.
msgs_tx: chan::Sender<WorkerMessage<E>>,
// Join handle of the message processing task; taken (left as `None`) by `shutdown`.
handle: RefCell<Option<JoinHandle<()>>>,
}
impl<E: RunError> Worker<E> {
pub async fn add_task(&self, new_task: Box<dyn Task<E>>) -> TaskHandle<E> {
let (done_tx, done_rx) = oneshot::channel();
let (interrupt_tx, interrupt_rx) = chan::bounded(1);
let worktable = Arc::new(TaskWorktable::new(self.id, interrupt_tx));
let task_id = new_task.id();
self.msgs_tx
.send(WorkerMessage::NewTask(TaskWorkState {
task: new_task,
worktable: Arc::clone(&worktable),
interrupter: Arc::new(Interrupter::new(interrupt_rx)),
done_tx,
}))
.await
.expect("Worker channel closed trying to add task");
TaskHandle {
worktable,
done_rx,
system_comm: self.system_comm.clone(),
task_id,
}
}
pub async fn task_count(&self) -> usize {
let (tx, rx) = oneshot::channel();
self.msgs_tx
.send(WorkerMessage::TaskCountRequest(tx))
.await
.expect("Worker channel closed trying to get task count");
rx.await
.expect("Worker channel closed trying to receive task count response")
}
pub async fn resume_task(
&self,
task_id: TaskId,
ack: oneshot::Sender<Result<(), SystemError>>,
) {
self.msgs_tx
.send(WorkerMessage::ResumeTask { task_id, ack })
.await
.expect("Worker channel closed trying to resume task");
}
pub async fn pause_not_running_task(
&self,
task_id: TaskId,
ack: oneshot::Sender<Result<(), SystemError>>,
) {
self.msgs_tx
.send(WorkerMessage::PauseNotRunningTask { task_id, ack })
.await
.expect("Worker channel closed trying to pause a not running task");
}
pub async fn cancel_not_running_task(
&self,
task_id: TaskId,
ack: oneshot::Sender<Result<(), SystemError>>,
) {
self.msgs_tx
.send(WorkerMessage::CancelNotRunningTask { task_id, ack })
.await
.expect("Worker channel closed trying to cancel a not running task");
}
pub async fn force_task_abortion(
&self,
task_id: TaskId,
ack: oneshot::Sender<Result<(), SystemError>>,
) {
self.msgs_tx
.send(WorkerMessage::ForceAbortion { task_id, ack })
.await
.expect("Worker channel closed trying to force task abortion");
}
pub async fn shutdown(&self) {
if let Some(handle) = self
.handle
.try_borrow_mut()
.ok()
.and_then(|mut maybe_handle| maybe_handle.take())
{
let (tx, rx) = oneshot::channel();
self.msgs_tx
.send(WorkerMessage::ShutdownRequest(tx))
.await
.expect("Worker channel closed trying to shutdown");
rx.await.expect("Worker channel closed trying to shutdown");
if let Err(e) = handle.await {
if e.is_panic() {
error!("Worker {} critically failed: {e:#?}", self.id);
}
}
} else {
warn!("Trying to shutdown a worker that was already shutdown");
}
}
pub async fn wake(&self) {
self.msgs_tx
.send(WorkerMessage::WakeUp)
.await
.expect("Worker channel closed trying to wake up");
}
}
/// SAFETY: Due to usage of refcell we lost `Sync` impl, but we only use it to have a shutdown method
/// receiving `&self` which is called once, and we also use `try_borrow_mut` so we never panic
// NOTE(review): `RefCell` tracks borrows with a non-atomic flag, so this impl is only sound if
// `shutdown` is never called concurrently from more than one thread. Nothing in this type enforces
// that - confirm against the callers, or consider a `Mutex<Option<JoinHandle<()>>>` instead.
unsafe impl<E: RunError> Sync for Worker<E> {}
/// Cloneable handle to a worker's message channel, used to request tasks to steal from that worker.
#[derive(Clone)]
pub(crate) struct WorkerComm<E: RunError> {
// Id of the worker that owns the receiving end of `msgs_tx`.
worker_id: WorkerId,
// Sending half of that worker's message channel.
msgs_tx: chan::Sender<WorkerMessage<E>>,
}
impl<E: RunError> WorkerComm<E> {
    /// Asks the worker behind this comm for a task on behalf of the stealer `worker_id`.
    ///
    /// Returns `None` when that worker answers with no task. A returned task is already
    /// re-assigned to `worker_id`.
    ///
    /// # Panics
    ///
    /// Panics if the worker message channel or the response channel is closed.
    pub async fn steal_task(&self, worker_id: WorkerId) -> Option<TaskWorkState<E>> {
        let (response_tx, response_rx) = oneshot::channel();

        self.msgs_tx
            .send(WorkerMessage::StealRequest(response_tx))
            .await
            .expect("Worker channel closed trying to steal task");

        let stolen = response_rx
            .await
            .expect("Worker channel closed trying to steal task")?;

        trace!(
            "Worker stole task: <worker_id='{worker_id}', stolen_worker_id='{}', task_id='{}'>",
            self.worker_id,
            stolen.task.id()
        );

        stolen.change_worker(worker_id);

        Some(stolen)
    }
}
/// Shared list of [`WorkerComm`]s that a worker goes through when it wants to steal a task.
pub(crate) struct WorkStealer<E: RunError> {
// Behind an `Arc` so that cloning the stealer does not clone the list.
worker_comms: Arc<Vec<WorkerComm<E>>>,
}
impl<E: RunError> Clone for WorkStealer<E> {
    /// Clones the stealer by sharing the same list of worker comms.
    fn clone(&self) -> Self {
        let worker_comms = Arc::clone(&self.worker_comms);

        Self { worker_comms }
    }
}
impl<E: RunError> WorkStealer<E> {
    /// Wraps the given worker comms so they can be shared between workers.
    pub fn new(worker_comms: Vec<WorkerComm<E>>) -> Self {
        let worker_comms = Arc::new(worker_comms);

        Self { worker_comms }
    }

    /// Tries to steal a task for `worker_id`, asking the other workers one at a time and
    /// returning the first task obtained, or `None` when no worker gave one.
    pub async fn steal(&self, worker_id: WorkerId) -> Option<TaskWorkState<E>> {
        let total_workers = self.worker_comms.len();

        let candidates = self
            .worker_comms
            .iter()
            // Endless round-robin over the list...
            .cycle()
            // ...advanced by `worker_id` positions
            // (presumably list position and worker id coincide; verify against the caller)...
            .skip(worker_id)
            // ...limited to a single lap
            .take(total_workers)
            // The stealer itself is dropped as we can't steal from ourselves
            .filter(|candidate| candidate.worker_id != worker_id);

        for worker_comm in candidates {
            trace!(
                "Trying to steal from worker <worker_id='{}', stealer_id='{worker_id}'>",
                worker_comm.worker_id
            );

            match worker_comm.steal_task(worker_id).await {
                Some(task) => return Some(task),
                None => trace!(
                    "Worker <worker_id='{}', stealer_id='{worker_id}'> has no tasks to steal",
                    worker_comm.worker_id
                ),
            }
        }

        None
    }

    /// Number of worker comms known to this stealer.
    pub fn workers_count(&self) -> usize {
        self.worker_comms.len()
    }
}
/// Result of running a task: the work state of the task together with its execution status.
struct TaskRunnerOutput<E: RunError> {
// Work state of the task that ran.
task_work_state: TaskWorkState<E>,
// Execution status of the run.
status: InternalTaskExecStatus<E>,
}
/// Messages handled by the worker loop alongside the regular worker messages.
enum RunnerMessage<E: RunError> {
// Outcome of running the task with the given id; the worker loop logs `Err(())` as a task failure.
TaskOutput(TaskId, Result<TaskRunnerOutput<E>, ()>),
// Result of a steal attempt: the stolen task, or `None`.
StoleTask(Option<TaskWorkState<E>>),
}

View file

@ -0,0 +1,113 @@
use std::pin::pin;
use async_channel as chan;
use futures::StreamExt;
use futures_concurrency::stream::Merge;
use tokio::time::{interval_at, Instant};
use tokio_stream::wrappers::IntervalStream;
use tracing::{error, warn};
use super::{
super::{error::RunError, message::WorkerMessage, system::SystemComm},
runner::Runner,
RunnerMessage, WorkStealer, WorkerId, ONE_SECOND,
};
/// Message loop of a worker.
///
/// Merges three sources into a single stream and hands each item to the [`Runner`]:
/// worker messages arriving on `msgs_rx`, messages produced by the runner itself, and an
/// idle-check tick that fires every [`ONE_SECOND`].
///
/// Returns after the runner has handled a shutdown request, or when the merged stream ends.
pub(super) async fn run<E: RunError>(
    id: WorkerId,
    system_comm: SystemComm,
    work_stealer: WorkStealer<E>,
    msgs_rx: chan::Receiver<WorkerMessage<E>>,
) {
    let (mut runner, runner_rx) = Runner::new(id, work_stealer, system_comm);

    let mut idle_checker_interval = interval_at(Instant::now(), ONE_SECOND);
    // Ticks missed while the loop was busy are skipped rather than delivered late
    idle_checker_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

    enum StreamMessage<E: RunError> {
        Commands(WorkerMessage<E>),
        RunnerMsg(RunnerMessage<E>),
        IdleCheck,
    }

    let mut msg_stream = pin!((
        msgs_rx.map(StreamMessage::Commands),
        runner_rx.map(StreamMessage::RunnerMsg),
        IntervalStream::new(idle_checker_interval).map(|_| StreamMessage::IdleCheck),
    )
        .merge());

    while let Some(msg) = msg_stream.next().await {
        match msg {
            // Worker messages
            StreamMessage::Commands(WorkerMessage::NewTask(task_work_state)) => {
                runner.abort_steal_task();
                runner.new_task(task_work_state).await;
            }

            StreamMessage::Commands(WorkerMessage::TaskCountRequest(tx)) => {
                if tx.send(runner.total_tasks()).is_err() {
                    warn!("Task count request channel closed before sending task count");
                }
            }

            StreamMessage::Commands(WorkerMessage::ResumeTask { task_id, ack }) => {
                if ack.send(runner.resume_task(task_id).await).is_err() {
                    warn!("Resume task channel closed before sending ack");
                }
            }

            StreamMessage::Commands(WorkerMessage::PauseNotRunningTask { task_id, ack }) => {
                if ack
                    .send(runner.pause_not_running_task(task_id).await)
                    .is_err()
                {
                    // This warning used to say "Resume task", copied from the arm above
                    warn!("Pause not running task channel closed before sending ack");
                }
            }

            StreamMessage::Commands(WorkerMessage::CancelNotRunningTask { task_id, ack }) => {
                if ack
                    .send(runner.cancel_not_running_task(task_id).await)
                    .is_err()
                {
                    // This warning used to say "Resume task", copied from the arm above
                    warn!("Cancel not running task channel closed before sending ack");
                }
            }

            StreamMessage::Commands(WorkerMessage::ForceAbortion { task_id, ack }) => {
                if ack.send(runner.force_task_abortion(task_id).await).is_err() {
                    warn!("Force abortion channel closed before sending ack");
                }
            }

            StreamMessage::Commands(WorkerMessage::ShutdownRequest(tx)) => {
                return runner.shutdown(tx).await;
            }

            StreamMessage::Commands(WorkerMessage::StealRequest(tx)) => runner.steal_request(tx),

            StreamMessage::Commands(WorkerMessage::WakeUp) => runner.wake_up().await,

            // Runner messages
            StreamMessage::RunnerMsg(RunnerMessage::TaskOutput(task_id, Ok(output))) => {
                runner.process_task_output(task_id, output).await
            }

            StreamMessage::RunnerMsg(RunnerMessage::TaskOutput(task_id, Err(()))) => {
                error!("Task failed <worker_id='{id}', task_id='{task_id}'>");

                runner.clean_suspended_task(task_id);

                runner.dispatch_next_task(task_id).await;
            }

            StreamMessage::RunnerMsg(RunnerMessage::StoleTask(maybe_new_task)) => {
                runner.process_stolen_task(maybe_new_task).await;
            }

            // Idle checking to steal some work
            StreamMessage::IdleCheck => runner.idle_check().await,
        }
    }
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,389 @@
use sd_task_system::{
ExecStatus, Interrupter, Task, TaskDispatcher, TaskHandle, TaskId, TaskOutput, TaskStatus,
};
use std::{
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use async_channel as chan;
use async_trait::async_trait;
use futures::stream::{self, FuturesUnordered, StreamExt};
use futures_concurrency::future::Race;
use serde::{Deserialize, Serialize};
use tokio::{fs, spawn, sync::broadcast};
use tracing::{error, info, trace, warn};
use crate::common::tasks::TimedTaskOutput;
use super::tasks::{SampleError, TimeTask};
/// Name of the file, inside the actor's data directory, where unfinished tasks are saved.
const SAMPLE_ACTOR_SAVE_STATE_FILE_NAME: &str = "sample_actor_save_state.bin";
/// Sample actor used by the tests: it dispatches timed tasks and forwards their handles to a
/// background task that waits for them.
pub struct SampleActor {
data: Arc<String>, // Can hold any kind of actor data, like an AI model
// Dispatcher used to send the actor's tasks to the task system.
task_dispatcher: TaskDispatcher<SampleError>,
// Handles of dispatched tasks are sent through here to the background `run` task.
task_handles_tx: chan::Sender<TaskHandle<SampleError>>,
}
impl SampleActor {
/// Creates the actor, spawns its background `run` task and re-dispatches any tasks found in the
/// save state file under `data_directory`.
///
/// Returns the actor together with a broadcast receiver that is notified whenever the number
/// of pending tasks drops to zero.
///
/// # Panics
///
/// Panics if the background task has already closed the task handle channel while saved tasks
/// are being re-dispatched.
pub async fn new(
data_directory: impl AsRef<Path>,
data: String,
task_dispatcher: TaskDispatcher<SampleError>,
) -> (Self, broadcast::Receiver<()>) {
let (task_handles_tx, task_handles_rx) = chan::bounded(8);
let (idle_tx, idle_rx) = broadcast::channel(1);
let save_state_file_path = data_directory
.as_ref()
.join(SAMPLE_ACTOR_SAVE_STATE_FILE_NAME);
let data = Arc::new(data);
// A missing, unreadable or undecodable save file is logged and treated as "no pending tasks"
let pending_tasks = fs::read(&save_state_file_path)
.await
.map_err(|e| {
if e.kind() == std::io::ErrorKind::NotFound {
info!("No saved actor tasks found");
} else {
error!("Failed to read saved actor tasks: {e:#?}");
}
})
.ok()
.and_then(|data| {
rmp_serde::from_slice::<Vec<SampleActorTaskSaveState>>(&data)
.map_err(|e| {
error!("Failed to deserialize saved actor tasks: {e:#?}");
})
.ok()
})
.unwrap_or_default();
spawn(Self::run(save_state_file_path, task_handles_rx, idle_tx));
// Saved tasks are rebuilt with their original ids and pause counts, then dispatched again
for SampleActorTaskSaveState {
id,
duration,
has_priority,
paused_count,
} in pending_tasks
{
task_handles_tx
.send(if has_priority {
task_dispatcher
.dispatch(SampleActorTaskWithPriority::with_id(
id,
duration,
Arc::clone(&data),
paused_count,
))
.await
} else {
task_dispatcher
.dispatch(SampleActorTask::with_id(
id,
duration,
Arc::clone(&data),
paused_count,
))
.await
})
.await
.expect("Task handle receiver dropped");
}
(
Self {
data,
task_dispatcher,
task_handles_tx,
},
idle_rx,
)
}
/// Builds a regular task of the given duration that shares the actor data.
pub fn new_task(&self, duration: Duration) -> SampleActorTask {
SampleActorTask::new(duration, Arc::clone(&self.data))
}
/// Builds a priority task of the given duration that shares the actor data.
pub fn new_priority_task(&self, duration: Duration) -> SampleActorTaskWithPriority {
SampleActorTaskWithPriority::new(duration, Arc::clone(&self.data))
}
// Dispatches a new task (priority or not) and forwards its handle to the background task.
// Panics if the background task has already closed the handle channel.
async fn inner_process(&self, duration: Duration, has_priority: bool) {
self.task_handles_tx
.send(if has_priority {
self.task_dispatcher
.dispatch(self.new_priority_task(duration))
.await
} else {
self.task_dispatcher.dispatch(self.new_task(duration)).await
})
.await
.expect("Task handle receiver dropped");
}
/// Dispatches a regular task of the given duration.
pub async fn process(&self, duration: Duration) {
self.inner_process(duration, false).await
}
/// Dispatches a priority task of the given duration.
pub async fn process_with_priority(&self, duration: Duration) {
self.inner_process(duration, true).await
}
// Background loop of the actor: collects incoming task handles, waits for them to complete and,
// once told to stop, writes the tasks that were shut down before finishing to the save state file.
async fn run(
save_state_file_path: PathBuf,
task_handles_rx: chan::Receiver<TaskHandle<SampleError>>,
idle_tx: broadcast::Sender<()>,
) {
let mut handles = FuturesUnordered::<TaskHandle<SampleError>>::new();
enum RaceOutput {
NewHandle(TaskHandle<SampleError>),
CompletedHandle,
Stop(Option<Box<dyn Task<SampleError>>>),
}
// Number of handles received that have not completed yet
let mut pending = 0usize;
loop {
match (
// Resolves with the next handle, or asks to stop once the handle channel is closed
async {
if let Ok(handle) = task_handles_rx.recv().await {
RaceOutput::NewHandle(handle)
} else {
RaceOutput::Stop(None)
}
},
// Resolves when one of the tracked handles completes
async {
if let Some(out) = handles.next().await {
match out {
Ok(TaskStatus::Done(maybe_out)) => {
if let TaskOutput::Out(out) = maybe_out {
info!(
"Task completed: {:?}",
out.downcast::<TimedTaskOutput>()
.expect("we know the task type")
);
}
}
Ok(TaskStatus::Canceled) => {
trace!("Task was canceled")
}
Ok(TaskStatus::ForcedAbortion) => {
warn!("Task was forcibly aborted");
}
Ok(TaskStatus::Shutdown(task)) => {
// If a task was shutdown, it means the task system is shutting down
// so all other tasks will also be shutdown
return RaceOutput::Stop(Some(task));
}
Ok(TaskStatus::Error(e)) => {
error!("Task failed: {e:#?}");
}
Err(e) => {
error!("Task system failed: {e:#?}");
}
}
RaceOutput::CompletedHandle
} else {
// NOTE(review): an empty `FuturesUnordered` yields `None` right away, so this branch asks
// to stop whenever no handles are tracked and this future wins the race - confirm this
// cannot happen before the first handle arrives.
RaceOutput::Stop(None)
}
},
)
.race()
.await
{
RaceOutput::NewHandle(handle) => {
pending += 1;
info!("Received new task handle, total pending tasks: {pending}");
handles.push(handle);
}
RaceOutput::CompletedHandle => {
pending -= 1;
info!("Task completed, total pending tasks: {pending}");
if pending == 0 {
info!("All tasks completed, sending idle report...");
idle_tx.send(()).expect("idle receiver dropped");
}
}
RaceOutput::Stop(maybe_task) => {
// Stop accepting handles, then move the ones still queued in the channel into `handles`
task_handles_rx.close();
task_handles_rx
.for_each(|handle| async { handles.push(handle) })
.await;
// Wait for every remaining handle, keeping only the tasks returned through a `Shutdown` status
// (plus the task that triggered the stop, when there is one)
let tasks = stream::iter(
maybe_task
.into_iter()
.map(SampleActorTaskSaveState::from_task),
)
.chain(handles.filter_map(|handle| async move {
match handle {
Ok(TaskStatus::Done(maybe_out)) => {
if let TaskOutput::Out(out) = maybe_out {
info!(
"Task completed: {:?}",
out.downcast::<TimedTaskOutput>()
.expect("we know the task type")
);
}
None
}
Ok(TaskStatus::Canceled) => None,
Ok(TaskStatus::ForcedAbortion) => {
warn!("Task was forcibly aborted");
None
}
Ok(TaskStatus::Shutdown(task)) => {
Some(SampleActorTaskSaveState::from_task(task))
}
Ok(TaskStatus::Error(e)) => {
error!("Task failed: {e:#?}");
None
}
Err(e) => {
error!("Task system failed: {e:#?}");
None
}
}
}))
.collect::<Vec<_>>()
.await;
// A failure to write the save state is only logged
if let Err(e) = fs::write(
&save_state_file_path,
rmp_serde::to_vec_named(&tasks).expect("failed to serialize"),
)
.await
{
error!("Failed to save actor tasks: {e:#?}");
}
return;
}
}
}
}
}
impl Drop for SampleActor {
    /// Closes the task handle channel when the actor is dropped.
    fn drop(&mut self) {
        let handles_tx = &self.task_handles_tx;

        handles_tx.close();
    }
}
/// Actor task without priority: a [`TimeTask`] bundled with the shared actor data.
#[derive(Debug)]
pub struct SampleActorTask {
    timed_task: TimeTask,
    actor_data: Arc<String>, // Can hold any kind of actor data
}

impl SampleActorTask {
    /// Creates a task with a fresh id that runs for `duration`.
    pub fn new(duration: Duration, actor_data: Arc<String>) -> Self {
        let timed_task = TimeTask::new(duration, false);

        Self {
            timed_task,
            actor_data,
        }
    }

    /// Recreates a task with a known id and pause count.
    fn with_id(id: TaskId, duration: Duration, actor_data: Arc<String>, paused_count: u32) -> Self {
        let timed_task = TimeTask::with_id(id, duration, false, paused_count);

        Self {
            timed_task,
            actor_data,
        }
    }
}
/// Actor task with priority: a [`TimeTask`] bundled with the shared actor data.
#[derive(Debug)]
pub struct SampleActorTaskWithPriority {
    timed_task: TimeTask,
    actor_data: Arc<String>, // Can hold any kind of actor data
}

impl SampleActorTaskWithPriority {
    /// Creates a priority task with a fresh id that runs for `duration`.
    fn new(duration: Duration, actor_data: Arc<String>) -> SampleActorTaskWithPriority {
        let timed_task = TimeTask::new(duration, true);

        Self {
            timed_task,
            actor_data,
        }
    }

    /// Recreates a priority task with a known id and pause count.
    fn with_id(id: TaskId, duration: Duration, actor_data: Arc<String>, paused_count: u32) -> Self {
        let timed_task = TimeTask::with_id(id, duration, true, paused_count);

        Self {
            timed_task,
            actor_data,
        }
    }
}
#[async_trait]
impl Task<SampleError> for SampleActorTask {
fn id(&self) -> TaskId {
self.timed_task.id()
}
async fn run(&mut self, interrupter: &Interrupter) -> Result<ExecStatus, SampleError> {
info!("Actor data: {:#?}", self.actor_data);
self.timed_task.run(interrupter).await
}
fn with_priority(&self) -> bool {
self.timed_task.with_priority()
}
}
#[async_trait]
impl Task<SampleError> for SampleActorTaskWithPriority {
fn id(&self) -> TaskId {
self.timed_task.id()
}
async fn run(&mut self, interrupter: &Interrupter) -> Result<ExecStatus, SampleError> {
info!("Actor data: {:#?}", self.actor_data);
self.timed_task.run(interrupter).await
}
fn with_priority(&self) -> bool {
self.timed_task.with_priority()
}
}
/// Serializable snapshot of an unfinished actor task, with enough data to dispatch it again.
#[derive(Debug, Serialize, Deserialize)]
struct SampleActorTaskSaveState {
// Id the task had when it was saved.
id: TaskId,
// Duration stored in the task's `TimeTask` when it was saved.
duration: Duration,
// Whether the task was a `SampleActorTaskWithPriority`.
has_priority: bool,
// Pause count stored in the task's `TimeTask` when it was saved.
paused_count: u32,
}
impl SampleActorTaskSaveState {
fn from_task(dyn_task: Box<dyn Task<SampleError>>) -> Self {
match dyn_task.downcast::<SampleActorTask>() {
Ok(concrete_task) => SampleActorTaskSaveState {
id: concrete_task.timed_task.id(),
duration: concrete_task.timed_task.duration,
has_priority: false,
paused_count: concrete_task.timed_task.paused_count,
},
Err(dyn_task) => {
let concrete_task = dyn_task
.downcast::<SampleActorTaskWithPriority>()
.expect("we know the task type");
SampleActorTaskSaveState {
id: concrete_task.timed_task.id(),
duration: concrete_task.timed_task.duration,
has_priority: true,
paused_count: concrete_task.timed_task.paused_count,
}
}
}
}
}

View file

@ -0,0 +1,119 @@
use async_trait::async_trait;
use futures_concurrency::future::FutureGroup;
use lending_stream::{LendingStream, StreamExt};
use sd_task_system::{
ExecStatus, Interrupter, IntoAnyTaskOutput, Task, TaskDispatcher, TaskHandle, TaskId,
TaskOutput, TaskStatus,
};
use tracing::trace;
use super::tasks::SampleError;
/// Sample job used by the tests: it dispatches one chain of tasks per worker and waits for all
/// of them, including the children each task dispatches.
#[derive(Debug)]
pub struct SampleJob {
// Length of each task chain; every initial task expects `total_steps - 1` children.
total_steps: u32,
// Dispatcher used for the initial tasks and cloned into each of them.
task_dispatcher: TaskDispatcher<SampleError>,
}
impl SampleJob {
pub fn new(total_steps: u32, task_dispatcher: TaskDispatcher<SampleError>) -> Self {
Self {
total_steps,
task_dispatcher,
}
}
pub async fn run(self) -> Result<(), SampleError> {
let Self {
total_steps,
task_dispatcher,
} = self;
let initial_steps = (0..task_dispatcher.workers_count())
.map(|_| SampleJobTask {
id: TaskId::new_v4(),
expected_children: total_steps - 1,
task_dispatcher: task_dispatcher.clone(),
})
.collect::<Vec<_>>();
let mut group = FutureGroup::from_iter(
task_dispatcher
.dispatch_many(initial_steps)
.await
.into_iter(),
)
.lend_mut();
while let Some((group, res)) = group.next().await {
match res.unwrap() {
TaskStatus::Done(TaskOutput::Out(out)) => {
group.insert(
out.downcast::<Output>()
.expect("we know the output type")
.children_handle,
);
trace!("Received more tasks to wait for ({} left)", group.len());
}
TaskStatus::Done(TaskOutput::Empty) => {
trace!(
"Step done, waiting for all children to finish ({} left)",
group.len()
);
}
TaskStatus::Canceled => {
trace!("Task was canceled");
}
TaskStatus::ForcedAbortion => {
trace!("Aborted")
}
TaskStatus::Shutdown(task) => {
trace!("Task was shutdown: {:?}", task);
}
TaskStatus::Error(e) => return Err(e),
}
}
Ok(())
}
}
/// One step of a [`SampleJob`] chain; while it still expects children, running it dispatches the
/// next step.
#[derive(Debug)]
struct SampleJobTask {
// Identifier of this step.
id: TaskId,
// How many more steps are still to be dispatched after this one.
expected_children: u32,
// Dispatcher used to dispatch the next step.
task_dispatcher: TaskDispatcher<SampleError>,
}
/// Output of a [`SampleJobTask`] that dispatched a child step.
#[derive(Debug)]
struct Output {
// Handle of the child step, for the job to wait on.
children_handle: TaskHandle<SampleError>,
}
#[async_trait]
impl Task<SampleError> for SampleJobTask {
    fn id(&self) -> TaskId {
        self.id
    }

    /// Finishes with an empty output when no children are expected; otherwise dispatches the next
    /// step and finishes with an [`Output`] carrying the handle of that step.
    async fn run(&mut self, _interrupter: &Interrupter) -> Result<ExecStatus, SampleError> {
        if self.expected_children == 0 {
            return Ok(ExecStatus::Done(TaskOutput::Empty));
        }

        let next_step = SampleJobTask {
            id: TaskId::new_v4(),
            expected_children: self.expected_children - 1,
            task_dispatcher: self.task_dispatcher.clone(),
        };

        let children_handle = self.task_dispatcher.dispatch(next_step).await;

        Ok(ExecStatus::Done(Output { children_handle }.into_output()))
    }
}

View file

@ -0,0 +1,3 @@
pub mod actors;
pub mod jobs;
pub mod tasks;

View file

@ -0,0 +1,278 @@
use std::{future::pending, time::Duration};
use sd_task_system::{
ExecStatus, Interrupter, InterruptionKind, IntoAnyTaskOutput, Task, TaskId, TaskOutput,
};
use async_trait::async_trait;
use futures_concurrency::future::Race;
use thiserror::Error;
use tokio::{
sync::oneshot,
time::{sleep, Instant},
};
use tracing::{error, info};
/// Error type used by the sample tasks of the test suite.
#[derive(Debug, Error)]
pub enum SampleError {
// The only failure the sample tasks report; `BogusTask` always returns it.
#[error("Sample error")]
SampleError,
}
/// Sample task that does nothing but wait for an interruption.
#[derive(Debug)]
pub struct NeverTask {
    id: TaskId,
}

impl Default for NeverTask {
    /// Creates the task with a fresh random id.
    fn default() -> Self {
        let id = TaskId::new_v4();

        Self { id }
    }
}
#[async_trait]
impl Task<SampleError> for NeverTask {
    fn id(&self) -> TaskId {
        self.id
    }

    /// Waits for an interruption and reports the matching status.
    async fn run(&mut self, interrupter: &Interrupter) -> Result<ExecStatus, SampleError> {
        let status = match interrupter.await {
            InterruptionKind::Pause => {
                info!("Pausing NeverTask <id='{}'>", self.id);
                ExecStatus::Paused
            }
            InterruptionKind::Cancel => {
                info!("Canceling NeverTask <id='{}'>", self.id);
                ExecStatus::Canceled
            }
        };

        Ok(status)
    }
}
/// Sample task that completes immediately.
#[derive(Debug)]
pub struct ReadyTask {
    id: TaskId,
}

impl Default for ReadyTask {
    /// Creates the task with a fresh random id.
    fn default() -> Self {
        let id = TaskId::new_v4();

        Self { id }
    }
}
#[async_trait]
impl Task<SampleError> for ReadyTask {
    fn id(&self) -> TaskId {
        self.id
    }

    /// Completes right away with an empty output.
    async fn run(&mut self, _interrupter: &Interrupter) -> Result<ExecStatus, SampleError> {
        let output = TaskOutput::Empty;

        Ok(ExecStatus::Done(output))
    }
}
/// Sample task that always fails.
#[derive(Debug)]
pub struct BogusTask {
    id: TaskId,
}

impl Default for BogusTask {
    /// Creates the task with a fresh random id.
    fn default() -> Self {
        let id = TaskId::new_v4();

        Self { id }
    }
}
#[async_trait]
impl Task<SampleError> for BogusTask {
    fn id(&self) -> TaskId {
        self.id
    }

    /// Fails right away with [`SampleError::SampleError`].
    async fn run(&mut self, _interrupter: &Interrupter) -> Result<ExecStatus, SampleError> {
        let failure = SampleError::SampleError;

        Err(failure)
    }
}
/// Sample task that sleeps for a given duration and counts how many times it was paused.
#[derive(Debug)]
pub struct TimeTask {
    id: TaskId,
    /// How long the task still has to sleep.
    pub duration: Duration,
    priority: bool,
    /// How many times the task was paused.
    pub paused_count: u32,
}

impl TimeTask {
    /// Creates a task with a fresh random id and a pause count of zero.
    pub fn new(duration: Duration, priority: bool) -> Self {
        let id = TaskId::new_v4();

        Self::with_id(id, duration, priority, 0)
    }

    /// Creates a task from already known values.
    pub fn with_id(id: TaskId, duration: Duration, priority: bool, paused_count: u32) -> Self {
        Self {
            id,
            duration,
            priority,
            paused_count,
        }
    }
}
/// Output produced by a `TimeTask` when it completes.
#[derive(Debug)]
pub struct TimedTaskOutput {
// Value of the task's `paused_count` when it completed.
pub pauses_count: u32,
}
#[async_trait]
impl Task<SampleError> for TimeTask {
    fn id(&self) -> TaskId {
        self.id
    }

    /// Sleeps for the task duration unless interrupted first.
    ///
    /// - When the sleep finishes, or a pause arrives with no time left, the task is done and its
    ///   output carries the pause count.
    /// - When a pause arrives earlier, the remaining time is stored as the new duration and the
    ///   pause count is incremented.
    /// - When a cancel arrives, the task is canceled.
    async fn run(&mut self, interrupter: &Interrupter) -> Result<ExecStatus, SampleError> {
        let start = Instant::now();

        info!("Running timed task for {:#?}", self.duration);

        enum RaceOutput {
            Paused(Duration),
            Canceled,
            Completed,
        }

        let task_work_fut = async {
            sleep(self.duration).await;
            RaceOutput::Completed
        };

        let interrupt_fut = async {
            let kind = interrupter.await;

            // Measured only after the interruption arrives: sampling it before the await would
            // always read (almost) zero and a resumed task would sleep the full duration again
            let elapsed = start.elapsed();

            match kind {
                InterruptionKind::Pause => {
                    RaceOutput::Paused(self.duration.saturating_sub(elapsed))
                }
                InterruptionKind::Cancel => RaceOutput::Canceled,
            }
        };

        Ok(match (task_work_fut, interrupt_fut).race().await {
            RaceOutput::Completed | RaceOutput::Paused(Duration::ZERO) => ExecStatus::Done(
                TimedTaskOutput {
                    pauses_count: self.paused_count,
                }
                .into_output(),
            ),
            RaceOutput::Paused(remaining_duration) => {
                self.duration = remaining_duration;
                self.paused_count += 1;
                ExecStatus::Paused
            }
            RaceOutput::Canceled => ExecStatus::Canceled,
        })
    }

    fn with_priority(&self) -> bool {
        self.priority
    }
}
/// Sample task that waits for one interruption on its first run and completes on the next one.
#[derive(Debug)]
pub struct PauseOnceTask {
    id: TaskId,
    has_paused: bool,
    began_tx: Option<oneshot::Sender<()>>,
}

impl PauseOnceTask {
    /// Creates the task together with a receiver that is signaled when the task first runs.
    pub fn new() -> (Self, oneshot::Receiver<()>) {
        let (began_tx, began_rx) = oneshot::channel();

        let task = Self {
            id: TaskId::new_v4(),
            has_paused: false,
            began_tx: Some(began_tx),
        };

        (task, began_rx)
    }
}
#[async_trait]
impl Task<SampleError> for PauseOnceTask {
    fn id(&self) -> TaskId {
        self.id
    }

    /// Signals that the task began (first run only), then either waits for an interruption or,
    /// when it has already waited once, completes with an empty output.
    async fn run(&mut self, interrupter: &Interrupter) -> Result<ExecStatus, SampleError> {
        let began_signal = self.began_tx.take().map(|began_tx| began_tx.send(()));

        if matches!(began_signal, Some(Err(_))) {
            error!("Failed to send began signal");
        }

        if self.has_paused {
            return Ok(ExecStatus::Done(TaskOutput::Empty));
        }

        // Flagged before waiting, so whichever run comes next completes immediately
        self.has_paused = true;

        let status = match interrupter.await {
            InterruptionKind::Pause => {
                info!("Pausing PauseOnceTask <id='{}'>", self.id);
                ExecStatus::Paused
            }
            InterruptionKind::Cancel => {
                info!("Canceling PauseOnceTask <id='{}'>", self.id);
                ExecStatus::Canceled
            }
        };

        Ok(status)
    }
}
/// Sample task that never finishes and ignores its interrupter.
#[derive(Debug)]
pub struct BrokenTask {
    id: TaskId,
    began_tx: Option<oneshot::Sender<()>>,
}

impl BrokenTask {
    /// Creates the task together with a receiver that is signaled when the task first runs.
    pub fn new() -> (Self, oneshot::Receiver<()>) {
        let (began_tx, began_rx) = oneshot::channel();

        let task = Self {
            id: TaskId::new_v4(),
            began_tx: Some(began_tx),
        };

        (task, began_rx)
    }
}
#[async_trait]
impl Task<SampleError> for BrokenTask {
    fn id(&self) -> TaskId {
        self.id
    }

    /// Signals that the task began (first run only) and then waits forever.
    async fn run(&mut self, _: &Interrupter) -> Result<ExecStatus, SampleError> {
        let began_signal = self.began_tx.take().map(|began_tx| began_tx.send(()));

        if matches!(began_signal, Some(Err(_))) {
            error!("Failed to send began signal");
        }

        pending().await
    }
}

View file

@ -0,0 +1,224 @@
use sd_task_system::{TaskOutput, TaskStatus, TaskSystem};
use std::{collections::VecDeque, time::Duration};
use futures_concurrency::future::Join;
use rand::Rng;
use tempfile::tempdir;
use tracing::info;
use tracing_test::traced_test;
mod common;
use common::{
actors::SampleActor,
tasks::{BogusTask, BrokenTask, NeverTask, PauseOnceTask, ReadyTask, SampleError},
};
use crate::common::jobs::SampleJob;
/// Dispatches 251 tasks of random duration through the sample actor (about one in ten with
/// priority), waits for the actor to report idle and shuts the system down.
#[tokio::test]
#[traced_test]
async fn test_actor() {
    let data_dir = tempdir().unwrap();
    let system = TaskSystem::new();
    let dispatcher = system.get_dispatcher();

    let (actor, mut actor_idle_rx) =
        SampleActor::new(data_dir.path(), "test".to_string(), dispatcher).await;

    let mut rng = rand::thread_rng();

    for i in 0..=250 {
        let prioritized = rng.gen_bool(0.1);

        if prioritized {
            info!("dispatching priority task {i}");
            let millis: u64 = rng.gen_range(50..150);
            actor
                .process_with_priority(Duration::from_millis(millis))
                .await;
        } else {
            info!("dispatching task {i}");
            let millis: u64 = rng.gen_range(200..500);
            actor.process(Duration::from_millis(millis)).await;
        }
    }

    info!("all tasks dispatched, now we wait a bit...");

    actor_idle_rx.recv().await.unwrap();

    system.shutdown().await;

    info!("done");
}
/// A task that never finishes must be handed back with a `Shutdown` status once the system shuts down.
#[tokio::test]
#[traced_test]
async fn shutdown_test() {
    let system = TaskSystem::new();
    let handle = system.dispatch(NeverTask::default()).await;

    system.shutdown().await;

    let outcome = handle.await;
    assert!(matches!(outcome, Ok(TaskStatus::Shutdown(_))));
}
/// Canceling a task that never finishes must resolve its handle with a `Canceled` status.
#[tokio::test]
#[traced_test]
async fn cancel_test() {
    let system = TaskSystem::new();
    let handle = system.dispatch(NeverTask::default()).await;

    info!("issuing cancel");
    handle.cancel().await.unwrap();

    let outcome = handle.await;
    assert!(matches!(outcome, Ok(TaskStatus::Canceled)));

    system.shutdown().await;
}
/// A task that completes immediately must resolve its handle with `Done` and an empty output.
#[tokio::test]
#[traced_test]
async fn done_test() {
    let system = TaskSystem::new();
    let handle = system.dispatch(ReadyTask::default()).await;

    let outcome = handle.await;
    assert!(matches!(outcome, Ok(TaskStatus::Done(TaskOutput::Empty))));

    system.shutdown().await;
}
/// Forcing the abortion of a task that ignores interruptions must resolve its handle with
/// `ForcedAbortion`.
#[tokio::test]
#[traced_test]
async fn abort_test() {
    let system = TaskSystem::new();
    let (task, began_rx) = BrokenTask::new();
    let handle = system.dispatch(task).await;

    // The abortion is only requested once the task is known to have started running
    began_rx.await.unwrap();
    handle.force_abortion().await.unwrap();

    let outcome = handle.await;
    assert!(matches!(outcome, Ok(TaskStatus::ForcedAbortion)));

    system.shutdown().await;
}
/// A failing task must resolve its handle with the error the task returned.
#[tokio::test]
#[traced_test]
async fn error_test() {
    let system = TaskSystem::new();
    let handle = system.dispatch(BogusTask::default()).await;

    let outcome = handle.await;
    assert!(matches!(
        outcome,
        Ok(TaskStatus::Error(SampleError::SampleError))
    ));

    system.shutdown().await;
}
// Dispatches a `PauseOnceTask`, waits for its "began" signal, pauses and then
// resumes it through the handle, and expects it to finish as
// `TaskStatus::Done` with an empty output.
#[tokio::test]
#[traced_test]
async fn pause_test() {
    let system = TaskSystem::new();

    let (pausable_task, started_rx) = PauseOnceTask::new();
    let task_handle = system.dispatch(pausable_task).await;

    info!("Task dispatched, now we wait for it to begin...");
    started_rx.await.unwrap();

    task_handle.pause().await.unwrap();
    info!("Paused task, now we resume it...");

    task_handle.resume().await.unwrap();
    info!("Resumed task, now we wait for it to complete...");

    let outcome = task_handle.await;
    assert!(matches!(outcome, Ok(TaskStatus::Done(TaskOutput::Empty))));

    system.shutdown().await;
}
// Builds a `SampleJob` (constructed with `256` and a clone of the system's
// dispatcher), runs it to completion, and then shuts the task system down.
#[tokio::test]
#[traced_test]
async fn jobs_test() {
    let system = TaskSystem::new();
    let dispatcher = system.get_dispatcher();

    // NOTE(review): the clone looks redundant since `dispatcher` is not used
    // afterwards; it is kept so the original handle stays alive until the end
    // of the test, exactly as before.
    let sample_job = SampleJob::new(256, dispatcher.clone());

    let run_result = sample_job.run().await;
    run_result.unwrap();

    system.shutdown().await;
}
// Work-stealing scenario: dispatches one `PauseOnceTask` per worker followed by
// 100 `ReadyTask`s, waits until every pause task has signalled that it began,
// then pauses/resumes the first pause task and awaits it. All ready tasks are
// expected to complete after that, and finally the remaining pause tasks are
// paused, resumed and awaited concurrently.
#[tokio::test]
#[traced_test]
async fn steal_test() {
    let system = TaskSystem::new();
    let workers_count = system.workers_count();

    // One pause task per worker, each paired with its "began" receiver.
    let (pause_tasks, pause_begans): (Vec<_>, Vec<_>) =
        (0..workers_count).map(|_| PauseOnceTask::new()).unzip();

    // With this, all workers will be busy
    let dispatched_pauses = system.dispatch_many(pause_tasks).await;
    let mut pause_handles = VecDeque::from(dispatched_pauses);

    let ready_handles = system
        .dispatch_many((0..100).map(|_| ReadyTask::default()).collect())
        .await;

    // Wait for every pause task to report that it has started.
    let began_waiters = pause_begans
        .into_iter()
        .map(|began_rx| async move { began_rx.await.unwrap() })
        .collect::<Vec<_>>();
    began_waiters.join().await;

    let first_paused_handle = pause_handles.pop_front().unwrap();

    info!("All tasks dispatched, will now release the first one, so the first worker can steal everything...");

    first_paused_handle.pause().await.unwrap();
    first_paused_handle.resume().await.unwrap();
    first_paused_handle.await.unwrap();

    // Every ready task must have produced an `Ok` result.
    for res in ready_handles.join().await {
        res.unwrap();
    }

    // Release the remaining pause tasks concurrently.
    let remaining_releases = pause_handles
        .into_iter()
        .map(|handle| async move {
            handle.pause().await.unwrap();
            handle.resume().await.unwrap();
            handle.await.unwrap();
        })
        .collect::<Vec<_>>();
    remaining_releases.join().await;

    system.shutdown().await;
}

View file

@ -9,6 +9,6 @@ edition = "2021"
sd-prisma = { path = "../prisma" }
prisma-client-rust = { workspace = true }
rspc = { workspace = true }
rspc = { workspace = true, features = ["unstable"] }
thiserror = { workspace = true }
uuid = { workspace = true }

View file

@ -1,2 +1,2 @@
[toolchain]
channel = "1.75"
channel = "stable"