mirror of
https://github.com/spacedriveapp/spacedrive
synced 2024-07-04 12:13:27 +00:00
[ENG-1595] Fix some problems with the AI system (#2063)
* Fix some problems with the AI system - Fix downloading model using an internal rust string representation as path for the model file - Fix Linux loading onnx shared lib from a hardcoded path - Fix App should not crash when the AI system fails to start - Fix sd-server failing to start due to onnxruntime incorrect linking - Some extra clippy auto fixes * Use latest ort * Fix dangling sd_ai reference - Use entrypoint.sh to initilize container * Fix server Dockerfile - Fix cargo warning * Workaround intro video breaking onboarding for the web version * Fix rebase
This commit is contained in:
parent
74e2d23c11
commit
72efcc9f62
|
@ -1,6 +1,9 @@
|
||||||
[env]
|
[env]
|
||||||
PROTOC = { force = true, value = "{{{protoc}}}" }
|
PROTOC = { force = true, value = "{{{protoc}}}" }
|
||||||
FFMPEG_DIR = { force = true, value = "{{{nativeDeps}}}" }
|
FFMPEG_DIR = { force = true, value = "{{{nativeDeps}}}" }
|
||||||
|
{{#isLinux}}
|
||||||
|
ORT_LIB_LOCATION = { force = true, value = "{{{nativeDeps}}}/lib" }
|
||||||
|
{{/isLinux}}
|
||||||
OPENSSL_STATIC = { force = true, value = "1" }
|
OPENSSL_STATIC = { force = true, value = "1" }
|
||||||
OPENSSL_NO_VENDOR = { force = true, value = "0" }
|
OPENSSL_NO_VENDOR = { force = true, value = "0" }
|
||||||
OPENSSL_RUST_USE_NASM = { force = true, value = "1" }
|
OPENSSL_RUST_USE_NASM = { force = true, value = "1" }
|
||||||
|
|
32
Cargo.lock
generated
32
Cargo.lock
generated
|
@ -1338,6 +1338,15 @@ dependencies = [
|
||||||
"toml 0.7.8",
|
"toml 0.7.8",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "castaway"
|
||||||
|
version = "0.2.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8a17ed5635fc8536268e5d4de1e22e81ac34419e5f052d4d51f4e01dcc263fcc"
|
||||||
|
dependencies = [
|
||||||
|
"rustversion",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cc"
|
name = "cc"
|
||||||
version = "1.0.83"
|
version = "1.0.83"
|
||||||
|
@ -1584,6 +1593,19 @@ dependencies = [
|
||||||
"memchr",
|
"memchr",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "compact_str"
|
||||||
|
version = "0.7.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f86b9c4c00838774a6d902ef931eff7470720c51d90c2e32cfe15dc304737b3f"
|
||||||
|
dependencies = [
|
||||||
|
"castaway",
|
||||||
|
"cfg-if",
|
||||||
|
"itoa 1.0.10",
|
||||||
|
"ryu",
|
||||||
|
"static_assertions",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "concurrent-queue"
|
name = "concurrent-queue"
|
||||||
version = "2.4.0"
|
version = "2.4.0"
|
||||||
|
@ -5584,14 +5606,14 @@ checksum = "a86ed3f5f244b372d6b1a00b72ef7f8876d0bc6a78a4c9985c53614041512063"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ort"
|
name = "ort"
|
||||||
version = "2.0.0-alpha.2"
|
version = "2.0.0-rc.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1a094a17bfe4f9eb561bfdf4454b8d0f6f89deaf9a5a572a1ef29c29ce708627"
|
checksum = "f8e5caf4eb2ead4bc137c3ff4e347940e3e556ceb11a4180627f04b63d7342dd"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"compact_str",
|
||||||
"half",
|
"half",
|
||||||
"libloading 0.8.1",
|
"libloading 0.8.1",
|
||||||
"ndarray",
|
"ndarray",
|
||||||
"once_cell",
|
|
||||||
"ort-sys",
|
"ort-sys",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
@ -5599,9 +5621,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ort-sys"
|
name = "ort-sys"
|
||||||
version = "2.0.0-alpha.3"
|
version = "2.0.0-rc.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "846ab4ca6873d26ac40f4bdfa8375e5e4547117693b06de5b7903e0c7471d1b7"
|
checksum = "f48b5623df2187e0db543ecb2032a6a999081086b7ffddd318000c00b23ace46"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"flate2",
|
"flate2",
|
||||||
"sha2 0.10.8",
|
"sha2 0.10.8",
|
||||||
|
|
|
@ -13,7 +13,6 @@ use std::{
|
||||||
|
|
||||||
use sd_core::{Node, NodeError};
|
use sd_core::{Node, NodeError};
|
||||||
|
|
||||||
use clear_localstorage::clear_localstorage;
|
|
||||||
use sd_fda::DiskAccess;
|
use sd_fda::DiskAccess;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tauri::{
|
use tauri::{
|
||||||
|
|
|
@ -77,6 +77,9 @@ RUN cargo build --features assets --release -p sd-server
|
||||||
# Debug just means it includes busybox, and we need its tools both for the entrypoint.sh script and during runtime
|
# Debug just means it includes busybox, and we need its tools both for the entrypoint.sh script and during runtime
|
||||||
FROM gcr.io/distroless/base-debian12:debug
|
FROM gcr.io/distroless/base-debian12:debug
|
||||||
|
|
||||||
|
RUN [ "/busybox/ln", "-s", "/busybox/sh", "/bin/sh" ]
|
||||||
|
RUN ln -s /busybox/env /usr/bin/env
|
||||||
|
|
||||||
ENV TZ=UTC \
|
ENV TZ=UTC \
|
||||||
PUID=1000 \
|
PUID=1000 \
|
||||||
PGID=1000 \
|
PGID=1000 \
|
||||||
|
@ -86,9 +89,10 @@ ENV TZ=UTC \
|
||||||
LANGUAGE=en \
|
LANGUAGE=en \
|
||||||
DATA_DIR=/data
|
DATA_DIR=/data
|
||||||
|
|
||||||
COPY --from=server /srv/spacedrive/target/release/sd-server /usr/bin/
|
COPY --from=server --chmod=755 /srv/spacedrive/target/release/sd-server /usr/bin/
|
||||||
COPY --from=server /lib/x86_64-linux-gnu/libgcc_s.so.1 /usr/lib/
|
COPY --from=server --chmod=755 /lib/x86_64-linux-gnu/libgcc_s.so.1 /usr/lib/
|
||||||
COPY --from=server /srv/spacedrive/apps/.deps/lib /usr/lib/spacedrive
|
COPY --from=server --chmod=755 /srv/spacedrive/apps/.deps/lib /usr/lib/spacedrive
|
||||||
|
ADD --chmod=755 https://github.com/spacedriveapp/native-deps/releases/download/yolo-2024-02-07/yolov8s.onnx /usr/share/spacedrive/models/yolov8s.onnx
|
||||||
|
|
||||||
COPY --chmod=755 entrypoint.sh /usr/bin/
|
COPY --chmod=755 entrypoint.sh /usr/bin/
|
||||||
|
|
||||||
|
@ -99,7 +103,7 @@ EXPOSE 8080
|
||||||
VOLUME [ "/data" ]
|
VOLUME [ "/data" ]
|
||||||
|
|
||||||
# Run the CLI when the container is started
|
# Run the CLI when the container is started
|
||||||
ENTRYPOINT [ "sd-server" ]
|
ENTRYPOINT [ "entrypoint.sh" ]
|
||||||
|
|
||||||
LABEL org.opencontainers.image.title="Spacedrive Server" \
|
LABEL org.opencontainers.image.title="Spacedrive Server" \
|
||||||
org.opencontainers.image.source="https://github.com/spacedriveapp/spacedrive"
|
org.opencontainers.image.source="https://github.com/spacedriveapp/spacedrive"
|
||||||
|
|
|
@ -36,4 +36,4 @@ fi
|
||||||
echo "Fix spacedrive's directories permissions"
|
echo "Fix spacedrive's directories permissions"
|
||||||
chown -R "${PUID}:${PGID}" /data
|
chown -R "${PUID}:${PGID}" /data
|
||||||
|
|
||||||
exec su spacedrive -s /bin/server -- "$@"
|
exec su spacedrive -s /usr/bin/sd-server -- "$@"
|
||||||
|
|
|
@ -198,7 +198,7 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
sync.write_op(
|
sync.write_op(
|
||||||
&db,
|
db,
|
||||||
sync.shared_update(
|
sync.shared_update(
|
||||||
prisma_sync::object::SyncId {
|
prisma_sync::object::SyncId {
|
||||||
pub_id: object.pub_id,
|
pub_id: object.pub_id,
|
||||||
|
@ -244,7 +244,7 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
sync.write_op(
|
sync.write_op(
|
||||||
&db,
|
db,
|
||||||
sync.shared_update(
|
sync.shared_update(
|
||||||
prisma_sync::object::SyncId {
|
prisma_sync::object::SyncId {
|
||||||
pub_id: object.pub_id,
|
pub_id: object.pub_id,
|
||||||
|
@ -324,7 +324,7 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
|
||||||
.unzip();
|
.unzip();
|
||||||
|
|
||||||
sync.write_ops(
|
sync.write_ops(
|
||||||
&db,
|
db,
|
||||||
(
|
(
|
||||||
sync_params,
|
sync_params,
|
||||||
db.object().update_many(
|
db.object().update_many(
|
||||||
|
@ -366,7 +366,7 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
|
||||||
})
|
})
|
||||||
.unzip();
|
.unzip();
|
||||||
sync.write_ops(
|
sync.write_ops(
|
||||||
&db,
|
db,
|
||||||
(
|
(
|
||||||
sync_params,
|
sync_params,
|
||||||
db.object().update_many(
|
db.object().update_many(
|
||||||
|
|
|
@ -96,21 +96,30 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
|
||||||
if let Some(model) = new_model {
|
if let Some(model) = new_model {
|
||||||
let version = model.version().to_string();
|
let version = model.version().to_string();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let notification = if let Err(e) =
|
let notification =
|
||||||
node.image_labeller.change_model(model).await
|
if let Some(image_labeller) = node.image_labeller.as_ref() {
|
||||||
{
|
if let Err(e) = image_labeller.change_model(model).await {
|
||||||
NotificationData {
|
NotificationData {
|
||||||
|
title: String::from(
|
||||||
|
"Failed to change image detection model",
|
||||||
|
),
|
||||||
|
content: format!("Error: {e}"),
|
||||||
|
kind: NotificationKind::Error,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
NotificationData {
|
||||||
|
title: String::from("Model download completed"),
|
||||||
|
content: format!("Sucessfuly loaded model: {version}"),
|
||||||
|
kind: NotificationKind::Success,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
NotificationData {
|
||||||
title: String::from("Failed to change image detection model"),
|
title: String::from("Failed to change image detection model"),
|
||||||
content: format!("Error: {e}"),
|
content: "The AI system is disabled due to a previous error. Contact support for help.".to_string(),
|
||||||
kind: NotificationKind::Error,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
NotificationData {
|
|
||||||
title: String::from("Model download completed"),
|
|
||||||
content: format!("Sucessfuly loaded model: {version}"),
|
|
||||||
kind: NotificationKind::Success,
|
kind: NotificationKind::Success,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
node.emit_notification(notification, None).await;
|
node.emit_notification(notification, None).await;
|
||||||
});
|
});
|
||||||
|
|
|
@ -210,7 +210,7 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
|
||||||
.unzip();
|
.unzip();
|
||||||
|
|
||||||
sync.write_ops(
|
sync.write_ops(
|
||||||
&db,
|
db,
|
||||||
(
|
(
|
||||||
sync_params,
|
sync_params,
|
||||||
db.saved_search()
|
db.saved_search()
|
||||||
|
@ -242,7 +242,7 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
sync.write_op(
|
sync.write_op(
|
||||||
&db,
|
db,
|
||||||
sync.shared_delete(prisma_sync::saved_search::SyncId {
|
sync.shared_delete(prisma_sync::saved_search::SyncId {
|
||||||
pub_id: search.pub_id,
|
pub_id: search.pub_id,
|
||||||
}),
|
}),
|
||||||
|
|
|
@ -79,7 +79,7 @@ pub async fn run_actor(
|
||||||
|uuid| sd_cloud_api::library::message_collections::get::InstanceTimestamp {
|
|uuid| sd_cloud_api::library::message_collections::get::InstanceTimestamp {
|
||||||
instance_uuid: *uuid,
|
instance_uuid: *uuid,
|
||||||
from_time: cloud_timestamps
|
from_time: cloud_timestamps
|
||||||
.get(&uuid)
|
.get(uuid)
|
||||||
.cloned()
|
.cloned()
|
||||||
.unwrap_or_default()
|
.unwrap_or_default()
|
||||||
.as_u64()
|
.as_u64()
|
||||||
|
|
|
@ -67,7 +67,7 @@ pub struct Node {
|
||||||
pub env: Arc<env::Env>,
|
pub env: Arc<env::Env>,
|
||||||
pub http: reqwest::Client,
|
pub http: reqwest::Client,
|
||||||
#[cfg(feature = "ai")]
|
#[cfg(feature = "ai")]
|
||||||
pub image_labeller: ImageLabeler,
|
pub image_labeller: Option<ImageLabeler>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Debug for Node {
|
impl fmt::Debug for Node {
|
||||||
|
@ -115,31 +115,35 @@ impl Node {
|
||||||
let libraries = library::Libraries::new(data_dir.join("libraries")).await?;
|
let libraries = library::Libraries::new(data_dir.join("libraries")).await?;
|
||||||
|
|
||||||
let (p2p, p2p_actor) = p2p::P2PManager::new(config.clone(), libraries.clone()).await?;
|
let (p2p, p2p_actor) = p2p::P2PManager::new(config.clone(), libraries.clone()).await?;
|
||||||
let node = Arc::new(Node {
|
let node =
|
||||||
data_dir: data_dir.to_path_buf(),
|
Arc::new(Node {
|
||||||
jobs,
|
data_dir: data_dir.to_path_buf(),
|
||||||
locations,
|
jobs,
|
||||||
notifications: notifications::Notifications::new(),
|
locations,
|
||||||
p2p,
|
notifications: notifications::Notifications::new(),
|
||||||
thumbnailer: Thumbnailer::new(
|
p2p,
|
||||||
data_dir,
|
thumbnailer: Thumbnailer::new(
|
||||||
libraries.clone(),
|
data_dir,
|
||||||
event_bus.0.clone(),
|
libraries.clone(),
|
||||||
config.preferences_watcher(),
|
event_bus.0.clone(),
|
||||||
)
|
config.preferences_watcher(),
|
||||||
.await,
|
)
|
||||||
config,
|
.await,
|
||||||
event_bus,
|
config,
|
||||||
libraries,
|
event_bus,
|
||||||
files_over_p2p_flag: Arc::new(AtomicBool::new(false)),
|
libraries,
|
||||||
cloud_sync_flag: Arc::new(AtomicBool::new(false)),
|
files_over_p2p_flag: Arc::new(AtomicBool::new(false)),
|
||||||
http: reqwest::Client::new(),
|
cloud_sync_flag: Arc::new(AtomicBool::new(false)),
|
||||||
env,
|
http: reqwest::Client::new(),
|
||||||
#[cfg(feature = "ai")]
|
env,
|
||||||
image_labeller: ImageLabeler::new(YoloV8::model(image_labeler_version)?, data_dir)
|
#[cfg(feature = "ai")]
|
||||||
.await
|
image_labeller: ImageLabeler::new(YoloV8::model(image_labeler_version)?, data_dir)
|
||||||
.map_err(sd_ai::Error::from)?,
|
.await
|
||||||
});
|
.map_err(|e| {
|
||||||
|
error!("Failed to initialize image labeller. AI features will be disabled: {e:#?}");
|
||||||
|
})
|
||||||
|
.ok(),
|
||||||
|
});
|
||||||
|
|
||||||
// Restore backend feature flags
|
// Restore backend feature flags
|
||||||
for feature in node.config.get().await.features {
|
for feature in node.config.get().await.features {
|
||||||
|
@ -227,7 +231,9 @@ impl Node {
|
||||||
self.jobs.shutdown().await;
|
self.jobs.shutdown().await;
|
||||||
self.p2p.shutdown().await;
|
self.p2p.shutdown().await;
|
||||||
#[cfg(feature = "ai")]
|
#[cfg(feature = "ai")]
|
||||||
self.image_labeller.shutdown().await;
|
if let Some(image_labeller) = &self.image_labeller {
|
||||||
|
image_labeller.shutdown().await;
|
||||||
|
}
|
||||||
info!("Spacedrive Core shutdown successful!");
|
info!("Spacedrive Core shutdown successful!");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -619,7 +619,7 @@ impl Libraries {
|
||||||
|
|
||||||
let _ = this
|
let _ = this
|
||||||
.edit(
|
.edit(
|
||||||
library.id.clone(),
|
library.id,
|
||||||
None,
|
None,
|
||||||
MaybeUndefined::Undefined,
|
MaybeUndefined::Undefined,
|
||||||
MaybeUndefined::Null,
|
MaybeUndefined::Null,
|
||||||
|
|
|
@ -917,7 +917,7 @@ pub(super) async fn remove_by_file_path(
|
||||||
.await?;
|
.await?;
|
||||||
} else {
|
} else {
|
||||||
sync.write_op(
|
sync.write_op(
|
||||||
&db,
|
db,
|
||||||
sync.shared_delete(prisma_sync::file_path::SyncId {
|
sync.shared_delete(prisma_sync::file_path::SyncId {
|
||||||
pub_id: file_path.pub_id.clone(),
|
pub_id: file_path.pub_id.clone(),
|
||||||
}),
|
}),
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
use std::io::Error;
|
use std::io::Error;
|
||||||
use std::process::Command;
|
|
||||||
use std::str;
|
use std::str;
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
@ -37,6 +36,8 @@ impl HardwareModel {
|
||||||
pub fn get_hardware_model_name() -> Result<HardwareModel, Error> {
|
pub fn get_hardware_model_name() -> Result<HardwareModel, Error> {
|
||||||
#[cfg(target_os = "macos")]
|
#[cfg(target_os = "macos")]
|
||||||
{
|
{
|
||||||
|
use std::process::Command;
|
||||||
|
|
||||||
let output = Command::new("system_profiler")
|
let output = Command::new("system_profiler")
|
||||||
.arg("SPHardwareDataType")
|
.arg("SPHardwareDataType")
|
||||||
.output()?;
|
.output()?;
|
||||||
|
|
|
@ -88,7 +88,7 @@ impl StatefulJob for FileDeleterJobInit {
|
||||||
step.full_path.display()
|
step.full_path.display()
|
||||||
);
|
);
|
||||||
sync.write_op(
|
sync.write_op(
|
||||||
&db,
|
db,
|
||||||
sync.shared_delete(prisma_sync::file_path::SyncId {
|
sync.shared_delete(prisma_sync::file_path::SyncId {
|
||||||
pub_id: step.file_path.pub_id.clone(),
|
pub_id: step.file_path.pub_id.clone(),
|
||||||
}),
|
}),
|
||||||
|
|
|
@ -178,17 +178,21 @@ impl StatefulJob for MediaProcessorJobInit {
|
||||||
let total_files_for_labeling = file_paths_for_labeling.len();
|
let total_files_for_labeling = file_paths_for_labeling.len();
|
||||||
|
|
||||||
#[cfg(feature = "ai")]
|
#[cfg(feature = "ai")]
|
||||||
let (labeler_batch_token, labels_rx) = ctx
|
let (labeler_batch_token, labels_rx) =
|
||||||
.node
|
if let Some(image_labeller) = ctx.node.image_labeller.as_ref() {
|
||||||
.image_labeller
|
let (labeler_batch_token, labels_rx) = image_labeller
|
||||||
.new_resumable_batch(
|
.new_resumable_batch(
|
||||||
location_id,
|
location_id,
|
||||||
location_path.clone(),
|
location_path.clone(),
|
||||||
file_paths_for_labeling,
|
file_paths_for_labeling,
|
||||||
Arc::clone(db),
|
Arc::clone(db),
|
||||||
sync.clone(),
|
sync.clone(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
(labeler_batch_token, Some(labels_rx))
|
||||||
|
} else {
|
||||||
|
(uuid::Uuid::new_v4(), None)
|
||||||
|
};
|
||||||
|
|
||||||
let total_files = file_paths.len();
|
let total_files = file_paths.len();
|
||||||
|
|
||||||
|
@ -240,7 +244,7 @@ impl StatefulJob for MediaProcessorJobInit {
|
||||||
#[cfg(feature = "ai")]
|
#[cfg(feature = "ai")]
|
||||||
labeler_batch_token,
|
labeler_batch_token,
|
||||||
#[cfg(feature = "ai")]
|
#[cfg(feature = "ai")]
|
||||||
maybe_labels_rx: Some(labels_rx),
|
maybe_labels_rx: labels_rx,
|
||||||
});
|
});
|
||||||
|
|
||||||
Ok((
|
Ok((
|
||||||
|
@ -323,6 +327,12 @@ impl StatefulJob for MediaProcessorJobInit {
|
||||||
|
|
||||||
#[cfg(feature = "ai")]
|
#[cfg(feature = "ai")]
|
||||||
MediaProcessorJobStep::WaitLabels(total_labels) => {
|
MediaProcessorJobStep::WaitLabels(total_labels) => {
|
||||||
|
let Some(image_labeller) = ctx.node.image_labeller.as_ref() else {
|
||||||
|
let err = "AI system is disabled due to a previous error, skipping labels job";
|
||||||
|
error!(err);
|
||||||
|
return Ok(JobRunErrors(vec![err.to_string()]).into());
|
||||||
|
};
|
||||||
|
|
||||||
ctx.progress(vec![
|
ctx.progress(vec![
|
||||||
JobReportUpdate::TaskCount(*total_labels),
|
JobReportUpdate::TaskCount(*total_labels),
|
||||||
JobReportUpdate::Phase("labels".to_string()),
|
JobReportUpdate::Phase("labels".to_string()),
|
||||||
|
@ -334,9 +344,7 @@ impl StatefulJob for MediaProcessorJobInit {
|
||||||
let mut labels_rx = pin!(if let Some(labels_rx) = data.maybe_labels_rx.clone() {
|
let mut labels_rx = pin!(if let Some(labels_rx) = data.maybe_labels_rx.clone() {
|
||||||
labels_rx
|
labels_rx
|
||||||
} else {
|
} else {
|
||||||
match ctx
|
match image_labeller
|
||||||
.node
|
|
||||||
.image_labeller
|
|
||||||
.resume_batch(
|
.resume_batch(
|
||||||
data.labeler_batch_token,
|
data.labeler_batch_token,
|
||||||
Arc::clone(&ctx.library.db),
|
Arc::clone(&ctx.library.db),
|
||||||
|
|
|
@ -109,16 +109,18 @@ pub async fn shallow(
|
||||||
);
|
);
|
||||||
|
|
||||||
#[cfg(feature = "ai")]
|
#[cfg(feature = "ai")]
|
||||||
let labels_rx = node
|
// Check if we have an image labeller and has_labels then enqueue a new batch
|
||||||
.image_labeller
|
let labels_rx = node.image_labeller.as_ref().and_then(|image_labeller| {
|
||||||
.new_batch(
|
has_labels.then(|| {
|
||||||
location_id,
|
image_labeller.new_batch(
|
||||||
location_path.clone(),
|
location_id,
|
||||||
file_paths_for_labelling,
|
location_path.clone(),
|
||||||
Arc::clone(db),
|
file_paths_for_labelling,
|
||||||
sync.clone(),
|
Arc::clone(db),
|
||||||
)
|
sync.clone(),
|
||||||
.await;
|
)
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
let mut run_metadata = MediaProcessorMetadata::default();
|
let mut run_metadata = MediaProcessorMetadata::default();
|
||||||
|
|
||||||
|
@ -144,27 +146,30 @@ pub async fn shallow(
|
||||||
#[cfg(feature = "ai")]
|
#[cfg(feature = "ai")]
|
||||||
{
|
{
|
||||||
if has_labels {
|
if has_labels {
|
||||||
labels_rx
|
if let Some(labels_rx) = labels_rx {
|
||||||
.for_each(
|
labels_rx
|
||||||
|LabelerOutput {
|
.await
|
||||||
file_path_id,
|
.for_each(
|
||||||
has_new_labels,
|
|LabelerOutput {
|
||||||
result,
|
file_path_id,
|
||||||
}| async move {
|
has_new_labels,
|
||||||
if let Err(e) = result {
|
result,
|
||||||
error!(
|
}| async move {
|
||||||
|
if let Err(e) = result {
|
||||||
|
error!(
|
||||||
"Failed to generate labels <file_path_id='{file_path_id}'>: {e:#?}"
|
"Failed to generate labels <file_path_id='{file_path_id}'>: {e:#?}"
|
||||||
);
|
);
|
||||||
} else if has_new_labels {
|
} else if has_new_labels {
|
||||||
// invalidate_query!(library, "labels.count"); // TODO: This query doesn't exist on main yet
|
// invalidate_query!(library, "labels.count"); // TODO: This query doesn't exist on main yet
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
invalidate_query!(library, "labels.list");
|
invalidate_query!(library, "labels.list");
|
||||||
invalidate_query!(library, "labels.getForObject");
|
invalidate_query!(library, "labels.getForObject");
|
||||||
invalidate_query!(library, "labels.getWithObjects");
|
invalidate_query!(library, "labels.getWithObjects");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -125,7 +125,7 @@ impl LibraryServices {
|
||||||
inserted = true;
|
inserted = true;
|
||||||
Arc::new(
|
Arc::new(
|
||||||
Service::new(
|
Service::new(
|
||||||
String::from_utf8_lossy(&base91::slice_encode(&*library.id.as_bytes())),
|
String::from_utf8_lossy(&base91::slice_encode(library.id.as_bytes())),
|
||||||
manager.manager.clone(),
|
manager.manager.clone(),
|
||||||
)
|
)
|
||||||
.expect("error creating service with duplicate service name"),
|
.expect("error creating service with duplicate service name"),
|
||||||
|
|
|
@ -42,21 +42,20 @@ half = { version = "2.1", features = ['num-traits'] }
|
||||||
# "gpu" means CUDA or TensorRT EP. Thus, the ort crate cannot download them at build time.
|
# "gpu" means CUDA or TensorRT EP. Thus, the ort crate cannot download them at build time.
|
||||||
# Ref: https://github.com/pykeio/ort/blob/d7defd1862969b4b44f7f3f4b9c72263690bd67b/build.rs#L148
|
# Ref: https://github.com/pykeio/ort/blob/d7defd1862969b4b44f7f3f4b9c72263690bd67b/build.rs#L148
|
||||||
[target.'cfg(target_os = "windows")'.dependencies]
|
[target.'cfg(target_os = "windows")'.dependencies]
|
||||||
ort = { version = "=2.0.0-alpha.2", default-features = false, features = [
|
ort = { version = "=2.0.0-rc.0", default-features = false, features = [
|
||||||
"ndarray",
|
"ndarray",
|
||||||
"half",
|
"half",
|
||||||
"load-dynamic",
|
"load-dynamic",
|
||||||
"directml",
|
"directml",
|
||||||
] }
|
] }
|
||||||
[target.'cfg(target_os = "linux")'.dependencies]
|
[target.'cfg(target_os = "linux")'.dependencies]
|
||||||
ort = { version = "=2.0.0-alpha.2", default-features = false, features = [
|
ort = { version = "=2.0.0-rc.0", default-features = false, features = [
|
||||||
"ndarray",
|
"ndarray",
|
||||||
"half",
|
"half",
|
||||||
"load-dynamic",
|
|
||||||
"xnnpack",
|
"xnnpack",
|
||||||
] }
|
] }
|
||||||
# [target.'cfg(target_os = "android")'.dependencies]
|
# [target.'cfg(target_os = "android")'.dependencies]
|
||||||
# ort = { version = "2.0.0-alpha.2", default-features = false, features = [
|
# ort = { version = "=2.0.0-rc.0", default-features = false, features = [
|
||||||
# "half",
|
# "half",
|
||||||
# "load-dynamic",
|
# "load-dynamic",
|
||||||
# "qnn",
|
# "qnn",
|
||||||
|
@ -66,7 +65,7 @@ ort = { version = "=2.0.0-alpha.2", default-features = false, features = [
|
||||||
# "armnn",
|
# "armnn",
|
||||||
# ] }
|
# ] }
|
||||||
[target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies]
|
[target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies]
|
||||||
ort = { version = "=2.0.0-alpha.2", features = [
|
ort = { version = "=2.0.0-rc.0", features = [
|
||||||
"ndarray",
|
"ndarray",
|
||||||
"half",
|
"half",
|
||||||
"load-dynamic",
|
"load-dynamic",
|
||||||
|
|
|
@ -29,9 +29,7 @@ pub enum ModelSource {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait Model: Send + Sync + 'static {
|
pub trait Model: Send + Sync + 'static {
|
||||||
fn name(&self) -> &'static str {
|
fn name(&self) -> &'static str;
|
||||||
std::any::type_name::<Self>()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn origin(&self) -> &ModelSource;
|
fn origin(&self) -> &ModelSource;
|
||||||
|
|
||||||
|
|
|
@ -73,6 +73,10 @@ impl YoloV8 {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Model for YoloV8 {
|
impl Model for YoloV8 {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"YoloV8"
|
||||||
|
}
|
||||||
|
|
||||||
fn origin(&self) -> &'static ModelSource {
|
fn origin(&self) -> &'static ModelSource {
|
||||||
self.model_origin
|
self.model_origin
|
||||||
}
|
}
|
||||||
|
|
|
@ -473,7 +473,7 @@ pub async fn assign_labels(
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
sync.write_ops(
|
sync.write_ops(
|
||||||
&db,
|
db,
|
||||||
(
|
(
|
||||||
sync_params,
|
sync_params,
|
||||||
db.label_on_object()
|
db.label_on_object()
|
||||||
|
|
|
@ -1,21 +1,17 @@
|
||||||
use std::path::Path;
|
|
||||||
|
|
||||||
use ort::{EnvironmentBuilder, LoggingLevel};
|
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
|
use ort::EnvironmentBuilder;
|
||||||
use tracing::{debug, error};
|
use tracing::{debug, error};
|
||||||
|
|
||||||
pub mod image_labeler;
|
pub mod image_labeler;
|
||||||
mod utils;
|
mod utils;
|
||||||
|
|
||||||
// This path must be relative to the running binary
|
// This path must be relative to the running binary
|
||||||
#[cfg(windows)]
|
#[cfg(target_os = "windows")]
|
||||||
const BINDING_LOCATION: &str = ".";
|
const BINDING_LOCATION: &str = ".";
|
||||||
#[cfg(unix)]
|
|
||||||
const BINDING_LOCATION: &str = if cfg!(target_os = "macos") {
|
#[cfg(target_os = "macos")]
|
||||||
"../Frameworks/Spacedrive.framework/Libraries"
|
const BINDING_LOCATION: &str = "../Frameworks/Spacedrive.framework/Libraries";
|
||||||
} else {
|
|
||||||
"../lib/spacedrive"
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg(target_os = "windows")]
|
#[cfg(target_os = "windows")]
|
||||||
const LIB_NAME: &str = "onnxruntime.dll";
|
const LIB_NAME: &str = "onnxruntime.dll";
|
||||||
|
@ -23,22 +19,17 @@ const LIB_NAME: &str = "onnxruntime.dll";
|
||||||
#[cfg(any(target_os = "macos", target_os = "ios"))]
|
#[cfg(any(target_os = "macos", target_os = "ios"))]
|
||||||
const LIB_NAME: &str = "libonnxruntime.dylib";
|
const LIB_NAME: &str = "libonnxruntime.dylib";
|
||||||
|
|
||||||
#[cfg(any(target_os = "linux", target_os = "android"))]
|
|
||||||
const LIB_NAME: &str = "libonnxruntime.so";
|
|
||||||
|
|
||||||
pub fn init() -> Result<(), Error> {
|
pub fn init() -> Result<(), Error> {
|
||||||
let path = utils::get_path_relative_to_exe(Path::new(BINDING_LOCATION).join(LIB_NAME));
|
#[cfg(any(target_os = "macos", target_os = "ios", target_os = "windows"))]
|
||||||
|
{
|
||||||
std::env::set_var("ORT_DYLIB_PATH", path);
|
use std::path::Path;
|
||||||
|
let path = utils::get_path_relative_to_exe(Path::new(BINDING_LOCATION).join(LIB_NAME));
|
||||||
|
std::env::set_var("ORT_DYLIB_PATH", path);
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize AI stuff
|
// Initialize AI stuff
|
||||||
EnvironmentBuilder::default()
|
EnvironmentBuilder::default()
|
||||||
.with_name("spacedrive")
|
.with_name("spacedrive")
|
||||||
.with_log_level(if cfg!(debug_assertions) {
|
|
||||||
LoggingLevel::Verbose
|
|
||||||
} else {
|
|
||||||
LoggingLevel::Info
|
|
||||||
})
|
|
||||||
.with_execution_providers({
|
.with_execution_providers({
|
||||||
#[cfg(any(target_os = "macos", target_os = "ios"))]
|
#[cfg(any(target_os = "macos", target_os = "ios"))]
|
||||||
{
|
{
|
||||||
|
@ -80,6 +71,7 @@ pub fn init() -> Result<(), Error> {
|
||||||
// }
|
// }
|
||||||
})
|
})
|
||||||
.commit()?;
|
.commit()?;
|
||||||
|
|
||||||
debug!("Initialized AI environment");
|
debug!("Initialized AI environment");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -97,7 +97,7 @@ impl Mdns {
|
||||||
// The max length of an MDNS record is painful so we just hash the data to come up with a pseudo-random but deterministic value.
|
// The max length of an MDNS record is painful so we just hash the data to come up with a pseudo-random but deterministic value.
|
||||||
// The full values are stored within TXT records.
|
// The full values are stored within TXT records.
|
||||||
let my_name = String::from_utf8_lossy(&base91::slice_encode(
|
let my_name = String::from_utf8_lossy(&base91::slice_encode(
|
||||||
&sha256::digest(format!("{}_{}", service_name, self.identity)).as_bytes(),
|
sha256::digest(format!("{}_{}", service_name, self.identity)).as_bytes(),
|
||||||
))[..63]
|
))[..63]
|
||||||
.to_string();
|
.to_string();
|
||||||
|
|
||||||
|
@ -236,7 +236,7 @@ impl Mdns {
|
||||||
info.get_fullname().to_string(),
|
info.get_fullname().to_string(),
|
||||||
TrackedService {
|
TrackedService {
|
||||||
service_name: service_name.to_string(),
|
service_name: service_name.to_string(),
|
||||||
identity: identity.clone(),
|
identity,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -89,7 +89,7 @@ pub fn r#enum(models: Vec<ModelWithSyncType>) -> TokenStream {
|
||||||
let item_model_sync_id_field_name_snake = models
|
let item_model_sync_id_field_name_snake = models
|
||||||
.iter()
|
.iter()
|
||||||
.find(|m| m.0.name() == item.related_model().name())
|
.find(|m| m.0.name() == item.related_model().name())
|
||||||
.and_then(|(m, sync)| sync.as_ref())
|
.and_then(|(_m, sync)| sync.as_ref())
|
||||||
.map(|sync| snake_ident(sync.sync_id()[0].name()))
|
.map(|sync| snake_ident(sync.sync_id()[0].name()))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let item_model_name_snake = snake_ident(item.related_model().name());
|
let item_model_name_snake = snake_ident(item.related_model().name());
|
||||||
|
|
|
@ -13,9 +13,10 @@ import { OnboardingContext, useContextValue } from './context';
|
||||||
import Progress from './Progress';
|
import Progress from './Progress';
|
||||||
|
|
||||||
export const Component = () => {
|
export const Component = () => {
|
||||||
const os = useOperatingSystem();
|
const os = useOperatingSystem(false);
|
||||||
const debugState = useDebugState();
|
const debugState = useDebugState();
|
||||||
const [showIntro, setShowIntro] = useState(true);
|
// FIX-ME: Intro video breaks onboarding for the web version
|
||||||
|
const [showIntro, setShowIntro] = useState(os !== 'browser');
|
||||||
const ctx = useContextValue();
|
const ctx = useContextValue();
|
||||||
|
|
||||||
if (ctx.libraries.isLoading) return null;
|
if (ctx.libraries.isLoading) return null;
|
||||||
|
|
Loading…
Reference in a new issue