sync support for labels (#2070)

* more sync support for file paths + saved searches

* sync support for labels

* update sync prisma generator to support more than tags

* workey

* don't do illegal db migration

* use name as label id in explorer
Brendan Allan 2024-02-09 21:20:51 +08:00 committed by GitHub
parent 6f28d8ec28
commit 177b2a23d6
11 changed files with 156 additions and 69 deletions
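
In short: labels become a synced model. They are now identified across devices by their unique `name` (the `pub_id` column is dropped), every label mutation is paired with a CRDT operation, and the sync generator is extended so relation models other than `TagOnObject` work. A minimal sketch of the pattern, assuming a `Library { db, sync, .. }` in scope as in the handlers below (this is not code from the commit; `name` and `date_created` are illustrative values):

// Sketch only: pairs a local DB write with the CRDT operation that
// describes it, using the calls this commit introduces for labels.
sync.write_ops(
    db,
    (
        sync.shared_create(
            prisma_sync::label::SyncId { name: name.clone() },
            [(label::date_created::NAME, json!(&date_created))],
        ),
        db.label()
            .create(name, vec![label::date_created::set(date_created)]),
    ),
)
.await?;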

Cargo.lock (generated, 3 changes)

@@ -7508,10 +7508,13 @@ dependencies = [
  "prisma-client-rust",
  "reqwest",
  "rmp-serde",
+ "sd-core-sync",
  "sd-file-path-helper",
  "sd-prisma",
+ "sd-sync",
  "sd-utils",
  "serde",
+ "serde_json",
  "thiserror",
  "tokio",
  "tokio-stream",


@@ -328,14 +328,14 @@ model Tag {
   @@map("tag")
 }
 
-/// @relation(item: tag, group: object)
+/// @relation(item: object, group: tag)
 model TagOnObject {
-  tag_id    Int
-  tag       Tag    @relation(fields: [tag_id], references: [id], onDelete: Restrict)
   object_id Int
   object    Object @relation(fields: [object_id], references: [id], onDelete: Restrict)
+  tag_id    Int
+  tag       Tag    @relation(fields: [tag_id], references: [id], onDelete: Restrict)
 
   date_created DateTime?
 
   @@id([tag_id, object_id])

@@ -344,9 +344,9 @@ model TagOnObject {
 
 //// Label ////
 
+/// @shared(id: name)
 model Label {
   id            Int      @id @default(autoincrement())
-  pub_id        Bytes    @unique
   name          String   @unique
   date_created  DateTime @default(now())
   date_modified DateTime @default(now())

@@ -356,15 +356,16 @@ model Label {
 
   @@map("label")
 }
 
+/// @relation(item: object, group: label)
 model LabelOnObject {
   date_created DateTime @default(now())
 
-  label_id  Int
-  label     Label  @relation(fields: [label_id], references: [id], onDelete: Restrict)
   object_id Int
   object    Object @relation(fields: [object_id], references: [id], onDelete: Restrict)
+  label_id  Int
+  label     Label  @relation(fields: [label_id], references: [id], onDelete: Restrict)
 
   @@id([label_id, object_id])
   @@map("label_on_object")
 }

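The doc-comments on the models are what the sync generator consumes: `/// @shared(id: name)` marks `Label` as a shared-synced model addressed by its unique `name` (which is why `pub_id` can be dropped), and `/// @relation(item: object, group: label)` marks `LabelOnObject` as a relation record addressed by the sync ids of its two sides. The generated `prisma_sync` module then exposes id types shaped roughly like this (a sketch inferred from how the diff uses them, not the generator's literal output; the derives are assumptions):

// Sketch of the generated sync-id types.
pub mod label {
    #[derive(Clone, serde::Serialize, serde::Deserialize)]
    pub struct SyncId {
        pub name: String,
    }
}

pub mod label_on_object {
    #[derive(Clone, serde::Serialize, serde::Deserialize)]
    pub struct SyncId {
        pub label: super::label::SyncId,
        pub object: super::object::SyncId,
    }
}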

@@ -1,6 +1,10 @@
 use crate::{invalidate_query, library::Library, object::media::thumbnail::get_indexed_thumb_key};
 
-use sd_prisma::prisma::{label, label_on_object, object, SortOrder};
+use sd_prisma::{
+    prisma::{label, label_on_object, object, SortOrder},
+    prisma_sync,
+};
+use sd_sync::OperationFactory;
 
 use std::collections::BTreeMap;

@@ -117,12 +121,26 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
             "delete",
             R.with2(library())
                 .mutation(|(_, library), label_id: i32| async move {
-                    library
-                        .db
+                    let Library { db, sync, .. } = library.as_ref();
+
+                    let label = db
                         .label()
-                        .delete(label::id::equals(label_id))
+                        .find_unique(label::id::equals(label_id))
                         .exec()
-                        .await?;
+                        .await?
+                        .ok_or_else(|| {
+                            rspc::Error::new(
+                                rspc::ErrorCode::NotFound,
+                                "Label not found".to_string(),
+                            )
+                        })?;
+
+                    sync.write_op(
+                        db,
+                        sync.shared_delete(prisma_sync::label::SyncId { name: label.name }),
+                        db.label().delete(label::id::equals(label_id)),
+                    )
+                    .await?;
 
                     invalidate_query!(library, "labels.list");

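The delete handler shows the pattern used throughout: fetch the row first so the CRDT record can be addressed by its sync id (the label's name), then let `write_op` persist the sync operation and the local delete together. Condensed into a hypothetical helper (sketch only, using the same calls as the handler above):

// Hypothetical helper, not part of the commit. The `?` conversion of the
// sync write error into rspc::Error is assumed, as in the handler above.
async fn delete_label_synced(
    db: &PrismaClient,
    sync: &sd_core_sync::Manager,
    label: label::Data, // fetched via find_unique, as above
) -> Result<(), rspc::Error> {
    sync.write_op(
        db,
        sync.shared_delete(prisma_sync::label::SyncId { name: label.name }),
        db.label().delete(label::id::equals(label.id)),
    )
    .await?;
    Ok(())
}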

@@ -105,7 +105,7 @@ impl StatefulJob for MediaProcessorJobInit {
         ctx: &WorkerContext,
         data: &mut Option<Self::Data>,
     ) -> Result<JobInitOutput<Self::RunMetadata, Self::Step>, JobError> {
-        let Library { db, .. } = ctx.library.as_ref();
+        let Library { db, sync, .. } = ctx.library.as_ref();
 
         let location_id = self.location.id;
         let location_path =

@@ -186,6 +186,7 @@ impl StatefulJob for MediaProcessorJobInit {
                 location_path.clone(),
                 file_paths_for_labeling,
                 Arc::clone(db),
+                sync.clone(),
             )
             .await;

@@ -336,7 +337,11 @@ impl StatefulJob for MediaProcessorJobInit {
             match ctx
                 .node
                 .image_labeller
-                .resume_batch(data.labeler_batch_token, Arc::clone(&ctx.library.db))
+                .resume_batch(
+                    data.labeler_batch_token,
+                    Arc::clone(&ctx.library.db),
+                    ctx.library.sync.clone(),
+                )
                 .await
             {
                 Ok(labels_rx) => labels_rx,

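The media processor job (and the shallow walker below) now thread the library's sync manager into the image labeler alongside the database handle. A condensed call sketch, assuming `new_batch` as shown in the labeler changes further down (the surrounding names are stand-ins for the job's fields):

// Sketch: spawning a labeler batch with the widened signature.
let labels_rx = node
    .image_labeller
    .new_batch(
        location_id,
        location_path.clone(),
        file_paths_for_labeling,
        Arc::clone(db),
        sync.clone(),
    )
    .await;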

@@ -40,7 +40,7 @@ const BATCH_SIZE: usize = 10;
 pub async fn shallow(
     location: &location::Data,
     sub_path: &PathBuf,
-    library @ Library { db, .. }: &Library,
+    library @ Library { db, sync, .. }: &Library,
     #[cfg(feature = "ai")] regenerate_labels: bool,
     node: &Node,
 ) -> Result<(), JobError> {

@@ -116,6 +116,7 @@ pub async fn shallow(
             location_path.clone(),
             file_paths_for_labelling,
             Arc::clone(db),
+            sync.clone(),
         )
         .await;


@@ -11,6 +11,8 @@ edition = { workspace = true }
 [dependencies]
 sd-prisma = { path = "../prisma" }
+sd-core-sync = { path = "../../core/crates/sync" }
+sd-sync = { path = "../sync" }
 sd-utils = { path = "../utils" }
 sd-file-path-helper = { path = "../file-path-helper" }

@@ -24,6 +26,7 @@ prisma-client-rust = { workspace = true }
 reqwest = { workspace = true, features = ["stream", "native-tls-vendored"] }
 rmp-serde = { workspace = true }
 serde = { workspace = true, features = ["derive"] }
+serde_json = { workspace = true }
 thiserror = { workspace = true }
 tokio = { workspace = true, features = ["fs"] }
 tokio-stream = { workspace = true }


@@ -37,6 +37,7 @@ const PENDING_BATCHES_FILE: &str = "pending_image_labeler_batches.bin";
 type ResumeBatchRequest = (
     BatchToken,
     Arc<PrismaClient>,
+    Arc<sd_core_sync::Manager>,
     oneshot::Sender<Result<chan::Receiver<LabelerOutput>, ImageLabelerError>>,
 );

@@ -53,6 +54,7 @@ pub(super) struct Batch {
     pub(super) output_tx: chan::Sender<LabelerOutput>,
     pub(super) is_resumable: bool,
     pub(super) db: Arc<PrismaClient>,
+    pub(super) sync: Arc<sd_core_sync::Manager>,
 }
 
 #[derive(Serialize, Deserialize, Debug)]

@@ -165,6 +167,7 @@ impl ImageLabeler {
         location_path: PathBuf,
         file_paths: Vec<file_path_for_media_processor::Data>,
         db: Arc<PrismaClient>,
+        sync: Arc<sd_core_sync::Manager>,
         is_resumable: bool,
     ) -> (BatchToken, chan::Receiver<LabelerOutput>) {
         let (tx, rx) = chan::bounded(usize::max(file_paths.len(), 1));

@@ -180,6 +183,7 @@ impl ImageLabeler {
                 output_tx: tx,
                 is_resumable,
                 db,
+                sync,
             })
             .await
             .is_err()

@@ -201,8 +205,9 @@ impl ImageLabeler {
         location_path: PathBuf,
         file_paths: Vec<file_path_for_media_processor::Data>,
         db: Arc<PrismaClient>,
+        sync: Arc<sd_core_sync::Manager>,
     ) -> chan::Receiver<LabelerOutput> {
-        self.new_batch_inner(location_id, location_path, file_paths, db, false)
+        self.new_batch_inner(location_id, location_path, file_paths, db, sync, false)
             .await
             .1
     }

@@ -214,8 +219,9 @@ impl ImageLabeler {
         location_path: PathBuf,
         file_paths: Vec<file_path_for_media_processor::Data>,
         db: Arc<PrismaClient>,
+        sync: Arc<sd_core_sync::Manager>,
     ) -> (BatchToken, chan::Receiver<LabelerOutput>) {
-        self.new_batch_inner(location_id, location_path, file_paths, db, true)
+        self.new_batch_inner(location_id, location_path, file_paths, db, sync, true)
             .await
     }

@@ -284,11 +290,12 @@ impl ImageLabeler {
         &self,
         token: BatchToken,
         db: Arc<PrismaClient>,
+        sync: Arc<sd_core_sync::Manager>,
     ) -> Result<chan::Receiver<LabelerOutput>, ImageLabelerError> {
         let (tx, rx) = oneshot::channel();
 
         self.resume_batch_tx
-            .send((token, db, tx))
+            .send((token, db, sync, tx))
             .await
             .expect("critical error: image labeler communication channel unexpectedly closed");

@@ -334,6 +341,7 @@ async fn actor_loop(
         ResumeBatch(
             BatchToken,
             Arc<PrismaClient>,
+            Arc<sd_core_sync::Manager>,
             oneshot::Sender<Result<chan::Receiver<LabelerOutput>, ImageLabelerError>>,
         ),
         UpdateModel(

@@ -350,7 +358,8 @@ async fn actor_loop(
     let mut msg_stream = pin!((
         new_batches_rx.map(StreamMessage::NewBatch),
-        resume_batch_rx.map(|(token, db, done_tx)| StreamMessage::ResumeBatch(token, db, done_tx)),
+        resume_batch_rx
+            .map(|(token, db, sync, done_tx)| StreamMessage::ResumeBatch(token, db, sync, done_tx)),
         update_model_rx.map(|(model, done_tx)| StreamMessage::UpdateModel(model, done_tx)),
         done_rx.clone().map(StreamMessage::BatchDone),
         shutdown_rx.map(StreamMessage::Shutdown)

@@ -376,7 +385,7 @@ async fn actor_loop(
             }
         }
 
-        StreamMessage::ResumeBatch(token, db, resume_done_tx) => {
+        StreamMessage::ResumeBatch(token, db, sync, resume_done_tx) => {
            let resume_result = if let Some((batch, output_rx)) =
                to_resume_batches.write().await.remove(&token).map(
                    |ResumableBatch {

@@ -390,6 +399,7 @@ async fn actor_loop(
                     Batch {
                         token,
                         db,
+                        sync,
                         output_tx,
                         location_id,
                         location_path,

@@ -1,9 +1,13 @@
 use sd_file_path_helper::{file_path_for_media_processor, IsolatedFilePathData};
-use sd_prisma::prisma::{file_path, label, label_on_object, object, PrismaClient};
+use sd_prisma::{
+    prisma::{file_path, label, label_on_object, object, PrismaClient},
+    prisma_sync,
+};
+use sd_sync::OperationFactory;
 use sd_utils::{db::MissingFieldError, error::FileIOError};
 
 use std::{
-    collections::{HashMap, HashSet, VecDeque},
+    collections::{BTreeMap, HashMap, HashSet, VecDeque},
     path::{Path, PathBuf},
     sync::Arc,
 };

@@ -12,12 +16,12 @@ use async_channel as chan;
 use chrono::{DateTime, FixedOffset, Utc};
 use futures_concurrency::future::{Join, Race};
 use image::ImageFormat;
+use serde_json::json;
 use tokio::{
     fs, spawn,
     sync::{oneshot, OwnedRwLockReadGuard, OwnedSemaphorePermit, RwLock, Semaphore},
 };
 use tracing::{error, warn};
-use uuid::Uuid;
 
 use super::{actor::Batch, model::ModelAndSession, BatchToken, ImageLabelerError, LabelerOutput};

@@ -65,6 +69,7 @@ pub(super) async fn spawned_processing(
         file_paths,
         output_tx,
         db,
+        sync,
         is_resumable,
     }: Batch,
     available_parallelism: usize,

@@ -213,6 +218,7 @@ pub(super) async fn spawned_processing(
                 format,
                 (output_tx.clone(), completed_tx.clone()),
                 Arc::clone(&db),
+                sync.clone(),
                 permit,
             )));
         }

@@ -247,6 +253,7 @@ pub(super) async fn spawned_processing(
                 .collect(),
             output_tx,
             db,
+            sync: sync.clone(),
             is_resumable,
         })
     };

@@ -289,6 +296,7 @@ async fn spawned_process_single_file(
         chan::Sender<file_path::id::Type>,
     ),
     db: Arc<PrismaClient>,
+    sync: Arc<sd_core_sync::Manager>,
     _permit: OwnedSemaphorePermit,
 ) {
     let image =

@@ -338,7 +346,7 @@ async fn spawned_process_single_file(
         }
     };
 
-    let (has_new_labels, result) = match assign_labels(object_id, labels, &db).await {
+    let (has_new_labels, result) = match assign_labels(object_id, labels, &db, &sync).await {
         Ok(has_new_labels) => (has_new_labels, Ok(())),
         Err(e) => (false, Err(e)),
     };

@@ -386,7 +394,16 @@ pub async fn assign_labels(
     object_id: object::id::Type,
     mut labels: HashSet<String>,
     db: &PrismaClient,
+    sync: &sd_core_sync::Manager,
 ) -> Result<bool, ImageLabelerError> {
+    let object = db
+        .object()
+        .find_unique(object::id::equals(object_id))
+        .select(object::select!({ pub_id }))
+        .exec()
+        .await?
+        .unwrap();
+
     let mut has_new_labels = false;
 
     let mut labels_ids = db
@@ -399,53 +416,72 @@ pub async fn assign_labels(
         .map(|label| {
             labels.remove(&label.name);
-            label.id
+            (label.id, label.name)
         })
-        .collect::<Vec<_>>();
-
-    labels_ids.reserve(labels.len());
+        .collect::<BTreeMap<_, _>>();
 
     let date_created: DateTime<FixedOffset> = Utc::now().into();
 
     if !labels.is_empty() {
-        labels_ids.extend(
-            db._batch(
-                labels
-                    .into_iter()
-                    .map(|name| {
-                        db.label()
-                            .create(
-                                Uuid::new_v4().as_bytes().to_vec(),
-                                name,
-                                vec![label::date_created::set(date_created)],
-                            )
-                            .select(label::select!({ id }))
-                    })
-                    .collect::<Vec<_>>(),
-            )
-            .await?
-            .into_iter()
-            .map(|label| label.id),
-        );
+        let mut sync_params = Vec::with_capacity(labels.len() * 2);
+
+        let db_params = labels
+            .into_iter()
+            .map(|name| {
+                sync_params.extend(sync.shared_create(
+                    prisma_sync::label::SyncId { name: name.clone() },
+                    [(label::date_created::NAME, json!(&date_created))],
+                ));
+
+                db.label()
+                    .create(name, vec![label::date_created::set(date_created)])
+                    .select(label::select!({ id name }))
+            })
+            .collect::<Vec<_>>();
+
+        labels_ids.extend(
+            sync.write_ops(db, (sync_params, db_params))
+                .await?
+                .into_iter()
+                .map(|l| (l.id, l.name)),
+        );
+
         has_new_labels = true;
     }
 
-    db.label_on_object()
-        .create_many(
-            labels_ids
-                .into_iter()
-                .map(|label_id| {
-                    label_on_object::create_unchecked(
-                        label_id,
-                        object_id,
-                        vec![label_on_object::date_created::set(date_created)],
-                    )
-                })
-                .collect(),
-        )
-        .skip_duplicates()
-        .exec()
-        .await?;
+    let mut sync_params = Vec::with_capacity(labels_ids.len() * 2);
+
+    let db_params: Vec<_> = labels_ids
+        .into_iter()
+        .map(|(label_id, name)| {
+            sync_params.extend(sync.relation_create(
+                prisma_sync::label_on_object::SyncId {
+                    label: prisma_sync::label::SyncId { name },
+                    object: prisma_sync::object::SyncId {
+                        pub_id: object.pub_id.clone(),
+                    },
+                },
+                [],
+            ));
+
+            label_on_object::create_unchecked(
+                label_id,
+                object_id,
+                vec![label_on_object::date_created::set(date_created)],
+            )
+        })
+        .collect();
+
+    sync.write_ops(
+        &db,
+        (
+            sync_params,
+            db.label_on_object()
+                .create_many(db_params)
+                .skip_duplicates(),
+        ),
+    )
+    .await?;
 
     Ok(has_new_labels)
 }

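`assign_labels` now runs in two synced phases: labels missing from the database are created with paired `shared_create` operations (keyed by name), then the `label_on_object` links are created with paired `relation_create` operations keyed by the label's name and the object's `pub_id`, which is why the function starts by fetching it. A hedged usage sketch (not from the commit):

// Usage sketch. Assumes `db: &PrismaClient` and `sync: &sd_core_sync::Manager`
// are in scope and that object 42 exists; the `.unwrap()` on the pub_id
// lookup above panics otherwise.
use std::collections::HashSet;

let labels: HashSet<String> = ["beach", "sunset"].into_iter().map(String::from).collect();
let has_new_labels = assign_labels(42, labels, db, sync).await?;
if has_new_labels {
    // callers in this commit invalidate label queries here
}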

@@ -75,9 +75,10 @@ pub fn r#enum(models: Vec<ModelWithSyncType>) -> TokenStream {
             ModelSyncType::Relation { item, group } => {
                 let compound_id = format_ident!(
                     "{}",
-                    item.fields()
+                    group
+                        .fields()
                         .unwrap()
-                        .chain(group.fields().unwrap())
+                        .chain(item.fields().unwrap())
                         .map(|f| f.name())
                         .collect::<Vec<_>>()
                         .join("_")

@@ -85,11 +86,19 @@ pub fn r#enum(models: Vec<ModelWithSyncType>) -> TokenStream {
 
                 let db_batch_items = {
                     let batch_item = |item: &RelationFieldWalker| {
+                        let item_model_sync_id_field_name_snake = models
+                            .iter()
+                            .find(|m| m.0.name() == item.related_model().name())
+                            .and_then(|(m, sync)| sync.as_ref())
+                            .map(|sync| snake_ident(sync.sync_id()[0].name()))
+                            .unwrap();
+
                         let item_model_name_snake = snake_ident(item.related_model().name());
                         let item_field_name_snake = snake_ident(item.name());
 
                         quote!(db.#item_model_name_snake().find_unique(
-                            prisma::#item_model_name_snake::pub_id::equals(id.#item_field_name_snake.pub_id.clone())
+                            prisma::#item_model_name_snake::#item_model_sync_id_field_name_snake::equals(
+                                id.#item_field_name_snake.#item_model_sync_id_field_name_snake.clone()
+                            )
                         ))
                     };

@@ -117,7 +126,7 @@ pub fn r#enum(models: Vec<ModelWithSyncType>) -> TokenStream {
                     panic!("item and group not found!");
                 };
 
-                let id = prisma::tag_on_object::#compound_id(item.id, group.id);
+                let id = prisma::#model_name_snake::#compound_id(item.id, group.id);
 
                 match data {
                     sd_sync::CRDTOperationData::Create => {

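Previously the generator hard-coded `pub_id` lookups and the `tag_on_object` compound id, so only tags resolved correctly; it now looks up each side's declared sync-id field and uses the actual model name. For `LabelOnObject`, the generated lookups would expand roughly like this (a sketch of the macro output under those assumptions, with `id` as the incoming relation sync id):

// Sketch of the expanded lookup for label_on_object: Label's sync id is
// `name`, Object's is `pub_id`, so the old hard-coded pub_id filter could
// never have matched a label.
let (label, object) = db
    ._batch((
        db.label()
            .find_unique(prisma::label::name::equals(id.label.name.clone())),
        db.object()
            .find_unique(prisma::object::pub_id::equals(id.object.pub_id.clone())),
    ))
    .await?;

let (Some(label), Some(object)) = (label, object) else {
    panic!("item and group not found!");
};

// Compound id in group-then-item field order, per @@id([label_id, object_id]).
let id = prisma::label_on_object::label_id_object_id(label.id, object.id);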

@@ -44,6 +44,7 @@ export const uniqueId = (item: ExplorerItem | { pub_id: number[] }) => {
     case 'NonIndexedPath':
         return item.item.path;
     case 'SpacedropPeer':
+    case 'Label':
         return item.item.name;
     default:
         return pubIdToString(item.item.pub_id);


@@ -19,7 +19,7 @@ export type Procedures = {
     { key: "jobs.isActive", input: LibraryArgs<null>, result: boolean } |
     { key: "jobs.reports", input: LibraryArgs<null>, result: JobGroup[] } |
     { key: "labels.count", input: LibraryArgs<null>, result: number } |
-    { key: "labels.get", input: LibraryArgs<number>, result: { id: number; pub_id: number[]; name: string; date_created: string; date_modified: string } | null } |
+    { key: "labels.get", input: LibraryArgs<number>, result: { id: number; name: string; date_created: string; date_modified: string } | null } |
     { key: "labels.getForObject", input: LibraryArgs<number>, result: Label[] } |
     { key: "labels.getWithObjects", input: LibraryArgs<number[]>, result: { [key in number]: { date_created: string; object: { id: number } }[] } } |
     { key: "labels.list", input: LibraryArgs<null>, result: Label[] } |

@@ -355,9 +355,9 @@ export type KindStatistic = { kind: number; name: string; count: number; total_b
 export type KindStatistics = { statistics: KindStatistic[] }
-export type Label = { id: number; pub_id: number[]; name: string; date_created: string; date_modified: string }
-export type LabelWithObjects = { id: number; pub_id: number[]; name: string; date_created: string; date_modified: string; label_objects: { object: { id: number; file_paths: FilePath[] } }[] }
+export type Label = { id: number; name: string; date_created: string; date_modified: string }
+export type LabelWithObjects = { id: number; name: string; date_created: string; date_modified: string; label_objects: { object: { id: number; file_paths: FilePath[] } }[] }
/**
* Can wrap a query argument to require it to contain a `library_id` and provide helpers for working with libraries.