[ENG-1595] Fix some problems with the AI system (#2063)

* Fix some problems with the AI system
 - Fix downloading model using an internal rust string representation as path for the model file
 - Fix Linux loading onnx shared lib from a hardcoded path
 - Fix App should not crash when the AI system fails to start
 - Fix sd-server failing to start due to onnxruntime incorrect linking
 - Some extra clippy auto fixes

* Use latest ort

* Fix dangling sd_ai reference
 - Use entrypoint.sh to initialize container

* Fix server Dockerfile
 - Fix cargo warning

* Workaround intro video breaking onboarding for the web version

* Fix rebase
This commit is contained in:
Vítor Vasconcellos 2024-02-12 14:45:17 -03:00 committed by GitHub
parent 74e2d23c11
commit 72efcc9f62
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 191 additions and 140 deletions

View file

@ -1,6 +1,9 @@
[env]
PROTOC = { force = true, value = "{{{protoc}}}" }
FFMPEG_DIR = { force = true, value = "{{{nativeDeps}}}" }
{{#isLinux}}
ORT_LIB_LOCATION = { force = true, value = "{{{nativeDeps}}}/lib" }
{{/isLinux}}
OPENSSL_STATIC = { force = true, value = "1" }
OPENSSL_NO_VENDOR = { force = true, value = "0" }
OPENSSL_RUST_USE_NASM = { force = true, value = "1" }

32
Cargo.lock generated
View file

@ -1338,6 +1338,15 @@ dependencies = [
"toml 0.7.8",
]
[[package]]
name = "castaway"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a17ed5635fc8536268e5d4de1e22e81ac34419e5f052d4d51f4e01dcc263fcc"
dependencies = [
"rustversion",
]
[[package]]
name = "cc"
version = "1.0.83"
@ -1584,6 +1593,19 @@ dependencies = [
"memchr",
]
[[package]]
name = "compact_str"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f86b9c4c00838774a6d902ef931eff7470720c51d90c2e32cfe15dc304737b3f"
dependencies = [
"castaway",
"cfg-if",
"itoa 1.0.10",
"ryu",
"static_assertions",
]
[[package]]
name = "concurrent-queue"
version = "2.4.0"
@ -5584,14 +5606,14 @@ checksum = "a86ed3f5f244b372d6b1a00b72ef7f8876d0bc6a78a4c9985c53614041512063"
[[package]]
name = "ort"
version = "2.0.0-alpha.2"
version = "2.0.0-rc.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a094a17bfe4f9eb561bfdf4454b8d0f6f89deaf9a5a572a1ef29c29ce708627"
checksum = "f8e5caf4eb2ead4bc137c3ff4e347940e3e556ceb11a4180627f04b63d7342dd"
dependencies = [
"compact_str",
"half",
"libloading 0.8.1",
"ndarray",
"once_cell",
"ort-sys",
"thiserror",
"tracing",
@ -5599,9 +5621,9 @@ dependencies = [
[[package]]
name = "ort-sys"
version = "2.0.0-alpha.3"
version = "2.0.0-rc.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "846ab4ca6873d26ac40f4bdfa8375e5e4547117693b06de5b7903e0c7471d1b7"
checksum = "f48b5623df2187e0db543ecb2032a6a999081086b7ffddd318000c00b23ace46"
dependencies = [
"flate2",
"sha2 0.10.8",

View file

@ -13,7 +13,6 @@ use std::{
use sd_core::{Node, NodeError};
use clear_localstorage::clear_localstorage;
use sd_fda::DiskAccess;
use serde::{Deserialize, Serialize};
use tauri::{

View file

@ -77,6 +77,9 @@ RUN cargo build --features assets --release -p sd-server
# Debug just means it includes busybox, and we need its tools both for the entrypoint.sh script and during runtime
FROM gcr.io/distroless/base-debian12:debug
RUN [ "/busybox/ln", "-s", "/busybox/sh", "/bin/sh" ]
RUN ln -s /busybox/env /usr/bin/env
ENV TZ=UTC \
PUID=1000 \
PGID=1000 \
@ -86,9 +89,10 @@ ENV TZ=UTC \
LANGUAGE=en \
DATA_DIR=/data
COPY --from=server /srv/spacedrive/target/release/sd-server /usr/bin/
COPY --from=server /lib/x86_64-linux-gnu/libgcc_s.so.1 /usr/lib/
COPY --from=server /srv/spacedrive/apps/.deps/lib /usr/lib/spacedrive
COPY --from=server --chmod=755 /srv/spacedrive/target/release/sd-server /usr/bin/
COPY --from=server --chmod=755 /lib/x86_64-linux-gnu/libgcc_s.so.1 /usr/lib/
COPY --from=server --chmod=755 /srv/spacedrive/apps/.deps/lib /usr/lib/spacedrive
ADD --chmod=755 https://github.com/spacedriveapp/native-deps/releases/download/yolo-2024-02-07/yolov8s.onnx /usr/share/spacedrive/models/yolov8s.onnx
COPY --chmod=755 entrypoint.sh /usr/bin/
@ -99,7 +103,7 @@ EXPOSE 8080
VOLUME [ "/data" ]
# Run the CLI when the container is started
ENTRYPOINT [ "sd-server" ]
ENTRYPOINT [ "entrypoint.sh" ]
LABEL org.opencontainers.image.title="Spacedrive Server" \
org.opencontainers.image.source="https://github.com/spacedriveapp/spacedrive"

View file

@ -36,4 +36,4 @@ fi
echo "Fix spacedrive's directories permissions"
chown -R "${PUID}:${PGID}" /data
exec su spacedrive -s /bin/server -- "$@"
exec su spacedrive -s /usr/bin/sd-server -- "$@"

View file

@ -198,7 +198,7 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
})?;
sync.write_op(
&db,
db,
sync.shared_update(
prisma_sync::object::SyncId {
pub_id: object.pub_id,
@ -244,7 +244,7 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
})?;
sync.write_op(
&db,
db,
sync.shared_update(
prisma_sync::object::SyncId {
pub_id: object.pub_id,
@ -324,7 +324,7 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
.unzip();
sync.write_ops(
&db,
db,
(
sync_params,
db.object().update_many(
@ -366,7 +366,7 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
})
.unzip();
sync.write_ops(
&db,
db,
(
sync_params,
db.object().update_many(

View file

@ -96,21 +96,30 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
if let Some(model) = new_model {
let version = model.version().to_string();
tokio::spawn(async move {
let notification = if let Err(e) =
node.image_labeller.change_model(model).await
{
NotificationData {
let notification =
if let Some(image_labeller) = node.image_labeller.as_ref() {
if let Err(e) = image_labeller.change_model(model).await {
NotificationData {
title: String::from(
"Failed to change image detection model",
),
content: format!("Error: {e}"),
kind: NotificationKind::Error,
}
} else {
NotificationData {
title: String::from("Model download completed"),
content: format!("Sucessfuly loaded model: {version}"),
kind: NotificationKind::Success,
}
}
} else {
NotificationData {
title: String::from("Failed to change image detection model"),
content: format!("Error: {e}"),
kind: NotificationKind::Error,
}
} else {
NotificationData {
title: String::from("Model download completed"),
content: format!("Sucessfuly loaded model: {version}"),
content: "The AI system is disabled due to a previous error. Contact support for help.".to_string(),
kind: NotificationKind::Success,
}
};
};
node.emit_notification(notification, None).await;
});

View file

@ -210,7 +210,7 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
.unzip();
sync.write_ops(
&db,
db,
(
sync_params,
db.saved_search()
@ -242,7 +242,7 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
})?;
sync.write_op(
&db,
db,
sync.shared_delete(prisma_sync::saved_search::SyncId {
pub_id: search.pub_id,
}),

View file

@ -79,7 +79,7 @@ pub async fn run_actor(
|uuid| sd_cloud_api::library::message_collections::get::InstanceTimestamp {
instance_uuid: *uuid,
from_time: cloud_timestamps
.get(&uuid)
.get(uuid)
.cloned()
.unwrap_or_default()
.as_u64()

View file

@ -67,7 +67,7 @@ pub struct Node {
pub env: Arc<env::Env>,
pub http: reqwest::Client,
#[cfg(feature = "ai")]
pub image_labeller: ImageLabeler,
pub image_labeller: Option<ImageLabeler>,
}
impl fmt::Debug for Node {
@ -115,31 +115,35 @@ impl Node {
let libraries = library::Libraries::new(data_dir.join("libraries")).await?;
let (p2p, p2p_actor) = p2p::P2PManager::new(config.clone(), libraries.clone()).await?;
let node = Arc::new(Node {
data_dir: data_dir.to_path_buf(),
jobs,
locations,
notifications: notifications::Notifications::new(),
p2p,
thumbnailer: Thumbnailer::new(
data_dir,
libraries.clone(),
event_bus.0.clone(),
config.preferences_watcher(),
)
.await,
config,
event_bus,
libraries,
files_over_p2p_flag: Arc::new(AtomicBool::new(false)),
cloud_sync_flag: Arc::new(AtomicBool::new(false)),
http: reqwest::Client::new(),
env,
#[cfg(feature = "ai")]
image_labeller: ImageLabeler::new(YoloV8::model(image_labeler_version)?, data_dir)
.await
.map_err(sd_ai::Error::from)?,
});
let node =
Arc::new(Node {
data_dir: data_dir.to_path_buf(),
jobs,
locations,
notifications: notifications::Notifications::new(),
p2p,
thumbnailer: Thumbnailer::new(
data_dir,
libraries.clone(),
event_bus.0.clone(),
config.preferences_watcher(),
)
.await,
config,
event_bus,
libraries,
files_over_p2p_flag: Arc::new(AtomicBool::new(false)),
cloud_sync_flag: Arc::new(AtomicBool::new(false)),
http: reqwest::Client::new(),
env,
#[cfg(feature = "ai")]
image_labeller: ImageLabeler::new(YoloV8::model(image_labeler_version)?, data_dir)
.await
.map_err(|e| {
error!("Failed to initialize image labeller. AI features will be disabled: {e:#?}");
})
.ok(),
});
// Restore backend feature flags
for feature in node.config.get().await.features {
@ -227,7 +231,9 @@ impl Node {
self.jobs.shutdown().await;
self.p2p.shutdown().await;
#[cfg(feature = "ai")]
self.image_labeller.shutdown().await;
if let Some(image_labeller) = &self.image_labeller {
image_labeller.shutdown().await;
}
info!("Spacedrive Core shutdown successful!");
}

View file

@ -619,7 +619,7 @@ impl Libraries {
let _ = this
.edit(
library.id.clone(),
library.id,
None,
MaybeUndefined::Undefined,
MaybeUndefined::Null,

View file

@ -917,7 +917,7 @@ pub(super) async fn remove_by_file_path(
.await?;
} else {
sync.write_op(
&db,
db,
sync.shared_delete(prisma_sync::file_path::SyncId {
pub_id: file_path.pub_id.clone(),
}),

View file

@ -1,5 +1,4 @@
use std::io::Error;
use std::process::Command;
use std::str;
use serde::{Deserialize, Serialize};
@ -37,6 +36,8 @@ impl HardwareModel {
pub fn get_hardware_model_name() -> Result<HardwareModel, Error> {
#[cfg(target_os = "macos")]
{
use std::process::Command;
let output = Command::new("system_profiler")
.arg("SPHardwareDataType")
.output()?;

View file

@ -88,7 +88,7 @@ impl StatefulJob for FileDeleterJobInit {
step.full_path.display()
);
sync.write_op(
&db,
db,
sync.shared_delete(prisma_sync::file_path::SyncId {
pub_id: step.file_path.pub_id.clone(),
}),

View file

@ -178,17 +178,21 @@ impl StatefulJob for MediaProcessorJobInit {
let total_files_for_labeling = file_paths_for_labeling.len();
#[cfg(feature = "ai")]
let (labeler_batch_token, labels_rx) = ctx
.node
.image_labeller
.new_resumable_batch(
location_id,
location_path.clone(),
file_paths_for_labeling,
Arc::clone(db),
sync.clone(),
)
.await;
let (labeler_batch_token, labels_rx) =
if let Some(image_labeller) = ctx.node.image_labeller.as_ref() {
let (labeler_batch_token, labels_rx) = image_labeller
.new_resumable_batch(
location_id,
location_path.clone(),
file_paths_for_labeling,
Arc::clone(db),
sync.clone(),
)
.await;
(labeler_batch_token, Some(labels_rx))
} else {
(uuid::Uuid::new_v4(), None)
};
let total_files = file_paths.len();
@ -240,7 +244,7 @@ impl StatefulJob for MediaProcessorJobInit {
#[cfg(feature = "ai")]
labeler_batch_token,
#[cfg(feature = "ai")]
maybe_labels_rx: Some(labels_rx),
maybe_labels_rx: labels_rx,
});
Ok((
@ -323,6 +327,12 @@ impl StatefulJob for MediaProcessorJobInit {
#[cfg(feature = "ai")]
MediaProcessorJobStep::WaitLabels(total_labels) => {
let Some(image_labeller) = ctx.node.image_labeller.as_ref() else {
let err = "AI system is disabled due to a previous error, skipping labels job";
error!(err);
return Ok(JobRunErrors(vec![err.to_string()]).into());
};
ctx.progress(vec![
JobReportUpdate::TaskCount(*total_labels),
JobReportUpdate::Phase("labels".to_string()),
@ -334,9 +344,7 @@ impl StatefulJob for MediaProcessorJobInit {
let mut labels_rx = pin!(if let Some(labels_rx) = data.maybe_labels_rx.clone() {
labels_rx
} else {
match ctx
.node
.image_labeller
match image_labeller
.resume_batch(
data.labeler_batch_token,
Arc::clone(&ctx.library.db),

View file

@ -109,16 +109,18 @@ pub async fn shallow(
);
#[cfg(feature = "ai")]
let labels_rx = node
.image_labeller
.new_batch(
location_id,
location_path.clone(),
file_paths_for_labelling,
Arc::clone(db),
sync.clone(),
)
.await;
// Check if we have an image labeller and has_labels then enqueue a new batch
let labels_rx = node.image_labeller.as_ref().and_then(|image_labeller| {
has_labels.then(|| {
image_labeller.new_batch(
location_id,
location_path.clone(),
file_paths_for_labelling,
Arc::clone(db),
sync.clone(),
)
})
});
let mut run_metadata = MediaProcessorMetadata::default();
@ -144,27 +146,30 @@ pub async fn shallow(
#[cfg(feature = "ai")]
{
if has_labels {
labels_rx
.for_each(
|LabelerOutput {
file_path_id,
has_new_labels,
result,
}| async move {
if let Err(e) = result {
error!(
if let Some(labels_rx) = labels_rx {
labels_rx
.await
.for_each(
|LabelerOutput {
file_path_id,
has_new_labels,
result,
}| async move {
if let Err(e) = result {
error!(
"Failed to generate labels <file_path_id='{file_path_id}'>: {e:#?}"
);
} else if has_new_labels {
// invalidate_query!(library, "labels.count"); // TODO: This query doesn't exist on main yet
}
},
)
.await;
} else if has_new_labels {
// invalidate_query!(library, "labels.count"); // TODO: This query doesn't exist on main yet
}
},
)
.await;
invalidate_query!(library, "labels.list");
invalidate_query!(library, "labels.getForObject");
invalidate_query!(library, "labels.getWithObjects");
invalidate_query!(library, "labels.list");
invalidate_query!(library, "labels.getForObject");
invalidate_query!(library, "labels.getWithObjects");
}
}
}

View file

@ -125,7 +125,7 @@ impl LibraryServices {
inserted = true;
Arc::new(
Service::new(
String::from_utf8_lossy(&base91::slice_encode(&*library.id.as_bytes())),
String::from_utf8_lossy(&base91::slice_encode(library.id.as_bytes())),
manager.manager.clone(),
)
.expect("error creating service with duplicate service name"),

View file

@ -42,21 +42,20 @@ half = { version = "2.1", features = ['num-traits'] }
# "gpu" means CUDA or TensorRT EP. Thus, the ort crate cannot download them at build time.
# Ref: https://github.com/pykeio/ort/blob/d7defd1862969b4b44f7f3f4b9c72263690bd67b/build.rs#L148
[target.'cfg(target_os = "windows")'.dependencies]
ort = { version = "=2.0.0-alpha.2", default-features = false, features = [
ort = { version = "=2.0.0-rc.0", default-features = false, features = [
"ndarray",
"half",
"load-dynamic",
"directml",
] }
[target.'cfg(target_os = "linux")'.dependencies]
ort = { version = "=2.0.0-alpha.2", default-features = false, features = [
ort = { version = "=2.0.0-rc.0", default-features = false, features = [
"ndarray",
"half",
"load-dynamic",
"xnnpack",
] }
# [target.'cfg(target_os = "android")'.dependencies]
# ort = { version = "2.0.0-alpha.2", default-features = false, features = [
# ort = { version = "=2.0.0-rc.0", default-features = false, features = [
# "half",
# "load-dynamic",
# "qnn",
@ -66,7 +65,7 @@ ort = { version = "=2.0.0-alpha.2", default-features = false, features = [
# "armnn",
# ] }
[target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies]
ort = { version = "=2.0.0-alpha.2", features = [
ort = { version = "=2.0.0-rc.0", features = [
"ndarray",
"half",
"load-dynamic",

View file

@ -29,9 +29,7 @@ pub enum ModelSource {
}
pub trait Model: Send + Sync + 'static {
fn name(&self) -> &'static str {
std::any::type_name::<Self>()
}
fn name(&self) -> &'static str;
fn origin(&self) -> &ModelSource;

View file

@ -73,6 +73,10 @@ impl YoloV8 {
}
impl Model for YoloV8 {
fn name(&self) -> &'static str {
"YoloV8"
}
fn origin(&self) -> &'static ModelSource {
self.model_origin
}

View file

@ -473,7 +473,7 @@ pub async fn assign_labels(
.collect();
sync.write_ops(
&db,
db,
(
sync_params,
db.label_on_object()

View file

@ -1,21 +1,17 @@
use std::path::Path;
use ort::{EnvironmentBuilder, LoggingLevel};
use thiserror::Error;
use ort::EnvironmentBuilder;
use tracing::{debug, error};
pub mod image_labeler;
mod utils;
// This path must be relative to the running binary
#[cfg(windows)]
#[cfg(target_os = "windows")]
const BINDING_LOCATION: &str = ".";
#[cfg(unix)]
const BINDING_LOCATION: &str = if cfg!(target_os = "macos") {
"../Frameworks/Spacedrive.framework/Libraries"
} else {
"../lib/spacedrive"
};
#[cfg(target_os = "macos")]
const BINDING_LOCATION: &str = "../Frameworks/Spacedrive.framework/Libraries";
#[cfg(target_os = "windows")]
const LIB_NAME: &str = "onnxruntime.dll";
@ -23,22 +19,17 @@ const LIB_NAME: &str = "onnxruntime.dll";
#[cfg(any(target_os = "macos", target_os = "ios"))]
const LIB_NAME: &str = "libonnxruntime.dylib";
#[cfg(any(target_os = "linux", target_os = "android"))]
const LIB_NAME: &str = "libonnxruntime.so";
pub fn init() -> Result<(), Error> {
let path = utils::get_path_relative_to_exe(Path::new(BINDING_LOCATION).join(LIB_NAME));
std::env::set_var("ORT_DYLIB_PATH", path);
#[cfg(any(target_os = "macos", target_os = "ios", target_os = "windows"))]
{
use std::path::Path;
let path = utils::get_path_relative_to_exe(Path::new(BINDING_LOCATION).join(LIB_NAME));
std::env::set_var("ORT_DYLIB_PATH", path);
}
// Initialize AI stuff
EnvironmentBuilder::default()
.with_name("spacedrive")
.with_log_level(if cfg!(debug_assertions) {
LoggingLevel::Verbose
} else {
LoggingLevel::Info
})
.with_execution_providers({
#[cfg(any(target_os = "macos", target_os = "ios"))]
{
@ -80,6 +71,7 @@ pub fn init() -> Result<(), Error> {
// }
})
.commit()?;
debug!("Initialized AI environment");
Ok(())

View file

@ -97,7 +97,7 @@ impl Mdns {
// The max length of an MDNS record is painful so we just hash the data to come up with a pseudo-random but deterministic value.
// The full values are stored within TXT records.
let my_name = String::from_utf8_lossy(&base91::slice_encode(
&sha256::digest(format!("{}_{}", service_name, self.identity)).as_bytes(),
sha256::digest(format!("{}_{}", service_name, self.identity)).as_bytes(),
))[..63]
.to_string();
@ -236,7 +236,7 @@ impl Mdns {
info.get_fullname().to_string(),
TrackedService {
service_name: service_name.to_string(),
identity: identity.clone(),
identity,
},
);

View file

@ -89,7 +89,7 @@ pub fn r#enum(models: Vec<ModelWithSyncType>) -> TokenStream {
let item_model_sync_id_field_name_snake = models
.iter()
.find(|m| m.0.name() == item.related_model().name())
.and_then(|(m, sync)| sync.as_ref())
.and_then(|(_m, sync)| sync.as_ref())
.map(|sync| snake_ident(sync.sync_id()[0].name()))
.unwrap();
let item_model_name_snake = snake_ident(item.related_model().name());

View file

@ -13,9 +13,10 @@ import { OnboardingContext, useContextValue } from './context';
import Progress from './Progress';
export const Component = () => {
const os = useOperatingSystem();
const os = useOperatingSystem(false);
const debugState = useDebugState();
const [showIntro, setShowIntro] = useState(true);
// FIX-ME: Intro video breaks onboarding for the web version
const [showIntro, setShowIntro] = useState(os !== 'browser');
const ctx = useContextValue();
if (ctx.libraries.isLoading) return null;