diff --git a/Cargo.lock b/Cargo.lock index 838077063..692051e78 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7508,10 +7508,13 @@ dependencies = [ "prisma-client-rust", "reqwest", "rmp-serde", + "sd-core-sync", "sd-file-path-helper", "sd-prisma", + "sd-sync", "sd-utils", "serde", + "serde_json", "thiserror", "tokio", "tokio-stream", diff --git a/core/prisma/schema.prisma b/core/prisma/schema.prisma index d753bc58f..4052b7a35 100644 --- a/core/prisma/schema.prisma +++ b/core/prisma/schema.prisma @@ -328,14 +328,14 @@ model Tag { @@map("tag") } -/// @relation(item: tag, group: object) +/// @relation(item: object, group: tag) model TagOnObject { - tag_id Int - tag Tag @relation(fields: [tag_id], references: [id], onDelete: Restrict) - object_id Int object Object @relation(fields: [object_id], references: [id], onDelete: Restrict) + tag_id Int + tag Tag @relation(fields: [tag_id], references: [id], onDelete: Restrict) + date_created DateTime? @@id([tag_id, object_id]) @@ -344,9 +344,9 @@ model TagOnObject { //// Label //// +/// @shared(id: name) model Label { id Int @id @default(autoincrement()) - pub_id Bytes @unique name String @unique date_created DateTime @default(now()) date_modified DateTime @default(now()) @@ -356,15 +356,16 @@ model Label { @@map("label") } +/// @relation(item: object, group: label) model LabelOnObject { date_created DateTime @default(now()) - label_id Int - label Label @relation(fields: [label_id], references: [id], onDelete: Restrict) - object_id Int object Object @relation(fields: [object_id], references: [id], onDelete: Restrict) + label_id Int + label Label @relation(fields: [label_id], references: [id], onDelete: Restrict) + @@id([label_id, object_id]) @@map("label_on_object") } diff --git a/core/src/api/labels.rs b/core/src/api/labels.rs index caf4d104c..df73b02d8 100644 --- a/core/src/api/labels.rs +++ b/core/src/api/labels.rs @@ -1,6 +1,10 @@ use crate::{invalidate_query, library::Library, object::media::thumbnail::get_indexed_thumb_key}; -use sd_prisma::prisma::{label, label_on_object, object, SortOrder}; +use sd_prisma::{ + prisma::{label, label_on_object, object, SortOrder}, + prisma_sync, +}; +use sd_sync::OperationFactory; use std::collections::BTreeMap; @@ -117,12 +121,26 @@ pub(crate) fn mount() -> AlphaRouter { "delete", R.with2(library()) .mutation(|(_, library), label_id: i32| async move { - library - .db + let Library { db, sync, .. } = library.as_ref(); + + let label = db .label() - .delete(label::id::equals(label_id)) + .find_unique(label::id::equals(label_id)) .exec() - .await?; + .await? + .ok_or_else(|| { + rspc::Error::new( + rspc::ErrorCode::NotFound, + "Label not found".to_string(), + ) + })?; + + sync.write_op( + db, + sync.shared_delete(prisma_sync::label::SyncId { name: label.name }), + db.label().delete(label::id::equals(label_id)), + ) + .await?; invalidate_query!(library, "labels.list"); diff --git a/core/src/object/media/media_processor/job.rs b/core/src/object/media/media_processor/job.rs index b43fd30f9..cf3db0a0e 100644 --- a/core/src/object/media/media_processor/job.rs +++ b/core/src/object/media/media_processor/job.rs @@ -105,7 +105,7 @@ impl StatefulJob for MediaProcessorJobInit { ctx: &WorkerContext, data: &mut Option, ) -> Result, JobError> { - let Library { db, .. } = ctx.library.as_ref(); + let Library { db, sync, .. } = ctx.library.as_ref(); let location_id = self.location.id; let location_path = @@ -186,6 +186,7 @@ impl StatefulJob for MediaProcessorJobInit { location_path.clone(), file_paths_for_labeling, Arc::clone(db), + sync.clone(), ) .await; @@ -336,7 +337,11 @@ impl StatefulJob for MediaProcessorJobInit { match ctx .node .image_labeller - .resume_batch(data.labeler_batch_token, Arc::clone(&ctx.library.db)) + .resume_batch( + data.labeler_batch_token, + Arc::clone(&ctx.library.db), + ctx.library.sync.clone(), + ) .await { Ok(labels_rx) => labels_rx, diff --git a/core/src/object/media/media_processor/shallow.rs b/core/src/object/media/media_processor/shallow.rs index 1e24b8aa8..a8e52cb50 100644 --- a/core/src/object/media/media_processor/shallow.rs +++ b/core/src/object/media/media_processor/shallow.rs @@ -40,7 +40,7 @@ const BATCH_SIZE: usize = 10; pub async fn shallow( location: &location::Data, sub_path: &PathBuf, - library @ Library { db, .. }: &Library, + library @ Library { db, sync, .. }: &Library, #[cfg(feature = "ai")] regenerate_labels: bool, node: &Node, ) -> Result<(), JobError> { @@ -116,6 +116,7 @@ pub async fn shallow( location_path.clone(), file_paths_for_labelling, Arc::clone(db), + sync.clone(), ) .await; diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index a91ddc51b..6a5e94263 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -11,6 +11,8 @@ edition = { workspace = true } [dependencies] sd-prisma = { path = "../prisma" } +sd-core-sync = { path = "../../core/crates/sync" } +sd-sync = { path = "../sync" } sd-utils = { path = "../utils" } sd-file-path-helper = { path = "../file-path-helper" } @@ -24,6 +26,7 @@ prisma-client-rust = { workspace = true } reqwest = { workspace = true, features = ["stream", "native-tls-vendored"] } rmp-serde = { workspace = true } serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["fs"] } tokio-stream = { workspace = true } diff --git a/crates/ai/src/image_labeler/actor.rs b/crates/ai/src/image_labeler/actor.rs index e67a616a4..289c8fe31 100644 --- a/crates/ai/src/image_labeler/actor.rs +++ b/crates/ai/src/image_labeler/actor.rs @@ -37,6 +37,7 @@ const PENDING_BATCHES_FILE: &str = "pending_image_labeler_batches.bin"; type ResumeBatchRequest = ( BatchToken, Arc, + Arc, oneshot::Sender, ImageLabelerError>>, ); @@ -53,6 +54,7 @@ pub(super) struct Batch { pub(super) output_tx: chan::Sender, pub(super) is_resumable: bool, pub(super) db: Arc, + pub(super) sync: Arc, } #[derive(Serialize, Deserialize, Debug)] @@ -165,6 +167,7 @@ impl ImageLabeler { location_path: PathBuf, file_paths: Vec, db: Arc, + sync: Arc, is_resumable: bool, ) -> (BatchToken, chan::Receiver) { let (tx, rx) = chan::bounded(usize::max(file_paths.len(), 1)); @@ -180,6 +183,7 @@ impl ImageLabeler { output_tx: tx, is_resumable, db, + sync, }) .await .is_err() @@ -201,8 +205,9 @@ impl ImageLabeler { location_path: PathBuf, file_paths: Vec, db: Arc, + sync: Arc, ) -> chan::Receiver { - self.new_batch_inner(location_id, location_path, file_paths, db, false) + self.new_batch_inner(location_id, location_path, file_paths, db, sync, false) .await .1 } @@ -214,8 +219,9 @@ impl ImageLabeler { location_path: PathBuf, file_paths: Vec, db: Arc, + sync: Arc, ) -> (BatchToken, chan::Receiver) { - self.new_batch_inner(location_id, location_path, file_paths, db, true) + self.new_batch_inner(location_id, location_path, file_paths, db, sync, true) .await } @@ -284,11 +290,12 @@ impl ImageLabeler { &self, token: BatchToken, db: Arc, + sync: Arc, ) -> Result, ImageLabelerError> { let (tx, rx) = oneshot::channel(); self.resume_batch_tx - .send((token, db, tx)) + .send((token, db, sync, tx)) .await .expect("critical error: image labeler communication channel unexpectedly closed"); @@ -334,6 +341,7 @@ async fn actor_loop( ResumeBatch( BatchToken, Arc, + Arc, oneshot::Sender, ImageLabelerError>>, ), UpdateModel( @@ -350,7 +358,8 @@ async fn actor_loop( let mut msg_stream = pin!(( new_batches_rx.map(StreamMessage::NewBatch), - resume_batch_rx.map(|(token, db, done_tx)| StreamMessage::ResumeBatch(token, db, done_tx)), + resume_batch_rx + .map(|(token, db, sync, done_tx)| StreamMessage::ResumeBatch(token, db, sync, done_tx)), update_model_rx.map(|(model, done_tx)| StreamMessage::UpdateModel(model, done_tx)), done_rx.clone().map(StreamMessage::BatchDone), shutdown_rx.map(StreamMessage::Shutdown) @@ -376,7 +385,7 @@ async fn actor_loop( } } - StreamMessage::ResumeBatch(token, db, resume_done_tx) => { + StreamMessage::ResumeBatch(token, db, sync, resume_done_tx) => { let resume_result = if let Some((batch, output_rx)) = to_resume_batches.write().await.remove(&token).map( |ResumableBatch { @@ -390,6 +399,7 @@ async fn actor_loop( Batch { token, db, + sync, output_tx, location_id, location_path, diff --git a/crates/ai/src/image_labeler/process.rs b/crates/ai/src/image_labeler/process.rs index b625bc621..c78a63004 100644 --- a/crates/ai/src/image_labeler/process.rs +++ b/crates/ai/src/image_labeler/process.rs @@ -1,9 +1,13 @@ use sd_file_path_helper::{file_path_for_media_processor, IsolatedFilePathData}; -use sd_prisma::prisma::{file_path, label, label_on_object, object, PrismaClient}; +use sd_prisma::{ + prisma::{file_path, label, label_on_object, object, PrismaClient}, + prisma_sync, +}; +use sd_sync::OperationFactory; use sd_utils::{db::MissingFieldError, error::FileIOError}; use std::{ - collections::{HashMap, HashSet, VecDeque}, + collections::{BTreeMap, HashMap, HashSet, VecDeque}, path::{Path, PathBuf}, sync::Arc, }; @@ -12,12 +16,12 @@ use async_channel as chan; use chrono::{DateTime, FixedOffset, Utc}; use futures_concurrency::future::{Join, Race}; use image::ImageFormat; +use serde_json::json; use tokio::{ fs, spawn, sync::{oneshot, OwnedRwLockReadGuard, OwnedSemaphorePermit, RwLock, Semaphore}, }; use tracing::{error, warn}; -use uuid::Uuid; use super::{actor::Batch, model::ModelAndSession, BatchToken, ImageLabelerError, LabelerOutput}; @@ -65,6 +69,7 @@ pub(super) async fn spawned_processing( file_paths, output_tx, db, + sync, is_resumable, }: Batch, available_parallelism: usize, @@ -213,6 +218,7 @@ pub(super) async fn spawned_processing( format, (output_tx.clone(), completed_tx.clone()), Arc::clone(&db), + sync.clone(), permit, ))); } @@ -247,6 +253,7 @@ pub(super) async fn spawned_processing( .collect(), output_tx, db, + sync: sync.clone(), is_resumable, }) }; @@ -289,6 +296,7 @@ async fn spawned_process_single_file( chan::Sender, ), db: Arc, + sync: Arc, _permit: OwnedSemaphorePermit, ) { let image = @@ -338,7 +346,7 @@ async fn spawned_process_single_file( } }; - let (has_new_labels, result) = match assign_labels(object_id, labels, &db).await { + let (has_new_labels, result) = match assign_labels(object_id, labels, &db, &sync).await { Ok(has_new_labels) => (has_new_labels, Ok(())), Err(e) => (false, Err(e)), }; @@ -386,7 +394,16 @@ pub async fn assign_labels( object_id: object::id::Type, mut labels: HashSet, db: &PrismaClient, + sync: &sd_core_sync::Manager, ) -> Result { + let object = db + .object() + .find_unique(object::id::equals(object_id)) + .select(object::select!({ pub_id })) + .exec() + .await? + .unwrap(); + let mut has_new_labels = false; let mut labels_ids = db @@ -399,53 +416,72 @@ pub async fn assign_labels( .map(|label| { labels.remove(&label.name); - label.id + (label.id, label.name) }) - .collect::>(); - - labels_ids.reserve(labels.len()); + .collect::>(); let date_created: DateTime = Utc::now().into(); if !labels.is_empty() { - labels_ids.extend( - db._batch( - labels - .into_iter() - .map(|name| { - db.label() - .create( - Uuid::new_v4().as_bytes().to_vec(), - name, - vec![label::date_created::set(date_created)], - ) - .select(label::select!({ id })) - }) - .collect::>(), - ) - .await? + let mut sync_params = Vec::with_capacity(labels.len() * 2); + + let db_params = labels .into_iter() - .map(|label| label.id), + .map(|name| { + sync_params.extend(sync.shared_create( + prisma_sync::label::SyncId { name: name.clone() }, + [(label::date_created::NAME, json!(&date_created))], + )); + + db.label() + .create(name, vec![label::date_created::set(date_created)]) + .select(label::select!({ id name })) + }) + .collect::>(); + + labels_ids.extend( + sync.write_ops(db, (sync_params, db_params)) + .await? + .into_iter() + .map(|l| (l.id, l.name)), ); + has_new_labels = true; } - db.label_on_object() - .create_many( - labels_ids - .into_iter() - .map(|label_id| { - label_on_object::create_unchecked( - label_id, - object_id, - vec![label_on_object::date_created::set(date_created)], - ) - }) - .collect(), - ) - .skip_duplicates() - .exec() - .await?; + let mut sync_params = Vec::with_capacity(labels_ids.len() * 2); + + let db_params: Vec<_> = labels_ids + .into_iter() + .map(|(label_id, name)| { + sync_params.extend(sync.relation_create( + prisma_sync::label_on_object::SyncId { + label: prisma_sync::label::SyncId { name }, + object: prisma_sync::object::SyncId { + pub_id: object.pub_id.clone(), + }, + }, + [], + )); + + label_on_object::create_unchecked( + label_id, + object_id, + vec![label_on_object::date_created::set(date_created)], + ) + }) + .collect(); + + sync.write_ops( + &db, + ( + sync_params, + db.label_on_object() + .create_many(db_params) + .skip_duplicates(), + ), + ) + .await?; Ok(has_new_labels) } diff --git a/crates/sync-generator/src/sync_data.rs b/crates/sync-generator/src/sync_data.rs index 3a33bfa37..d04e3e9ae 100644 --- a/crates/sync-generator/src/sync_data.rs +++ b/crates/sync-generator/src/sync_data.rs @@ -75,9 +75,10 @@ pub fn r#enum(models: Vec) -> TokenStream { ModelSyncType::Relation { item, group } => { let compound_id = format_ident!( "{}", - item.fields() + group + .fields() .unwrap() - .chain(group.fields().unwrap()) + .chain(item.fields().unwrap()) .map(|f| f.name()) .collect::>() .join("_") @@ -85,11 +86,19 @@ pub fn r#enum(models: Vec) -> TokenStream { let db_batch_items = { let batch_item = |item: &RelationFieldWalker| { + let item_model_sync_id_field_name_snake = models + .iter() + .find(|m| m.0.name() == item.related_model().name()) + .and_then(|(m, sync)| sync.as_ref()) + .map(|sync| snake_ident(sync.sync_id()[0].name())) + .unwrap(); let item_model_name_snake = snake_ident(item.related_model().name()); let item_field_name_snake = snake_ident(item.name()); quote!(db.#item_model_name_snake().find_unique( - prisma::#item_model_name_snake::pub_id::equals(id.#item_field_name_snake.pub_id.clone()) + prisma::#item_model_name_snake::#item_model_sync_id_field_name_snake::equals( + id.#item_field_name_snake.#item_model_sync_id_field_name_snake.clone() + ) )) }; @@ -117,7 +126,7 @@ pub fn r#enum(models: Vec) -> TokenStream { panic!("item and group not found!"); }; - let id = prisma::tag_on_object::#compound_id(item.id, group.id); + let id = prisma::#model_name_snake::#compound_id(item.id, group.id); match data { sd_sync::CRDTOperationData::Create => { diff --git a/interface/app/$libraryId/Explorer/util.ts b/interface/app/$libraryId/Explorer/util.ts index 33fddfe3e..ce46fb310 100644 --- a/interface/app/$libraryId/Explorer/util.ts +++ b/interface/app/$libraryId/Explorer/util.ts @@ -44,6 +44,7 @@ export const uniqueId = (item: ExplorerItem | { pub_id: number[] }) => { case 'NonIndexedPath': return item.item.path; case 'SpacedropPeer': + case 'Label': return item.item.name; default: return pubIdToString(item.item.pub_id); diff --git a/packages/client/src/core.ts b/packages/client/src/core.ts index d532dacab..7a6432490 100644 --- a/packages/client/src/core.ts +++ b/packages/client/src/core.ts @@ -19,7 +19,7 @@ export type Procedures = { { key: "jobs.isActive", input: LibraryArgs, result: boolean } | { key: "jobs.reports", input: LibraryArgs, result: JobGroup[] } | { key: "labels.count", input: LibraryArgs, result: number } | - { key: "labels.get", input: LibraryArgs, result: { id: number; pub_id: number[]; name: string; date_created: string; date_modified: string } | null } | + { key: "labels.get", input: LibraryArgs, result: { id: number; name: string; date_created: string; date_modified: string } | null } | { key: "labels.getForObject", input: LibraryArgs, result: Label[] } | { key: "labels.getWithObjects", input: LibraryArgs, result: { [key in number]: { date_created: string; object: { id: number } }[] } } | { key: "labels.list", input: LibraryArgs, result: Label[] } | @@ -355,9 +355,9 @@ export type KindStatistic = { kind: number; name: string; count: number; total_b export type KindStatistics = { statistics: KindStatistic[] } -export type Label = { id: number; pub_id: number[]; name: string; date_created: string; date_modified: string } +export type Label = { id: number; name: string; date_created: string; date_modified: string } -export type LabelWithObjects = { id: number; pub_id: number[]; name: string; date_created: string; date_modified: string; label_objects: { object: { id: number; file_paths: FilePath[] } }[] } +export type LabelWithObjects = { id: number; name: string; date_created: string; date_modified: string; label_objects: { object: { id: number; file_paths: FilePath[] } }[] } /** * Can wrap a query argument to require it to contain a `library_id` and provide helpers for working with libraries.