diff --git a/Cargo.lock b/Cargo.lock index 82f494c00..0954fd44b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3288,12 +3288,13 @@ dependencies = [ [[package]] name = "futures-concurrency" -version = "7.4.3" +version = "7.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef6712e11cdeed5c8cf21ea0b90fec40fbe64afc9bbf2339356197eeca829fc3" +checksum = "51ee14e256b9143bfafbf2fddeede6f396650bacf95d06fc1b3f2b503df129a0" dependencies = [ "bitvec", "futures-core", + "futures-lite 1.13.0", "pin-project", "slab", "smallvec 1.13.1", @@ -9077,8 +9078,10 @@ dependencies = [ "futures", "futures-concurrency", "globset", + "image", "itertools 0.12.0", "lending-stream", + "once_cell", "prisma-client-rust", "rmp-serde", "rmpv", @@ -9087,7 +9090,10 @@ dependencies = [ "sd-core-indexer-rules", "sd-core-prisma-helpers", "sd-core-sync", + "sd-ffmpeg", "sd-file-ext", + "sd-images", + "sd-media-metadata", "sd-prisma", "sd-sync", "sd-task-system", @@ -9104,6 +9110,7 @@ dependencies = [ "tracing", "tracing-test", "uuid", + "webp", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 776cdb6d3..bd7e4cb1e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,7 +52,7 @@ blake3 = "1.5.0" chrono = "0.4.38" clap = "4.4.7" futures = "0.3.30" -futures-concurrency = "7.4.3" +futures-concurrency = "7.6.0" globset = "^0.4.13" hex = "0.4.3" http = "0.2.9" @@ -61,7 +61,7 @@ itertools = "0.12.0" lending-stream = "1.0.0" libc = "0.2" normpath = "1.1.1" -once_cell = "1.18.0" +once_cell = "1.19.0" pin-project-lite = "0.2.13" rand = "0.8.5" rand_chacha = "0.3.1" diff --git a/apps/desktop/src-tauri/src/menu.rs b/apps/desktop/src-tauri/src/menu.rs index 307deba89..90cec4af4 100644 --- a/apps/desktop/src-tauri/src/menu.rs +++ b/apps/desktop/src-tauri/src/menu.rs @@ -2,14 +2,15 @@ use std::str::FromStr; use serde::Deserialize; use specta::Type; -use strum::{AsRefStr, EnumString}; use tauri::{ menu::{Menu, MenuItemKind}, AppHandle, Manager, Wry, }; use tracing::error; -#[derive(Debug, Clone, Copy, EnumString, AsRefStr, Type, Deserialize)] +#[derive( + Debug, Clone, Copy, Type, Deserialize, strum::EnumString, strum::AsRefStr, strum::Display, +)] pub enum MenuEvent { NewLibrary, NewFile, @@ -27,12 +28,6 @@ pub enum MenuEvent { ReloadWebview, } -impl ToString for MenuEvent { - fn to_string(&self) -> String { - self.as_ref().to_string() - } -} - /// Menu items which require a library to be open to use. /// They will be disabled/enabled automatically. const LIBRARY_LOCKED_MENU_IDS: &[MenuEvent] = &[ diff --git a/core/Cargo.toml b/core/Cargo.toml index 9fb13341a..e16f6f0dc 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -13,7 +13,7 @@ default = [] # This feature allows features to be disabled when the Core is running on mobile. mobile = [] # This feature controls whether the Spacedrive Core contains functionality which requires FFmpeg. 
-ffmpeg = ["dep:sd-ffmpeg", "sd-media-metadata/ffmpeg"] +ffmpeg = ["dep:sd-ffmpeg", "sd-core-heavy-lifting/ffmpeg", "sd-media-metadata/ffmpeg"] heif = ["sd-images/heif"] ai = ["dep:sd-ai"] crypto = ["dep:sd-crypto"] diff --git a/core/crates/heavy-lifting/Cargo.toml b/core/crates/heavy-lifting/Cargo.toml index 380443708..32e251a13 100644 --- a/core/crates/heavy-lifting/Cargo.toml +++ b/core/crates/heavy-lifting/Cargo.toml @@ -8,6 +8,11 @@ edition = { workspace = true } # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +default = [] +# This feature controls whether the Spacedrive Heavy Lifting contains functionality which requires FFmpeg. +ffmpeg = ["dep:sd-ffmpeg"] + [dependencies] # Inner Core Sub-crates sd-core-file-path-helper = { path = "../file-path-helper" } @@ -15,7 +20,10 @@ sd-core-indexer-rules = { path = "../indexer-rules" } sd-core-prisma-helpers = { path = "../prisma-helpers" } sd-core-sync = { path = "../sync" } # Sub-crates +sd-ffmpeg = { path = "../../../crates/ffmpeg", optional = true } sd-file-ext = { path = "../../../crates/file-ext" } +sd-images = { path = "../../../crates/images" } +sd-media-metadata = { path = "../../../crates/media-metadata" } sd-prisma = { path = "../../../crates/prisma" } sd-sync = { path = "../../../crates/sync" } sd-task-system = { path = "../../../crates/task-system" } @@ -28,8 +36,10 @@ chrono = { workspace = true, features = ["serde"] } futures = { workspace = true } futures-concurrency = { workspace = true } globset = { workspace = true } +image = { workspace = true } itertools = { workspace = true } lending-stream = { workspace = true } +once_cell = { workspace = true } prisma-client-rust = { workspace = true } rmp-serde = { workspace = true } rmpv = { workspace = true } @@ -44,6 +54,8 @@ tokio = { workspace = true, features = ["fs", "sync", "parking_lot"] } tokio-stream = { workspace = true, features = ["fs"] } tracing = { workspace = true } uuid = { workspace = true, features = ["v4", "serde"] } +webp = { workspace = true } + [dev-dependencies] tempfile = { workspace = true } diff --git a/core/crates/heavy-lifting/src/file_identifier/job.rs b/core/crates/heavy-lifting/src/file_identifier/job.rs index d01a55f50..8ae358dee 100644 --- a/core/crates/heavy-lifting/src/file_identifier/job.rs +++ b/core/crates/heavy-lifting/src/file_identifier/job.rs @@ -1,4 +1,5 @@ use crate::{ + file_identifier, job_system::{ job::{Job, JobReturn, JobTaskDispatcher, ReturnStatus}, report::ReportOutputMetadata, @@ -6,7 +7,7 @@ use crate::{ SerializableJob, SerializedTasks, }, utils::sub_path::maybe_get_iso_file_path_from_sub_path, - Error, JobContext, JobName, LocationScanState, NonCriticalJobError, ProgressUpdate, + Error, JobName, LocationScanState, NonCriticalError, OuterContext, ProgressUpdate, UpdateEvent, }; use sd_core_file_path_helper::IsolatedFilePathData; @@ -20,7 +21,7 @@ use sd_task_system::{ use sd_utils::db::maybe_missing; use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, hash::{Hash, Hasher}, mem, path::PathBuf, @@ -30,35 +31,36 @@ use std::{ use futures::{stream::FuturesUnordered, StreamExt}; use futures_concurrency::future::TryJoin; -use prisma_client_rust::or; use serde::{Deserialize, Serialize}; use serde_json::json; use tokio::time::Instant; use tracing::warn; use super::{ + orphan_path_filters_deep, orphan_path_filters_shallow, tasks::{ - ExtractFileMetadataTask, ExtractFileMetadataTaskOutput, ObjectProcessorTask, - ObjectProcessorTaskMetrics, + 
extract_file_metadata, object_processor, ExtractFileMetadataTask, ObjectProcessorTask, }, - FileIdentifierError, CHUNK_SIZE, + CHUNK_SIZE, }; #[derive(Debug)] -pub struct FileIdentifierJob { +pub struct FileIdentifier { location: Arc, location_path: Arc, sub_path: Option, metadata: Metadata, - errors: Vec, + priority_tasks_ids: HashSet, + + errors: Vec, pending_tasks_on_resume: Vec>, tasks_for_shutdown: Vec>>, } -impl Hash for FileIdentifierJob { +impl Hash for FileIdentifier { fn hash(&self, state: &mut H) { self.location.id.hash(state); if let Some(ref sub_path) = self.sub_path { @@ -67,19 +69,19 @@ impl Hash for FileIdentifierJob { } } -impl Job for FileIdentifierJob { +impl Job for FileIdentifier { const NAME: JobName = JobName::FileIdentifier; async fn resume_tasks( &mut self, dispatcher: &JobTaskDispatcher, - ctx: &impl JobContext, + ctx: &impl OuterContext, SerializedTasks(serialized_tasks): SerializedTasks, ) -> Result<(), Error> { self.pending_tasks_on_resume = dispatcher .dispatch_many_boxed( rmp_serde::from_slice::)>>(&serialized_tasks) - .map_err(FileIdentifierError::from)? + .map_err(file_identifier::Error::from)? .into_iter() .map(|(task_kind, task_bytes)| async move { match task_kind { @@ -103,17 +105,17 @@ impl Job for FileIdentifierJob { .collect::>() .try_join() .await - .map_err(FileIdentifierError::from)?, + .map_err(file_identifier::Error::from)?, ) .await; Ok(()) } - async fn run( + async fn run( mut self, dispatcher: JobTaskDispatcher, - ctx: impl JobContext, + ctx: Ctx, ) -> Result { let mut pending_running_tasks = FuturesUnordered::new(); @@ -160,7 +162,9 @@ impl Job for FileIdentifierJob { } if !self.tasks_for_shutdown.is_empty() { - return Ok(ReturnStatus::Shutdown(self.serialize().await)); + return Ok(ReturnStatus::Shutdown( + SerializableJob::::serialize(self).await, + )); } // From this point onward, we are done with the job and it can't be interrupted anymore @@ -181,7 +185,7 @@ impl Job for FileIdentifierJob { ) .exec() .await - .map_err(FileIdentifierError::from)?; + .map_err(file_identifier::Error::from)?; Ok(ReturnStatus::Completed( JobReturn::builder() @@ -192,11 +196,11 @@ impl Job for FileIdentifierJob { } } -impl FileIdentifierJob { +impl FileIdentifier { pub fn new( location: location::Data, sub_path: Option, - ) -> Result { + ) -> Result { Ok(Self { location_path: maybe_missing(&location.path, "location.path") .map(PathBuf::from) @@ -204,6 +208,7 @@ impl FileIdentifierJob { location: Arc::new(location), sub_path, metadata: Metadata::default(), + priority_tasks_ids: HashSet::new(), errors: Vec::new(), pending_tasks_on_resume: Vec::new(), tasks_for_shutdown: Vec::new(), @@ -213,12 +218,12 @@ impl FileIdentifierJob { async fn init_or_resume( &mut self, pending_running_tasks: &mut FuturesUnordered>, - job_ctx: &impl JobContext, + ctx: &impl OuterContext, dispatcher: &JobTaskDispatcher, - ) -> Result<(), FileIdentifierError> { + ) -> Result<(), file_identifier::Error> { // if we don't have any pending task, then this is a fresh job if self.pending_tasks_on_resume.is_empty() { - let db = job_ctx.db(); + let db = ctx.db(); let maybe_sub_iso_file_path = maybe_get_iso_file_path_from_sub_path( self.location.id, &self.sub_path, @@ -227,53 +232,43 @@ impl FileIdentifierJob { ) .await?; - let mut orphans_count = 0; let mut last_orphan_file_path_id = None; let start = Instant::now(); - loop { - #[allow(clippy::cast_possible_wrap)] - // SAFETY: we know that CHUNK_SIZE is a valid i64 - let orphan_paths = db - .file_path() - .find_many(orphan_path_filters( - 
self.location.id, - last_orphan_file_path_id, - &maybe_sub_iso_file_path, - )) - .order_by(file_path::id::order(SortOrder::Asc)) - .take(CHUNK_SIZE as i64) - .select(file_path_for_file_identifier::select()) - .exec() - .await?; + let location_root_iso_file_path = IsolatedFilePathData::new( + self.location.id, + &*self.location_path, + &*self.location_path, + true, + ) + .map_err(file_identifier::Error::from)?; - if orphan_paths.is_empty() { - break; - } + // First we dispatch some shallow priority tasks to quickly identify orphans in the location + // root directory or in the desired sub-path + let file_paths_already_identifying = self + .dispatch_priority_identifier_tasks( + &mut last_orphan_file_path_id, + maybe_sub_iso_file_path + .as_ref() + .unwrap_or(&location_root_iso_file_path), + ctx, + dispatcher, + pending_running_tasks, + ) + .await?; - orphans_count += orphan_paths.len() as u64; - last_orphan_file_path_id = - Some(orphan_paths.last().expect("orphan_paths is not empty").id); - - job_ctx.progress(vec![ - ProgressUpdate::TaskCount(orphans_count), - ProgressUpdate::Message(format!("{orphans_count} files to be identified")), - ]); - - pending_running_tasks.push( - dispatcher - .dispatch(ExtractFileMetadataTask::new_deep( - Arc::clone(&self.location), - Arc::clone(&self.location_path), - orphan_paths, - )) - .await, - ); - } + self.dispatch_deep_identifier_tasks( + &mut last_orphan_file_path_id, + &maybe_sub_iso_file_path, + ctx, + dispatcher, + pending_running_tasks, + &file_paths_already_identifying, + ) + .await?; self.metadata.seeking_orphans_time = start.elapsed(); - self.metadata.total_found_orphans = orphans_count; } else { pending_running_tasks.extend(mem::take(&mut self.pending_tasks_on_resume)); } @@ -290,25 +285,27 @@ impl FileIdentifierJob { &mut self, task_id: TaskId, any_task_output: Box, - job_ctx: &impl JobContext, + ctx: &impl OuterContext, dispatcher: &JobTaskDispatcher, ) -> Option> { - if any_task_output.is::() { + if any_task_output.is::() { return self .process_extract_file_metadata_output( + task_id, *any_task_output - .downcast::() + .downcast::() .expect("just checked"), - job_ctx, + ctx, dispatcher, ) .await; - } else if any_task_output.is::() { + } else if any_task_output.is::() { self.process_object_processor_output( + task_id, *any_task_output - .downcast::() + .downcast::() .expect("just checked"), - job_ctx, + ctx, ); } else { unreachable!("Unexpected task output type: "); @@ -319,12 +316,13 @@ impl FileIdentifierJob { async fn process_extract_file_metadata_output( &mut self, - ExtractFileMetadataTaskOutput { + task_id: TaskId, + extract_file_metadata::Output { identified_files, extract_metadata_time, errors, - }: ExtractFileMetadataTaskOutput, - job_ctx: &impl JobContext, + }: extract_file_metadata::Output, + ctx: &impl OuterContext, dispatcher: &JobTaskDispatcher, ) -> Option> { self.metadata.extract_metadata_time += extract_metadata_time; @@ -333,37 +331,46 @@ impl FileIdentifierJob { if identified_files.is_empty() { self.metadata.completed_tasks += 1; - job_ctx.progress(vec![ProgressUpdate::CompletedTaskCount( + ctx.progress(vec![ProgressUpdate::CompletedTaskCount( self.metadata.completed_tasks, )]); None } else { - job_ctx.progress_msg(format!("Identified {} files", identified_files.len())); + ctx.progress_msg(format!("Identified {} files", identified_files.len())); - Some( - dispatcher - .dispatch(ObjectProcessorTask::new_deep( - identified_files, - Arc::clone(job_ctx.db()), - Arc::clone(job_ctx.sync()), - )) - .await, - ) + let with_priority = 
self.priority_tasks_ids.remove(&task_id); + + let task = dispatcher + .dispatch(ObjectProcessorTask::new( + identified_files, + Arc::clone(ctx.db()), + Arc::clone(ctx.sync()), + with_priority, + )) + .await; + + if with_priority { + self.priority_tasks_ids.insert(task.task_id()); + } + + Some(task) } } fn process_object_processor_output( &mut self, - ObjectProcessorTaskMetrics { + task_id: TaskId, + object_processor::Output { + file_path_ids_with_new_object, assign_cas_ids_time, fetch_existing_objects_time, assign_to_existing_object_time, create_object_time, created_objects_count, linked_objects_count, - }: ObjectProcessorTaskMetrics, - job_ctx: &impl JobContext, + }: object_processor::Output, + ctx: &impl OuterContext, ) { self.metadata.assign_cas_ids_time += assign_cas_ids_time; self.metadata.fetch_existing_objects_time += fetch_existing_objects_time; @@ -374,7 +381,7 @@ impl FileIdentifierJob { self.metadata.completed_tasks += 1; - job_ctx.progress(vec![ + ctx.progress(vec![ ProgressUpdate::CompletedTaskCount(self.metadata.completed_tasks), ProgressUpdate::Message(format!( "Processed {} of {} objects", @@ -382,6 +389,143 @@ impl FileIdentifierJob { self.metadata.total_found_orphans )), ]); + + if self.priority_tasks_ids.remove(&task_id) { + ctx.report_update(UpdateEvent::NewIdentifiedObjects { + file_path_ids: file_path_ids_with_new_object, + }); + } + } + + async fn dispatch_priority_identifier_tasks( + &mut self, + last_orphan_file_path_id: &mut Option, + sub_iso_file_path: &IsolatedFilePathData<'static>, + ctx: &impl OuterContext, + dispatcher: &JobTaskDispatcher, + pending_running_tasks: &FuturesUnordered>, + ) -> Result, file_identifier::Error> { + let db = ctx.db(); + + let mut file_paths_already_identifying = HashSet::new(); + + loop { + #[allow(clippy::cast_possible_wrap)] + // SAFETY: we know that CHUNK_SIZE is a valid i64 + let orphan_paths = db + .file_path() + .find_many(orphan_path_filters_shallow( + self.location.id, + *last_orphan_file_path_id, + sub_iso_file_path, + )) + .order_by(file_path::id::order(SortOrder::Asc)) + .take(CHUNK_SIZE as i64) + .select(file_path_for_file_identifier::select()) + .exec() + .await?; + + if orphan_paths.is_empty() { + break; + } + + file_paths_already_identifying.extend(orphan_paths.iter().map(|path| path.id)); + + self.metadata.total_found_orphans += orphan_paths.len() as u64; + *last_orphan_file_path_id = + Some(orphan_paths.last().expect("orphan_paths is not empty").id); + + ctx.progress(vec![ + ProgressUpdate::TaskCount(self.metadata.total_found_orphans), + ProgressUpdate::Message(format!( + "{} files to be identified", + self.metadata.total_found_orphans + )), + ]); + + let priority_task = dispatcher + .dispatch(ExtractFileMetadataTask::new( + Arc::clone(&self.location), + Arc::clone(&self.location_path), + orphan_paths, + true, + )) + .await; + + self.priority_tasks_ids.insert(priority_task.task_id()); + + pending_running_tasks.push(priority_task); + } + + Ok(file_paths_already_identifying) + } + + async fn dispatch_deep_identifier_tasks( + &mut self, + last_orphan_file_path_id: &mut Option, + maybe_sub_iso_file_path: &Option>, + ctx: &impl OuterContext, + dispatcher: &JobTaskDispatcher, + pending_running_tasks: &FuturesUnordered>, + file_paths_already_identifying: &HashSet, + ) -> Result<(), file_identifier::Error> { + let db = ctx.db(); + + loop { + #[allow(clippy::cast_possible_wrap)] + // SAFETY: we know that CHUNK_SIZE is a valid i64 + let mut orphan_paths = db + .file_path() + .find_many(orphan_path_filters_deep( + 
self.location.id, + *last_orphan_file_path_id, + maybe_sub_iso_file_path, + )) + .order_by(file_path::id::order(SortOrder::Asc)) + .take(CHUNK_SIZE as i64) + .select(file_path_for_file_identifier::select()) + .exec() + .await?; + + // No other orphans to identify, we can break the loop + if orphan_paths.is_empty() { + break; + } + + // We grab the last id to use as a starting point for the next iteration, in case we skip this one + *last_orphan_file_path_id = + Some(orphan_paths.last().expect("orphan_paths is not empty").id); + + orphan_paths.retain(|path| !file_paths_already_identifying.contains(&path.id)); + + // If we don't have any new orphan paths after filtering out, we can skip this iteration + if orphan_paths.is_empty() { + continue; + } + + self.metadata.total_found_orphans += orphan_paths.len() as u64; + + ctx.progress(vec![ + ProgressUpdate::TaskCount(self.metadata.total_found_orphans), + ProgressUpdate::Message(format!( + "{} files to be identified", + self.metadata.total_found_orphans + )), + ]); + + pending_running_tasks.push( + dispatcher + .dispatch(ExtractFileMetadataTask::new( + Arc::clone(&self.location), + Arc::clone(&self.location_path), + orphan_paths, + false, + )) + .await, + ); + } + + Ok(()) } } @@ -399,7 +543,9 @@ struct SaveState { metadata: Metadata, - errors: Vec, + priority_tasks_ids: HashSet, + + errors: Vec, tasks_for_shutdown_bytes: Option, } @@ -459,13 +605,14 @@ impl From for ReportOutputMetadata { } } -impl SerializableJob for FileIdentifierJob { +impl SerializableJob for FileIdentifier { async fn serialize(self) -> Result>, rmp_serde::encode::Error> { let Self { location, location_path, sub_path, metadata, + priority_tasks_ids, errors, tasks_for_shutdown, .. @@ -476,6 +623,7 @@ impl SerializableJob for FileIdentifierJob { location_path, sub_path, metadata, + priority_tasks_ids, tasks_for_shutdown_bytes: Some(SerializedTasks(rmp_serde::to_vec_named( &tasks_for_shutdown .into_iter() @@ -509,14 +657,14 @@ impl SerializableJob for FileIdentifierJob { async fn deserialize( serialized_job: &[u8], - _: &impl JobContext, + _: &Ctx, ) -> Result)>, rmp_serde::decode::Error> { let SaveState { location, location_path, sub_path, metadata, - + priority_tasks_ids, errors, tasks_for_shutdown_bytes, } = rmp_serde::from_slice::(serialized_job)?; @@ -527,6 +675,7 @@ impl SerializableJob for FileIdentifierJob { location_path, sub_path, metadata, + priority_tasks_ids, errors, pending_tasks_on_resume: Vec::new(), tasks_for_shutdown: Vec::new(), @@ -535,32 +684,3 @@ impl SerializableJob for FileIdentifierJob { ))) } } - -fn orphan_path_filters( - location_id: location::id::Type, - file_path_id: Option, - maybe_sub_iso_file_path: &Option>, -) -> Vec { - sd_utils::chain_optional_iter( - [ - or!( - file_path::object_id::equals(None), - file_path::cas_id::equals(None) - ), - file_path::is_dir::equals(Some(false)), - file_path::location_id::equals(Some(location_id)), - file_path::size_in_bytes_bytes::not(Some(0u64.to_be_bytes().to_vec())), - ], - [ - // this is a workaround for the cursor not working properly - file_path_id.map(file_path::id::gte), - maybe_sub_iso_file_path.as_ref().map(|sub_iso_file_path| { - file_path::materialized_path::starts_with( - sub_iso_file_path - .materialized_path_for_children() - .expect("sub path iso_file_path must be a directory"), - ) - }), - ], - ) -} diff --git a/core/crates/heavy-lifting/src/file_identifier/mod.rs b/core/crates/heavy-lifting/src/file_identifier/mod.rs index 6659ef375..b25e08578 100644 --- 
a/core/crates/heavy-lifting/src/file_identifier/mod.rs +++ b/core/crates/heavy-lifting/src/file_identifier/mod.rs @@ -1,13 +1,14 @@ -use crate::utils::sub_path::SubPathError; +use crate::utils::sub_path; use sd_core_file_path_helper::{FilePathError, IsolatedFilePathData}; use sd_file_ext::{extensions::Extension, kind::ObjectKind}; +use sd_prisma::prisma::{file_path, location}; use sd_utils::{db::MissingFieldError, error::FileIOError}; use std::{fs::Metadata, path::Path}; -use prisma_client_rust::QueryError; +use prisma_client_rust::{or, QueryError}; use rspc::ErrorCode; use serde::{Deserialize, Serialize}; use specta::Type; @@ -15,20 +16,20 @@ use tokio::fs; use tracing::trace; mod cas_id; -mod job; +pub mod job; mod shallow; mod tasks; use cas_id::generate_cas_id; -pub use job::FileIdentifierJob; +pub use job::FileIdentifier; pub use shallow::shallow; // we break these tasks into chunks of 100 to improve performance const CHUNK_SIZE: usize = 100; #[derive(thiserror::Error, Debug)] -pub enum FileIdentifierError { +pub enum Error { #[error("missing field on database: {0}")] MissingField(#[from] MissingFieldError), #[error("failed to deserialized stored tasks for job resume: {0}")] @@ -39,13 +40,13 @@ pub enum FileIdentifierError { #[error(transparent)] FilePathError(#[from] FilePathError), #[error(transparent)] - SubPath(#[from] SubPathError), + SubPath(#[from] sub_path::Error), } -impl From for rspc::Error { - fn from(err: FileIdentifierError) -> Self { +impl From for rspc::Error { + fn from(err: Error) -> Self { match err { - FileIdentifierError::SubPath(sub_path_err) => sub_path_err.into(), + Error::SubPath(sub_path_err) => sub_path_err.into(), _ => Self::with_cause(ErrorCode::InternalServerError, err.to_string(), err), } @@ -53,7 +54,7 @@ impl From for rspc::Error { } #[derive(thiserror::Error, Debug, Serialize, Deserialize, Type)] -pub enum NonCriticalFileIdentifierError { +pub enum NonCriticalError { #[error("failed to extract file metadata: {0}")] FailedToExtractFileMetadata(String), #[cfg(target_os = "windows")] @@ -118,3 +119,56 @@ impl FileMetadata { }) } } + +fn orphan_path_filters_shallow( + location_id: location::id::Type, + file_path_id: Option, + sub_iso_file_path: &IsolatedFilePathData<'_>, +) -> Vec { + sd_utils::chain_optional_iter( + [ + or!( + file_path::object_id::equals(None), + file_path::cas_id::equals(None) + ), + file_path::is_dir::equals(Some(false)), + file_path::location_id::equals(Some(location_id)), + file_path::materialized_path::equals(Some( + sub_iso_file_path + .materialized_path_for_children() + .expect("sub path for shallow identifier must be a directory"), + )), + file_path::size_in_bytes_bytes::not(Some(0u64.to_be_bytes().to_vec())), + ], + [file_path_id.map(file_path::id::gte)], + ) +} + +fn orphan_path_filters_deep( + location_id: location::id::Type, + file_path_id: Option, + maybe_sub_iso_file_path: &Option>, +) -> Vec { + sd_utils::chain_optional_iter( + [ + or!( + file_path::object_id::equals(None), + file_path::cas_id::equals(None) + ), + file_path::is_dir::equals(Some(false)), + file_path::location_id::equals(Some(location_id)), + file_path::size_in_bytes_bytes::not(Some(0u64.to_be_bytes().to_vec())), + ], + [ + // this is a workaround for the cursor not working properly + file_path_id.map(file_path::id::gte), + maybe_sub_iso_file_path.as_ref().map(|sub_iso_file_path| { + file_path::materialized_path::starts_with( + sub_iso_file_path + .materialized_path_for_children() + .expect("sub path iso_file_path must be a directory"), + ) + }), + ], + ) 
+} diff --git a/core/crates/heavy-lifting/src/file_identifier/shallow.rs b/core/crates/heavy-lifting/src/file_identifier/shallow.rs index ef85a07b8..dbbedb2c2 100644 --- a/core/crates/heavy-lifting/src/file_identifier/shallow.rs +++ b/core/crates/heavy-lifting/src/file_identifier/shallow.rs @@ -1,10 +1,12 @@ -use crate::{utils::sub_path::maybe_get_iso_file_path_from_sub_path, Error, NonCriticalJobError}; +use crate::{ + file_identifier, utils::sub_path::maybe_get_iso_file_path_from_sub_path, Error, + NonCriticalError, OuterContext, +}; use sd_core_file_path_helper::IsolatedFilePathData; use sd_core_prisma_helpers::file_path_for_file_identifier; -use sd_core_sync::Manager as SyncManager; -use sd_prisma::prisma::{file_path, location, PrismaClient, SortOrder}; +use sd_prisma::prisma::{file_path, location, SortOrder}; use sd_task_system::{ BaseTaskDispatcher, CancelTaskOnDrop, TaskDispatcher, TaskOutput, TaskStatus, }; @@ -17,39 +19,40 @@ use std::{ use futures_concurrency::future::FutureGroup; use lending_stream::{LendingStream, StreamExt}; -use prisma_client_rust::or; use tracing::{debug, warn}; use super::{ - tasks::{ExtractFileMetadataTask, ExtractFileMetadataTaskOutput, ObjectProcessorTask}, - FileIdentifierError, CHUNK_SIZE, + orphan_path_filters_shallow, + tasks::{ + extract_file_metadata, object_processor, ExtractFileMetadataTask, ObjectProcessorTask, + }, + CHUNK_SIZE, }; pub async fn shallow( location: location::Data, sub_path: impl AsRef + Send, dispatcher: BaseTaskDispatcher, - db: Arc, - sync: Arc, - invalidate_query: impl Fn(&'static str) + Send + Sync, -) -> Result, Error> { + ctx: impl OuterContext, +) -> Result, Error> { let sub_path = sub_path.as_ref(); + let db = ctx.db(); let location_path = maybe_missing(&location.path, "location.path") .map(PathBuf::from) .map(Arc::new) - .map_err(FileIdentifierError::from)?; + .map_err(file_identifier::Error::from)?; let location = Arc::new(location); let sub_iso_file_path = - maybe_get_iso_file_path_from_sub_path(location.id, &Some(sub_path), &*location_path, &db) + maybe_get_iso_file_path_from_sub_path(location.id, &Some(sub_path), &*location_path, db) .await - .map_err(FileIdentifierError::from)? + .map_err(file_identifier::Error::from)? .map_or_else( || { IsolatedFilePathData::new(location.id, &*location_path, &*location_path, true) - .map_err(FileIdentifierError::from) + .map_err(file_identifier::Error::from) }, Ok, )?; @@ -64,7 +67,7 @@ pub async fn shallow( // SAFETY: we know that CHUNK_SIZE is a valid i64 let orphan_paths = db .file_path() - .find_many(orphan_path_filters( + .find_many(orphan_path_filters_shallow( location.id, last_orphan_file_path_id, &sub_iso_file_path, @@ -74,7 +77,7 @@ pub async fn shallow( .select(file_path_for_file_identifier::select()) .exec() .await - .map_err(FileIdentifierError::from)?; + .map_err(file_identifier::Error::from)?; let Some(last_orphan) = orphan_paths.last() else { // No orphans here! 
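The hunks around here change the shallow file identifier to take a single `ctx: impl OuterContext` instead of separate `db`, `sync`, and `invalidate_query` arguments. The following stand-alone sketch illustrates only that pattern; `Db`, `SyncManager`, `ExampleContext`, and the trait body are simplified stand-ins assumed for the example, not Spacedrive's real definitions.

use std::sync::Arc;

struct Db;
struct SyncManager;

// Simplified stand-in for the context trait: one value that carries the
// database handle, the sync manager, and query invalidation.
trait OuterContext {
    fn db(&self) -> &Arc<Db>;
    fn sync(&self) -> &Arc<SyncManager>;
    fn invalidate_query(&self, query: &'static str);
}

struct ExampleContext {
    db: Arc<Db>,
    sync: Arc<SyncManager>,
}

impl OuterContext for ExampleContext {
    fn db(&self) -> &Arc<Db> {
        &self.db
    }
    fn sync(&self) -> &Arc<SyncManager> {
        &self.sync
    }
    fn invalidate_query(&self, query: &'static str) {
        println!("invalidate {query}");
    }
}

// Before: a function like this would thread `db`, `sync`, and an
// `invalidate_query` closure as separate parameters; after, it takes the context.
fn shallow_sketch(ctx: &impl OuterContext) {
    let _db = ctx.db();
    let _sync = ctx.sync();
    ctx.invalidate_query("search.paths");
}

fn main() {
    let ctx = ExampleContext {
        db: Arc::new(Db),
        sync: Arc::new(SyncManager),
    };
    shallow_sketch(&ctx);
}

Bundling the three capabilities behind one trait keeps call sites like `shallow` and `process_tasks` from threading extra parameters, which matches how the hunks below use `ctx.db()`, `ctx.sync()`, and `ctx.invalidate_query(...)`.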
@@ -86,10 +89,11 @@ pub async fn shallow( pending_running_tasks.insert(CancelTaskOnDrop( dispatcher - .dispatch(ExtractFileMetadataTask::new_shallow( + .dispatch(ExtractFileMetadataTask::new( Arc::clone(&location), Arc::clone(&location_path), orphan_paths, + true, )) .await, )); @@ -104,10 +108,7 @@ pub async fn shallow( return Ok(vec![]); } - let errors = process_tasks(pending_running_tasks, dispatcher, db, sync).await?; - - invalidate_query("search.paths"); - invalidate_query("search.objects"); + let errors = process_tasks(pending_running_tasks, dispatcher, ctx).await?; Ok(errors) } @@ -115,11 +116,13 @@ pub async fn shallow( async fn process_tasks( pending_running_tasks: FutureGroup>, dispatcher: BaseTaskDispatcher, - db: Arc, - sync: Arc, -) -> Result, Error> { + ctx: impl OuterContext, +) -> Result, Error> { let mut pending_running_tasks = pending_running_tasks.lend_mut(); + let db = ctx.db(); + let sync = ctx.sync(); + let mut errors = vec![]; while let Some((pending_running_tasks, task_result)) = pending_running_tasks.next().await { @@ -128,28 +131,36 @@ async fn process_tasks( // We only care about ExtractFileMetadataTaskOutput because we need to dispatch further tasks // and the ObjectProcessorTask only gives back some metrics not much important for // shallow file identifier - if any_task_output.is::() { - let ExtractFileMetadataTaskOutput { + if any_task_output.is::() { + let extract_file_metadata::Output { identified_files, errors: more_errors, .. - } = *any_task_output - .downcast::() - .expect("just checked"); + } = *any_task_output.downcast().expect("just checked"); errors.extend(more_errors); if !identified_files.is_empty() { pending_running_tasks.insert(CancelTaskOnDrop( dispatcher - .dispatch(ObjectProcessorTask::new_shallow( + .dispatch(ObjectProcessorTask::new( identified_files, - Arc::clone(&db), - Arc::clone(&sync), + Arc::clone(db), + Arc::clone(sync), + true, )) .await, )); } + } else { + let object_processor::Output { + file_path_ids_with_new_object, + .. 
+ } = *any_task_output.downcast().expect("just checked"); + + ctx.report_update(crate::UpdateEvent::NewIdentifiedObjects { + file_path_ids: file_path_ids_with_new_object, + }); } } @@ -181,27 +192,3 @@ async fn process_tasks( Ok(errors) } - -fn orphan_path_filters( - location_id: location::id::Type, - file_path_id: Option, - sub_iso_file_path: &IsolatedFilePathData<'_>, -) -> Vec { - sd_utils::chain_optional_iter( - [ - or!( - file_path::object_id::equals(None), - file_path::cas_id::equals(None) - ), - file_path::is_dir::equals(Some(false)), - file_path::location_id::equals(Some(location_id)), - file_path::materialized_path::equals(Some( - sub_iso_file_path - .materialized_path_for_children() - .expect("sub path for shallow identifier must be a directory"), - )), - file_path::size_in_bytes_bytes::not(Some(0u64.to_be_bytes().to_vec())), - ], - [file_path_id.map(file_path::id::gte)], - ) -} diff --git a/core/crates/heavy-lifting/src/file_identifier/tasks/extract_file_metadata.rs b/core/crates/heavy-lifting/src/file_identifier/tasks/extract_file_metadata.rs index c433bfb00..f8dd41fdc 100644 --- a/core/crates/heavy-lifting/src/file_identifier/tasks/extract_file_metadata.rs +++ b/core/crates/heavy-lifting/src/file_identifier/tasks/extract_file_metadata.rs @@ -1,6 +1,6 @@ use crate::{ - file_identifier::{FileMetadata, NonCriticalFileIdentifierError}, - Error, NonCriticalJobError, + file_identifier::{self, FileMetadata}, + Error, NonCriticalError, }; use sd_core_file_path_helper::IsolatedFilePathData; @@ -34,23 +34,24 @@ pub struct ExtractFileMetadataTask { file_paths_by_id: HashMap, identified_files: HashMap, extract_metadata_time: Duration, - errors: Vec, - is_shallow: bool, + errors: Vec, + with_priority: bool, } #[derive(Debug)] -pub struct ExtractFileMetadataTaskOutput { +pub struct Output { pub identified_files: HashMap, pub extract_metadata_time: Duration, - pub errors: Vec, + pub errors: Vec, } impl ExtractFileMetadataTask { - fn new( + #[must_use] + pub fn new( location: Arc, location_path: Arc, file_paths: Vec, - is_shallow: bool, + with_priority: bool, ) -> Self { Self { id: TaskId::new_v4(), @@ -69,27 +70,9 @@ impl ExtractFileMetadataTask { .collect(), extract_metadata_time: Duration::ZERO, errors: Vec::new(), - is_shallow, + with_priority, } } - - #[must_use] - pub fn new_deep( - location: Arc, - location_path: Arc, - file_paths: Vec, - ) -> Self { - Self::new(location, location_path, file_paths, false) - } - - #[must_use] - pub fn new_shallow( - location: Arc, - location_path: Arc, - file_paths: Vec, - ) -> Self { - Self::new(location, location_path, file_paths, true) - } } #[async_trait::async_trait] @@ -99,7 +82,7 @@ impl Task for ExtractFileMetadataTask { } fn with_priority(&self) -> bool { - self.is_shallow + self.with_priority } async fn run(&mut self, interrupter: &Interrupter) -> Result { @@ -196,7 +179,7 @@ impl Task for ExtractFileMetadataTask { } Ok(ExecStatus::Done( - ExtractFileMetadataTaskOutput { + Output { identified_files: mem::take(identified_files), extract_metadata_time: *extract_metadata_time + start_time.elapsed(), errors: mem::take(errors), @@ -210,7 +193,7 @@ fn handle_non_critical_errors( location_id: location::id::Type, file_path_pub_id: Uuid, e: &FileIOError, - errors: &mut Vec, + errors: &mut Vec, ) { error!("Failed to extract file metadata : {e:#?}"); @@ -221,14 +204,15 @@ fn handle_non_critical_errors( // Handle case where file is on-demand (NTFS only) if e.source.raw_os_error().map_or(false, |code| code == 362) { errors.push( - 
NonCriticalFileIdentifierError::FailedToExtractMetadataFromOnDemandFile( + file_identifier::NonCriticalError::FailedToExtractMetadataFromOnDemandFile( formatted_error, ) .into(), ); } else { errors.push( - NonCriticalFileIdentifierError::FailedToExtractFileMetadata(formatted_error).into(), + file_identifier::NonCriticalError::FailedToExtractFileMetadata(formatted_error) + .into(), ); } } @@ -236,7 +220,7 @@ fn handle_non_critical_errors( #[cfg(not(target_os = "windows"))] { errors.push( - NonCriticalFileIdentifierError::FailedToExtractFileMetadata(formatted_error).into(), + file_identifier::NonCriticalError::FailedToExtractFileMetadata(formatted_error).into(), ); } } @@ -246,7 +230,7 @@ fn try_iso_file_path_extraction( file_path_pub_id: Uuid, file_path: &file_path_for_file_identifier::Data, location_path: Arc, - errors: &mut Vec, + errors: &mut Vec, ) -> Option<(Uuid, IsolatedFilePathData<'static>, Arc)> { IsolatedFilePathData::try_from((location_id, file_path)) .map(IsolatedFilePathData::to_owned) @@ -254,7 +238,7 @@ fn try_iso_file_path_extraction( .map_err(|e| { error!("Failed to extract isolated file path data: {e:#?}"); errors.push( - NonCriticalFileIdentifierError::FailedToExtractIsolatedFilePathData(format!( + file_identifier::NonCriticalError::FailedToExtractIsolatedFilePathData(format!( "" )) .into(), diff --git a/core/crates/heavy-lifting/src/file_identifier/tasks/mod.rs b/core/crates/heavy-lifting/src/file_identifier/tasks/mod.rs index c5bac9fb1..c06fc8ad0 100644 --- a/core/crates/heavy-lifting/src/file_identifier/tasks/mod.rs +++ b/core/crates/heavy-lifting/src/file_identifier/tasks/mod.rs @@ -4,11 +4,11 @@ use sd_file_ext::kind::ObjectKind; use serde::{Deserialize, Serialize}; -mod extract_file_metadata; -mod object_processor; +pub mod extract_file_metadata; +pub mod object_processor; -pub use extract_file_metadata::{ExtractFileMetadataTask, ExtractFileMetadataTaskOutput}; -pub use object_processor::{ObjectProcessorTask, ObjectProcessorTaskMetrics}; +pub use extract_file_metadata::ExtractFileMetadataTask; +pub use object_processor::ObjectProcessorTask; #[derive(Debug, Serialize, Deserialize)] pub(super) struct IdentifiedFile { diff --git a/core/crates/heavy-lifting/src/file_identifier/tasks/object_processor.rs b/core/crates/heavy-lifting/src/file_identifier/tasks/object_processor.rs index cdb9f0842..bdc826ddc 100644 --- a/core/crates/heavy-lifting/src/file_identifier/tasks/object_processor.rs +++ b/core/crates/heavy-lifting/src/file_identifier/tasks/object_processor.rs @@ -1,4 +1,4 @@ -use crate::{file_identifier::FileIdentifierError, Error}; +use crate::{file_identifier, Error}; use sd_core_prisma_helpers::{ file_path_for_file_identifier, file_path_pub_id, object_for_file_identifier, @@ -36,22 +36,23 @@ pub struct ObjectProcessorTask { db: Arc, sync: Arc, identified_files: HashMap, - metrics: ObjectProcessorTaskMetrics, + output: Output, stage: Stage, - is_shallow: bool, + with_priority: bool, } #[derive(Debug, Serialize, Deserialize)] pub struct SaveState { id: TaskId, identified_files: HashMap, - metrics: ObjectProcessorTaskMetrics, + output: Output, stage: Stage, - is_shallow: bool, + with_priority: bool, } #[derive(Debug, Serialize, Deserialize, Default)] -pub struct ObjectProcessorTaskMetrics { +pub struct Output { + pub file_path_ids_with_new_object: Vec, pub assign_cas_ids_time: Duration, pub fetch_existing_objects_time: Duration, pub assign_to_existing_object_time: Duration, @@ -71,11 +72,12 @@ enum Stage { } impl ObjectProcessorTask { - fn new( + #[must_use] + pub fn 
new( identified_files: HashMap, db: Arc, sync: Arc, - is_shallow: bool, + with_priority: bool, ) -> Self { Self { id: TaskId::new_v4(), @@ -83,26 +85,10 @@ impl ObjectProcessorTask { sync, identified_files, stage: Stage::Starting, - metrics: ObjectProcessorTaskMetrics::default(), - is_shallow, + output: Output::default(), + with_priority, } } - - pub fn new_deep( - identified_files: HashMap, - db: Arc, - sync: Arc, - ) -> Self { - Self::new(identified_files, db, sync, false) - } - - pub fn new_shallow( - identified_files: HashMap, - db: Arc, - sync: Arc, - ) -> Self { - Self::new(identified_files, db, sync, true) - } } #[async_trait::async_trait] @@ -112,7 +98,7 @@ impl Task for ObjectProcessorTask { } fn with_priority(&self) -> bool { - self.is_shallow + self.with_priority } async fn run(&mut self, interrupter: &Interrupter) -> Result { @@ -121,8 +107,9 @@ impl Task for ObjectProcessorTask { sync, identified_files, stage, - metrics: - ObjectProcessorTaskMetrics { + output: + Output { + file_path_ids_with_new_object, assign_cas_ids_time, fetch_existing_objects_time, assign_to_existing_object_time, @@ -193,6 +180,11 @@ impl Task for ObjectProcessorTask { *created_objects_count = create_objects(identified_files, db, sync).await?; *create_object_time = start.elapsed(); + *file_path_ids_with_new_object = identified_files + .values() + .map(|IdentifiedFile { file_path, .. }| file_path.id) + .collect(); + break; } } @@ -200,7 +192,7 @@ impl Task for ObjectProcessorTask { check_interruption!(interrupter); } - Ok(ExecStatus::Done(mem::take(&mut self.metrics).into_output())) + Ok(ExecStatus::Done(mem::take(&mut self.output).into_output())) } } @@ -208,7 +200,7 @@ async fn assign_cas_id_to_file_paths( identified_files: &HashMap, db: &PrismaClient, sync: &SyncManager, -) -> Result<(), FileIdentifierError> { +) -> Result<(), file_identifier::Error> { // Assign cas_id to each file path sync.write_ops( db, @@ -243,7 +235,7 @@ async fn assign_cas_id_to_file_paths( async fn fetch_existing_objects_by_cas_id( identified_files: &HashMap, db: &PrismaClient, -) -> Result, FileIdentifierError> { +) -> Result, file_identifier::Error> { // Retrieves objects that are already connected to file paths with the same id db.object() .find_many(vec![object::file_paths::some(vec![ @@ -280,7 +272,7 @@ async fn assign_existing_objects_to_file_paths( objects_by_cas_id: &HashMap, db: &PrismaClient, sync: &SyncManager, -) -> Result, FileIdentifierError> { +) -> Result, file_identifier::Error> { // Attempt to associate each file path with an object that has been // connected to file paths with the same cas_id sync.write_ops( @@ -341,7 +333,7 @@ async fn create_objects( identified_files: &HashMap, db: &PrismaClient, sync: &SyncManager, -) -> Result { +) -> Result { trace!("Creating {} new Objects", identified_files.len(),); let (object_create_args, file_path_update_args) = identified_files @@ -433,18 +425,18 @@ impl SerializableTask for ObjectProcessorTask { let Self { id, identified_files, - metrics, + output, stage, - is_shallow, + with_priority, .. 
} = self; rmp_serde::to_vec_named(&SaveState { id, identified_files, - metrics, + output, stage, - is_shallow, + with_priority, }) } @@ -456,17 +448,17 @@ impl SerializableTask for ObjectProcessorTask { |SaveState { id, identified_files, - metrics, + output, stage, - is_shallow, + with_priority, }| Self { id, db, sync, identified_files, - metrics, + output, stage, - is_shallow, + with_priority, }, ) } diff --git a/core/crates/heavy-lifting/src/indexer/job.rs b/core/crates/heavy-lifting/src/indexer/job.rs index dacea9ab5..f2ad3f6e5 100644 --- a/core/crates/heavy-lifting/src/indexer/job.rs +++ b/core/crates/heavy-lifting/src/indexer/job.rs @@ -1,15 +1,15 @@ use crate::{ - indexer::BATCH_SIZE, + indexer, job_system::{ job::{ - Job, JobContext, JobName, JobReturn, JobTaskDispatcher, ProgressUpdate, ReturnStatus, + Job, JobName, JobReturn, JobTaskDispatcher, OuterContext, ProgressUpdate, ReturnStatus, }, report::ReportOutputMetadata, utils::cancel_pending_tasks, SerializableJob, SerializedTasks, }, utils::sub_path::get_full_path_from_sub_path, - Error, LocationScanState, NonCriticalJobError, + Error, LocationScanState, NonCriticalError, }; use sd_core_file_path_helper::IsolatedFilePathData; @@ -47,11 +47,11 @@ use super::{ updater::{UpdateTask, UpdateTaskOutput}, walker::{WalkDirTask, WalkTaskOutput, WalkedEntry}, }, - update_directory_sizes, update_location_size, IndexerError, IsoFilePathFactory, WalkerDBProxy, + update_directory_sizes, update_location_size, IsoFilePathFactory, WalkerDBProxy, BATCH_SIZE, }; #[derive(Debug)] -pub struct IndexerJob { +pub struct Indexer { location: location_with_indexer_rules::Data, sub_path: Option, metadata: Metadata, @@ -63,19 +63,19 @@ pub struct IndexerJob { ancestors_already_indexed: HashSet>, iso_paths_and_sizes: HashMap, u64>, - errors: Vec, + errors: Vec, pending_tasks_on_resume: Vec>, tasks_for_shutdown: Vec>>, } -impl Job for IndexerJob { +impl Job for Indexer { const NAME: JobName = JobName::Indexer; async fn resume_tasks( &mut self, dispatcher: &JobTaskDispatcher, - ctx: &impl JobContext, + ctx: &impl OuterContext, SerializedTasks(serialized_tasks): SerializedTasks, ) -> Result<(), Error> { let location_id = self.location.id; @@ -83,7 +83,7 @@ impl Job for IndexerJob { self.pending_tasks_on_resume = dispatcher .dispatch_many_boxed( rmp_serde::from_slice::)>>(&serialized_tasks) - .map_err(IndexerError::from)? + .map_err(indexer::Error::from)? 
.into_iter() .map(|(task_kind, task_bytes)| { let indexer_ruler = self.indexer_ruler.clone(); @@ -123,17 +123,17 @@ impl Job for IndexerJob { .collect::>() .try_join() .await - .map_err(IndexerError::from)?, + .map_err(indexer::Error::from)?, ) .await; Ok(()) } - async fn run( + async fn run( mut self, dispatcher: JobTaskDispatcher, - ctx: impl JobContext, + ctx: Ctx, ) -> Result { let mut pending_running_tasks = FuturesUnordered::new(); @@ -148,7 +148,9 @@ impl Job for IndexerJob { } if !self.tasks_for_shutdown.is_empty() { - return Ok(ReturnStatus::Shutdown(self.serialize().await)); + return Ok(ReturnStatus::Shutdown( + SerializableJob::::serialize(self).await, + )); } if !self.ancestors_needing_indexing.is_empty() { @@ -182,7 +184,9 @@ impl Job for IndexerJob { } if !self.tasks_for_shutdown.is_empty() { - return Ok(ReturnStatus::Shutdown(self.serialize().await)); + return Ok(ReturnStatus::Shutdown( + SerializableJob::::serialize(self).await, + )); } } @@ -217,7 +221,7 @@ impl Job for IndexerJob { .await?; } - update_location_size(location.id, ctx.db(), &ctx.query_invalidator()).await?; + update_location_size(location.id, ctx.db(), &ctx).await?; metadata.db_write_time += start_size_update_time.elapsed(); } @@ -243,7 +247,7 @@ impl Job for IndexerJob { ) .exec() .await - .map_err(IndexerError::from)?; + .map_err(indexer::Error::from)?; Ok(ReturnStatus::Completed( JobReturn::builder() @@ -254,11 +258,11 @@ impl Job for IndexerJob { } } -impl IndexerJob { +impl Indexer { pub fn new( location: location_with_indexer_rules::Data, sub_path: Option, - ) -> Result { + ) -> Result { Ok(Self { indexer_ruler: location .indexer_rules @@ -295,12 +299,12 @@ impl IndexerJob { &mut self, task_id: TaskId, any_task_output: Box, - job_ctx: &impl JobContext, + ctx: &impl OuterContext, dispatcher: &JobTaskDispatcher, - ) -> Result>, IndexerError> { + ) -> Result>, indexer::Error> { self.metadata.completed_tasks += 1; - job_ctx.progress(vec![ProgressUpdate::CompletedTaskCount( + ctx.progress(vec![ProgressUpdate::CompletedTaskCount( self.metadata.completed_tasks, )]); @@ -310,7 +314,7 @@ impl IndexerJob { *any_task_output .downcast::() .expect("just checked"), - job_ctx, + ctx, dispatcher, ) .await; @@ -319,14 +323,14 @@ impl IndexerJob { *any_task_output .downcast::() .expect("just checked"), - job_ctx, + ctx, ); } else if any_task_output.is::() { self.process_update_output( *any_task_output .downcast::() .expect("just checked"), - job_ctx, + ctx, ); } else { unreachable!("Unexpected task output type: "); @@ -348,9 +352,9 @@ impl IndexerJob { mut handles, scan_time, }: WalkTaskOutput, - job_ctx: &impl JobContext, + ctx: &impl OuterContext, dispatcher: &JobTaskDispatcher, - ) -> Result>, IndexerError> { + ) -> Result>, indexer::Error> { self.metadata.scan_read_time += scan_time; let (to_create_count, to_update_count) = (to_create.len(), to_update.len()); @@ -398,7 +402,7 @@ impl IndexerJob { let db_delete_time = Instant::now(); self.metadata.removed_count += - remove_non_existing_file_paths(to_remove, job_ctx.db(), job_ctx.sync()).await?; + remove_non_existing_file_paths(to_remove, ctx.db(), ctx.sync()).await?; self.metadata.db_write_time += db_delete_time.elapsed(); let save_tasks = to_create @@ -414,8 +418,8 @@ impl IndexerJob { self.location.id, self.location.pub_id.clone(), chunked_saves, - Arc::clone(job_ctx.db()), - Arc::clone(job_ctx.sync()), + Arc::clone(ctx.db()), + Arc::clone(ctx.sync()), ) }) .collect::>(); @@ -431,8 +435,8 @@ impl IndexerJob { UpdateTask::new_deep( chunked_updates, - 
Arc::clone(job_ctx.db()), - Arc::clone(job_ctx.sync()), + Arc::clone(ctx.db()), + Arc::clone(ctx.sync()), ) }) .collect::>(); @@ -442,7 +446,7 @@ impl IndexerJob { self.metadata.total_tasks += handles.len() as u64; - job_ctx.progress(vec![ + ctx.progress(vec![ ProgressUpdate::TaskCount(handles.len() as u64), ProgressUpdate::message(format!( "Found {to_create_count} new files and {to_update_count} to update" @@ -458,12 +462,12 @@ impl IndexerJob { saved_count, save_duration, }: SaveTaskOutput, - job_ctx: &impl JobContext, + ctx: &impl OuterContext, ) { self.metadata.indexed_count += saved_count; self.metadata.db_write_time += save_duration; - job_ctx.progress_msg(format!("Saved {saved_count} files")); + ctx.progress_msg(format!("Saved {saved_count} files")); } fn process_update_output( @@ -472,25 +476,25 @@ impl IndexerJob { updated_count, update_duration, }: UpdateTaskOutput, - job_ctx: &impl JobContext, + ctx: &impl OuterContext, ) { self.metadata.updated_count += updated_count; self.metadata.db_write_time += update_duration; - job_ctx.progress_msg(format!("Updated {updated_count} files")); + ctx.progress_msg(format!("Updated {updated_count} files")); } async fn process_handles( &mut self, pending_running_tasks: &mut FuturesUnordered>, - job_ctx: &impl JobContext, + ctx: &impl OuterContext, dispatcher: &JobTaskDispatcher, ) -> Option> { while let Some(task) = pending_running_tasks.next().await { match task { Ok(TaskStatus::Done((task_id, TaskOutput::Out(out)))) => { let more_handles = match self - .process_task_output(task_id, out, job_ctx, dispatcher) + .process_task_output(task_id, out, ctx, dispatcher) .await { Ok(more_handles) => more_handles, @@ -538,9 +542,9 @@ impl IndexerJob { async fn init_or_resume( &mut self, pending_running_tasks: &mut FuturesUnordered>, - job_ctx: &impl JobContext, + ctx: &impl OuterContext, dispatcher: &JobTaskDispatcher, - ) -> Result<(), IndexerError> { + ) -> Result<(), indexer::Error> { // if we don't have any pending task, then this is a fresh job if self.pending_tasks_on_resume.is_empty() { let walker_root_path = Arc::new( @@ -548,7 +552,7 @@ impl IndexerJob { self.location.id, &self.sub_path, &*self.iso_file_path_factory.location_path, - job_ctx.db(), + ctx.db(), ) .await?, ); @@ -562,7 +566,7 @@ impl IndexerJob { self.iso_file_path_factory.clone(), WalkerDBProxy { location_id: self.location.id, - db: Arc::clone(job_ctx.db()), + db: Arc::clone(ctx.db()), }, dispatcher.clone(), )?) 
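Both job implementations above replace `self.serialize().await` with a fully qualified `SerializableJob::<...>::serialize(self).await` call (the concrete context parameter is elided in the hunk as rendered here). One plausible reading, shown with hypothetical names below, is that once the serialization trait carries a context type parameter, a plain method call no longer tells the compiler which trait instantiation to use; the fully qualified form pins it down.

// Stand-alone illustration; `Serializable`, `MyJob`, `CtxA`, and `CtxB` are hypothetical.
trait Serializable<Ctx> {
    fn serialize(self) -> Vec<u8>;
}

struct MyJob;
struct CtxA;
struct CtxB;

impl Serializable<CtxA> for MyJob {
    fn serialize(self) -> Vec<u8> {
        vec![0xA]
    }
}

impl Serializable<CtxB> for MyJob {
    fn serialize(self) -> Vec<u8> {
        vec![0xB]
    }
}

fn main() {
    let job = MyJob;
    // `job.serialize()` would be ambiguous here; naming the context pins the impl.
    let bytes = Serializable::<CtxA>::serialize(job);
    assert_eq!(bytes, vec![0xA]);
}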
@@ -633,12 +637,12 @@ struct SaveState { ancestors_already_indexed: HashSet>, paths_and_sizes: HashMap, u64>, - errors: Vec, + errors: Vec, tasks_for_shutdown_bytes: Option, } -impl SerializableJob for IndexerJob { +impl SerializableJob for Indexer { async fn serialize(self) -> Result>, rmp_serde::encode::Error> { let Self { location, @@ -706,7 +710,7 @@ impl SerializableJob for IndexerJob { async fn deserialize( serialized_job: &[u8], - _: &impl JobContext, + _: &Ctx, ) -> Result)>, rmp_serde::decode::Error> { let SaveState { location, @@ -744,7 +748,7 @@ impl SerializableJob for IndexerJob { } } -impl Hash for IndexerJob { +impl Hash for Indexer { fn hash(&self, state: &mut H) { self.location.id.hash(state); if let Some(ref sub_path) = self.sub_path { diff --git a/core/crates/heavy-lifting/src/indexer/mod.rs b/core/crates/heavy-lifting/src/indexer/mod.rs index 2bac41b1b..78b9d3827 100644 --- a/core/crates/heavy-lifting/src/indexer/mod.rs +++ b/core/crates/heavy-lifting/src/indexer/mod.rs @@ -1,4 +1,4 @@ -use crate::{utils::sub_path::SubPathError, NonCriticalJobError}; +use crate::{utils::sub_path, OuterContext}; use sd_core_file_path_helper::{FilePathError, IsolatedFilePathData}; use sd_core_indexer_rules::IndexerRuleError; @@ -8,7 +8,7 @@ use sd_core_prisma_helpers::{ use sd_core_sync::Manager as SyncManager; use sd_prisma::{ - prisma::{file_path, location, PrismaClient, SortOrder}, + prisma::{file_path, indexer_rule, location, PrismaClient, SortOrder}, prisma_sync, }; use sd_sync::OperationFactory; @@ -33,11 +33,10 @@ use serde::{Deserialize, Serialize}; use specta::Type; use tracing::warn; -mod job; +pub mod job; mod shallow; mod tasks; -pub use job::IndexerJob; pub use shallow::shallow; use tasks::walker; @@ -46,12 +45,12 @@ use tasks::walker; const BATCH_SIZE: usize = 1000; #[derive(thiserror::Error, Debug)] -pub enum IndexerError { +pub enum Error { // Not Found errors #[error("indexer rule not found: ")] - IndexerRuleNotFound(i32), + IndexerRuleNotFound(indexer_rule::id::Type), #[error(transparent)] - SubPath(#[from] SubPathError), + SubPath(#[from] sub_path::Error), // Internal Errors #[error("database Error: {0}")] @@ -72,16 +71,16 @@ pub enum IndexerError { Rules(#[from] IndexerRuleError), } -impl From for rspc::Error { - fn from(err: IndexerError) -> Self { +impl From for rspc::Error { + fn from(err: Error) -> Self { match err { - IndexerError::IndexerRuleNotFound(_) => { + Error::IndexerRuleNotFound(_) => { Self::with_cause(ErrorCode::NotFound, err.to_string(), err) } - IndexerError::SubPath(sub_path_err) => sub_path_err.into(), + Error::SubPath(sub_path_err) => sub_path_err.into(), - IndexerError::Rules(rule_err) => rule_err.into(), + Error::Rules(rule_err) => rule_err.into(), _ => Self::with_cause(ErrorCode::InternalServerError, err.to_string(), err), } @@ -89,7 +88,7 @@ impl From for rspc::Error { } #[derive(thiserror::Error, Debug, Serialize, Deserialize, Type)] -pub enum NonCriticalIndexerError { +pub enum NonCriticalError { #[error("failed to read directory entry: {0}")] FailedDirectoryEntry(String), #[error("failed to fetch metadata: {0}")] @@ -134,7 +133,7 @@ async fn update_directory_sizes( iso_paths_and_sizes: HashMap, u64, impl BuildHasher + Send>, db: &PrismaClient, sync: &SyncManager, -) -> Result<(), IndexerError> { +) -> Result<(), Error> { let to_sync_and_update = db ._batch(chunk_db_queries(iso_paths_and_sizes.keys(), db)) .await? @@ -160,7 +159,7 @@ async fn update_directory_sizes( ), )) }) - .collect::, IndexerError>>()? + .collect::, Error>>()? 
.into_iter() .unzip::<_, _, Vec<_>, Vec<_>>(); @@ -169,11 +168,11 @@ async fn update_directory_sizes( Ok(()) } -async fn update_location_size( +async fn update_location_size( location_id: location::id::Type, db: &PrismaClient, - invalidate_query: &InvalidateQuery, -) -> Result<(), IndexerError> { + ctx: &impl OuterContext, +) -> Result<(), Error> { let total_size = db .file_path() .find_many(vec![ @@ -201,8 +200,8 @@ async fn update_location_size( .exec() .await?; - invalidate_query("locations.list"); - invalidate_query("locations.get"); + ctx.invalidate_query("locations.list"); + ctx.invalidate_query("locations.get"); Ok(()) } @@ -211,7 +210,7 @@ async fn remove_non_existing_file_paths( to_remove: Vec, db: &PrismaClient, sync: &sd_core_sync::Manager, -) -> Result { +) -> Result { #[allow(clippy::cast_sign_loss)] let (sync_params, db_params): (Vec<_>, Vec<_>) = to_remove .into_iter() @@ -248,8 +247,8 @@ async fn reverse_update_directories_sizes( location_path: impl AsRef + Send, db: &PrismaClient, sync: &SyncManager, - errors: &mut Vec, -) -> Result<(), IndexerError> { + errors: &mut Vec, +) -> Result<(), Error> { let location_path = location_path.as_ref(); let ancestors = base_path @@ -279,7 +278,7 @@ async fn reverse_update_directories_sizes( IsolatedFilePathData::try_from(file_path) .map_err(|e| { errors.push( - NonCriticalIndexerError::MissingFilePathData(format!( + NonCriticalError::MissingFilePathData(format!( "Found a file_path missing data: , error: {e:#?}", from_bytes_to_uuid(&pub_id) )) @@ -345,8 +344,8 @@ async fn compute_sizes( materialized_paths: Vec, pub_id_by_ancestor_materialized_path: &mut HashMap, db: &PrismaClient, - errors: &mut Vec, -) -> Result<(), IndexerError> { + errors: &mut Vec, +) -> Result<(), Error> { db.file_path() .find_many(vec![ file_path::location_id::equals(Some(location_id)), @@ -371,7 +370,7 @@ async fn compute_sizes( } } else { errors.push( - NonCriticalIndexerError::MissingFilePathData(format!( + NonCriticalError::MissingFilePathData(format!( "Corrupt database possessing a file_path entry without materialized_path: ", from_bytes_to_uuid(&file_path.pub_id) )) @@ -409,7 +408,7 @@ impl walker::WalkerDBProxy for WalkerDBProxy { async fn fetch_file_paths( &self, found_paths: Vec, - ) -> Result, IndexerError> { + ) -> Result, Error> { // Each found path is a AND with 4 terms, and SQLite has a expression tree limit of 1000 terms // so we will use chunks of 200 just to be safe self.db @@ -435,7 +434,7 @@ impl walker::WalkerDBProxy for WalkerDBProxy { &self, parent_iso_file_path: &IsolatedFilePathData<'_>, unique_location_id_materialized_path_name_extension_params: Vec, - ) -> Result, NonCriticalIndexerError> { + ) -> Result, NonCriticalError> { // NOTE: This batch size can be increased if we wish to trade memory for more performance const BATCH_SIZE: i64 = 1000; @@ -461,7 +460,7 @@ impl walker::WalkerDBProxy for WalkerDBProxy { .flat_map(|file_paths| file_paths.into_iter().map(|file_path| file_path.id)) .collect::>() }) - .map_err(|e| NonCriticalIndexerError::FetchAlreadyExistingFilePathIds(e.to_string()))?; + .map_err(|e| NonCriticalError::FetchAlreadyExistingFilePathIds(e.to_string()))?; let mut to_remove = vec![]; let mut cursor = 1; @@ -484,7 +483,7 @@ impl walker::WalkerDBProxy for WalkerDBProxy { .select(file_path_pub_and_cas_ids::select()) .exec() .await - .map_err(|e| NonCriticalIndexerError::FetchFilePathsToRemove(e.to_string()))?; + .map_err(|e| NonCriticalError::FetchFilePathsToRemove(e.to_string()))?; #[allow(clippy::cast_possible_truncation)] // 
Safe because we are using a constant let should_stop = found.len() < BATCH_SIZE as usize; diff --git a/core/crates/heavy-lifting/src/indexer/shallow.rs b/core/crates/heavy-lifting/src/indexer/shallow.rs index 96eaf4398..085b6f1a5 100644 --- a/core/crates/heavy-lifting/src/indexer/shallow.rs +++ b/core/crates/heavy-lifting/src/indexer/shallow.rs @@ -1,4 +1,6 @@ -use crate::{utils::sub_path::get_full_path_from_sub_path, Error, NonCriticalJobError}; +use crate::{ + indexer, utils::sub_path::get_full_path_from_sub_path, Error, NonCriticalError, OuterContext, +}; use sd_core_indexer_rules::{IndexerRule, IndexerRuler}; use sd_core_prisma_helpers::location_with_indexer_rules; @@ -25,29 +27,28 @@ use super::{ updater::{UpdateTask, UpdateTaskOutput}, walker::{ToWalkEntry, WalkDirTask, WalkTaskOutput, WalkedEntry}, }, - update_directory_sizes, update_location_size, IndexerError, IsoFilePathFactory, WalkerDBProxy, - BATCH_SIZE, + update_directory_sizes, update_location_size, IsoFilePathFactory, WalkerDBProxy, BATCH_SIZE, }; pub async fn shallow( location: location_with_indexer_rules::Data, sub_path: impl AsRef + Send, dispatcher: BaseTaskDispatcher, - db: Arc, - sync: Arc, - invalidate_query: impl Fn(&'static str) + Send + Sync, -) -> Result, Error> { + ctx: impl OuterContext, +) -> Result, Error> { let sub_path = sub_path.as_ref(); + let db = ctx.db(); + let sync = ctx.sync(); let location_path = maybe_missing(&location.path, "location.path") .map(PathBuf::from) .map(Arc::new) - .map_err(IndexerError::from)?; + .map_err(indexer::Error::from)?; let to_walk_path = Arc::new( - get_full_path_from_sub_path(location.id, &Some(sub_path), &*location_path, &db) + get_full_path_from_sub_path(location.id, &Some(sub_path), &*location_path, db) .await - .map_err(IndexerError::from)?, + .map_err(indexer::Error::from)?, ); let Some(WalkTaskOutput { @@ -62,7 +63,7 @@ pub async fn shallow( &location, Arc::clone(&location_path), Arc::clone(&to_walk_path), - Arc::clone(&db), + Arc::clone(db), &dispatcher, ) .await? @@ -70,7 +71,7 @@ pub async fn shallow( return Ok(vec![]); }; - let removed_count = remove_non_existing_file_paths(to_remove, &db, &sync).await?; + let removed_count = remove_non_existing_file_paths(to_remove, db, sync).await?; let Some(Metadata { indexed_count, @@ -79,8 +80,8 @@ pub async fn shallow( &location, to_create, to_update, - Arc::clone(&db), - Arc::clone(&sync), + Arc::clone(db), + Arc::clone(sync), &dispatcher, ) .await? 
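Several loops in these hunks page through the database in fixed-size chunks: order by `file_path::id`, take `CHUNK_SIZE`/`BATCH_SIZE` rows, and carry the last seen id forward as a filter instead of relying on a database cursor (the code comments call this a cursor workaround). The toy below sketches that keyset-style loop over an in-memory table; it uses a strict greater-than comparison so the stand-alone loop terminates cleanly, whereas the real queries use `gte` together with additional orphan filters.

// Self-contained sketch; the in-memory "table" of ids is a stand-in for the file_path table.
const CHUNK_SIZE: usize = 100;

fn fetch_chunk(table: &[u32], last_id: Option<u32>) -> Vec<u32> {
    table
        .iter()
        .copied()
        .filter(|id| last_id.map_or(true, |last| *id > last))
        .take(CHUNK_SIZE)
        .collect()
}

fn main() {
    let table: Vec<u32> = (1..=250).collect(); // ids already sorted ascending
    let mut last_id = None;
    let mut total = 0;

    loop {
        let chunk = fetch_chunk(&table, last_id);
        if chunk.is_empty() {
            break; // nothing left to identify/index
        }
        total += chunk.len();
        last_id = chunk.last().copied(); // starting point for the next iteration
        // ...dispatch a task for `chunk` here...
    }

    assert_eq!(total, 250);
}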
@@ -91,8 +92,8 @@ pub async fn shallow( if indexed_count > 0 || removed_count > 0 || updated_count > 0 { update_directory_sizes( HashMap::from([(directory_iso_file_path, total_size)]), - &db, - &sync, + db, + sync, ) .await?; @@ -101,18 +102,18 @@ pub async fn shallow( &*to_walk_path, location.id, &*location_path, - &db, - &sync, + db, + sync, &mut errors, ) .await?; } - update_location_size(location.id, &db, &invalidate_query).await?; + update_location_size(location.id, db, &ctx).await?; } if indexed_count > 0 || removed_count > 0 { - invalidate_query("search.paths"); + ctx.invalidate_query("search.paths"); } Ok(errors) @@ -135,7 +136,7 @@ async fn walk( .map(|rule| IndexerRule::try_from(&rule.indexer_rule)) .collect::, _>>() .map(IndexerRuler::new) - .map_err(IndexerError::from)?, + .map_err(indexer::Error::from)?, IsoFilePathFactory { location_id: location.id, location_path, diff --git a/core/crates/heavy-lifting/src/indexer/tasks/saver.rs b/core/crates/heavy-lifting/src/indexer/tasks/saver.rs index 715fc770c..3c97b01ad 100644 --- a/core/crates/heavy-lifting/src/indexer/tasks/saver.rs +++ b/core/crates/heavy-lifting/src/indexer/tasks/saver.rs @@ -1,4 +1,4 @@ -use crate::{indexer::IndexerError, Error}; +use crate::{indexer, Error}; use sd_core_file_path_helper::IsolatedFilePathDataParts; use sd_core_sync::Manager as SyncManager; @@ -234,7 +234,7 @@ impl Task for SaveTask { ), ) .await - .map_err(IndexerError::from)? as u64; + .map_err(indexer::Error::from)? as u64; trace!("Inserted {saved_count} records"); diff --git a/core/crates/heavy-lifting/src/indexer/tasks/updater.rs b/core/crates/heavy-lifting/src/indexer/tasks/updater.rs index f4cf0d7fd..e547ec8ac 100644 --- a/core/crates/heavy-lifting/src/indexer/tasks/updater.rs +++ b/core/crates/heavy-lifting/src/indexer/tasks/updater.rs @@ -1,4 +1,4 @@ -use crate::{indexer::IndexerError, Error}; +use crate::{indexer, Error}; use sd_core_file_path_helper::IsolatedFilePathDataParts; use sd_core_sync::Manager as SyncManager; @@ -222,7 +222,7 @@ impl Task for UpdateTask { (sync_stuff.into_iter().flatten().collect(), paths_to_update), ) .await - .map_err(IndexerError::from)?; + .map_err(indexer::Error::from)?; trace!("Updated {updated:?} records"); @@ -240,7 +240,7 @@ async fn fetch_objects_ids_to_unlink( walked_entries: &[WalkedEntry], object_ids_that_should_be_unlinked: &mut HashSet, db: &PrismaClient, -) -> Result<(), IndexerError> { +) -> Result<(), indexer::Error> { if object_ids_that_should_be_unlinked.is_empty() { // First we consult which file paths we should unlink let object_ids = walked_entries diff --git a/core/crates/heavy-lifting/src/indexer/tasks/walker.rs b/core/crates/heavy-lifting/src/indexer/tasks/walker.rs index 3ed771e2e..238d8ab52 100644 --- a/core/crates/heavy-lifting/src/indexer/tasks/walker.rs +++ b/core/crates/heavy-lifting/src/indexer/tasks/walker.rs @@ -1,7 +1,4 @@ -use crate::{ - indexer::{IndexerError, NonCriticalIndexerError}, - Error, NonCriticalJobError, -}; +use crate::{indexer, Error, NonCriticalError}; use sd_core_file_path_helper::{FilePathError, FilePathMetadata, IsolatedFilePathData}; use sd_core_indexer_rules::{IndexerRuler, MetadataForIndexerRules, RuleKind}; @@ -111,13 +108,14 @@ pub trait WalkerDBProxy: Clone + Send + Sync + fmt::Debug + 'static { fn fetch_file_paths( &self, found_paths: Vec, - ) -> impl Future, IndexerError>> + Send; + ) -> impl Future, indexer::Error>> + Send; fn fetch_file_paths_to_remove( &self, parent_iso_file_path: &IsolatedFilePathData<'_>, 
unique_location_id_materialized_path_name_extension_params: Vec, - ) -> impl Future, NonCriticalIndexerError>> + Send; + ) -> impl Future, indexer::NonCriticalError>> + + Send; } #[derive(Debug, Serialize, Deserialize)] @@ -141,7 +139,7 @@ pub struct WalkTaskOutput { pub to_update: Vec, pub to_remove: Vec, pub accepted_ancestors: HashSet, - pub errors: Vec, + pub errors: Vec, pub directory_iso_file_path: IsolatedFilePathData<'static>, pub total_size: u64, pub handles: Vec>, @@ -160,7 +158,7 @@ struct InnerMetadata { } impl InnerMetadata { - fn new(path: impl AsRef, metadata: &Metadata) -> Result { + fn new(path: impl AsRef, metadata: &Metadata) -> Result { let FilePathMetadata { inode, size_in_bytes, @@ -168,7 +166,7 @@ impl InnerMetadata { modified_at, hidden, } = FilePathMetadata::from_path(path, metadata) - .map_err(|e| NonCriticalIndexerError::FilePathMetadata(e.to_string()))?; + .map_err(|e| indexer::NonCriticalError::FilePathMetadata(e.to_string()))?; Ok(Self { is_dir: metadata.is_dir(), @@ -237,7 +235,7 @@ struct WalkDirSaveState { root: Arc, entry_iso_file_path: IsolatedFilePathData<'static>, stage: WalkerStageSaveState, - errors: Vec, + errors: Vec, scan_time: Duration, is_shallow: bool, } @@ -367,7 +365,7 @@ where db_proxy: DBProxy, stage: WalkerStage, maybe_dispatcher: Option, - errors: Vec, + errors: Vec, scan_time: Duration, is_shallow: bool, } @@ -385,7 +383,7 @@ where iso_file_path_factory: IsoPathFactory, db_proxy: DBProxy, dispatcher: Dispatcher, - ) -> Result { + ) -> Result { let entry = entry.into(); Ok(Self { id: TaskId::new_v4(), @@ -415,7 +413,7 @@ where indexer_ruler: IndexerRuler, iso_file_path_factory: IsoPathFactory, db_proxy: DBProxy, - ) -> Result { + ) -> Result { let entry = entry.into(); Ok(Self { id: TaskId::new_v4(), @@ -545,7 +543,7 @@ where *stage = WalkerStage::Walking { read_dir_stream: ReadDirStream::new(fs::read_dir(&path).await.map_err( |e| { - IndexerError::FileIO( + indexer::Error::FileIO( (&path, e, "Failed to open directory to read its entries") .into(), ) @@ -565,8 +563,8 @@ where found_paths.push(dir_entry.path()); } Err(e) => { - errors.push(NonCriticalJobError::Indexer( - NonCriticalIndexerError::FailedDirectoryEntry( + errors.push(NonCriticalError::Indexer( + indexer::NonCriticalError::FailedDirectoryEntry( FileIOError::from((&path, e)).to_string(), ), )); @@ -709,7 +707,7 @@ where async fn segregate_creates_and_updates( walking_entries: &mut Vec, db_proxy: &impl WalkerDBProxy, -) -> Result<(Vec, Vec, u64), IndexerError> { +) -> Result<(Vec, Vec, u64), Error> { if walking_entries.is_empty() { Ok((vec![], vec![], 0)) } else { @@ -791,7 +789,7 @@ async fn keep_walking( db_proxy: &impl WalkerDBProxy, maybe_to_keep_walking: &mut Option>, dispatcher: &Option>, - errors: &mut Vec, + errors: &mut Vec, ) -> Vec> { if let (Some(dispatcher), Some(to_keep_walking)) = (dispatcher, maybe_to_keep_walking) { dispatcher @@ -807,7 +805,7 @@ async fn keep_walking( db_proxy.clone(), dispatcher.clone(), ) - .map_err(|e| NonCriticalIndexerError::DispatchKeepWalking(e.to_string())) + .map_err(|e| indexer::NonCriticalError::DispatchKeepWalking(e.to_string())) }) .filter_map(|res| res.map_err(|e| errors.push(e.into())).ok()), ) @@ -819,7 +817,7 @@ async fn keep_walking( async fn collect_metadata( found_paths: &mut Vec, - errors: &mut Vec, + errors: &mut Vec, ) -> HashMap { found_paths .drain(..) 
@@ -827,7 +825,7 @@ async fn collect_metadata(
fs::metadata(&current_path)
.await
.map_err(|e| {
- NonCriticalIndexerError::Metadata(
+ indexer::NonCriticalError::Metadata(
FileIOError::from((&current_path, e)).to_string(),
)
})
@@ -847,7 +845,7 @@ async fn collect_metadata(
async fn apply_indexer_rules(
paths_and_metadatas: &mut HashMap,
indexer_ruler: &IndexerRuler,
- errors: &mut Vec,
+ errors: &mut Vec,
) -> HashMap>)> {
paths_and_metadatas
.drain()
@@ -860,7 +858,7 @@ async fn apply_indexer_rules(
.map(|acceptance_per_rule_kind| {
(current_path, (metadata, acceptance_per_rule_kind))
})
- .map_err(|e| NonCriticalIndexerError::IndexerRule(e.to_string()))
+ .map_err(|e| indexer::NonCriticalError::IndexerRule(e.to_string()))
})
.collect::>()
.join()
@@ -879,7 +877,7 @@ async fn process_rules_results(
(InnerMetadata, HashMap>),
>,
maybe_to_keep_walking: &mut Option>,
- errors: &mut Vec,
+ errors: &mut Vec,
) -> (HashMap, HashSet) {
let root = root.as_ref();
@@ -951,7 +949,7 @@ async fn process_rules_results(
fs::metadata(&ancestor_path)
.await
.map_err(|e| {
- NonCriticalIndexerError::Metadata(
+ indexer::NonCriticalError::Metadata(
FileIOError::from((&ancestor_path, e)).to_string(),
)
})
@@ -964,7 +962,7 @@ async fn process_rules_results(
}
.into()
})
- .map_err(|e| NonCriticalIndexerError::FilePathMetadata(e.to_string()))
+ .map_err(|e| indexer::NonCriticalError::FilePathMetadata(e.to_string()))
})
})
.collect::>()
@@ -1023,7 +1021,7 @@ fn accept_ancestors(
accepted: &mut HashMap,
iso_file_path_factory: &impl IsoFilePathFactory,
accepted_ancestors: &mut HashMap, PathBuf>,
- errors: &mut Vec,
+ errors: &mut Vec,
) {
// If the ancestor directories weren't indexed before, we do it now
for ancestor in current_path
@@ -1033,7 +1031,7 @@ fn accept_ancestors(
{
if let Ok(iso_file_path) = iso_file_path_factory
.build(ancestor, true)
- .map_err(|e| errors.push(NonCriticalIndexerError::IsoFilePath(e.to_string()).into()))
+ .map_err(|e| errors.push(indexer::NonCriticalError::IsoFilePath(e.to_string()).into()))
{
match accepted_ancestors.entry(iso_file_path) {
Entry::Occupied(_) => {
@@ -1083,7 +1081,7 @@ async fn gather_file_paths_to_remove(
entry_iso_file_path: &IsolatedFilePathData<'_>,
iso_file_path_factory: &impl IsoFilePathFactory,
db_proxy: &impl WalkerDBProxy,
- errors: &mut Vec,
+ errors: &mut Vec,
) -> (Vec, Vec) {
let (walking, to_delete_params) = accepted_paths
.drain()
@@ -1102,7 +1100,7 @@ async fn gather_file_paths_to_remove(
)
})
.map_err(|e| {
- errors.push(NonCriticalIndexerError::IsoFilePath(e.to_string()).into());
+ errors.push(indexer::NonCriticalError::IsoFilePath(e.to_string()).into());
})
.ok()
})
@@ -1158,7 +1156,7 @@ mod tests {
async fn fetch_file_paths(
&self,
_: Vec,
- ) -> Result, IndexerError> {
+ ) -> Result, indexer::Error> {
Ok(vec![])
}
@@ -1166,7 +1164,7 @@
&self,
_: &IsolatedFilePathData<'_>,
_: Vec,
- ) -> Result, NonCriticalIndexerError> {
+ ) -> Result, indexer::NonCriticalError> {
Ok(vec![])
}
}
diff --git a/core/crates/heavy-lifting/src/job_system/job.rs b/core/crates/heavy-lifting/src/job_system/job.rs
index eadb587e1..4006481fd 100644
--- a/core/crates/heavy-lifting/src/job_system/job.rs
+++ b/core/crates/heavy-lifting/src/job_system/job.rs
@@ -1,4 +1,4 @@
-use crate::{Error, NonCriticalJobError};
+use crate::{Error, NonCriticalError, UpdateEvent};
use sd_core_sync::Manager as SyncManager;
@@ -11,6 +11,7 @@ use std::{
collections::{hash_map::DefaultHasher, VecDeque},
hash::{Hash, Hasher},
marker::PhantomData,
+ path::Path,
pin::pin,
sync::Arc,
};
@@ -46,6
+47,7 @@ use super::{ pub enum JobName { Indexer, FileIdentifier, + MediaProcessor, // TODO: Add more job names as needed } @@ -72,7 +74,7 @@ impl ProgressUpdate { } } -pub trait JobContext: Send + Sync + Clone + 'static { +pub trait OuterContext: Send + Sync + Clone + 'static { fn id(&self) -> Uuid; fn db(&self) -> &Arc; fn sync(&self) -> &Arc; @@ -82,6 +84,8 @@ pub trait JobContext: Send + Sync + Clone + 'static { fn progress_msg(&self, msg: impl Into) { self.progress(vec![ProgressUpdate::Message(msg.into())]); } + fn report_update(&self, update: UpdateEvent); + fn get_data_directory(&self) -> &Path; } pub trait Job: Send + Sync + Hash + 'static { @@ -91,31 +95,31 @@ pub trait Job: Send + Sync + Hash + 'static { fn resume_tasks( &mut self, dispatcher: &JobTaskDispatcher, - ctx: &impl JobContext, + ctx: &impl OuterContext, serialized_tasks: SerializedTasks, ) -> impl Future> + Send { async move { Ok(()) } } - fn run( + fn run( self, dispatcher: JobTaskDispatcher, - ctx: impl JobContext, + ctx: Ctx, ) -> impl Future> + Send; } pub trait IntoJob where - J: Job + SerializableJob, - Ctx: JobContext, + J: Job + SerializableJob, + Ctx: OuterContext, { fn into_job(self) -> Box>; } impl IntoJob for J where - J: Job + SerializableJob, - Ctx: JobContext, + J: Job + SerializableJob, + Ctx: OuterContext, { fn into_job(self) -> Box> { let id = JobId::new_v4(); @@ -132,8 +136,8 @@ where impl IntoJob for JobBuilder where - J: Job + SerializableJob, - Ctx: JobContext, + J: Job + SerializableJob, + Ctx: OuterContext, { fn into_job(self) -> Box> { self.build() @@ -144,7 +148,7 @@ where pub struct JobReturn { data: JobOutputData, metadata: Option, - non_critical_errors: Vec, + non_critical_errors: Vec, } impl JobReturn { @@ -185,7 +189,7 @@ impl JobReturnBuilder { } #[must_use] - pub fn with_non_critical_errors(mut self, errors: Vec) -> Self { + pub fn with_non_critical_errors(mut self, errors: Vec) -> Self { if self.job_return.non_critical_errors.is_empty() { self.job_return.non_critical_errors = errors; } else { @@ -207,7 +211,7 @@ pub struct JobOutput { job_name: JobName, data: JobOutputData, metadata: Vec, - non_critical_errors: Vec, + non_critical_errors: Vec, } impl JobOutput { @@ -260,8 +264,8 @@ pub enum JobOutputData { pub struct JobBuilder where - J: Job + SerializableJob, - Ctx: JobContext, + J: Job + SerializableJob, + Ctx: OuterContext, { id: JobId, job: J, @@ -272,8 +276,8 @@ where impl JobBuilder where - J: Job + SerializableJob, - Ctx: JobContext, + J: Job + SerializableJob, + Ctx: OuterContext, { pub fn build(self) -> Box> { Box::new(JobHolder { @@ -315,7 +319,7 @@ where } #[must_use] - pub fn enqueue_next(mut self, next: impl Job + SerializableJob) -> Self { + pub fn enqueue_next(mut self, next: impl Job + SerializableJob) -> Self { let next_job_order = self.next_jobs.len() + 1; let mut child_job_builder = JobBuilder::new(next).with_parent_id(self.id); @@ -333,8 +337,8 @@ where pub struct JobHolder where - J: Job + SerializableJob, - Ctx: JobContext, + J: Job + SerializableJob, + Ctx: OuterContext, { pub(super) id: JobId, pub(super) job: J, @@ -343,14 +347,14 @@ where pub(super) _ctx: PhantomData, } -pub struct JobHandle { +pub struct JobHandle { pub(crate) next_jobs: VecDeque>>, - pub(crate) job_ctx: Ctx, + pub(crate) ctx: Ctx, pub(crate) report: Report, pub(crate) commands_tx: chan::Sender, } -impl JobHandle { +impl JobHandle { pub async fn send_command(&mut self, command: Command) -> Result<(), JobSystemError> { if self.commands_tx.send(command).await.is_err() { warn!("Tried to send a 
{command:?} to a job that was already completed"); @@ -375,7 +379,7 @@ impl JobHandle { next_job_report.status = new_status; next_job_report.completed_at = completed_at; - next_job_report.update(self.job_ctx.db()).await + next_job_report.update(self.ctx.db()).await }) .collect::>() .try_join() @@ -391,7 +395,7 @@ impl JobHandle { let Self { next_jobs, report, - job_ctx, + ctx, .. } = self; @@ -400,7 +404,7 @@ impl JobHandle { report.started_at = Some(start_time); } - let db = job_ctx.db(); + let db = ctx.db(); // If the report doesn't have a created_at date, it's a new report if report.created_at.is_none() { @@ -432,21 +436,17 @@ impl JobHandle { &mut self, job_return: JobReturn, ) -> Result { - let Self { - report, job_ctx, .. - } = self; + let Self { report, ctx, .. } = self; let output = JobOutput::prepare_output_and_report(job_return, report); - report.update(job_ctx.db()).await?; + report.update(ctx.db()).await?; Ok(output) } pub async fn failed_job(&mut self, e: &Error) -> Result<(), JobSystemError> { - let Self { - report, job_ctx, .. - } = self; + let Self { report, ctx, .. } = self; error!( "Job failed with a critical error: {e:#?};", report.id, report.name @@ -456,15 +456,13 @@ impl JobHandle { report.critical_error = Some(e.to_string()); report.completed_at = Some(Utc::now()); - report.update(job_ctx.db()).await?; + report.update(ctx.db()).await?; self.command_children(Command::Cancel).await } pub async fn shutdown_pause_job(&mut self) -> Result<(), JobSystemError> { - let Self { - report, job_ctx, .. - } = self; + let Self { report, ctx, .. } = self; info!( "Job paused due to system shutdown, we will pause all children jobs", report.id, report.name @@ -472,15 +470,13 @@ impl JobHandle { report.status = Status::Paused; - report.update(job_ctx.db()).await?; + report.update(ctx.db()).await?; self.command_children(Command::Pause).await } pub async fn cancel_job(&mut self) -> Result<(), JobSystemError> { - let Self { - report, job_ctx, .. - } = self; + let Self { report, ctx, .. 
} = self; info!( "Job canceled, we will cancel all children jobs", report.id, report.name @@ -489,14 +485,14 @@ impl JobHandle { report.status = Status::Canceled; report.completed_at = Some(Utc::now()); - report.update(job_ctx.db()).await?; + report.update(ctx.db()).await?; self.command_children(Command::Cancel).await } } #[async_trait::async_trait] -pub trait DynJob: Send + Sync + 'static { +pub trait DynJob: Send + Sync + 'static { fn id(&self) -> JobId; fn job_name(&self) -> JobName; @@ -514,14 +510,14 @@ pub trait DynJob: Send + Sync + 'static { fn dispatch( self: Box, base_dispatcher: BaseTaskDispatcher, - job_ctx: Ctx, + ctx: Ctx, done_tx: chan::Sender<(JobId, Result)>, ) -> JobHandle; fn resume( self: Box, base_dispatcher: BaseTaskDispatcher, - job_ctx: Ctx, + ctx: Ctx, serialized_tasks: Option, done_tx: chan::Sender<(JobId, Result)>, ) -> JobHandle; @@ -530,8 +526,8 @@ pub trait DynJob: Send + Sync + 'static { #[async_trait::async_trait] impl DynJob for JobHolder where - J: Job + SerializableJob, - Ctx: JobContext, + J: Job + SerializableJob, + Ctx: OuterContext, { fn id(&self) -> JobId { self.id @@ -567,7 +563,7 @@ where fn dispatch( self: Box, base_dispatcher: BaseTaskDispatcher, - job_ctx: Ctx, + ctx: Ctx, done_tx: chan::Sender<(JobId, Result)>, ) -> JobHandle { let (commands_tx, commands_rx) = chan::bounded(8); @@ -575,7 +571,7 @@ where spawn(to_spawn_job( self.id, self.job, - job_ctx.clone(), + ctx.clone(), None, base_dispatcher, commands_rx, @@ -584,7 +580,7 @@ where JobHandle { next_jobs: self.next_jobs, - job_ctx, + ctx, report: self.report, commands_tx, } @@ -593,7 +589,7 @@ where fn resume( self: Box, base_dispatcher: BaseTaskDispatcher, - job_ctx: Ctx, + ctx: Ctx, serialized_tasks: Option, done_tx: chan::Sender<(JobId, Result)>, ) -> JobHandle { @@ -602,7 +598,7 @@ where spawn(to_spawn_job( self.id, self.job, - job_ctx.clone(), + ctx.clone(), serialized_tasks, base_dispatcher, commands_rx, @@ -611,17 +607,17 @@ where JobHandle { next_jobs: self.next_jobs, - job_ctx, + ctx, report: self.report, commands_tx, } } } -async fn to_spawn_job( +async fn to_spawn_job( id: JobId, mut job: impl Job, - job_ctx: Ctx, + ctx: Ctx, existing_tasks: Option, base_dispatcher: BaseTaskDispatcher, commands_rx: chan::Receiver, @@ -641,10 +637,7 @@ async fn to_spawn_job( JobTaskDispatcher::new(base_dispatcher, running_state_rx); if let Some(existing_tasks) = existing_tasks { - if let Err(e) = job - .resume_tasks(&dispatcher, &job_ctx, existing_tasks) - .await - { + if let Err(e) = job.resume_tasks(&dispatcher, &ctx, existing_tasks).await { done_tx .send((id, Err(e))) .await @@ -657,7 +650,7 @@ async fn to_spawn_job( let mut msgs_stream = pin!(( commands_rx.map(StreamMessage::Commands), remote_controllers_rx.map(StreamMessage::NewRemoteController), - stream::once(job.run(dispatcher, job_ctx)).map(StreamMessage::Done), + stream::once(job.run(dispatcher, ctx)).map(StreamMessage::Done), ) .merge()); diff --git a/core/crates/heavy-lifting/src/job_system/mod.rs b/core/crates/heavy-lifting/src/job_system/mod.rs index 9f8c6c15b..a8b552a70 100644 --- a/core/crates/heavy-lifting/src/job_system/mod.rs +++ b/core/crates/heavy-lifting/src/job_system/mod.rs @@ -21,7 +21,7 @@ mod store; pub mod utils; use error::JobSystemError; -use job::{IntoJob, Job, JobContext, JobName, JobOutput}; +use job::{IntoJob, Job, JobName, JobOutput, OuterContext}; use runner::{run, JobSystemRunner, RunnerMessage}; use store::{load_jobs, StoredJobEntry}; @@ -38,13 +38,13 @@ pub enum Command { Cancel, } -pub struct JobSystem { +pub 
struct JobSystem { msgs_tx: chan::Sender>, job_outputs_rx: chan::Receiver<(JobId, Result)>, runner_handle: RefCell>>, } -impl JobSystem { +impl JobSystem { pub async fn new( base_dispatcher: BaseTaskDispatcher, data_directory: impl AsRef + Send, @@ -164,11 +164,11 @@ impl JobSystem { /// Dispatch a new job to the system /// # Panics /// Panics only happen if internal channels are unexpectedly closed - pub async fn dispatch( + pub async fn dispatch>( &mut self, job: impl IntoJob + Send, location_id: location::id::Type, - job_ctx: Ctx, + ctx: Ctx, ) -> Result { let dyn_job = job.into_job(); let id = dyn_job.id(); @@ -179,7 +179,7 @@ impl JobSystem { id, location_id, dyn_job, - job_ctx, + ctx, ack_tx, }) .await @@ -230,9 +230,9 @@ impl JobSystem { /// SAFETY: Due to usage of refcell we lost `Sync` impl, but we only use it to have a shutdown method /// receiving `&self` which is called once, and we also use `try_borrow_mut` so we never panic -unsafe impl Sync for JobSystem {} +unsafe impl Sync for JobSystem {} -async fn load_stored_job_entries( +async fn load_stored_job_entries( store_jobs_file: impl AsRef + Send, previously_existing_job_contexts: &HashMap, msgs_tx: &chan::Sender>, @@ -273,11 +273,11 @@ async fn load_stored_job_entries( res.map_err(|e| error!("Failed to load stored jobs: {e:#?}")) .ok() }) - .flat_map(|(stored_jobs, job_ctx)| { + .flat_map(|(stored_jobs, ctx)| { stored_jobs .into_iter() .map(move |(location_id, dyn_job, serialized_tasks)| { - let job_ctx = job_ctx.clone(); + let ctx = ctx.clone(); async move { let (ack_tx, ack_rx) = oneshot::channel(); @@ -286,7 +286,7 @@ async fn load_stored_job_entries( id: dyn_job.id(), location_id, dyn_job, - job_ctx, + ctx, serialized_tasks, ack_tx, }) diff --git a/core/crates/heavy-lifting/src/job_system/runner.rs b/core/crates/heavy-lifting/src/job_system/runner.rs index 0f257f259..f1ea8f137 100644 --- a/core/crates/heavy-lifting/src/job_system/runner.rs +++ b/core/crates/heavy-lifting/src/job_system/runner.rs @@ -26,7 +26,7 @@ use tracing::{debug, error, info, warn}; use uuid::Uuid; use super::{ - job::{DynJob, JobContext, JobHandle, JobName, JobOutput, ReturnStatus}, + job::{DynJob, JobHandle, JobName, JobOutput, OuterContext, ReturnStatus}, report, store::{StoredJob, StoredJobEntry}, Command, JobId, JobSystemError, SerializedTasks, @@ -35,19 +35,19 @@ use super::{ const JOBS_INITIAL_CAPACITY: usize = 32; const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60); -pub(super) enum RunnerMessage { +pub(super) enum RunnerMessage { NewJob { id: JobId, location_id: location::id::Type, dyn_job: Box>, - job_ctx: Ctx, + ctx: Ctx, ack_tx: oneshot::Sender>, }, ResumeStoredJob { id: JobId, location_id: location::id::Type, dyn_job: Box>, - job_ctx: Ctx, + ctx: Ctx, serialized_tasks: Option, ack_tx: oneshot::Sender>, }, @@ -64,7 +64,7 @@ pub(super) enum RunnerMessage { Shutdown, } -pub(super) struct JobSystemRunner { +pub(super) struct JobSystemRunner { base_dispatcher: BaseTaskDispatcher, handles: HashMap>, job_hashes: HashMap, @@ -76,7 +76,7 @@ pub(super) struct JobSystemRunner { job_outputs_tx: chan::Sender<(JobId, Result)>, } -impl JobSystemRunner { +impl JobSystemRunner { pub(super) fn new( base_dispatcher: BaseTaskDispatcher, job_return_status_tx: chan::Sender<(JobId, Result)>, @@ -100,7 +100,7 @@ impl JobSystemRunner { id: JobId, location_id: location::id::Type, dyn_job: Box>, - job_ctx: Ctx, + ctx: Ctx, maybe_existing_tasks: Option, ) -> Result<(), JobSystemError> { let Self { @@ -114,7 +114,7 @@ impl JobSystemRunner { .. 
} = self; - let db = job_ctx.db(); + let db = ctx.db(); let job_name = dyn_job.job_name(); let job_hash = dyn_job.hash(); @@ -137,14 +137,14 @@ impl JobSystemRunner { let mut handle = if maybe_existing_tasks.is_some() { dyn_job.resume( base_dispatcher.clone(), - job_ctx.clone(), + ctx.clone(), maybe_existing_tasks, job_return_status_tx.clone(), ) } else { dyn_job.dispatch( base_dispatcher.clone(), - job_ctx.clone(), + ctx.clone(), job_return_status_tx.clone(), ) }; @@ -169,7 +169,7 @@ impl JobSystemRunner { .map(|dyn_job| dyn_job.report_mut()) .map(|next_job_report| async { if next_job_report.created_at.is_none() { - next_job_report.create(job_ctx.db()).await + next_job_report.create(ctx.db()).await } else { Ok(()) } @@ -277,7 +277,7 @@ impl JobSystemRunner { }; jobs_to_store_by_ctx_id - .entry(handle.job_ctx.id()) + .entry(handle.ctx.id()) .or_default() .push(StoredJobEntry { location_id, @@ -384,7 +384,7 @@ impl JobSystemRunner { } } -fn try_dispatch_next_job( +fn try_dispatch_next_job( handle: &mut JobHandle, base_dispatcher: BaseTaskDispatcher, (job_hashes, job_hashes_by_id): (&mut HashMap, &mut HashMap), @@ -397,11 +397,8 @@ fn try_dispatch_next_job( if let Entry::Vacant(e) = job_hashes.entry(next_hash) { e.insert(next_id); job_hashes_by_id.insert(next_id, next_hash); - let mut next_handle = next.dispatch( - base_dispatcher, - handle.job_ctx.clone(), - job_return_status_tx, - ); + let mut next_handle = + next.dispatch(base_dispatcher, handle.ctx.clone(), job_return_status_tx); assert!( next_handle.next_jobs.is_empty(), @@ -418,13 +415,13 @@ fn try_dispatch_next_job( } } -pub(super) async fn run( +pub(super) async fn run( mut runner: JobSystemRunner, store_jobs_file: impl AsRef + Send, msgs_rx: chan::Receiver>, job_return_status_rx: chan::Receiver<(JobId, Result)>, ) { - enum StreamMessage { + enum StreamMessage { ReturnStatus((JobId, Result)), RunnerMessage(RunnerMessage), CleanMemoryTick, @@ -453,15 +450,11 @@ pub(super) async fn run( id, location_id, dyn_job, - job_ctx, + ctx, ack_tx, }) => { ack_tx - .send( - runner - .new_job(id, location_id, dyn_job, job_ctx, None) - .await, - ) + .send(runner.new_job(id, location_id, dyn_job, ctx, None).await) .expect("ack channel closed before sending new job response"); } @@ -469,14 +462,14 @@ pub(super) async fn run( id, location_id, dyn_job, - job_ctx, + ctx, serialized_tasks, ack_tx, }) => { ack_tx .send( runner - .new_job(id, location_id, dyn_job, job_ctx, serialized_tasks) + .new_job(id, location_id, dyn_job, ctx, serialized_tasks) .await, ) .expect("ack channel closed before sending resume job response"); diff --git a/core/crates/heavy-lifting/src/job_system/store.rs b/core/crates/heavy-lifting/src/job_system/store.rs index 4d1cb9485..8c40c7dc5 100644 --- a/core/crates/heavy-lifting/src/job_system/store.rs +++ b/core/crates/heavy-lifting/src/job_system/store.rs @@ -1,4 +1,4 @@ -use crate::{file_identifier::FileIdentifierJob, indexer::IndexerJob}; +use crate::{file_identifier, indexer, media_processor}; use sd_prisma::prisma::{job, location}; use sd_utils::uuid_to_bytes; @@ -14,7 +14,7 @@ use futures_concurrency::future::TryJoin; use serde::{Deserialize, Serialize}; use super::{ - job::{DynJob, Job, JobContext, JobHolder, JobName}, + job::{DynJob, Job, JobHolder, JobName, OuterContext}, report::{Report, ReportError}, JobId, JobSystemError, }; @@ -22,7 +22,7 @@ use super::{ #[derive(Debug, Serialize, Deserialize)] pub struct SerializedTasks(pub Vec); -pub trait SerializableJob: 'static +pub trait SerializableJob: 'static where Self: Sized, 
{ @@ -35,7 +35,7 @@ where #[allow(unused_variables)] fn deserialize( serialized_job: &[u8], - ctx: &impl JobContext, + ctx: &Ctx, ) -> impl Future< Output = Result)>, rmp_serde::decode::Error>, > + Send { @@ -57,9 +57,9 @@ pub struct StoredJobEntry { pub(super) next_jobs: Vec, } -pub async fn load_jobs( +pub async fn load_jobs( entries: Vec, - job_ctx: &Ctx, + ctx: &Ctx, ) -> Result< Vec<( location::id::Type, @@ -68,7 +68,7 @@ pub async fn load_jobs( )>, JobSystemError, > { - let mut reports = job_ctx + let mut reports = ctx .db() .job() .find_many(vec![job::id::in_vec( @@ -105,7 +105,7 @@ pub async fn load_jobs( .ok_or(ReportError::MissingReport(root_job.id))?; Ok(async move { - load_job(root_job, report, job_ctx) + load_job(root_job, report, ctx) .await .map(|maybe_loaded_job| { maybe_loaded_job @@ -135,7 +135,7 @@ pub async fn load_jobs( next_jobs_and_reports .into_iter() .map(|(next_job, report)| async move { - load_job(next_job, report, job_ctx) + load_job(next_job, report, ctx) .await .map(|maybe_loaded_next_job| { maybe_loaded_next_job.map(|(next_dyn_job, next_tasks)| { @@ -166,7 +166,7 @@ pub async fn load_jobs( } macro_rules! match_deserialize_job { - ($stored_job:ident, $report:ident, $job_ctx:ident, $ctx_type:ty, [$($job_type:ty),+ $(,)?]) => {{ + ($stored_job:ident, $report:ident, $ctx:ident, $ctx_type:ty, [$($job_type:ty),+ $(,)?]) => {{ let StoredJob { id, name, @@ -175,9 +175,9 @@ macro_rules! match_deserialize_job { match name { - $(<$job_type as Job>::NAME => <$job_type as SerializableJob>::deserialize( + $(<$job_type as Job>::NAME => <$job_type as SerializableJob<$ctx_type>>::deserialize( &serialized_job, - $job_ctx, + $ctx, ).await .map(|maybe_job| maybe_job.map(|(job, tasks)| -> ( Box>, @@ -200,21 +200,21 @@ macro_rules! match_deserialize_job { }}; } -async fn load_job( +async fn load_job( stored_job: StoredJob, report: Report, - job_ctx: &Ctx, + ctx: &Ctx, ) -> Result>, Option)>, JobSystemError> { match_deserialize_job!( stored_job, report, - job_ctx, + ctx, Ctx, [ - IndexerJob, - FileIdentifierJob, + indexer::job::Indexer, + file_identifier::job::FileIdentifier, + media_processor::job::MediaProcessor, // TODO: Add more jobs here - // e.g.: FileIdentifierJob, MediaProcessorJob, etc., ] ) } diff --git a/core/crates/heavy-lifting/src/lib.rs b/core/crates/heavy-lifting/src/lib.rs index 1cc079f8d..696940179 100644 --- a/core/crates/heavy-lifting/src/lib.rs +++ b/core/crates/heavy-lifting/src/lib.rs @@ -27,6 +27,7 @@ #![forbid(deprecated_in_future)] #![allow(clippy::missing_errors_doc, clippy::module_name_repetitions)] +use sd_prisma::prisma::file_path; use sd_task_system::TaskSystemError; use serde::{Deserialize, Serialize}; @@ -36,22 +37,24 @@ use thiserror::Error; pub mod file_identifier; pub mod indexer; pub mod job_system; +pub mod media_processor; pub mod utils; -use file_identifier::{FileIdentifierError, NonCriticalFileIdentifierError}; -use indexer::{IndexerError, NonCriticalIndexerError}; +use media_processor::ThumbKey; pub use job_system::{ - job::{IntoJob, JobBuilder, JobContext, JobName, JobOutput, JobOutputData, ProgressUpdate}, + job::{IntoJob, JobBuilder, JobName, JobOutput, JobOutputData, OuterContext, ProgressUpdate}, JobId, JobSystem, }; #[derive(Error, Debug)] pub enum Error { #[error(transparent)] - Indexer(#[from] IndexerError), + Indexer(#[from] indexer::Error), #[error(transparent)] - FileIdentifier(#[from] FileIdentifierError), + FileIdentifier(#[from] file_identifier::Error), + #[error(transparent)] + MediaProcessor(#[from] media_processor::Error), 
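// The new `MediaProcessor` variant follows the same pattern as `Indexer` and `FileIdentifier`:
// the wrapped `media_processor::Error` is converted via `#[from]`, and the
// `From<Error> for rspc::Error` impl below forwards it with `Error::MediaProcessor(e) => e.into()`.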
#[error(transparent)] TaskSystem(#[from] TaskSystemError), @@ -62,6 +65,7 @@ impl From for rspc::Error { match e { Error::Indexer(e) => e.into(), Error::FileIdentifier(e) => e.into(), + Error::MediaProcessor(e) => e.into(), Error::TaskSystem(e) => { Self::with_cause(rspc::ErrorCode::InternalServerError, e.to_string(), e) } @@ -70,12 +74,14 @@ impl From for rspc::Error { } #[derive(thiserror::Error, Debug, Serialize, Deserialize, Type)] -pub enum NonCriticalJobError { +pub enum NonCriticalError { // TODO: Add variants as needed #[error(transparent)] - Indexer(#[from] NonCriticalIndexerError), + Indexer(#[from] indexer::NonCriticalError), #[error(transparent)] - FileIdentifier(#[from] NonCriticalFileIdentifierError), + FileIdentifier(#[from] file_identifier::NonCriticalError), + #[error(transparent)] + MediaProcessor(#[from] media_processor::NonCriticalError), } #[repr(i32)] @@ -86,3 +92,13 @@ pub enum LocationScanState { FilesIdentified = 2, Completed = 3, } + +#[derive(Debug, Serialize, Type)] +pub enum UpdateEvent { + NewThumbnailEvent { + thumb_key: ThumbKey, + }, + NewIdentifiedObjects { + file_path_ids: Vec, + }, +} diff --git a/core/crates/heavy-lifting/src/media_processor/helpers/exif_media_data.rs b/core/crates/heavy-lifting/src/media_processor/helpers/exif_media_data.rs new file mode 100644 index 000000000..3fa2c7618 --- /dev/null +++ b/core/crates/heavy-lifting/src/media_processor/helpers/exif_media_data.rs @@ -0,0 +1,85 @@ +use crate::media_processor::{self, media_data_extractor}; + +use sd_file_ext::extensions::{Extension, ImageExtension, ALL_IMAGE_EXTENSIONS}; +use sd_media_metadata::ExifMetadata; +use sd_prisma::prisma::{exif_data, object, PrismaClient}; + +use std::path::Path; + +use once_cell::sync::Lazy; + +pub static AVAILABLE_EXTENSIONS: Lazy> = Lazy::new(|| { + ALL_IMAGE_EXTENSIONS + .iter() + .copied() + .filter(|&ext| can_extract(ext)) + .map(Extension::Image) + .collect() +}); + +pub const fn can_extract(image_extension: ImageExtension) -> bool { + use ImageExtension::{ + Avci, Avcs, Avif, Dng, Heic, Heif, Heifs, Hif, Jpeg, Jpg, Png, Tiff, Webp, + }; + matches!( + image_extension, + Tiff | Dng | Jpeg | Jpg | Heif | Heifs | Heic | Avif | Avcs | Avci | Hif | Png | Webp + ) +} + +pub fn to_query( + mdi: ExifMetadata, + object_id: exif_data::object_id::Type, +) -> exif_data::CreateUnchecked { + exif_data::CreateUnchecked { + object_id, + _params: vec![ + exif_data::camera_data::set(serde_json::to_vec(&mdi.camera_data).ok()), + exif_data::media_date::set(serde_json::to_vec(&mdi.date_taken).ok()), + exif_data::resolution::set(serde_json::to_vec(&mdi.resolution).ok()), + exif_data::media_location::set(serde_json::to_vec(&mdi.location).ok()), + exif_data::artist::set(mdi.artist), + exif_data::description::set(mdi.description), + exif_data::copyright::set(mdi.copyright), + exif_data::exif_version::set(mdi.exif_version), + exif_data::epoch_time::set(mdi.date_taken.map(|x| x.unix_timestamp())), + ], + } +} + +pub async fn extract( + path: impl AsRef + Send, +) -> Result, media_processor::NonCriticalError> { + let path = path.as_ref(); + + ExifMetadata::from_path(&path).await.map_err(|e| { + media_data_extractor::NonCriticalError::FailedToExtractImageMediaData( + path.to_path_buf(), + e.to_string(), + ) + .into() + }) +} + +pub async fn save( + media_datas: Vec<(ExifMetadata, object::id::Type)>, + db: &PrismaClient, +) -> Result { + db.exif_data() + .create_many( + media_datas + .into_iter() + .map(|(exif_data, object_id)| to_query(exif_data, object_id)) + .collect(), + ) + 
.skip_duplicates() + .exec() + .await + .map(|created| { + #[allow(clippy::cast_sign_loss)] + { + created as u64 + } + }) + .map_err(Into::into) +} diff --git a/core/crates/heavy-lifting/src/media_processor/helpers/ffmpeg_media_data.rs b/core/crates/heavy-lifting/src/media_processor/helpers/ffmpeg_media_data.rs new file mode 100644 index 000000000..0d1734c22 --- /dev/null +++ b/core/crates/heavy-lifting/src/media_processor/helpers/ffmpeg_media_data.rs @@ -0,0 +1,572 @@ +use crate::media_processor::{self, media_data_extractor}; + +use sd_file_ext::extensions::{ + AudioExtension, Extension, VideoExtension, ALL_AUDIO_EXTENSIONS, ALL_VIDEO_EXTENSIONS, +}; +use sd_media_metadata::{ + ffmpeg::{ + audio_props::AudioProps, + chapter::Chapter, + codec::{Codec, Props}, + metadata::Metadata, + program::Program, + stream::Stream, + video_props::VideoProps, + }, + FFmpegMetadata, +}; +use sd_prisma::prisma::{ + ffmpeg_data, ffmpeg_media_audio_props, ffmpeg_media_chapter, ffmpeg_media_codec, + ffmpeg_media_program, ffmpeg_media_stream, ffmpeg_media_video_props, object, PrismaClient, +}; +use sd_utils::db::ffmpeg_data_field_to_db; + +use std::{collections::HashMap, path::Path}; + +use futures_concurrency::future::TryJoin; +use once_cell::sync::Lazy; +use prisma_client_rust::QueryError; +use tracing::error; + +pub static AVAILABLE_EXTENSIONS: Lazy> = Lazy::new(|| { + ALL_AUDIO_EXTENSIONS + .iter() + .copied() + .filter(|&ext| can_extract_for_audio(ext)) + .map(Extension::Audio) + .chain( + ALL_VIDEO_EXTENSIONS + .iter() + .copied() + .filter(|&ext| can_extract_for_video(ext)) + .map(Extension::Video), + ) + .collect() +}); + +pub const fn can_extract_for_audio(audio_extension: AudioExtension) -> bool { + use AudioExtension::{ + Aac, Adts, Aif, Aiff, Amr, Aptx, Ast, Caf, Flac, Loas, M4a, Mid, Mp2, Mp3, Oga, Ogg, Opus, + Tta, Voc, Wav, Wma, Wv, + }; + + matches!( + audio_extension, + Mp3 | Mp2 + | M4a | Wav | Aiff + | Aif | Flac | Ogg + | Oga | Opus | Wma + | Amr | Aac | Wv + | Voc | Tta | Loas + | Caf | Aptx | Adts + | Ast | Mid + ) +} + +pub const fn can_extract_for_video(video_extension: VideoExtension) -> bool { + use VideoExtension::{ + Asf, Avi, Avifs, F4v, Flv, Hevc, M2ts, M2v, M4v, Mjpeg, Mkv, Mov, Mp4, Mpe, Mpeg, Mpg, Mts, + Mxf, Ogv, Qt, Swf, Ts, Vob, Webm, Wm, Wmv, Wtv, _3gp, + }; + + matches!( + video_extension, + Avi | Avifs + | Qt | Mov | Swf + | Mjpeg | Ts | Mts + | Mpeg | Mxf | M2v + | Mpg | Mpe | M2ts + | Flv | Wm | _3gp + | M4v | Wmv | Asf + | Mp4 | Webm | Mkv + | Vob | Ogv | Wtv + | Hevc | F4v + ) +} + +pub async fn extract( + path: impl AsRef + Send, +) -> Result { + let path = path.as_ref(); + + FFmpegMetadata::from_path(&path).await.map_err(|e| { + media_data_extractor::NonCriticalError::FailedToExtractImageMediaData( + path.to_path_buf(), + e.to_string(), + ) + .into() + }) +} + +pub async fn save( + ffmpeg_datas: impl IntoIterator + Send, + db: &PrismaClient, +) -> Result { + ffmpeg_datas + .into_iter() + .map( + move |( + FFmpegMetadata { + formats, + duration, + start_time, + bit_rate, + chapters, + programs, + metadata, + }, + object_id, + )| { + db._transaction() + .with_timeout(30 * 1000) + .run(move |db| async move { + let data_id = create_ffmpeg_data( + formats, bit_rate, duration, start_time, metadata, object_id, &db, + ) + .await?; + + create_ffmpeg_chapters(data_id, chapters, &db).await?; + + let streams = create_ffmpeg_programs(data_id, programs, &db).await?; + + let codecs = create_ffmpeg_streams(data_id, streams, &db).await?; + + let (audio_props, video_props) = + 
create_ffmpeg_codecs(data_id, codecs, &db).await?; + + ( + create_ffmpeg_audio_props(audio_props, &db), + create_ffmpeg_video_props(video_props, &db), + ) + .try_join() + .await + .map(|_| ()) + }) + }, + ) + .collect::>() + .try_join() + .await + .map(|created| created.len() as u64) + .map_err(Into::into) +} + +async fn create_ffmpeg_data( + formats: Vec, + (bit_rate_high, bit_rate_low): (i32, u32), + maybe_duration: Option<(i32, u32)>, + maybe_start_time: Option<(i32, u32)>, + metadata: Metadata, + object_id: i32, + db: &PrismaClient, +) -> Result { + db.ffmpeg_data() + .create( + formats.join(","), + ffmpeg_data_field_to_db(i64::from(bit_rate_high) << 32 | i64::from(bit_rate_low)), + object::id::equals(object_id), + vec![ + ffmpeg_data::duration::set(maybe_duration.map(|(duration_high, duration_low)| { + ffmpeg_data_field_to_db( + i64::from(duration_high) << 32 | i64::from(duration_low), + ) + })), + ffmpeg_data::start_time::set(maybe_start_time.map( + |(start_time_high, start_time_low)| { + ffmpeg_data_field_to_db( + i64::from(start_time_high) << 32 | i64::from(start_time_low), + ) + }, + )), + ffmpeg_data::metadata::set( + serde_json::to_vec(&metadata) + .map_err(|err| { + error!("Error reading FFmpegData metadata: {err:#?}"); + err + }) + .ok(), + ), + ], + ) + .select(ffmpeg_data::select!({ id })) + .exec() + .await + .map(|data| data.id) +} + +async fn create_ffmpeg_chapters( + ffmpeg_data_id: ffmpeg_data::id::Type, + chapters: Vec, + db: &PrismaClient, +) -> Result<(), QueryError> { + db.ffmpeg_media_chapter() + .create_many( + chapters + .into_iter() + .map( + |Chapter { + id: chapter_id, + start: (start_high, start_low), + end: (end_high, end_low), + time_base_den, + time_base_num, + metadata, + }| ffmpeg_media_chapter::CreateUnchecked { + chapter_id, + start: ffmpeg_data_field_to_db( + i64::from(start_high) << 32 | i64::from(start_low), + ), + end: ffmpeg_data_field_to_db( + i64::from(end_high) << 32 | i64::from(end_low), + ), + time_base_den, + time_base_num, + ffmpeg_data_id, + _params: vec![ffmpeg_media_chapter::metadata::set( + serde_json::to_vec(&metadata) + .map_err(|err| { + error!("Error reading FFmpegMediaChapter metadata: {err:#?}"); + err + }) + .ok(), + )], + }, + ) + .collect(), + ) + .exec() + .await + .map(|_| ()) +} + +async fn create_ffmpeg_programs( + data_id: i32, + programs: Vec, + db: &PrismaClient, +) -> Result)>, QueryError> { + let (creates, streams_by_program_id) = + programs + .into_iter() + .map( + |Program { + id: program_id, + name, + metadata, + streams, + }| { + ( + ffmpeg_media_program::CreateUnchecked { + program_id, + ffmpeg_data_id: data_id, + _params: vec![ + ffmpeg_media_program::name::set(name), + ffmpeg_media_program::metadata::set( + serde_json::to_vec(&metadata) + .map_err(|err| { + error!("Error reading FFmpegMediaProgram metadata: {err:#?}"); + err + }) + .ok(), + ), + ], + }, + (program_id, streams), + ) + }, + ) + .unzip::<_, _, Vec<_>, Vec<_>>(); + + db.ffmpeg_media_program() + .create_many(creates) + .exec() + .await + .map(|_| streams_by_program_id) +} + +async fn create_ffmpeg_streams( + ffmpeg_data_id: ffmpeg_data::id::Type, + streams: Vec<(ffmpeg_media_program::program_id::Type, Vec)>, + db: &PrismaClient, +) -> Result< + Vec<( + ffmpeg_media_program::program_id::Type, + ffmpeg_media_stream::stream_id::Type, + Codec, + )>, + QueryError, +> { + let (creates, maybe_codecs) = streams + .into_iter() + .flat_map(|(program_id, streams)| { + streams.into_iter().map( + move |Stream { + id: stream_id, + name, + codec: maybe_codec, + 
aspect_ratio_num, + aspect_ratio_den, + frames_per_second_num, + frames_per_second_den, + time_base_real_den, + time_base_real_num, + dispositions, + metadata, + }| { + ( + ffmpeg_media_stream::CreateUnchecked { + stream_id, + aspect_ratio_num, + aspect_ratio_den, + frames_per_second_num, + frames_per_second_den, + time_base_real_den, + time_base_real_num, + program_id, + ffmpeg_data_id, + _params: vec![ + ffmpeg_media_stream::name::set(name), + ffmpeg_media_stream::dispositions::set( + (!dispositions.is_empty()).then_some(dispositions.join(",")), + ), + ffmpeg_media_stream::title::set(metadata.title.clone()), + ffmpeg_media_stream::encoder::set(metadata.encoder.clone()), + ffmpeg_media_stream::language::set(metadata.language.clone()), + ffmpeg_media_stream::metadata::set( + serde_json::to_vec(&metadata) + .map_err(|err| { + error!("Error reading FFmpegMediaStream metadata: {err:#?}"); + err + }) + .ok(), + ), + ], + }, + maybe_codec.map(|codec| (program_id, stream_id, codec)), + ) + }, + ) + }) + .unzip::<_, _, Vec<_>, Vec<_>>(); + + db.ffmpeg_media_stream() + .create_many(creates) + .exec() + .await + .map(|_| maybe_codecs.into_iter().flatten().collect()) +} + +async fn create_ffmpeg_codecs( + ffmpeg_data_id: ffmpeg_data::id::Type, + codecs: Vec<( + ffmpeg_media_program::program_id::Type, + ffmpeg_media_stream::stream_id::Type, + Codec, + )>, + db: &PrismaClient, +) -> Result< + ( + Vec<(ffmpeg_media_codec::id::Type, AudioProps)>, + Vec<(ffmpeg_media_codec::id::Type, VideoProps)>, + ), + QueryError, +> { + let expected_creates = codecs.len(); + + let (creates, mut audio_props, mut video_props) = codecs.into_iter().enumerate().fold( + ( + Vec::with_capacity(expected_creates), + HashMap::with_capacity(expected_creates), + HashMap::with_capacity(expected_creates), + ), + |(mut creates, mut audio_props, mut video_props), + ( + idx, + ( + program_id, + stream_id, + Codec { + kind, + sub_kind, + tag, + name, + profile, + bit_rate, + props: maybe_props, + }, + ), + )| { + creates.push(ffmpeg_media_codec::CreateUnchecked { + bit_rate, + stream_id, + program_id, + ffmpeg_data_id, + _params: vec![ + ffmpeg_media_codec::kind::set(kind), + ffmpeg_media_codec::sub_kind::set(sub_kind), + ffmpeg_media_codec::tag::set(tag), + ffmpeg_media_codec::name::set(name), + ffmpeg_media_codec::profile::set(profile), + ], + }); + + if let Some(props) = maybe_props { + match props { + Props::Audio(props) => { + audio_props.insert(idx, props); + } + Props::Video(props) => { + video_props.insert(idx, props); + } + Props::Subtitle(_) => { + // We don't care about subtitles props for now :D + } + } + } + + (creates, audio_props, video_props) + }, + ); + + let created_ids = creates + .into_iter() + .map( + |ffmpeg_media_codec::CreateUnchecked { + bit_rate, + stream_id, + program_id, + ffmpeg_data_id, + _params: params, + }| { + db.ffmpeg_media_codec() + .create_unchecked(bit_rate, stream_id, program_id, ffmpeg_data_id, params) + .select(ffmpeg_media_codec::select!({ id })) + .exec() + }, + ) + .collect::>() + .try_join() + .await?; + + assert_eq!( + created_ids.len(), + expected_creates, + "Not all codecs were created and our invariant is broken!" + ); + + debug_assert!( + created_ids + .windows(2) + .all(|window| window[0].id < window[1].id), + "Codecs were created in a different order than we expected, our invariant is broken!" 
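// Note on the assertions above: `audio_props` and `video_props` are keyed by the
// `enumerate()` index of the original `codecs` iterator, so the pairing in the fold below
// only works if exactly one codec row was created per codec and in the same order,
// letting the id at position `idx` of `created_ids` line up with the props captured for `idx`.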
+ ); + + Ok(created_ids.into_iter().enumerate().fold( + ( + Vec::with_capacity(audio_props.len()), + Vec::with_capacity(video_props.len()), + ), + |(mut a_props, mut v_props), (idx, codec_data)| { + if let Some(audio_props) = audio_props.remove(&idx) { + a_props.push((codec_data.id, audio_props)); + } else if let Some(video_props) = video_props.remove(&idx) { + v_props.push((codec_data.id, video_props)); + } + + (a_props, v_props) + }, + )) +} + +async fn create_ffmpeg_audio_props( + audio_props: Vec<(ffmpeg_media_codec::id::Type, AudioProps)>, + db: &PrismaClient, +) -> Result<(), QueryError> { + db.ffmpeg_media_audio_props() + .create_many( + audio_props + .into_iter() + .map( + |( + codec_id, + AudioProps { + delay, + padding, + sample_rate, + sample_format, + bit_per_sample, + channel_layout, + }, + )| ffmpeg_media_audio_props::CreateUnchecked { + delay, + padding, + codec_id, + _params: vec![ + ffmpeg_media_audio_props::sample_rate::set(sample_rate), + ffmpeg_media_audio_props::sample_format::set(sample_format), + ffmpeg_media_audio_props::bit_per_sample::set(bit_per_sample), + ffmpeg_media_audio_props::channel_layout::set(channel_layout), + ], + }, + ) + .collect(), + ) + .exec() + .await + .map(|_| ()) +} + +async fn create_ffmpeg_video_props( + video_props: Vec<(ffmpeg_media_codec::id::Type, VideoProps)>, + db: &PrismaClient, +) -> Result<(), QueryError> { + db.ffmpeg_media_video_props() + .create_many( + video_props + .into_iter() + .map( + |( + codec_id, + VideoProps { + pixel_format, + color_range, + bits_per_channel, + color_space, + color_primaries, + color_transfer, + field_order, + chroma_location, + width, + height, + aspect_ratio_num, + aspect_ratio_den, + properties, + }, + )| { + ffmpeg_media_video_props::CreateUnchecked { + width, + height, + codec_id, + _params: vec![ + ffmpeg_media_video_props::pixel_format::set(pixel_format), + ffmpeg_media_video_props::color_range::set(color_range), + ffmpeg_media_video_props::bits_per_channel::set(bits_per_channel), + ffmpeg_media_video_props::color_space::set(color_space), + ffmpeg_media_video_props::color_primaries::set(color_primaries), + ffmpeg_media_video_props::color_transfer::set(color_transfer), + ffmpeg_media_video_props::field_order::set(field_order), + ffmpeg_media_video_props::chroma_location::set(chroma_location), + ffmpeg_media_video_props::aspect_ratio_num::set(aspect_ratio_num), + ffmpeg_media_video_props::aspect_ratio_den::set(aspect_ratio_den), + ffmpeg_media_video_props::properties::set(Some( + properties.join(","), + )), + ], + } + }, + ) + .collect(), + ) + .exec() + .await + .map(|_| ()) +} diff --git a/core/crates/heavy-lifting/src/media_processor/helpers/mod.rs b/core/crates/heavy-lifting/src/media_processor/helpers/mod.rs new file mode 100644 index 000000000..4432d19a7 --- /dev/null +++ b/core/crates/heavy-lifting/src/media_processor/helpers/mod.rs @@ -0,0 +1,3 @@ +pub mod exif_media_data; +pub mod ffmpeg_media_data; +pub mod thumbnailer; diff --git a/core/crates/heavy-lifting/src/media_processor/helpers/thumbnailer.rs b/core/crates/heavy-lifting/src/media_processor/helpers/thumbnailer.rs new file mode 100644 index 000000000..5f2de34e7 --- /dev/null +++ b/core/crates/heavy-lifting/src/media_processor/helpers/thumbnailer.rs @@ -0,0 +1,135 @@ +use once_cell::sync::Lazy; +use sd_file_ext::extensions::{ + DocumentExtension, Extension, ImageExtension, ALL_DOCUMENT_EXTENSIONS, ALL_IMAGE_EXTENSIONS, +}; + +#[cfg(feature = "ffmpeg")] +use sd_file_ext::extensions::{VideoExtension, ALL_VIDEO_EXTENSIONS}; + +use 
std::time::Duration;
+
+use serde::{Deserialize, Serialize};
+use specta::Type;
+use uuid::Uuid;
+
+// Files names constants
+pub const THUMBNAIL_CACHE_DIR_NAME: &str = "thumbnails";
+pub const WEBP_EXTENSION: &str = "webp";
+pub const EPHEMERAL_DIR: &str = "ephemeral";
+
+/// This is the target pixel count for all thumbnails to be resized to, and it is eventually downscaled
+/// to [`TARGET_QUALITY`].
+pub const TARGET_PX: f32 = 1_048_576.0; // 1024x1024
+
+/// This is the target quality that we render thumbnails at, it is a float between 0-100
+/// and is treated as a percentage (so 60% in this case, or it's the same as multiplying by `0.6`).
+pub const TARGET_QUALITY: f32 = 60.0;
+
+/// How much time we allow for the thumbnail generation process to complete before we give up.
+pub const THUMBNAIL_GENERATION_TIMEOUT: Duration = Duration::from_secs(60);
+
+#[cfg(feature = "ffmpeg")]
+pub static THUMBNAILABLE_VIDEO_EXTENSIONS: Lazy> = Lazy::new(|| {
+ ALL_VIDEO_EXTENSIONS
+ .iter()
+ .copied()
+ .filter(|&ext| can_generate_thumbnail_for_video(ext))
+ .map(Extension::Video)
+ .collect()
+});
+
+pub static THUMBNAILABLE_EXTENSIONS: Lazy> = Lazy::new(|| {
+ ALL_IMAGE_EXTENSIONS
+ .iter()
+ .copied()
+ .filter(|&ext| can_generate_thumbnail_for_image(ext))
+ .map(Extension::Image)
+ .chain(
+ ALL_DOCUMENT_EXTENSIONS
+ .iter()
+ .copied()
+ .filter(|&ext| can_generate_thumbnail_for_document(ext))
+ .map(Extension::Document),
+ )
+ .collect()
+});
+
+pub static ALL_THUMBNAILABLE_EXTENSIONS: Lazy> = Lazy::new(|| {
+ #[cfg(feature = "ffmpeg")]
+ return THUMBNAILABLE_EXTENSIONS
+ .iter()
+ .cloned()
+ .chain(THUMBNAILABLE_VIDEO_EXTENSIONS.iter().cloned())
+ .collect();
+
+ #[cfg(not(feature = "ffmpeg"))]
+ THUMBNAILABLE_EXTENSIONS.clone()
+});
+
+/// This type is used to pass the relevant data to the frontend so it can request the thumbnail.
+/// It supports extending the shard hex to support deeper directory structures in the future
+#[derive(Debug, Serialize, Deserialize, Type)]
+pub struct ThumbKey {
+ pub shard_hex: String,
+ pub cas_id: String,
+ pub base_directory_str: String,
+}
+
+impl ThumbKey {
+ #[must_use]
+ pub fn new(cas_id: &str, kind: &ThumbnailKind) -> Self {
+ Self {
+ shard_hex: get_shard_hex(cas_id).to_string(),
+ cas_id: cas_id.to_string(),
+ base_directory_str: match kind {
+ ThumbnailKind::Ephemeral => String::from(EPHEMERAL_DIR),
+ ThumbnailKind::Indexed(library_id) => library_id.to_string(),
+ },
+ }
+ }
+}
+
+#[derive(Debug, Serialize, Deserialize, Type, Clone, Copy)]
+pub enum ThumbnailKind {
+ Ephemeral,
+ Indexed(Uuid),
+}
+
+/// The practice of dividing files into hex coded folders, often called "sharding,"
+/// is mainly used to optimize file system performance. File systems can start to slow down
+/// as the number of files in a directory increases. Thus, it's often beneficial to split
+/// files into multiple directories to avoid this performance degradation.
+///
+/// `get_shard_hex` takes a `cas_id` (a hexadecimal hash) as input and returns the first
+/// three characters of the hash as the directory name. Because we're using these first
+/// three characters of the hash, this will give us 4096 (16^3) possible directories,
+/// named 000 to fff.
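// A worked example of the sharding scheme described above, using a hypothetical cas_id
// (not taken from this changeset): get_shard_hex("a1b2c3d4e5f6") returns "a1b", so
//     ThumbKey::new("a1b2c3d4e5f6", &ThumbnailKind::Indexed(library_id))
// yields { shard_hex: "a1b", cas_id: "a1b2c3d4e5f6", base_directory_str: library_id.to_string() },
// which the frontend can turn into a path shaped roughly like
//     <data_dir>/thumbnails/<library_id>/a1b/a1b2c3d4e5f6.webp
// (the exact on-disk layout is an assumption based on THUMBNAIL_CACHE_DIR_NAME and
// WEBP_EXTENSION above, not something this file spells out).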
+#[inline] +pub fn get_shard_hex(cas_id: &str) -> &str { + // Use the first three characters of the hash as the directory name + &cas_id[0..3] +} + +#[cfg(feature = "ffmpeg")] +pub const fn can_generate_thumbnail_for_video(video_extension: VideoExtension) -> bool { + use VideoExtension::{Hevc, M2ts, M2v, Mpg, Mts, Swf, Ts}; + // File extensions that are specifically not supported by the thumbnailer + !matches!(video_extension, Mpg | Swf | M2v | Hevc | M2ts | Mts | Ts) +} + +pub const fn can_generate_thumbnail_for_image(image_extension: ImageExtension) -> bool { + use ImageExtension::{ + Avif, Bmp, Gif, Heic, Heics, Heif, Heifs, Ico, Jpeg, Jpg, Png, Svg, Webp, + }; + + matches!( + image_extension, + Jpg | Jpeg | Png | Webp | Gif | Svg | Heic | Heics | Heif | Heifs | Avif | Bmp | Ico + ) +} + +pub const fn can_generate_thumbnail_for_document(document_extension: DocumentExtension) -> bool { + use DocumentExtension::Pdf; + + matches!(document_extension, Pdf) +} diff --git a/core/crates/heavy-lifting/src/media_processor/job.rs b/core/crates/heavy-lifting/src/media_processor/job.rs new file mode 100644 index 000000000..bacca5e87 --- /dev/null +++ b/core/crates/heavy-lifting/src/media_processor/job.rs @@ -0,0 +1,825 @@ +use crate::{ + job_system::{ + job::{Job, JobReturn, JobTaskDispatcher, ReturnStatus}, + report::ReportOutputMetadata, + utils::cancel_pending_tasks, + SerializableJob, SerializedTasks, + }, + media_processor::{self, helpers::thumbnailer::THUMBNAIL_CACHE_DIR_NAME}, + utils::sub_path::{self, maybe_get_iso_file_path_from_sub_path}, + Error, JobName, LocationScanState, OuterContext, ProgressUpdate, +}; +use sd_core_file_path_helper::IsolatedFilePathData; +use sd_core_prisma_helpers::file_path_for_media_processor; + +use sd_file_ext::extensions::Extension; +use sd_prisma::prisma::{location, PrismaClient}; +use sd_task_system::{ + AnyTaskOutput, IntoTask, SerializableTask, Task, TaskDispatcher, TaskHandle, TaskOutput, + TaskStatus, +}; +use sd_utils::db::maybe_missing; + +use std::{ + collections::HashMap, + fmt, + hash::{Hash, Hasher}, + mem, + path::PathBuf, + sync::Arc, + time::Duration, +}; + +use futures::{stream::FuturesUnordered, StreamExt}; +use futures_concurrency::future::TryJoin; +use itertools::Itertools; +use prisma_client_rust::{raw, PrismaValue}; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use tracing::{debug, warn}; + +use super::{ + helpers, + tasks::{self, media_data_extractor, thumbnailer}, + NewThumbnailsReporter, BATCH_SIZE, +}; + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +enum TaskKind { + MediaDataExtractor, + Thumbnailer, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Copy)] +enum Phase { + MediaDataExtraction, + ThumbnailGeneration, + // LabelsGeneration, // TODO: Implement labels generation +} + +impl Default for Phase { + fn default() -> Self { + Self::MediaDataExtraction + } +} + +impl fmt::Display for Phase { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::MediaDataExtraction => write!(f, "media_data"), + Self::ThumbnailGeneration => write!(f, "thumbnails"), + // Self::LabelsGeneration => write!(f, "labels"), // TODO: Implement labels generation + } + } +} + +#[derive(Debug)] +pub struct MediaProcessor { + location: Arc, + location_path: Arc, + sub_path: Option, + regenerate_thumbnails: bool, + + total_media_data_extraction_tasks: u64, + total_thumbnailer_tasks: u64, + total_thumbnailer_files: u64, + + phase: Phase, + + metadata: Metadata, + + errors: Vec, + + pending_tasks_on_resume: 
Vec>, + tasks_for_shutdown: Vec>>, +} + +impl Job for MediaProcessor { + const NAME: JobName = JobName::MediaProcessor; + + async fn resume_tasks( + &mut self, + dispatcher: &JobTaskDispatcher, + ctx: &impl OuterContext, + SerializedTasks(serialized_tasks): SerializedTasks, + ) -> Result<(), Error> { + let reporter = Arc::new(NewThumbnailsReporter { ctx: ctx.clone() }); + + self.pending_tasks_on_resume = dispatcher + .dispatch_many_boxed( + rmp_serde::from_slice::)>>(&serialized_tasks) + .map_err(media_processor::Error::from)? + .into_iter() + .map(|(task_kind, task_bytes)| { + let reporter = Arc::clone(&reporter); + async move { + match task_kind { + TaskKind::MediaDataExtractor => { + tasks::MediaDataExtractor::deserialize( + &task_bytes, + Arc::clone(ctx.db()), + ) + .await + .map(IntoTask::into_task) + } + + TaskKind::Thumbnailer => tasks::Thumbnailer::deserialize( + &task_bytes, + Arc::clone(&reporter), + ) + .await + .map(IntoTask::into_task), + } + } + }) + .collect::>() + .try_join() + .await + .map_err(media_processor::Error::from)?, + ) + .await; + + Ok(()) + } + + async fn run( + mut self, + dispatcher: JobTaskDispatcher, + ctx: Ctx, + ) -> Result { + let mut pending_running_tasks = FuturesUnordered::new(); + + self.init_or_resume(&mut pending_running_tasks, &ctx, &dispatcher) + .await?; + + if let Some(res) = self.process_handles(&mut pending_running_tasks, &ctx).await { + return res; + } + + if !self.tasks_for_shutdown.is_empty() { + return Ok(ReturnStatus::Shutdown( + SerializableJob::::serialize(self).await, + )); + } + + // From this point onward, we are done with the job and it can't be interrupted anymore + let Self { + location, + metadata, + errors, + .. + } = self; + + ctx.db() + .location() + .update( + location::id::equals(location.id), + vec![location::scan_state::set( + LocationScanState::Completed as i32, + )], + ) + .exec() + .await + .map_err(media_processor::Error::from)?; + + Ok(ReturnStatus::Completed( + JobReturn::builder() + .with_metadata(metadata) + .with_non_critical_errors(errors) + .build(), + )) + } +} + +impl MediaProcessor { + pub fn new( + location: location::Data, + sub_path: Option, + regenerate_thumbnails: bool, + ) -> Result { + Ok(Self { + location_path: maybe_missing(&location.path, "location.path") + .map(PathBuf::from) + .map(Arc::new)?, + location: Arc::new(location), + sub_path, + regenerate_thumbnails, + total_media_data_extraction_tasks: 0, + total_thumbnailer_tasks: 0, + total_thumbnailer_files: 0, + phase: Phase::default(), + metadata: Metadata::default(), + errors: Vec::new(), + pending_tasks_on_resume: Vec::new(), + tasks_for_shutdown: Vec::new(), + }) + } + + async fn init_or_resume( + &mut self, + pending_running_tasks: &mut FuturesUnordered>, + ctx: &impl OuterContext, + dispatcher: &JobTaskDispatcher, + ) -> Result<(), media_processor::Error> { + // if we don't have any pending task, then this is a fresh job + if self.pending_tasks_on_resume.is_empty() { + let location_id = self.location.id; + let location_path = &*self.location_path; + + let iso_file_path = maybe_get_iso_file_path_from_sub_path( + location_id, + &self.sub_path, + &*self.location_path, + ctx.db(), + ) + .await? 
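// If the optional sub_path above did not resolve to an isolated file path, the closure in
// `map_or_else` below falls back to the location root: `IsolatedFilePathData::new` receives
// `location_path` as both the location root and the directory to process.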
+ .map_or_else( + || { + IsolatedFilePathData::new(location_id, location_path, location_path, true) + .map_err(sub_path::Error::from) + }, + Ok, + )?; + + debug!( + "Searching for media files in location {location_id} at directory \"{iso_file_path}\"" + ); + + // First we will dispatch all tasks for media data extraction so we have a nice reporting + let (total_media_data_extraction_files, task_handles) = + dispatch_media_data_extractor_tasks( + ctx.db(), + &iso_file_path, + &self.location_path, + dispatcher, + ) + .await?; + self.total_media_data_extraction_tasks = task_handles.len() as u64; + + pending_running_tasks.extend(task_handles); + + ctx.progress(vec![ + ProgressUpdate::TaskCount(total_media_data_extraction_files), + ProgressUpdate::Phase(self.phase.to_string()), + ProgressUpdate::Message(format!( + "Preparing to process {total_media_data_extraction_files} files in {} chunks", + self.total_media_data_extraction_tasks + )), + ]); + + // Now we dispatch thumbnailer tasks + let (total_thumbnailer_tasks, task_handles) = dispatch_thumbnailer_tasks( + &iso_file_path, + self.regenerate_thumbnails, + &self.location_path, + dispatcher, + ctx, + ) + .await?; + pending_running_tasks.extend(task_handles); + + self.total_thumbnailer_tasks = total_thumbnailer_tasks; + } else { + pending_running_tasks.extend(mem::take(&mut self.pending_tasks_on_resume)); + } + + Ok(()) + } + + async fn process_handles( + &mut self, + pending_running_tasks: &mut FuturesUnordered>, + ctx: &impl OuterContext, + ) -> Option> { + while let Some(task) = pending_running_tasks.next().await { + match task { + Ok(TaskStatus::Done((task_id, TaskOutput::Out(out)))) => { + self.process_task_output(task_id, out, ctx); + } + + Ok(TaskStatus::Done((task_id, TaskOutput::Empty))) => { + warn!("Task returned an empty output"); + } + + Ok(TaskStatus::Shutdown(task)) => { + self.tasks_for_shutdown.push(task); + } + + Ok(TaskStatus::Error(e)) => { + cancel_pending_tasks(&*pending_running_tasks).await; + + return Some(Err(e)); + } + + Ok(TaskStatus::Canceled | TaskStatus::ForcedAbortion) => { + cancel_pending_tasks(&*pending_running_tasks).await; + + return Some(Ok(ReturnStatus::Canceled)); + } + + Err(e) => { + cancel_pending_tasks(&*pending_running_tasks).await; + + return Some(Err(e.into())); + } + } + } + + None + } + + fn process_task_output( + &mut self, + task_id: uuid::Uuid, + any_task_output: Box, + ctx: &impl OuterContext, + ) { + if any_task_output.is::() { + let media_data_extractor::Output { + extracted, + skipped, + db_read_time, + filtering_time, + extraction_time, + db_write_time, + errors, + } = *any_task_output.downcast().expect("just checked"); + + self.metadata.media_data_metrics.extracted += extracted; + self.metadata.media_data_metrics.skipped += skipped; + self.metadata.media_data_metrics.db_read_time += db_read_time; + self.metadata.media_data_metrics.filtering_time += filtering_time; + self.metadata.media_data_metrics.extraction_time += extraction_time; + self.metadata.media_data_metrics.db_write_time += db_write_time; + self.metadata.media_data_metrics.total_successful_tasks += 1; + + self.errors.extend(errors); + + debug!( + "Processed {}/{} media data extraction tasks", + self.metadata.media_data_metrics.total_successful_tasks, + self.total_media_data_extraction_tasks + ); + ctx.progress(vec![ProgressUpdate::CompletedTaskCount( + self.metadata.media_data_metrics.extracted + + self.metadata.media_data_metrics.skipped, + )]); + + if self.total_media_data_extraction_tasks + == 
self.metadata.media_data_metrics.total_successful_tasks + { + debug!("All media data extraction tasks have been processed"); + + self.phase = Phase::ThumbnailGeneration; + + ctx.progress(vec![ + ProgressUpdate::TaskCount(self.total_thumbnailer_files), + ProgressUpdate::Phase(self.phase.to_string()), + ProgressUpdate::Message(format!( + "Waiting for processing of {} thumbnails in {} tasks", + self.total_thumbnailer_files, self.total_thumbnailer_tasks + )), + ]); + } + } else if any_task_output.is::() { + let thumbnailer::Output { + generated, + skipped, + errors, + total_time, + mean_time_acc, + std_dev_acc, + } = *any_task_output.downcast().expect("just checked"); + + self.metadata.thumbnailer_metrics_acc.generated += generated; + self.metadata.thumbnailer_metrics_acc.skipped += skipped; + self.metadata.thumbnailer_metrics_acc.total_time += total_time; + self.metadata.thumbnailer_metrics_acc.mean_time_acc += mean_time_acc; + self.metadata.thumbnailer_metrics_acc.std_dev_acc += std_dev_acc; + self.metadata.thumbnailer_metrics_acc.total_successful_tasks += 1; + + self.errors.extend(errors); + + ctx.progress(vec![ProgressUpdate::CompletedTaskCount( + self.metadata.thumbnailer_metrics_acc.generated + + self.metadata.thumbnailer_metrics_acc.skipped, + )]); + + // if self.total_thumbnailer_tasks + // == self.metadata.thumbnailer_metrics_acc.total_successful_tasks + // { + // debug!("All thumbnailer tasks have been processed"); + + // self.phase = Phase::LabelsGeneration; + + // ctx.progress(vec![ + // ProgressUpdate::TaskCount(self.total_thumbnailer_files), + // ProgressUpdate::Phase(self.phase.to_string()), + // ProgressUpdate::Message(format!( + // "Waiting for processing of {} labels in {} tasks", + // self.total_labeller_files, self.total_labeller_tasks + // )), + // ]); + // } + } else { + unreachable!("Unexpected task output type: "); + } + } +} + +#[derive(Debug, Serialize, Deserialize, Default)] +struct Metadata { + media_data_metrics: MediaExtractorMetrics, + thumbnailer_metrics_acc: ThumbnailerMetricsAccumulator, +} + +impl From for ReportOutputMetadata { + fn from( + Metadata { + media_data_metrics, + thumbnailer_metrics_acc: thumbnailer_metrics_accumulator, + }: Metadata, + ) -> Self { + let thumbnailer_metrics = ThumbnailerMetrics::from(thumbnailer_metrics_accumulator); + + Self::Metrics(HashMap::from([ + // + // Media data extractor + // + ( + "media_data_extraction_metrics".into(), + json!(media_data_metrics), + ), + // + // Thumbnailer + // + ("thumbnailer_metrics".into(), json!(thumbnailer_metrics)), + ])) + } +} + +#[derive(Debug, Serialize, Deserialize, Default)] +struct MediaExtractorMetrics { + extracted: u64, + skipped: u64, + db_read_time: Duration, + filtering_time: Duration, + extraction_time: Duration, + db_write_time: Duration, + total_successful_tasks: u64, +} + +#[derive(Debug, Serialize, Deserialize, Default)] +struct ThumbnailerMetricsAccumulator { + generated: u64, + skipped: u64, + total_time: Duration, + mean_time_acc: f64, + std_dev_acc: f64, + total_successful_tasks: u64, +} + +#[derive(Debug, Serialize, Deserialize, Default)] +struct ThumbnailerMetrics { + generated: u64, + skipped: u64, + total_generation_time: Duration, + mean_generation_time: Duration, + std_dev: Duration, + total_successful_tasks: u64, +} + +impl From for ThumbnailerMetrics { + fn from( + ThumbnailerMetricsAccumulator { + generated, + skipped, + total_time: total_generation_time, + mean_time_acc: mean_generation_time_acc, + std_dev_acc, + total_successful_tasks, + }: 
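process_task_output above accumulates metrics by first checking the concrete type behind the boxed task output and then downcasting it. A stripped-down sketch of that routing, using std::any::Any as a stand-in for the task system's boxed output type; the output structs here are illustrative, not types from this crate:

use std::any::Any;

#[derive(Debug)]
struct MediaDataOutput {
    extracted: u64,
}

#[derive(Debug)]
struct ThumbnailerOutput {
    generated: u64,
}

// Route a type-erased output to the right handler, mirroring process_task_output.
// With plain Box<dyn Any>, a failed downcast hands the box back instead of allowing
// the `is::<T>()` + `expect` shape the job code uses on its own output trait.
fn route(any_output: Box<dyn Any>) {
    match any_output.downcast::<MediaDataOutput>() {
        Ok(out) => println!("media data extractor: extracted {}", out.extracted),
        Err(other) => match other.downcast::<ThumbnailerOutput>() {
            Ok(out) => println!("thumbnailer: generated {}", out.generated),
            Err(_) => unreachable!("unexpected task output type"),
        },
    }
}

fn main() {
    route(Box::new(MediaDataOutput { extracted: 10 }));
    route(Box::new(ThumbnailerOutput { generated: 4 }));
}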
ThumbnailerMetricsAccumulator, + ) -> Self { + #[allow(clippy::cast_precision_loss)] + // SAFETY: we're probably won't have 2^52 thumbnails being generated on a single job for this cast to have + // a precision loss issue + let total = (generated + skipped) as f64; + let mean_generation_time = mean_generation_time_acc / total; + + let std_dev = Duration::from_secs_f64( + (mean_generation_time.mul_add(-mean_generation_time, std_dev_acc / total)).sqrt(), + ); + + Self { + generated, + skipped, + total_generation_time, + mean_generation_time: Duration::from_secs_f64(mean_generation_time), + std_dev, + total_successful_tasks, + } + } +} + +#[derive(Serialize, Deserialize)] +struct SaveState { + location: Arc, + location_path: Arc, + sub_path: Option, + regenerate_thumbnails: bool, + + total_media_data_extraction_tasks: u64, + total_thumbnailer_tasks: u64, + total_thumbnailer_files: u64, + + phase: Phase, + + metadata: Metadata, + + errors: Vec, + + tasks_for_shutdown_bytes: Option, +} + +impl SerializableJob for MediaProcessor { + async fn serialize(self) -> Result>, rmp_serde::encode::Error> { + let Self { + location, + location_path, + sub_path, + regenerate_thumbnails, + total_media_data_extraction_tasks, + total_thumbnailer_tasks, + total_thumbnailer_files, + phase, + metadata, + errors, + tasks_for_shutdown, + .. + } = self; + + rmp_serde::to_vec_named(&SaveState { + location, + location_path, + sub_path, + regenerate_thumbnails, + total_media_data_extraction_tasks, + total_thumbnailer_tasks, + total_thumbnailer_files, + phase, + metadata, + tasks_for_shutdown_bytes: Some(SerializedTasks(rmp_serde::to_vec_named( + &tasks_for_shutdown + .into_iter() + .map(|task| async move { + if task.is::() { + task.downcast::() + .expect("just checked") + .serialize() + .await + .map(|bytes| (TaskKind::MediaDataExtractor, bytes)) + } else if task.is::>>() { + task.downcast::>>() + .expect("just checked") + .serialize() + .await + .map(|bytes| (TaskKind::Thumbnailer, bytes)) + } else { + unreachable!("Unexpected task type") + } + }) + .collect::>() + .try_join() + .await?, + )?)), + errors, + }) + .map(Some) + } + + async fn deserialize( + serialized_job: &[u8], + _: &Ctx, + ) -> Result)>, rmp_serde::decode::Error> { + let SaveState { + location, + location_path, + sub_path, + regenerate_thumbnails, + total_media_data_extraction_tasks, + total_thumbnailer_tasks, + total_thumbnailer_files, + phase, + metadata, + errors, + tasks_for_shutdown_bytes, + } = rmp_serde::from_slice::(serialized_job)?; + + Ok(Some(( + Self { + location, + location_path, + sub_path, + regenerate_thumbnails, + total_media_data_extraction_tasks, + total_thumbnailer_tasks, + total_thumbnailer_files, + phase, + metadata, + errors, + pending_tasks_on_resume: Vec::new(), + tasks_for_shutdown: Vec::new(), + }, + tasks_for_shutdown_bytes, + ))) + } +} + +impl Hash for MediaProcessor { + fn hash(&self, state: &mut H) { + self.location.id.hash(state); + if let Some(ref sub_path) = self.sub_path { + sub_path.hash(state); + } + } +} + +async fn dispatch_media_data_extractor_tasks( + db: &Arc, + parent_iso_file_path: &IsolatedFilePathData<'_>, + location_path: &Arc, + dispatcher: &JobTaskDispatcher, +) -> Result<(u64, Vec>), media_processor::Error> { + let (extract_exif_file_paths, extract_ffmpeg_file_paths) = ( + get_all_children_files_by_extensions( + db, + parent_iso_file_path, + &helpers::exif_media_data::AVAILABLE_EXTENSIONS, + ), + get_all_children_files_by_extensions( + db, + parent_iso_file_path, + 
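The conversion above recovers the mean and standard deviation of per-thumbnail generation times from two running sums: mean_time_acc accumulates the times in seconds, std_dev_acc accumulates their squares, and the identity Var(x) = E[x²] − E[x]² yields the population standard deviation without keeping individual samples around. A self-contained sketch of that arithmetic with made-up timings:

fn main() {
    // Per-thumbnail generation times in seconds (made-up samples).
    let times = [0.12_f64, 0.08, 0.15, 0.11];

    // The accumulators a task keeps while running.
    let mean_time_acc: f64 = times.iter().sum();              // Σ x
    let std_dev_acc: f64 = times.iter().map(|t| t * t).sum(); // Σ x²

    let total = times.len() as f64;
    let mean = mean_time_acc / total; // E[x]

    // Population variance E[x²] − E[x]², the same value as
    // mean.mul_add(-mean, std_dev_acc / total) in the conversion above.
    let variance = std_dev_acc / total - mean * mean;
    let std_dev = variance.sqrt();

    println!("mean = {mean:.4}s, std dev = {std_dev:.4}s");
}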
&helpers::ffmpeg_media_data::AVAILABLE_EXTENSIONS, + ), + ) + .try_join() + .await?; + + let files_count = (extract_exif_file_paths.len() + extract_ffmpeg_file_paths.len()) as u64; + + let tasks = extract_exif_file_paths + .into_iter() + .chunks(BATCH_SIZE) + .into_iter() + .map(Iterator::collect::>) + .map(|chunked_file_paths| { + tasks::MediaDataExtractor::new_exif( + &chunked_file_paths, + parent_iso_file_path.location_id(), + Arc::clone(location_path), + Arc::clone(db), + ) + }) + .map(IntoTask::into_task) + .chain( + extract_ffmpeg_file_paths + .into_iter() + .chunks(BATCH_SIZE) + .into_iter() + .map(Iterator::collect::>) + .map(|chunked_file_paths| { + tasks::MediaDataExtractor::new_ffmpeg( + &chunked_file_paths, + parent_iso_file_path.location_id(), + Arc::clone(location_path), + Arc::clone(db), + ) + }) + .map(IntoTask::into_task), + ) + .collect::>(); + + Ok((files_count, dispatcher.dispatch_many_boxed(tasks).await)) +} + +async fn get_all_children_files_by_extensions( + db: &PrismaClient, + parent_iso_file_path: &IsolatedFilePathData<'_>, + extensions: &[Extension], +) -> Result, media_processor::Error> { + // FIXME: Had to use format! macro because PCR doesn't support IN with Vec for SQLite + // We have no data coming from the user, so this is sql injection safe + db._query_raw(raw!( + &format!( + "SELECT id, materialized_path, is_dir, name, extension, cas_id, object_id + FROM file_path + WHERE + location_id={{}} + AND cas_id IS NOT NULL + AND LOWER(extension) IN ({}) + AND materialized_path LIKE {{}} + ORDER BY materialized_path ASC", + // Ordering by materialized_path so we can prioritize processing the first files + // in the above part of the directories tree + extensions + .iter() + .map(|ext| format!("LOWER('{ext}')")) + .collect::>() + .join(",") + ), + PrismaValue::Int(i64::from(parent_iso_file_path.location_id())), + PrismaValue::String(format!( + "{}%", + parent_iso_file_path + .materialized_path_for_children() + .expect("sub path iso_file_path must be a directory") + )) + )) + .exec() + .await + .map_err(Into::into) +} + +async fn dispatch_thumbnailer_tasks( + parent_iso_file_path: &IsolatedFilePathData<'_>, + should_regenerate: bool, + location_path: &PathBuf, + dispatcher: &JobTaskDispatcher, + ctx: &impl OuterContext, +) -> Result<(u64, Vec>), media_processor::Error> { + let thumbnails_directory_path = + Arc::new(ctx.get_data_directory().join(THUMBNAIL_CACHE_DIR_NAME)); + let location_id = parent_iso_file_path.location_id(); + let library_id = ctx.id(); + let db = ctx.db(); + let reporter = Arc::new(NewThumbnailsReporter { ctx: ctx.clone() }); + + let mut file_paths = get_all_children_files_by_extensions( + db, + parent_iso_file_path, + &helpers::thumbnailer::ALL_THUMBNAILABLE_EXTENSIONS, + ) + .await?; + + let thumbs_count = file_paths.len() as u64; + + let first_materialized_path = file_paths[0].materialized_path.clone(); + + // Only the first materialized_path should be processed with priority as the user must see the thumbnails ASAP + let different_materialized_path_idx = file_paths + .iter() + .position(|file_path| file_path.materialized_path != first_materialized_path); + + let non_priority_tasks = different_materialized_path_idx + .map(|idx| { + file_paths + .drain(idx..) 
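As the FIXME in get_all_children_files_by_extensions notes, prisma-client-rust cannot bind a Vec into an IN (...) clause for SQLite, so the extension list is spliced into the SQL text with format! while location_id and the materialized_path pattern remain bound parameters. A standalone sketch of the string that ends up being built; the extension values are illustrative:

fn main() {
    let extensions = ["jpg", "png", "webp"]; // illustrative values

    // Same shape as the IN (...) list built in get_all_children_files_by_extensions.
    let in_list = extensions
        .iter()
        .map(|ext| format!("LOWER('{ext}')"))
        .collect::<Vec<_>>()
        .join(",");

    let sql = format!(
        "SELECT id, materialized_path, is_dir, name, extension, cas_id, object_id
         FROM file_path
         WHERE
             location_id={{}}
             AND cas_id IS NOT NULL
             AND LOWER(extension) IN ({in_list})
             AND materialized_path LIKE {{}}
         ORDER BY materialized_path ASC"
    );

    // The remaining {} placeholders are still bound as PrismaValue parameters,
    // so no user-controlled data is ever interpolated into the SQL text.
    println!("{sql}");
}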
+ .chunks(BATCH_SIZE) + .into_iter() + .map(|chunk| { + tasks::Thumbnailer::new_indexed( + Arc::clone(&thumbnails_directory_path), + &chunk.collect::>(), + (location_id, location_path), + library_id, + should_regenerate, + false, + Arc::clone(&reporter), + ) + }) + .map(IntoTask::into_task) + .collect::>() + }) + .unwrap_or_default(); + + let priority_tasks = file_paths + .into_iter() + .chunks(BATCH_SIZE) + .into_iter() + .map(|chunk| { + tasks::Thumbnailer::new_indexed( + Arc::clone(&thumbnails_directory_path), + &chunk.collect::>(), + (location_id, location_path), + library_id, + should_regenerate, + true, + Arc::clone(&reporter), + ) + }) + .map(IntoTask::into_task) + .collect::>(); + + debug!( + "Dispatching {thumbs_count} thumbnails to be processed, {} with priority and {} without priority tasks", + priority_tasks.len(), + non_priority_tasks.len() + ); + + Ok(( + thumbs_count, + dispatcher + .dispatch_many_boxed(priority_tasks.into_iter().chain(non_priority_tasks)) + .await, + )) +} diff --git a/core/crates/heavy-lifting/src/media_processor/mod.rs b/core/crates/heavy-lifting/src/media_processor/mod.rs new file mode 100644 index 000000000..7197e686f --- /dev/null +++ b/core/crates/heavy-lifting/src/media_processor/mod.rs @@ -0,0 +1,73 @@ +use crate::{utils::sub_path, OuterContext, UpdateEvent}; + +use sd_core_file_path_helper::FilePathError; + +use sd_utils::db::MissingFieldError; + +use std::fmt; + +use serde::{Deserialize, Serialize}; +use specta::Type; + +mod helpers; +pub mod job; +mod shallow; +mod tasks; + +pub use tasks::{ + media_data_extractor::{self, MediaDataExtractor}, + thumbnailer::{self, Thumbnailer}, +}; + +pub use helpers::thumbnailer::{ThumbKey, ThumbnailKind}; +pub use shallow::shallow; + +use self::thumbnailer::NewThumbnailReporter; + +const BATCH_SIZE: usize = 10; + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("missing field on database: {0}")] + MissingField(#[from] MissingFieldError), + #[error("database error: {0}")] + Database(#[from] prisma_client_rust::QueryError), + #[error("failed to deserialized stored tasks for job resume: {0}")] + DeserializeTasks(#[from] rmp_serde::decode::Error), + + #[error(transparent)] + FilePathError(#[from] FilePathError), + #[error(transparent)] + SubPath(#[from] sub_path::Error), +} + +impl From for rspc::Error { + fn from(e: Error) -> Self { + Self::with_cause(rspc::ErrorCode::InternalServerError, e.to_string(), e) + } +} + +#[derive(thiserror::Error, Debug, Serialize, Deserialize, Type)] +pub enum NonCriticalError { + #[error(transparent)] + MediaDataExtractor(#[from] media_data_extractor::NonCriticalError), + #[error(transparent)] + Thumbnailer(#[from] thumbnailer::NonCriticalError), +} + +struct NewThumbnailsReporter { + ctx: Ctx, +} + +impl fmt::Debug for NewThumbnailsReporter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("NewThumbnailsReporter").finish() + } +} + +impl NewThumbnailReporter for NewThumbnailsReporter { + fn new_thumbnail(&self, thumb_key: ThumbKey) { + self.ctx + .report_update(UpdateEvent::NewThumbnailEvent { thumb_key }); + } +} diff --git a/core/crates/heavy-lifting/src/media_processor/shallow.rs b/core/crates/heavy-lifting/src/media_processor/shallow.rs new file mode 100644 index 000000000..1bcb45e31 --- /dev/null +++ b/core/crates/heavy-lifting/src/media_processor/shallow.rs @@ -0,0 +1,258 @@ +use crate::{ + media_processor, utils::sub_path::maybe_get_iso_file_path_from_sub_path, Error, + NonCriticalError, OuterContext, +}; + +use 
sd_core_file_path_helper::IsolatedFilePathData; +use sd_core_prisma_helpers::file_path_for_media_processor; + +use sd_file_ext::extensions::Extension; +use sd_prisma::prisma::{location, PrismaClient}; +use sd_task_system::{ + BaseTaskDispatcher, CancelTaskOnDrop, IntoTask, TaskDispatcher, TaskHandle, TaskOutput, + TaskStatus, +}; +use sd_utils::db::maybe_missing; + +use std::{ + path::{Path, PathBuf}, + sync::Arc, +}; + +use futures::StreamExt; +use futures_concurrency::future::{FutureGroup, TryJoin}; +use itertools::Itertools; +use prisma_client_rust::{raw, PrismaValue}; +use tracing::{debug, warn}; + +use super::{ + helpers::{self, exif_media_data, ffmpeg_media_data, thumbnailer::THUMBNAIL_CACHE_DIR_NAME}, + tasks::{self, media_data_extractor, thumbnailer}, + NewThumbnailsReporter, BATCH_SIZE, +}; + +#[allow(clippy::missing_panics_doc)] // SAFETY: It doesn't actually panics +pub async fn shallow( + location: location::Data, + sub_path: impl AsRef + Send, + dispatcher: BaseTaskDispatcher, + ctx: impl OuterContext, +) -> Result, Error> { + let sub_path = sub_path.as_ref(); + + let location_path = maybe_missing(&location.path, "location.path") + .map(PathBuf::from) + .map(Arc::new) + .map_err(media_processor::Error::from)?; + + let location = Arc::new(location); + + let sub_iso_file_path = maybe_get_iso_file_path_from_sub_path( + location.id, + &Some(sub_path), + &*location_path, + ctx.db(), + ) + .await + .map_err(media_processor::Error::from)? + .map_or_else( + || { + IsolatedFilePathData::new(location.id, &*location_path, &*location_path, true) + .map_err(media_processor::Error::from) + }, + Ok, + )?; + + let mut errors = vec![]; + + let mut futures = dispatch_media_data_extractor_tasks( + ctx.db(), + &sub_iso_file_path, + &location_path, + &dispatcher, + ) + .await? + .into_iter() + .map(CancelTaskOnDrop) + .chain( + dispatch_thumbnailer_tasks(&sub_iso_file_path, false, &location_path, &dispatcher, &ctx) + .await? 
+ .into_iter() + .map(CancelTaskOnDrop), + ) + .collect::>(); + + while let Some(res) = futures.next().await { + match res { + Ok(TaskStatus::Done((_, TaskOutput::Out(out)))) => { + if out.is::() { + errors.extend( + out.downcast::() + .expect("just checked") + .errors, + ); + } else if out.is::() { + errors.extend( + out.downcast::() + .expect("just checked") + .errors, + ); + } else { + unreachable!( + "Task returned unexpected output type on media processor shallow job" + ); + } + } + Ok(TaskStatus::Done((_, TaskOutput::Empty))) => { + warn!("Task returned empty output on media processor shallow job"); + } + Ok(TaskStatus::Canceled | TaskStatus::ForcedAbortion | TaskStatus::Shutdown(_)) => { + return Ok(errors); + } + Ok(TaskStatus::Error(e)) => return Err(e), + + Err(e) => return Err(e.into()), + } + } + + Ok(errors) +} + +async fn dispatch_media_data_extractor_tasks( + db: &Arc, + parent_iso_file_path: &IsolatedFilePathData<'_>, + location_path: &Arc, + dispatcher: &BaseTaskDispatcher, +) -> Result>, media_processor::Error> { + let (extract_exif_file_paths, extract_ffmpeg_file_paths) = ( + get_files_by_extensions( + db, + parent_iso_file_path, + &exif_media_data::AVAILABLE_EXTENSIONS, + ), + get_files_by_extensions( + db, + parent_iso_file_path, + &ffmpeg_media_data::AVAILABLE_EXTENSIONS, + ), + ) + .try_join() + .await?; + + let tasks = extract_exif_file_paths + .into_iter() + .chunks(BATCH_SIZE) + .into_iter() + .map(Iterator::collect::>) + .map(|chunked_file_paths| { + tasks::MediaDataExtractor::new_exif( + &chunked_file_paths, + parent_iso_file_path.location_id(), + Arc::clone(location_path), + Arc::clone(db), + ) + }) + .map(IntoTask::into_task) + .chain( + extract_ffmpeg_file_paths + .into_iter() + .chunks(BATCH_SIZE) + .into_iter() + .map(Iterator::collect::>) + .map(|chunked_file_paths| { + tasks::MediaDataExtractor::new_ffmpeg( + &chunked_file_paths, + parent_iso_file_path.location_id(), + Arc::clone(location_path), + Arc::clone(db), + ) + }) + .map(IntoTask::into_task), + ) + .collect::>(); + + Ok(dispatcher.dispatch_many_boxed(tasks).await) +} + +async fn get_files_by_extensions( + db: &PrismaClient, + parent_iso_file_path: &IsolatedFilePathData<'_>, + extensions: &[Extension], +) -> Result, media_processor::Error> { + // FIXME: Had to use format! 
macro because PCR doesn't support IN with Vec for SQLite + // We have no data coming from the user, so this is sql injection safe + db._query_raw(raw!( + &format!( + "SELECT id, materialized_path, is_dir, name, extension, cas_id, object_id + FROM file_path + WHERE + location_id={{}} + AND cas_id IS NOT NULL + AND LOWER(extension) IN ({}) + AND materialized_path = {{}}", + extensions + .iter() + .map(|ext| format!("LOWER('{ext}')")) + .collect::>() + .join(",") + ), + PrismaValue::Int(i64::from(parent_iso_file_path.location_id())), + PrismaValue::String( + parent_iso_file_path + .materialized_path_for_children() + .expect("sub path iso_file_path must be a directory") + ) + )) + .exec() + .await + .map_err(Into::into) +} + +async fn dispatch_thumbnailer_tasks( + parent_iso_file_path: &IsolatedFilePathData<'_>, + should_regenerate: bool, + location_path: &PathBuf, + dispatcher: &BaseTaskDispatcher, + ctx: &impl OuterContext, +) -> Result>, media_processor::Error> { + let thumbnails_directory_path = + Arc::new(ctx.get_data_directory().join(THUMBNAIL_CACHE_DIR_NAME)); + let location_id = parent_iso_file_path.location_id(); + let library_id = ctx.id(); + let db = ctx.db(); + let reporter = Arc::new(NewThumbnailsReporter { ctx: ctx.clone() }); + + let file_paths = get_files_by_extensions( + db, + parent_iso_file_path, + &helpers::thumbnailer::ALL_THUMBNAILABLE_EXTENSIONS, + ) + .await?; + + let thumbs_count = file_paths.len() as u64; + + let tasks = file_paths + .into_iter() + .chunks(BATCH_SIZE) + .into_iter() + .map(|chunk| { + tasks::Thumbnailer::new_indexed( + Arc::clone(&thumbnails_directory_path), + &chunk.collect::>(), + (location_id, location_path), + library_id, + should_regenerate, + true, + Arc::clone(&reporter), + ) + }) + .map(IntoTask::into_task) + .collect::>(); + + debug!( + "Dispatching {thumbs_count} thumbnails to be processed, in {} priority tasks", + tasks.len(), + ); + + Ok(dispatcher.dispatch_many_boxed(tasks).await) +} diff --git a/core/crates/heavy-lifting/src/media_processor/tasks/media_data_extractor.rs b/core/crates/heavy-lifting/src/media_processor/tasks/media_data_extractor.rs new file mode 100644 index 000000000..4a5f6661f --- /dev/null +++ b/core/crates/heavy-lifting/src/media_processor/tasks/media_data_extractor.rs @@ -0,0 +1,525 @@ +use crate::{ + media_processor::{ + self, + helpers::{exif_media_data, ffmpeg_media_data}, + }, + Error, +}; + +use sd_core_file_path_helper::IsolatedFilePathData; +use sd_core_prisma_helpers::file_path_for_media_processor; + +use sd_media_metadata::{ExifMetadata, FFmpegMetadata}; +use sd_prisma::prisma::{exif_data, ffmpeg_data, file_path, location, object, PrismaClient}; +use sd_task_system::{ + check_interruption, ExecStatus, Interrupter, InterruptionKind, IntoAnyTaskOutput, + SerializableTask, Task, TaskId, +}; + +use std::{ + collections::{HashMap, HashSet}, + future::{Future, IntoFuture}, + mem, + path::{Path, PathBuf}, + pin::pin, + sync::Arc, + time::Duration, +}; + +use futures::{FutureExt, StreamExt}; +use futures_concurrency::future::{FutureGroup, Race}; +use serde::{Deserialize, Serialize}; +use specta::Type; +use tokio::time::Instant; + +#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)] +enum Kind { + Exif, + FFmpeg, +} + +#[derive(Debug)] +pub struct MediaDataExtractor { + id: TaskId, + kind: Kind, + file_paths: Vec, + location_id: location::id::Type, + location_path: Arc, + stage: Stage, + db: Arc, + output: Output, +} + +#[derive(Debug, Serialize, Deserialize)] +enum Stage { + Starting, + 
FetchedObjectsAlreadyWithMediaData(Vec), + ExtractingMediaData { + paths_by_id: HashMap, + exif_media_datas: Vec<(ExifMetadata, object::id::Type)>, + ffmpeg_media_datas: Vec<(FFmpegMetadata, object::id::Type)>, + extract_ids_to_remove_from_map: Vec, + }, + SaveMediaData { + exif_media_datas: Vec<(ExifMetadata, object::id::Type)>, + ffmpeg_media_datas: Vec<(FFmpegMetadata, object::id::Type)>, + }, +} + +impl MediaDataExtractor { + fn new( + kind: Kind, + file_paths: &[file_path_for_media_processor::Data], + location_id: location::id::Type, + location_path: Arc, + db: Arc, + ) -> Self { + let mut output = Output::default(); + + Self { + id: TaskId::new_v4(), + kind, + file_paths: file_paths + .iter() + .filter(|file_path| { + if file_path.object_id.is_some() { + true + } else { + output.errors.push( + media_processor::NonCriticalError::from( + NonCriticalError::FilePathMissingObjectId(file_path.id), + ) + .into(), + ); + false + } + }) + .cloned() + .collect(), + location_id, + location_path, + stage: Stage::Starting, + db, + output, + } + } + + #[must_use] + pub fn new_exif( + file_paths: &[file_path_for_media_processor::Data], + location_id: location::id::Type, + location_path: Arc, + db: Arc, + ) -> Self { + Self::new(Kind::Exif, file_paths, location_id, location_path, db) + } + + #[must_use] + pub fn new_ffmpeg( + file_paths: &[file_path_for_media_processor::Data], + location_id: location::id::Type, + location_path: Arc, + db: Arc, + ) -> Self { + Self::new(Kind::FFmpeg, file_paths, location_id, location_path, db) + } +} + +#[async_trait::async_trait] +impl Task for MediaDataExtractor { + fn id(&self) -> TaskId { + self.id + } + + /// MediaDataExtractor never needs priority, as the data it generates are only accessed through + /// the media inspector, so it isn't latency sensitive like other tasks, like FileIdentifier or + /// the Thumbnailer + fn with_priority(&self) -> bool { + false + } + + async fn run(&mut self, interrupter: &Interrupter) -> Result { + loop { + match &mut self.stage { + Stage::Starting => { + let db_read_start = Instant::now(); + let object_ids = fetch_objects_already_with_media_data( + self.kind, + &self.file_paths, + &self.db, + ) + .await?; + self.output.db_read_time = db_read_start.elapsed(); + + self.stage = Stage::FetchedObjectsAlreadyWithMediaData(object_ids); + } + + Stage::FetchedObjectsAlreadyWithMediaData(objects_already_with_media_data) => { + let filtering_start = Instant::now(); + if self.file_paths.len() == objects_already_with_media_data.len() { + // All files already have media data, skipping + self.output.skipped = self.file_paths.len() as u64; + + break; + } + let paths_by_id = filter_files_to_extract_media_data( + mem::take(objects_already_with_media_data), + self.location_id, + &self.location_path, + &mut self.file_paths, + &mut self.output, + ); + + self.output.filtering_time = filtering_start.elapsed(); + + self.stage = Stage::ExtractingMediaData { + extract_ids_to_remove_from_map: Vec::with_capacity(paths_by_id.len()), + exif_media_datas: if self.kind == Kind::Exif { + Vec::with_capacity(paths_by_id.len()) + } else { + Vec::new() + }, + ffmpeg_media_datas: if self.kind == Kind::FFmpeg { + Vec::with_capacity(paths_by_id.len()) + } else { + Vec::new() + }, + paths_by_id, + }; + } + + Stage::ExtractingMediaData { + paths_by_id, + exif_media_datas, + ffmpeg_media_datas, + extract_ids_to_remove_from_map, + } => { + { + // This inner scope is necessary to appease the mighty borrowck + let extraction_start = Instant::now(); + for id in 
extract_ids_to_remove_from_map.drain(..) { + paths_by_id.remove(&id); + } + + let mut futures = pin!(prepare_extraction_futures( + self.kind, + paths_by_id, + interrupter + )); + + while let Some(race_output) = futures.next().await { + match race_output { + InterruptRace::Processed(out) => { + process_output( + out, + exif_media_datas, + ffmpeg_media_datas, + extract_ids_to_remove_from_map, + &mut self.output, + ); + } + + InterruptRace::Interrupted(kind) => { + self.output.extraction_time += extraction_start.elapsed(); + return Ok(match kind { + InterruptionKind::Pause => ExecStatus::Paused, + InterruptionKind::Cancel => ExecStatus::Canceled, + }); + } + } + } + } + + self.stage = Stage::SaveMediaData { + exif_media_datas: mem::take(exif_media_datas), + ffmpeg_media_datas: mem::take(ffmpeg_media_datas), + }; + } + + Stage::SaveMediaData { + exif_media_datas, + ffmpeg_media_datas, + } => { + let db_write_start = Instant::now(); + self.output.extracted = + save(self.kind, exif_media_datas, ffmpeg_media_datas, &self.db).await?; + self.output.db_write_time = db_write_start.elapsed(); + + self.output.skipped += self.output.errors.len() as u64; + + break; + } + } + + check_interruption!(interrupter); + } + + Ok(ExecStatus::Done(mem::take(&mut self.output).into_output())) + } +} + +#[derive(thiserror::Error, Debug, Serialize, Deserialize, Type)] +pub enum NonCriticalError { + #[error("failed to extract media data from : {1}", .0.display())] + FailedToExtractImageMediaData(PathBuf, String), + #[error("file path missing object id: ")] + FilePathMissingObjectId(file_path::id::Type), + #[error("failed to construct isolated file path data: : {1}")] + FailedToConstructIsolatedFilePathData(file_path::id::Type, String), +} + +#[derive(Serialize, Deserialize, Default, Debug)] +pub struct Output { + pub extracted: u64, + pub skipped: u64, + pub db_read_time: Duration, + pub filtering_time: Duration, + pub extraction_time: Duration, + pub db_write_time: Duration, + pub errors: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +struct SaveState { + id: TaskId, + kind: Kind, + file_paths: Vec, + location_id: location::id::Type, + location_path: Arc, + stage: Stage, + output: Output, +} + +impl SerializableTask for MediaDataExtractor { + type SerializeError = rmp_serde::encode::Error; + + type DeserializeError = rmp_serde::decode::Error; + + type DeserializeCtx = Arc; + + async fn serialize(self) -> Result, Self::SerializeError> { + let Self { + id, + kind, + file_paths, + location_id, + location_path, + stage, + output, + .. 
+ } = self; + + rmp_serde::to_vec_named(&SaveState { + id, + kind, + file_paths, + location_id, + location_path, + stage, + output, + }) + } + + async fn deserialize( + data: &[u8], + db: Self::DeserializeCtx, + ) -> Result { + rmp_serde::from_slice(data).map( + |SaveState { + id, + kind, + file_paths, + location_id, + location_path, + stage, + output, + }| Self { + id, + kind, + file_paths, + location_id, + location_path, + stage, + db, + output, + }, + ) + } +} + +#[inline] +async fn fetch_objects_already_with_media_data( + kind: Kind, + file_paths: &[file_path_for_media_processor::Data], + db: &PrismaClient, +) -> Result, media_processor::Error> { + let object_ids = file_paths + .iter() + .filter_map(|file_path| file_path.object_id) + .collect(); + + match kind { + Kind::Exif => db + .exif_data() + .find_many(vec![exif_data::object_id::in_vec(object_ids)]) + .select(exif_data::select!({ object_id })) + .exec() + .await + .map(|object_ids| object_ids.into_iter().map(|data| data.object_id).collect()) + .map_err(Into::into), + + Kind::FFmpeg => db + .ffmpeg_data() + .find_many(vec![ffmpeg_data::object_id::in_vec(object_ids)]) + .select(ffmpeg_data::select!({ object_id })) + .exec() + .await + .map(|object_ids| object_ids.into_iter().map(|data| data.object_id).collect()) + .map_err(Into::into), + } +} + +#[inline] +fn filter_files_to_extract_media_data( + objects_already_with_media_data: Vec, + location_id: location::id::Type, + location_path: &Path, + file_paths: &mut Vec, + Output { + skipped, errors, .. + }: &mut Output, +) -> HashMap { + let unique_objects_already_with_media_data = objects_already_with_media_data + .into_iter() + .collect::>(); + + *skipped = unique_objects_already_with_media_data.len() as u64; + + file_paths.retain(|file_path| { + !unique_objects_already_with_media_data + .contains(&file_path.object_id.expect("already checked")) + }); + + file_paths + .iter() + .filter_map(|file_path| { + IsolatedFilePathData::try_from((location_id, file_path)) + .map_err(|e| { + errors.push( + media_processor::NonCriticalError::from( + NonCriticalError::FailedToConstructIsolatedFilePathData( + file_path.id, + e.to_string(), + ), + ) + .into(), + ); + }) + .map(|iso_file_path| { + ( + file_path.id, + ( + location_path.join(iso_file_path), + file_path.object_id.expect("already checked"), + ), + ) + }) + .ok() + }) + .collect() +} + +enum ExtractionOutputKind { + Exif(Result, media_processor::NonCriticalError>), + FFmpeg(Result), +} + +struct ExtractionOutput { + file_path_id: file_path::id::Type, + object_id: object::id::Type, + kind: ExtractionOutputKind, +} + +#[allow(clippy::large_enum_variant)] +/* + * NOTE(fogodev): Interrupts will be pretty rare, so paying the boxing price for + * the Processed variant isn't worth it to avoid the enum size disparity between variants + */ +enum InterruptRace { + Interrupted(InterruptionKind), + Processed(ExtractionOutput), +} + +#[inline] +fn prepare_extraction_futures<'a>( + kind: Kind, + paths_by_id: &'a HashMap, + interrupter: &'a Interrupter, +) -> FutureGroup + 'a> { + paths_by_id + .iter() + .map(|(file_path_id, (path, object_id))| async move { + InterruptRace::Processed(ExtractionOutput { + file_path_id: *file_path_id, + object_id: *object_id, + kind: match kind { + Kind::Exif => ExtractionOutputKind::Exif(exif_media_data::extract(path).await), + Kind::FFmpeg => { + ExtractionOutputKind::FFmpeg(ffmpeg_media_data::extract(path).await) + } + }, + }) + }) + .map(|fut| { + ( + fut, + interrupter.into_future().map(InterruptRace::Interrupted), + ) 
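The SerializableTask implementations in this diff all follow the same shape: the serde-friendly fields round-trip through a SaveState mirror struct via rmp_serde, and the handle that cannot be serialized (the database client for the extractor, the thumbnail reporter for the thumbnailer) is re-injected as the deserialize context. A toy round-trip under those assumptions; MyTask, Ctx, and this SaveState are stand-ins, not types from this crate:

use serde::{Deserialize, Serialize};
use std::sync::Arc;

// Stand-in for a non-serializable handle such as Arc<PrismaClient>.
#[derive(Debug)]
struct Ctx;

#[derive(Debug)]
struct MyTask {
    id: u32,
    pending: Vec<String>,
    ctx: Arc<Ctx>, // never serialized
}

// Mirror of the task state that is serializable.
#[derive(Serialize, Deserialize)]
struct SaveState {
    id: u32,
    pending: Vec<String>,
}

fn serialize(task: MyTask) -> Result<Vec<u8>, rmp_serde::encode::Error> {
    let MyTask { id, pending, .. } = task;
    rmp_serde::to_vec_named(&SaveState { id, pending })
}

fn deserialize(bytes: &[u8], ctx: Arc<Ctx>) -> Result<MyTask, rmp_serde::decode::Error> {
    rmp_serde::from_slice(bytes).map(|SaveState { id, pending }| MyTask { id, pending, ctx })
}

fn main() {
    let task = MyTask { id: 7, pending: vec!["a.jpg".into()], ctx: Arc::new(Ctx) };
    let bytes = serialize(task).expect("encode");
    let restored = deserialize(&bytes, Arc::new(Ctx)).expect("decode");
    println!("{restored:?}");
}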
+ .race() + }) + .collect::>() +} + +#[inline] +fn process_output( + ExtractionOutput { + file_path_id, + object_id, + kind, + }: ExtractionOutput, + exif_media_datas: &mut Vec<(ExifMetadata, object::id::Type)>, + ffmpeg_media_datas: &mut Vec<(FFmpegMetadata, object::id::Type)>, + extract_ids_to_remove_from_map: &mut Vec, + output: &mut Output, +) { + match kind { + ExtractionOutputKind::Exif(Ok(Some(exif_data))) => { + exif_media_datas.push((exif_data, object_id)); + } + ExtractionOutputKind::Exif(Ok(None)) => { + // No exif media data found + output.skipped += 1; + } + ExtractionOutputKind::FFmpeg(Ok(ffmpeg_data)) => { + ffmpeg_media_datas.push((ffmpeg_data, object_id)); + } + ExtractionOutputKind::Exif(Err(e)) | ExtractionOutputKind::FFmpeg(Err(e)) => { + output.errors.push(e.into()); + } + } + + extract_ids_to_remove_from_map.push(file_path_id); +} + +#[inline] +async fn save( + kind: Kind, + exif_media_datas: &mut Vec<(ExifMetadata, object::id::Type)>, + ffmpeg_media_datas: &mut Vec<(FFmpegMetadata, object::id::Type)>, + db: &PrismaClient, +) -> Result { + match kind { + Kind::Exif => exif_media_data::save(mem::take(exif_media_datas), db).await, + Kind::FFmpeg => ffmpeg_media_data::save(mem::take(ffmpeg_media_datas), db).await, + } +} diff --git a/core/crates/heavy-lifting/src/media_processor/tasks/mod.rs b/core/crates/heavy-lifting/src/media_processor/tasks/mod.rs new file mode 100644 index 000000000..cb88d09d0 --- /dev/null +++ b/core/crates/heavy-lifting/src/media_processor/tasks/mod.rs @@ -0,0 +1,5 @@ +pub mod media_data_extractor; +pub mod thumbnailer; + +pub use media_data_extractor::MediaDataExtractor; +pub use thumbnailer::Thumbnailer; diff --git a/core/crates/heavy-lifting/src/media_processor/tasks/thumbnailer.rs b/core/crates/heavy-lifting/src/media_processor/tasks/thumbnailer.rs new file mode 100644 index 000000000..c04fb6c55 --- /dev/null +++ b/core/crates/heavy-lifting/src/media_processor/tasks/thumbnailer.rs @@ -0,0 +1,677 @@ +//! Thumbnails directory have the following structure: +//! thumbnails/ +//! ├── version.txt +//! ├── ephemeral/ # ephemeral ones have it's own directory +//! │ └── <`cas_id`>[0..3]/ # sharding +//! │ └── <`cas_id`>.webp +//! └── <`library_id`>/ # we segregate thumbnails by library +//! └── <`cas_id`>[0..3]/ # sharding +//! 
└── <`cas_id`>.webp + +use crate::{ + media_processor::{ + self, + helpers::thumbnailer::{ + can_generate_thumbnail_for_document, can_generate_thumbnail_for_image, get_shard_hex, + EPHEMERAL_DIR, TARGET_PX, TARGET_QUALITY, THUMBNAIL_GENERATION_TIMEOUT, WEBP_EXTENSION, + }, + ThumbKey, ThumbnailKind, + }, + Error, +}; + +use sd_core_file_path_helper::IsolatedFilePathData; +use sd_core_prisma_helpers::file_path_for_media_processor; + +use sd_file_ext::extensions::{DocumentExtension, ImageExtension}; +use sd_images::{format_image, scale_dimensions, ConvertibleExtension}; +use sd_media_metadata::exif::Orientation; +use sd_prisma::prisma::{file_path, location}; +use sd_task_system::{ + ExecStatus, Interrupter, InterruptionKind, IntoAnyTaskOutput, SerializableTask, Task, TaskId, +}; +use sd_utils::error::FileIOError; + +use std::{ + collections::HashMap, + fmt, + future::IntoFuture, + mem, + ops::Deref, + path::{Path, PathBuf}, + pin::pin, + str::FromStr, + sync::Arc, + time::Duration, +}; + +use futures::{FutureExt, StreamExt}; +use futures_concurrency::future::{FutureGroup, Race}; +use image::{imageops, DynamicImage, GenericImageView}; +use serde::{Deserialize, Serialize}; +use specta::Type; +use tokio::{ + fs, io, + task::spawn_blocking, + time::{sleep, Instant}, +}; +use tracing::{error, info, trace}; +use uuid::Uuid; +use webp::Encoder; + +#[derive(Debug, Serialize, Deserialize)] +pub struct GenerateThumbnailArgs { + pub extension: String, + pub cas_id: String, + pub path: PathBuf, +} + +impl GenerateThumbnailArgs { + #[must_use] + pub const fn new(extension: String, cas_id: String, path: PathBuf) -> Self { + Self { + extension, + cas_id, + path, + } + } +} + +pub type ThumbnailId = u32; + +pub trait NewThumbnailReporter: Send + Sync + fmt::Debug + 'static { + fn new_thumbnail(&self, thumb_key: ThumbKey); +} + +#[derive(Debug)] +pub struct Thumbnailer { + id: TaskId, + reporter: Arc, + thumbs_kind: ThumbnailKind, + thumbnails_directory_path: Arc, + thumbnails_to_generate: HashMap, + already_processed_ids: Vec, + should_regenerate: bool, + with_priority: bool, + output: Output, +} + +#[async_trait::async_trait] +impl Task for Thumbnailer { + fn id(&self) -> TaskId { + self.id + } + + fn with_priority(&self) -> bool { + self.with_priority + } + + fn with_timeout(&self) -> Option { + Some(Duration::from_secs(60 * 5)) // The entire task must not take more than 5 minutes + } + + async fn run(&mut self, interrupter: &Interrupter) -> Result { + enum InterruptRace { + Interrupted(InterruptionKind), + Processed(ThumbnailGenerationOutput), + } + + let Self { + thumbs_kind, + thumbnails_directory_path, + thumbnails_to_generate, + already_processed_ids, + should_regenerate, + with_priority, + reporter, + output, + .. 
+ } = self; + + // Removing already processed thumbnails from a possible previous run + already_processed_ids.drain(..).for_each(|id| { + thumbnails_to_generate.remove(&id); + }); + + let start = Instant::now(); + + let mut futures = pin!(thumbnails_to_generate + .iter() + .map(|(id, generate_args)| { + let path = &generate_args.path; + + ( + generate_thumbnail( + thumbnails_directory_path, + generate_args, + thumbs_kind, + *should_regenerate, + ) + .map(|res| (*id, res)), + sleep(THUMBNAIL_GENERATION_TIMEOUT).map(|()| { + ( + *id, + ( + THUMBNAIL_GENERATION_TIMEOUT, + Err(NonCriticalError::ThumbnailGenerationTimeout(path.clone())), + ), + ) + }), + ) + .race() + .map(InterruptRace::Processed) + }) + .map(|fut| ( + fut, + interrupter.into_future().map(InterruptRace::Interrupted) + ) + .race()) + .collect::>()); + + while let Some(race_output) = futures.next().await { + match race_output { + InterruptRace::Processed(out) => process_thumbnail_generation_output( + out, + *with_priority, + reporter.as_ref(), + already_processed_ids, + output, + ), + + InterruptRace::Interrupted(kind) => { + output.total_time += start.elapsed(); + return Ok(match kind { + InterruptionKind::Pause => ExecStatus::Paused, + InterruptionKind::Cancel => ExecStatus::Canceled, + }); + } + } + } + + output.total_time += start.elapsed(); + + #[allow(clippy::cast_precision_loss)] + // SAFETY: we're probably won't have 2^52 thumbnails being generated on a single task for this cast to have + // a precision loss issue + let total = (output.generated + output.skipped) as f64; + + let mean_generation_time = output.mean_time_acc / total; + + let generation_time_std_dev = Duration::from_secs_f64( + (mean_generation_time.mul_add(-mean_generation_time, output.std_dev_acc / total)) + .sqrt(), + ); + + info!( + "{{generated: {generated}, skipped: {skipped}}} thumbnails; \ + mean generation time: {mean_generation_time:?} ± {generation_time_std_dev:?}", + generated = output.generated, + skipped = output.skipped, + mean_generation_time = Duration::from_secs_f64(mean_generation_time) + ); + + Ok(ExecStatus::Done(mem::take(output).into_output())) + } +} + +#[derive(Serialize, Deserialize, Default, Debug)] +pub struct Output { + pub generated: u64, + pub skipped: u64, + pub errors: Vec, + pub total_time: Duration, + pub mean_time_acc: f64, + pub std_dev_acc: f64, +} + +#[derive(thiserror::Error, Debug, Serialize, Deserialize, Type)] +pub enum NonCriticalError { + #[error("file path has no cas_id")] + MissingCasId(file_path::id::Type), + #[error("failed to extract isolated file path data from file path : {1}")] + FailedToExtractIsolatedFilePathData(file_path::id::Type, String), + #[error("failed to generate video file thumbnail : {1}", .0.display())] + VideoThumbnailGenerationFailed(PathBuf, String), + #[error("failed to format image : {1}", .0.display())] + FormatImage(PathBuf, String), + #[error("failed to encode webp image : {1}", .0.display())] + WebPEncoding(PathBuf, String), + #[error("processing thread panicked while generating thumbnail from : {1}", .0.display())] + PanicWhileGeneratingThumbnail(PathBuf, String), + #[error("failed to create shard directory for thumbnail: {0}")] + CreateShardDirectory(String), + #[error("failed to save thumbnail : {1}", .0.display())] + SaveThumbnail(PathBuf, String), + #[error("thumbnail generation timed out ", .0.display())] + ThumbnailGenerationTimeout(PathBuf), +} + +impl Thumbnailer { + fn new( + thumbs_kind: ThumbnailKind, + thumbnails_directory_path: Arc, + thumbnails_to_generate: HashMap, + 
errors: Vec, + should_regenerate: bool, + with_priority: bool, + reporter: Arc, + ) -> Self { + Self { + id: TaskId::new_v4(), + thumbs_kind, + thumbnails_directory_path, + already_processed_ids: Vec::with_capacity(thumbnails_to_generate.len()), + thumbnails_to_generate, + should_regenerate, + with_priority, + output: Output { + errors, + ..Default::default() + }, + reporter, + } + } + + #[must_use] + pub fn new_ephemeral( + thumbnails_directory_path: Arc, + thumbnails_to_generate: Vec, + reporter: Arc, + ) -> Self { + Self::new( + ThumbnailKind::Ephemeral, + thumbnails_directory_path, + thumbnails_to_generate + .into_iter() + .enumerate() + .map(|(i, args)| { + #[allow(clippy::cast_possible_truncation)] + { + // SAFETY: it's fine, we will never process more than 4 billion thumbnails + // on a single task LMAO + (i as ThumbnailId, args) + } + }) + .collect(), + Vec::new(), + false, + true, + reporter, + ) + } + + #[must_use] + pub fn new_indexed( + thumbnails_directory_path: Arc, + file_paths: &[file_path_for_media_processor::Data], + (location_id, location_path): (location::id::Type, &Path), + library_id: Uuid, + should_regenerate: bool, + with_priority: bool, + reporter: Arc, + ) -> Self { + let mut errors = Vec::new(); + + Self::new( + ThumbnailKind::Indexed(library_id), + thumbnails_directory_path, + file_paths + .iter() + .filter_map(|file_path| { + if let Some(cas_id) = file_path.cas_id.as_ref() { + let file_path_id = file_path.id; + IsolatedFilePathData::try_from((location_id, file_path)) + .map_err(|e| { + errors.push( + media_processor::NonCriticalError::from( + NonCriticalError::FailedToExtractIsolatedFilePathData( + file_path_id, + e.to_string(), + ), + ) + .into(), + ); + }) + .ok() + .map(|iso_file_path| (file_path_id, cas_id, iso_file_path)) + } else { + errors.push( + media_processor::NonCriticalError::from( + NonCriticalError::MissingCasId(file_path.id), + ) + .into(), + ); + None + } + }) + .map(|(file_path_id, cas_id, iso_file_path)| { + let full_path = location_path.join(&iso_file_path); + + #[allow(clippy::cast_sign_loss)] + { + ( + // SAFETY: db doesn't have negative indexes + file_path_id as u32, + GenerateThumbnailArgs::new( + iso_file_path.extension().to_string(), + cas_id.clone(), + full_path, + ), + ) + } + }) + .collect::>(), + errors, + should_regenerate, + with_priority, + reporter, + ) + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct SaveState { + id: TaskId, + thumbs_kind: ThumbnailKind, + thumbnails_directory_path: Arc, + thumbnails_to_generate: HashMap, + should_regenerate: bool, + with_priority: bool, + output: Output, +} + +impl SerializableTask for Thumbnailer { + type SerializeError = rmp_serde::encode::Error; + + type DeserializeError = rmp_serde::decode::Error; + + type DeserializeCtx = Arc; + + async fn serialize(self) -> Result, Self::SerializeError> { + let Self { + id, + thumbs_kind, + thumbnails_directory_path, + mut thumbnails_to_generate, + already_processed_ids, + should_regenerate, + with_priority, + output, + .. 
+ } = self; + + for id in already_processed_ids { + thumbnails_to_generate.remove(&id); + } + + rmp_serde::to_vec_named(&SaveState { + id, + thumbs_kind, + thumbnails_directory_path, + thumbnails_to_generate, + should_regenerate, + with_priority, + output, + }) + } + + async fn deserialize( + data: &[u8], + reporter: Self::DeserializeCtx, + ) -> Result { + rmp_serde::from_slice(data).map( + |SaveState { + id, + thumbs_kind, + thumbnails_to_generate, + thumbnails_directory_path, + should_regenerate, + with_priority, + output, + }| Self { + id, + reporter, + thumbs_kind, + thumbnails_to_generate, + thumbnails_directory_path, + already_processed_ids: Vec::new(), + should_regenerate, + with_priority, + output, + }, + ) + } +} + +enum GenerationStatus { + Generated, + Skipped, +} + +type ThumbnailGenerationOutput = ( + ThumbnailId, + ( + Duration, + Result<(ThumbKey, GenerationStatus), NonCriticalError>, + ), +); + +fn process_thumbnail_generation_output( + (id, (elapsed_time, res)): ThumbnailGenerationOutput, + with_priority: bool, + reporter: &impl NewThumbnailReporter, + already_processed_ids: &mut Vec, + Output { + generated, + skipped, + errors, + mean_time_acc: mean_generation_time_accumulator, + std_dev_acc: std_dev_accumulator, + .. + }: &mut Output, +) { + let elapsed_time = elapsed_time.as_secs_f64(); + *mean_generation_time_accumulator += elapsed_time; + *std_dev_accumulator += elapsed_time * elapsed_time; + + match res { + Ok((thumb_key, status)) => { + match status { + GenerationStatus::Generated => { + *generated += 1; + } + GenerationStatus::Skipped => { + *skipped += 1; + } + } + + // This if is REALLY needed, due to the sheer performance of the thumbnailer, + // I restricted to only send events notifying for thumbnails in the current + // opened directory, sending events for the entire location turns into a + // humongous bottleneck in the frontend lol, since it doesn't even knows + // what to do with thumbnails for inner directories lol + // - fogodev + if with_priority { + reporter.new_thumbnail(thumb_key); + } + } + Err(e) => { + errors.push(media_processor::NonCriticalError::from(e).into()); + *skipped += 1; + } + } + + already_processed_ids.push(id); +} + +async fn generate_thumbnail( + thumbnails_directory: &Path, + GenerateThumbnailArgs { + extension, + cas_id, + path, + }: &GenerateThumbnailArgs, + kind: &ThumbnailKind, + should_regenerate: bool, +) -> ( + Duration, + Result<(ThumbKey, GenerationStatus), NonCriticalError>, +) { + trace!("Generating thumbnail for {}", path.display()); + let start = Instant::now(); + + let mut output_path = match kind { + ThumbnailKind::Ephemeral => thumbnails_directory.join(EPHEMERAL_DIR), + ThumbnailKind::Indexed(library_id) => thumbnails_directory.join(library_id.to_string()), + }; + + output_path.push(get_shard_hex(cas_id)); + output_path.push(cas_id); + output_path.set_extension(WEBP_EXTENSION); + + if let Err(e) = fs::metadata(&*output_path).await { + if e.kind() != io::ErrorKind::NotFound { + error!( + "Failed to check if thumbnail exists, but we will try to generate it anyway: {e:#?}" + ); + } + // Otherwise we good, thumbnail doesn't exist so we can generate it + } else if !should_regenerate { + trace!( + "Skipping thumbnail generation for {} because it already exists", + path.display() + ); + return ( + start.elapsed(), + Ok((ThumbKey::new(cas_id, kind), GenerationStatus::Skipped)), + ); + } + + if let Ok(extension) = ImageExtension::from_str(extension) { + if can_generate_thumbnail_for_image(extension) { + if let Err(e) = 
generate_image_thumbnail(&path, &output_path).await { + return (start.elapsed(), Err(e)); + } + } + } else if let Ok(extension) = DocumentExtension::from_str(extension) { + if can_generate_thumbnail_for_document(extension) { + if let Err(e) = generate_image_thumbnail(&path, &output_path).await { + return (start.elapsed(), Err(e)); + } + } + } + + #[cfg(feature = "ffmpeg")] + { + use crate::media_processor::helpers::thumbnailer::can_generate_thumbnail_for_video; + use sd_file_ext::extensions::VideoExtension; + + if let Ok(extension) = VideoExtension::from_str(extension) { + if can_generate_thumbnail_for_video(extension) { + if let Err(e) = generate_video_thumbnail(&path, &output_path).await { + return (start.elapsed(), Err(e)); + } + } + } + } + + trace!("Generated thumbnail for {}", path.display()); + + ( + start.elapsed(), + Ok((ThumbKey::new(cas_id, kind), GenerationStatus::Generated)), + ) +} + +async fn generate_image_thumbnail( + file_path: impl AsRef + Send, + output_path: impl AsRef + Send, +) -> Result<(), NonCriticalError> { + let file_path = file_path.as_ref().to_path_buf(); + + let webp = spawn_blocking({ + let file_path = file_path.clone(); + + move || -> Result<_, NonCriticalError> { + let mut img = format_image(&file_path) + .map_err(|e| NonCriticalError::FormatImage(file_path.clone(), e.to_string()))?; + + let (w, h) = img.dimensions(); + + #[allow(clippy::cast_precision_loss)] + let (w_scaled, h_scaled) = scale_dimensions(w as f32, h as f32, TARGET_PX); + + // Optionally, resize the existing photo and convert back into DynamicImage + if w != w_scaled && h != h_scaled { + img = DynamicImage::ImageRgba8(imageops::resize( + &img, + w_scaled, + h_scaled, + imageops::FilterType::Triangle, + )); + } + + // this corrects the rotation/flip of the image based on the *available* exif data + // not all images have exif data, so we don't error. we also don't rotate HEIF as that's against the spec + if let Some(orientation) = Orientation::from_path(&file_path) { + if ConvertibleExtension::try_from(file_path.as_ref()) + .expect("we already checked if the image was convertible") + .should_rotate() + { + img = orientation.correct_thumbnail(img); + } + } + + // Create the WebP encoder for the above image + let encoder = Encoder::from_image(&img) + .map_err(|reason| NonCriticalError::WebPEncoding(file_path, reason.to_string()))?; + + // Type `WebPMemory` is !Send, which makes the `Future` in this function `!Send`, + // this make us `deref` to have a `&[u8]` and then `to_owned` to make a `Vec` + // which implies on a unwanted clone... 
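// The WebPMemory returned by encode() is !Send, so the closure above copies it into an
// owned Vec<u8> before anything crosses back to the async side. A reduced, self-contained
// sketch of that spawn_blocking + owned-buffer pattern follows; it assumes the image and
// webp crates added by this diff, and the 75.0 quality value is illustrative rather than
// the module's TARGET_QUALITY constant.
use std::{ops::Deref, path::PathBuf};

use tokio::{fs, task::spawn_blocking};
use webp::Encoder;

async fn encode_webp_thumbnail(input: PathBuf, output: PathBuf) -> Result<(), String> {
    let webp_bytes = spawn_blocking(move || -> Result<Vec<u8>, String> {
        // Decoding and WebP encoding are CPU-bound, so they run on the blocking pool.
        let img = image::open(&input).map_err(|e| e.to_string())?;

        let encoder = Encoder::from_image(&img).map_err(|e| e.to_string())?;

        // Copy the !Send WebPMemory into an owned Vec so the returned future stays Send.
        Ok(encoder.encode(75.0).deref().to_owned())
    })
    .await
    .map_err(|e| e.to_string())??;

    fs::write(&output, &webp_bytes)
        .await
        .map_err(|e| e.to_string())
}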
+ Ok(encoder.encode(TARGET_QUALITY).deref().to_owned()) + } + }) + .await + .map_err(|e| { + NonCriticalError::PanicWhileGeneratingThumbnail(file_path.clone(), e.to_string()) + })??; + + let output_path = output_path.as_ref(); + + if let Some(shard_dir) = output_path.parent() { + fs::create_dir_all(shard_dir).await.map_err(|e| { + NonCriticalError::CreateShardDirectory(FileIOError::from((shard_dir, e)).to_string()) + })?; + } else { + error!( + "Failed to get parent directory of '{}' for sharding parent directory", + output_path.display() + ); + } + + fs::write(output_path, &webp).await.map_err(|e| { + NonCriticalError::SaveThumbnail(file_path, FileIOError::from((output_path, e)).to_string()) + }) +} + +#[cfg(feature = "ffmpeg")] +async fn generate_video_thumbnail( + file_path: impl AsRef + Send, + output_path: impl AsRef + Send, +) -> Result<(), NonCriticalError> { + use sd_ffmpeg::{to_thumbnail, ThumbnailSize}; + + let file_path = file_path.as_ref(); + + to_thumbnail( + file_path, + output_path, + ThumbnailSize::Scale(1024), + TARGET_QUALITY, + ) + .await + .map_err(|e| { + NonCriticalError::VideoThumbnailGenerationFailed(file_path.to_path_buf(), e.to_string()) + }) +} diff --git a/core/crates/heavy-lifting/src/utils/sub_path.rs b/core/crates/heavy-lifting/src/utils/sub_path.rs index 6461ccdb7..f9e607b41 100644 --- a/core/crates/heavy-lifting/src/utils/sub_path.rs +++ b/core/crates/heavy-lifting/src/utils/sub_path.rs @@ -11,7 +11,7 @@ use std::path::{Path, PathBuf}; use prisma_client_rust::QueryError; #[derive(thiserror::Error, Debug)] -pub enum SubPathError { +pub enum Error { #[error("received sub path not in database: ", .0.display())] SubPathNotFound(Box), @@ -22,10 +22,10 @@ pub enum SubPathError { IsoFilePath(#[from] FilePathError), } -impl From for rspc::Error { - fn from(err: SubPathError) -> Self { +impl From for rspc::Error { + fn from(err: Error) -> Self { match err { - SubPathError::SubPathNotFound(_) => { + Error::SubPathNotFound(_) => { Self::with_cause(ErrorCode::NotFound, err.to_string(), err) } @@ -39,7 +39,7 @@ pub async fn get_full_path_from_sub_path( sub_path: &Option + Send + Sync>, location_path: impl AsRef + Send, db: &PrismaClient, -) -> Result { +) -> Result { let location_path = location_path.as_ref(); match sub_path { @@ -53,7 +53,7 @@ pub async fn get_full_path_from_sub_path( sub_path, &IsolatedFilePathData::new(location_id, location_path, &full_path, true)?, db, - SubPathError::SubPathNotFound, + Error::SubPathNotFound, ) .await?; @@ -68,7 +68,7 @@ pub async fn maybe_get_iso_file_path_from_sub_path( sub_path: &Option + Send + Sync>, location_path: impl AsRef + Send, db: &PrismaClient, -) -> Result>, SubPathError> { +) -> Result>, Error> { let location_path = location_path.as_ref(); match sub_path { @@ -79,14 +79,9 @@ pub async fn maybe_get_iso_file_path_from_sub_path( let sub_iso_file_path = IsolatedFilePathData::new(location_id, location_path, &full_path, true)?; - ensure_file_path_exists( - sub_path, - &sub_iso_file_path, - db, - SubPathError::SubPathNotFound, - ) - .await - .map(|()| Some(sub_iso_file_path)) + ensure_file_path_exists(sub_path, &sub_iso_file_path, db, Error::SubPathNotFound) + .await + .map(|()| Some(sub_iso_file_path)) } _ => Ok(None), } diff --git a/crates/task-system/src/error.rs b/crates/task-system/src/error.rs index 626be73c3..9b19387f8 100644 --- a/crates/task-system/src/error.rs +++ b/crates/task-system/src/error.rs @@ -11,6 +11,8 @@ pub enum SystemError { TaskAborted(TaskId), #[error("task join error ")] TaskJoin(TaskId), + 
#[error("task timeout error ")] + TaskTimeout(TaskId), #[error("forced abortion for task timed out")] TaskForcedAbortTimeout(TaskId), } diff --git a/crates/task-system/src/task.rs b/crates/task-system/src/task.rs index 9fd05d5b5..7804a01ca 100644 --- a/crates/task-system/src/task.rs +++ b/crates/task-system/src/task.rs @@ -7,6 +7,7 @@ use std::{ Arc, }, task::{Context, Poll}, + time::Duration, }; use async_channel as chan; @@ -141,6 +142,13 @@ pub trait Task: fmt::Debug + Downcast + Send + Sync + 'static { false } + /// Here we define if we want the task system to shutdown our task if it takes too long to finish. By default the + /// task system will wait indefinitely for the task to finish, but if the user wants to have a timeout, they can + /// return a [`Duration`] here and the task system will cancel the task if it takes longer than the specified time. + fn with_timeout(&self) -> Option { + None + } + /// This method represent the work that should be done by the worker, it will be called by the /// worker when there is a slot available in its internal queue. /// We receive a `&mut self` so any internal data can be mutated on each `run` invocation. diff --git a/crates/task-system/src/worker/runner.rs b/crates/task-system/src/worker/runner.rs index d3cc3d91e..ac4788266 100644 --- a/crates/task-system/src/worker/runner.rs +++ b/crates/task-system/src/worker/runner.rs @@ -10,13 +10,13 @@ use std::{ }; use async_channel as chan; -use futures::StreamExt; +use futures::{FutureExt, StreamExt}; use futures_concurrency::future::Race; use tokio::{ spawn, sync::oneshot, task::{JoinError, JoinHandle}, - time::{timeout, Instant}, + time::{sleep, timeout, Instant}, }; use tracing::{debug, error, trace, warn}; @@ -1165,11 +1165,26 @@ fn handle_run_task_attempt( (task, Err(SystemError::TaskAborted(task_id))) } else { - let res = task.run(&interrupter).await; + let run_result = if let Some(timeout_duration) = task.with_timeout() { + (task.run(&interrupter).map(Ok), async move { + sleep(timeout_duration) + .map(|()| Err(SystemError::TaskTimeout(task_id))) + .await + }) + .race() + .await + } else { + task.run(&interrupter).map(Ok).await + }; - trace!("Ran task: : {res:?}"); + match run_result { + Ok(res) => { + trace!("Ran task: : {res:?}"); - (task, Ok(res)) + (task, Ok(res)) + } + Err(e) => (task, Err(e)), + } } } })