Mirror of https://github.com/spacedriveapp/spacedrive (synced 2024-07-02 10:03:28 +00:00)
[ENG-1595] Fix some problems with the AI system (#2063)
* Fix some problems with the AI system
  - Fix downloading the model using an internal Rust string representation as the path for the model file
  - Fix Linux loading the onnx shared lib from a hardcoded path
  - Fix: the app should not crash when the AI system fails to start
  - Fix sd-server failing to start due to incorrect onnxruntime linking
  - Some extra clippy auto fixes
* Use latest ort
* Fix dangling sd_ai reference
  - Use entrypoint.sh to initialize the container
* Fix server Dockerfile
  - Fix cargo warning
* Workaround intro video breaking onboarding for the web version
* Fix rebase
This commit is contained in:
parent 74e2d23c11
commit 72efcc9f62
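Note: the central fix is the no-crash item above. Instead of `Node::new` propagating an `ImageLabeler` construction error and aborting startup, the labeller becomes optional and every call site handles its absence. A minimal sketch of that pattern (hypothetical simplified types, not the real Spacedrive API):

// Sketch of the crash-fix pattern: a fallible subsystem is stored as an
// Option so an initialization failure disables the feature instead of
// aborting startup.
struct ImageLabeler;

impl ImageLabeler {
    fn new() -> Result<Self, String> {
        Err("onnxruntime failed to load".into())
    }
}

struct Node {
    image_labeller: Option<ImageLabeler>,
}

fn main() {
    let node = Node {
        // Log the error and keep going with None, instead of bubbling it up with `?`.
        image_labeller: ImageLabeler::new()
            .map_err(|e| eprintln!("AI disabled: {e}"))
            .ok(),
    };

    // Every consumer must now handle the disabled case explicitly.
    match node.image_labeller.as_ref() {
        Some(_labeller) => println!("AI features enabled"),
        None => println!("AI features disabled"),
    }
}

The diffs below apply this shape to the `Node` struct, its constructor and shutdown path, and each job that consumes the labeller.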
@@ -1,6 +1,9 @@
 [env]
 PROTOC = { force = true, value = "{{{protoc}}}" }
 FFMPEG_DIR = { force = true, value = "{{{nativeDeps}}}" }
+{{#isLinux}}
+ORT_LIB_LOCATION = { force = true, value = "{{{nativeDeps}}}/lib" }
+{{/isLinux}}
 OPENSSL_STATIC = { force = true, value = "1" }
 OPENSSL_NO_VENDOR = { force = true, value = "0" }
 OPENSSL_RUST_USE_NASM = { force = true, value = "1" }
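Note: ort's build script honors an `ORT_LIB_LOCATION` override when locating a prebuilt ONNX Runtime, so exporting it for Linux builds points the build at the bundled `nativeDeps` libraries instead of a hardcoded path, which appears to be the "Linux loading onnx shared lib from a hardcoded path" fix from the commit message.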
Cargo.lock (generated, 32 lines changed)
@@ -1338,6 +1338,15 @@ dependencies = [
  "toml 0.7.8",
 ]
 
+[[package]]
+name = "castaway"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8a17ed5635fc8536268e5d4de1e22e81ac34419e5f052d4d51f4e01dcc263fcc"
+dependencies = [
+ "rustversion",
+]
+
 [[package]]
 name = "cc"
 version = "1.0.83"

@@ -1584,6 +1593,19 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "compact_str"
+version = "0.7.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f86b9c4c00838774a6d902ef931eff7470720c51d90c2e32cfe15dc304737b3f"
+dependencies = [
+ "castaway",
+ "cfg-if",
+ "itoa 1.0.10",
+ "ryu",
+ "static_assertions",
+]
+
 [[package]]
 name = "concurrent-queue"
 version = "2.4.0"

@@ -5584,14 +5606,14 @@ checksum = "a86ed3f5f244b372d6b1a00b72ef7f8876d0bc6a78a4c9985c53614041512063"
 [[package]]
 name = "ort"
-version = "2.0.0-alpha.2"
+version = "2.0.0-rc.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1a094a17bfe4f9eb561bfdf4454b8d0f6f89deaf9a5a572a1ef29c29ce708627"
+checksum = "f8e5caf4eb2ead4bc137c3ff4e347940e3e556ceb11a4180627f04b63d7342dd"
 dependencies = [
+ "compact_str",
  "half",
  "libloading 0.8.1",
  "ndarray",
  "once_cell",
  "ort-sys",
  "thiserror",
  "tracing",

@@ -5599,9 +5621,9 @@ dependencies = [
 [[package]]
 name = "ort-sys"
-version = "2.0.0-alpha.3"
+version = "2.0.0-rc.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "846ab4ca6873d26ac40f4bdfa8375e5e4547117693b06de5b7903e0c7471d1b7"
+checksum = "f48b5623df2187e0db543ecb2032a6a999081086b7ffddd318000c00b23ace46"
 dependencies = [
  "flate2",
  "sha2 0.10.8",
@@ -13,7 +13,6 @@ use std::{
 use sd_core::{Node, NodeError};
 
-use clear_localstorage::clear_localstorage;
 use sd_fda::DiskAccess;
 use serde::{Deserialize, Serialize};
 use tauri::{
@@ -77,6 +77,9 @@ RUN cargo build --features assets --release -p sd-server
 # Debug just means it includes busybox, and we need its tools both for the entrypoint.sh script and during runtime
 FROM gcr.io/distroless/base-debian12:debug
 
+RUN [ "/busybox/ln", "-s", "/busybox/sh", "/bin/sh" ]
+RUN ln -s /busybox/env /usr/bin/env
+
 ENV TZ=UTC \
     PUID=1000 \
     PGID=1000 \

@@ -86,9 +89,10 @@ ENV TZ=UTC \
     LANGUAGE=en \
     DATA_DIR=/data
 
-COPY --from=server /srv/spacedrive/target/release/sd-server /usr/bin/
-COPY --from=server /lib/x86_64-linux-gnu/libgcc_s.so.1 /usr/lib/
-COPY --from=server /srv/spacedrive/apps/.deps/lib /usr/lib/spacedrive
+COPY --from=server --chmod=755 /srv/spacedrive/target/release/sd-server /usr/bin/
+COPY --from=server --chmod=755 /lib/x86_64-linux-gnu/libgcc_s.so.1 /usr/lib/
+COPY --from=server --chmod=755 /srv/spacedrive/apps/.deps/lib /usr/lib/spacedrive
 ADD --chmod=755 https://github.com/spacedriveapp/native-deps/releases/download/yolo-2024-02-07/yolov8s.onnx /usr/share/spacedrive/models/yolov8s.onnx
 
+COPY --chmod=755 entrypoint.sh /usr/bin/

@@ -99,7 +103,7 @@ EXPOSE 8080
 VOLUME [ "/data" ]
 
 # Run the CLI when the container is started
-ENTRYPOINT [ "sd-server" ]
+ENTRYPOINT [ "entrypoint.sh" ]
 
 LABEL org.opencontainers.image.title="Spacedrive Server" \
 	org.opencontainers.image.source="https://github.com/spacedriveapp/spacedrive"
@@ -36,4 +36,4 @@ fi
 echo "Fix spacedrive's directories permissions"
 chown -R "${PUID}:${PGID}" /data
 
-exec su spacedrive -s /bin/server -- "$@"
+exec su spacedrive -s /usr/bin/sd-server -- "$@"
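Note: this is the usual unprivileged-container pattern: the entrypoint runs as root only long enough to `chown` the data volume, then `exec`s the server as the `spacedrive` user so PID 1 is replaced rather than wrapped. The one-line fix points `su` at the path the Dockerfile actually installs the binary to (`/usr/bin/sd-server`) instead of the stale `/bin/server`.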
@@ -198,7 +198,7 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
 				})?;
 
 				sync.write_op(
-					&db,
+					db,
 					sync.shared_update(
 						prisma_sync::object::SyncId {
 							pub_id: object.pub_id,

@@ -244,7 +244,7 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
 				})?;
 
 				sync.write_op(
-					&db,
+					db,
 					sync.shared_update(
 						prisma_sync::object::SyncId {
 							pub_id: object.pub_id,

@@ -324,7 +324,7 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
 					.unzip();
 
 				sync.write_ops(
-					&db,
+					db,
 					(
 						sync_params,
 						db.object().update_many(

@@ -366,7 +366,7 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
 				})
 				.unzip();
 				sync.write_ops(
-					&db,
+					db,
 					(
 						sync_params,
 						db.object().update_many(
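Note: the recurring one-character `&db` to `db` edits here and in several files below are the clippy auto-fixes mentioned in the commit message (presumably `needless_borrow`): the argument is already a reference, so the extra borrow was redundant.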
@@ -96,21 +96,30 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
 			if let Some(model) = new_model {
 				let version = model.version().to_string();
 				tokio::spawn(async move {
-					let notification = if let Err(e) =
-						node.image_labeller.change_model(model).await
-					{
-						NotificationData {
-							title: String::from("Failed to change image detection model"),
-							content: format!("Error: {e}"),
-							kind: NotificationKind::Error,
-						}
-					} else {
-						NotificationData {
-							title: String::from("Model download completed"),
-							content: format!("Sucessfuly loaded model: {version}"),
-							kind: NotificationKind::Success,
-						}
-					};
+					let notification =
+						if let Some(image_labeller) = node.image_labeller.as_ref() {
+							if let Err(e) = image_labeller.change_model(model).await {
+								NotificationData {
+									title: String::from(
+										"Failed to change image detection model",
+									),
+									content: format!("Error: {e}"),
+									kind: NotificationKind::Error,
+								}
+							} else {
+								NotificationData {
+									title: String::from("Model download completed"),
+									content: format!("Sucessfuly loaded model: {version}"),
+									kind: NotificationKind::Success,
+								}
+							}
+						} else {
+							NotificationData {
+								title: String::from("Failed to change image detection model"),
+								content: "The AI system is disabled due to a previous error. Contact support for help.".to_string(),
+								kind: NotificationKind::Error,
+							}
+						};
 
 					node.emit_notification(notification, None).await;
 				});
@@ -210,7 +210,7 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
 					.unzip();
 
 				sync.write_ops(
-					&db,
+					db,
 					(
 						sync_params,
 						db.saved_search()

@@ -242,7 +242,7 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
 				})?;
 
 				sync.write_op(
-					&db,
+					db,
 					sync.shared_delete(prisma_sync::saved_search::SyncId {
 						pub_id: search.pub_id,
 					}),
@@ -79,7 +79,7 @@ pub async fn run_actor(
 				|uuid| sd_cloud_api::library::message_collections::get::InstanceTimestamp {
 					instance_uuid: *uuid,
 					from_time: cloud_timestamps
-						.get(&uuid)
+						.get(uuid)
 						.cloned()
 						.unwrap_or_default()
 						.as_u64()
@@ -67,7 +67,7 @@ pub struct Node {
 	pub env: Arc<env::Env>,
 	pub http: reqwest::Client,
 	#[cfg(feature = "ai")]
-	pub image_labeller: ImageLabeler,
+	pub image_labeller: Option<ImageLabeler>,
 }
 
 impl fmt::Debug for Node {

@@ -115,31 +115,35 @@ impl Node {
 		let libraries = library::Libraries::new(data_dir.join("libraries")).await?;
 
 		let (p2p, p2p_actor) = p2p::P2PManager::new(config.clone(), libraries.clone()).await?;
-		let node = Arc::new(Node {
-			data_dir: data_dir.to_path_buf(),
-			jobs,
-			locations,
-			notifications: notifications::Notifications::new(),
-			p2p,
-			thumbnailer: Thumbnailer::new(
-				data_dir,
-				libraries.clone(),
-				event_bus.0.clone(),
-				config.preferences_watcher(),
-			)
-			.await,
-			config,
-			event_bus,
-			libraries,
-			files_over_p2p_flag: Arc::new(AtomicBool::new(false)),
-			cloud_sync_flag: Arc::new(AtomicBool::new(false)),
-			http: reqwest::Client::new(),
-			env,
-			#[cfg(feature = "ai")]
-			image_labeller: ImageLabeler::new(YoloV8::model(image_labeler_version)?, data_dir)
-				.await
-				.map_err(sd_ai::Error::from)?,
-		});
+		let node =
+			Arc::new(Node {
+				data_dir: data_dir.to_path_buf(),
+				jobs,
+				locations,
+				notifications: notifications::Notifications::new(),
+				p2p,
+				thumbnailer: Thumbnailer::new(
+					data_dir,
+					libraries.clone(),
+					event_bus.0.clone(),
+					config.preferences_watcher(),
+				)
+				.await,
+				config,
+				event_bus,
+				libraries,
+				files_over_p2p_flag: Arc::new(AtomicBool::new(false)),
+				cloud_sync_flag: Arc::new(AtomicBool::new(false)),
+				http: reqwest::Client::new(),
+				env,
+				#[cfg(feature = "ai")]
+				image_labeller: ImageLabeler::new(YoloV8::model(image_labeler_version)?, data_dir)
+					.await
+					.map_err(|e| {
+						error!("Failed to initialize image labeller. AI features will be disabled: {e:#?}");
+					})
+					.ok(),
+			});
 
 		// Restore backend feature flags
 		for feature in node.config.get().await.features {

@@ -227,7 +231,9 @@ impl Node {
 		self.jobs.shutdown().await;
 		self.p2p.shutdown().await;
 		#[cfg(feature = "ai")]
-		self.image_labeller.shutdown().await;
+		if let Some(image_labeller) = &self.image_labeller {
+			image_labeller.shutdown().await;
+		}
 		info!("Spacedrive Core shutdown successful!");
 	}
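Note: the constructor change above is where the no-crash behavior comes from: `ImageLabeler::new(...)` previously aborted `Node` startup via `?`; it now logs the failure through `map_err` and converts the `Result` into the new `Option<ImageLabeler>` field with `.ok()`. The shutdown path gains the matching `if let Some(...)` guard.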
@@ -619,7 +619,7 @@ impl Libraries {
 
 			let _ = this
 				.edit(
-					library.id.clone(),
+					library.id,
 					None,
 					MaybeUndefined::Undefined,
 					MaybeUndefined::Null,
@@ -917,7 +917,7 @@ pub(super) async fn remove_by_file_path(
 			.await?;
 	} else {
 		sync.write_op(
-			&db,
+			db,
 			sync.shared_delete(prisma_sync::file_path::SyncId {
 				pub_id: file_path.pub_id.clone(),
 			}),
@@ -1,5 +1,4 @@
 use std::io::Error;
-use std::process::Command;
 use std::str;
 
 use serde::{Deserialize, Serialize};

@@ -37,6 +36,8 @@ impl HardwareModel {
 pub fn get_hardware_model_name() -> Result<HardwareModel, Error> {
 	#[cfg(target_os = "macos")]
 	{
+		use std::process::Command;
+
 		let output = Command::new("system_profiler")
 			.arg("SPHardwareDataType")
 			.output()?;
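Note: moving `use std::process::Command;` inside the `#[cfg(target_os = "macos")]` block scopes the import to the only code that uses it; at the top level it produced an unused-import warning on non-macOS targets, which looks like the "Fix cargo warning" item from the commit message.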
@@ -88,7 +88,7 @@ impl StatefulJob for FileDeleterJobInit {
 			step.full_path.display()
 		);
 		sync.write_op(
-			&db,
+			db,
 			sync.shared_delete(prisma_sync::file_path::SyncId {
 				pub_id: step.file_path.pub_id.clone(),
 			}),
@@ -178,17 +178,21 @@ impl StatefulJob for MediaProcessorJobInit {
 		let total_files_for_labeling = file_paths_for_labeling.len();
 
 		#[cfg(feature = "ai")]
-		let (labeler_batch_token, labels_rx) = ctx
-			.node
-			.image_labeller
-			.new_resumable_batch(
-				location_id,
-				location_path.clone(),
-				file_paths_for_labeling,
-				Arc::clone(db),
-				sync.clone(),
-			)
-			.await;
+		let (labeler_batch_token, labels_rx) =
+			if let Some(image_labeller) = ctx.node.image_labeller.as_ref() {
+				let (labeler_batch_token, labels_rx) = image_labeller
+					.new_resumable_batch(
+						location_id,
+						location_path.clone(),
+						file_paths_for_labeling,
+						Arc::clone(db),
+						sync.clone(),
+					)
+					.await;
+				(labeler_batch_token, Some(labels_rx))
+			} else {
+				(uuid::Uuid::new_v4(), None)
+			};
 
 		let total_files = file_paths.len();

@@ -240,7 +244,7 @@ impl StatefulJob for MediaProcessorJobInit {
 			#[cfg(feature = "ai")]
 			labeler_batch_token,
 			#[cfg(feature = "ai")]
-			maybe_labels_rx: Some(labels_rx),
+			maybe_labels_rx: labels_rx,
 		});
 
 		Ok((

@@ -323,6 +327,12 @@ impl StatefulJob for MediaProcessorJobInit {
 
 			#[cfg(feature = "ai")]
 			MediaProcessorJobStep::WaitLabels(total_labels) => {
+				let Some(image_labeller) = ctx.node.image_labeller.as_ref() else {
+					let err = "AI system is disabled due to a previous error, skipping labels job";
+					error!(err);
+					return Ok(JobRunErrors(vec![err.to_string()]).into());
+				};
+
 				ctx.progress(vec![
 					JobReportUpdate::TaskCount(*total_labels),
 					JobReportUpdate::Phase("labels".to_string()),

@@ -334,9 +344,7 @@ impl StatefulJob for MediaProcessorJobInit {
 				let mut labels_rx = pin!(if let Some(labels_rx) = data.maybe_labels_rx.clone() {
 					labels_rx
 				} else {
-					match ctx
-						.node
-						.image_labeller
+					match image_labeller
 						.resume_batch(
 							data.labeler_batch_token,
 							Arc::clone(&ctx.library.db),
@@ -109,16 +109,18 @@ pub async fn shallow(
 	);
 
 	#[cfg(feature = "ai")]
-	let labels_rx = node
-		.image_labeller
-		.new_batch(
-			location_id,
-			location_path.clone(),
-			file_paths_for_labelling,
-			Arc::clone(db),
-			sync.clone(),
-		)
-		.await;
+	// Check if we have an image labeller and has_labels then enqueue a new batch
+	let labels_rx = node.image_labeller.as_ref().and_then(|image_labeller| {
+		has_labels.then(|| {
+			image_labeller.new_batch(
+				location_id,
+				location_path.clone(),
+				file_paths_for_labelling,
+				Arc::clone(db),
+				sync.clone(),
+			)
+		})
+	});
 
 	let mut run_metadata = MediaProcessorMetadata::default();

@@ -144,27 +146,30 @@ pub async fn shallow(
 	#[cfg(feature = "ai")]
 	{
-		if has_labels {
-			labels_rx
-				.for_each(
-					|LabelerOutput {
-					     file_path_id,
-					     has_new_labels,
-					     result,
-					 }| async move {
-						if let Err(e) = result {
-							error!(
-								"Failed to generate labels <file_path_id='{file_path_id}'>: {e:#?}"
-							);
-						} else if has_new_labels {
-							// invalidate_query!(library, "labels.count"); // TODO: This query doesn't exist on main yet
-						}
-					},
-				)
-				.await;
-
-			invalidate_query!(library, "labels.list");
-			invalidate_query!(library, "labels.getForObject");
-			invalidate_query!(library, "labels.getWithObjects");
-		}
+		if let Some(labels_rx) = labels_rx {
+			labels_rx
+				.await
+				.for_each(
+					|LabelerOutput {
+					     file_path_id,
+					     has_new_labels,
+					     result,
+					 }| async move {
+						if let Err(e) = result {
+							error!(
+								"Failed to generate labels <file_path_id='{file_path_id}'>: {e:#?}"
+							);
+						} else if has_new_labels {
+							// invalidate_query!(library, "labels.count"); // TODO: This query doesn't exist on main yet
+						}
+					},
+				)
+				.await;
+
+			invalidate_query!(library, "labels.list");
+			invalidate_query!(library, "labels.getForObject");
+			invalidate_query!(library, "labels.getWithObjects");
+		}
 	}
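Note: the shallow path now builds the batch lazily with Option and bool combinators. A self-contained sketch of that shape (hypothetical stand-in types, not the Spacedrive API):

// `bool::then` yields Some(closure()) only when the flag is true, and
// `and_then` flattens the nested Options into one.
fn new_batch(labeller: &str) -> String {
    format!("batch for {labeller}")
}

fn main() {
    let image_labeller: Option<&str> = Some("yolov8s");
    let has_labels = true;

    let labels_rx: Option<String> = image_labeller
        .as_ref()
        .and_then(|labeller| has_labels.then(|| new_batch(labeller)));

    assert_eq!(labels_rx.as_deref(), Some("batch for yolov8s"));
}

In the real code the closure returns a future rather than a String, which is why the consumer above writes `if let Some(labels_rx) = labels_rx { labels_rx.await... }`.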
@@ -125,7 +125,7 @@ impl LibraryServices {
 					inserted = true;
 					Arc::new(
 						Service::new(
-							String::from_utf8_lossy(&base91::slice_encode(&*library.id.as_bytes())),
+							String::from_utf8_lossy(&base91::slice_encode(library.id.as_bytes())),
 							manager.manager.clone(),
 						)
 						.expect("error creating service with duplicate service name"),
@@ -42,21 +42,20 @@ half = { version = "2.1", features = ['num-traits'] }
 # "gpu" means CUDA or TensorRT EP. Thus, the ort crate cannot download them at build time.
 # Ref: https://github.com/pykeio/ort/blob/d7defd1862969b4b44f7f3f4b9c72263690bd67b/build.rs#L148
 [target.'cfg(target_os = "windows")'.dependencies]
-ort = { version = "=2.0.0-alpha.2", default-features = false, features = [
+ort = { version = "=2.0.0-rc.0", default-features = false, features = [
 	"ndarray",
 	"half",
 	"load-dynamic",
 	"directml",
 ] }
 [target.'cfg(target_os = "linux")'.dependencies]
-ort = { version = "=2.0.0-alpha.2", default-features = false, features = [
+ort = { version = "=2.0.0-rc.0", default-features = false, features = [
 	"ndarray",
 	"half",
 	"load-dynamic",
 	"xnnpack",
 ] }
 # [target.'cfg(target_os = "android")'.dependencies]
-# ort = { version = "2.0.0-alpha.2", default-features = false, features = [
+# ort = { version = "=2.0.0-rc.0", default-features = false, features = [
 	# "half",
 	# "load-dynamic",
 	# "qnn",

@@ -66,7 +65,7 @@ ort = { version = "=2.0.0-alpha.2", default-features = false, features = [
 	# "armnn",
 # ] }
 [target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies]
-ort = { version = "=2.0.0-alpha.2", features = [
+ort = { version = "=2.0.0-rc.0", features = [
 	"ndarray",
 	"half",
 	"load-dynamic",
@@ -29,9 +29,7 @@ pub enum ModelSource {
 }
 
 pub trait Model: Send + Sync + 'static {
-	fn name(&self) -> &'static str {
-		std::any::type_name::<Self>()
-	}
+	fn name(&self) -> &'static str;
 
 	fn origin(&self) -> &ModelSource;
 
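Note: the deleted default is very likely the "internal rust string representation as path" bug from the commit message: `std::any::type_name` returns an unspecified fully qualified Rust path, and using it as the model's name put that string into the downloaded model's file path. A sketch of what the old default produced:

mod image_labeler {
    pub struct YoloV8;
}

fn main() {
    // Prints a fully qualified name such as "sketch::image_labeler::YoloV8";
    // the exact format is explicitly not guaranteed by the standard library.
    println!("{}", std::any::type_name::<image_labeler::YoloV8>());
}

Making `name()` a required method, implemented below as the plain string "YoloV8", gives a stable, filesystem-safe identifier.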
@@ -73,6 +73,10 @@ impl YoloV8 {
 }
 
 impl Model for YoloV8 {
+	fn name(&self) -> &'static str {
+		"YoloV8"
+	}
+
 	fn origin(&self) -> &'static ModelSource {
 		self.model_origin
 	}
@@ -473,7 +473,7 @@ pub async fn assign_labels(
 		.collect();
 
 	sync.write_ops(
-		&db,
+		db,
 		(
 			sync_params,
 			db.label_on_object()
@@ -1,21 +1,17 @@
-use std::path::Path;
-
-use ort::{EnvironmentBuilder, LoggingLevel};
 use thiserror::Error;
+
+use ort::EnvironmentBuilder;
 use tracing::{debug, error};
 
 pub mod image_labeler;
 mod utils;
 
 // This path must be relative to the running binary
-#[cfg(windows)]
+#[cfg(target_os = "windows")]
 const BINDING_LOCATION: &str = ".";
-#[cfg(unix)]
-const BINDING_LOCATION: &str = if cfg!(target_os = "macos") {
-	"../Frameworks/Spacedrive.framework/Libraries"
-} else {
-	"../lib/spacedrive"
-};
+
+#[cfg(target_os = "macos")]
+const BINDING_LOCATION: &str = "../Frameworks/Spacedrive.framework/Libraries";
 
 #[cfg(target_os = "windows")]
 const LIB_NAME: &str = "onnxruntime.dll";

@@ -23,22 +19,17 @@ const LIB_NAME: &str = "onnxruntime.dll";
 #[cfg(any(target_os = "macos", target_os = "ios"))]
 const LIB_NAME: &str = "libonnxruntime.dylib";
 
-#[cfg(any(target_os = "linux", target_os = "android"))]
-const LIB_NAME: &str = "libonnxruntime.so";
-
 pub fn init() -> Result<(), Error> {
-	let path = utils::get_path_relative_to_exe(Path::new(BINDING_LOCATION).join(LIB_NAME));
-
-	std::env::set_var("ORT_DYLIB_PATH", path);
+	#[cfg(any(target_os = "macos", target_os = "ios", target_os = "windows"))]
+	{
+		use std::path::Path;
+		let path = utils::get_path_relative_to_exe(Path::new(BINDING_LOCATION).join(LIB_NAME));
+		std::env::set_var("ORT_DYLIB_PATH", path);
+	}
 
 	// Initialize AI stuff
 	EnvironmentBuilder::default()
 		.with_name("spacedrive")
-		.with_log_level(if cfg!(debug_assertions) {
-			LoggingLevel::Verbose
-		} else {
-			LoggingLevel::Info
-		})
 		.with_execution_providers({
 			#[cfg(any(target_os = "macos", target_os = "ios"))]
 			{

@@ -80,6 +71,7 @@ pub fn init() -> Result<(), Error> {
 			// }
 		})
 		.commit()?;
 
+	debug!("Initialized AI environment");
 
 	Ok(())
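Note: with ort's `load-dynamic` feature, `ORT_DYLIB_PATH` tells the crate where to load the ONNX Runtime shared library at run time. Limiting the exe-relative override to macOS, iOS, and Windows means Linux no longer uses a hardcoded relative path at all; the library is instead resolved through the build-time `ORT_LIB_LOCATION` override added earlier and the libraries the Dockerfile ships in `/usr/lib/spacedrive`. The `with_log_level` call and the `LoggingLevel` import disappear, presumably an ort 2.0.0-rc.0 API change.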
@@ -97,7 +97,7 @@ impl Mdns {
 		// The max length of an MDNS record is painful so we just hash the data to come up with a pseudo-random but deterministic value.
 		// The full values are stored within TXT records.
 		let my_name = String::from_utf8_lossy(&base91::slice_encode(
-			&sha256::digest(format!("{}_{}", service_name, self.identity)).as_bytes(),
+			sha256::digest(format!("{}_{}", service_name, self.identity)).as_bytes(),
 		))[..63]
 			.to_string();

@@ -236,7 +236,7 @@ impl Mdns {
 					info.get_fullname().to_string(),
 					TrackedService {
 						service_name: service_name.to_string(),
-						identity: identity.clone(),
+						identity,
 					},
 				);
@@ -89,7 +89,7 @@ pub fn r#enum(models: Vec<ModelWithSyncType>) -> TokenStream {
 	let item_model_sync_id_field_name_snake = models
 		.iter()
 		.find(|m| m.0.name() == item.related_model().name())
-		.and_then(|(m, sync)| sync.as_ref())
+		.and_then(|(_m, sync)| sync.as_ref())
 		.map(|sync| snake_ident(sync.sync_id()[0].name()))
 		.unwrap();
 	let item_model_name_snake = snake_ident(item.related_model().name());
@@ -13,9 +13,10 @@ import { OnboardingContext, useContextValue } from './context';
 import Progress from './Progress';
 
 export const Component = () => {
-	const os = useOperatingSystem();
+	const os = useOperatingSystem(false);
 	const debugState = useDebugState();
-	const [showIntro, setShowIntro] = useState(true);
+	// FIX-ME: Intro video breaks onboarding for the web version
+	const [showIntro, setShowIntro] = useState(os !== 'browser');
 	const ctx = useContextValue();
 
 	if (ctx.libraries.isLoading) return null;
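Note: the workaround initializes `showIntro` from the platform, so the web build skips the intro video entirely instead of getting stuck in onboarding, as the FIX-ME comment flags. Passing `false` to `useOperatingSystem` appears to request the frontend platform (which reports 'browser' on the web) rather than the underlying OS.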