diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 00f72dc59f..6db2b1eae2 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -27,6 +27,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?; push::register_module(py, m)?; + tree_cache::binding::register_module(py, m)?; Ok(()) } diff --git a/rust/src/tree_cache/binding.rs b/rust/src/tree_cache/binding.rs new file mode 100644 index 0000000000..70207f8781 --- /dev/null +++ b/rust/src/tree_cache/binding.rs @@ -0,0 +1,128 @@ +use std::hash::Hash; + +use anyhow::Error; +use pyo3::{ + pyclass, pymethods, types::PyModule, IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject, +}; + +use super::TreeCache; + +pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { + let child_module = PyModule::new(py, "tree_cache")?; + child_module.add_class::()?; + + m.add_submodule(child_module)?; + + // We need to manually add the module to sys.modules to make `from + // synapse.synapse_rust import push` work. + py.import("sys")? + .getattr("modules")? + .set_item("synapse.synapse_rust.tree_cache", child_module)?; + + Ok(()) +} + +struct HashablePyObject { + obj: PyObject, + hash: isize, +} + +impl HashablePyObject { + pub fn new(obj: &PyAny) -> Result { + let hash = obj.hash()?; + + Ok(HashablePyObject { + obj: obj.to_object(obj.py()), + hash, + }) + } +} + +impl IntoPy for &HashablePyObject { + fn into_py(self, _: Python<'_>) -> PyObject { + self.obj.clone() + } +} + +impl Hash for HashablePyObject { + fn hash(&self, state: &mut H) { + self.hash.hash(state); + } +} + +impl PartialEq for HashablePyObject { + fn eq(&self, other: &Self) -> bool { + let equal = Python::with_gil(|py| { + let result = self.obj.as_ref(py).eq(other.obj.as_ref(py)); + result.unwrap_or(false) + }); + + equal + } +} + +impl Eq for HashablePyObject {} + +#[pyclass] +struct PythonTreeCache(TreeCache); + +#[pymethods] +impl PythonTreeCache { + #[new] + fn new() -> Self { + PythonTreeCache(Default::default()) + } + + pub fn set(&mut self, key: &PyAny, value: PyObject) -> Result<(), Error> { + let v: Vec = key + .iter()? + .map(|obj| HashablePyObject::new(obj?)) + .collect::>()?; + + self.0.set(v, value)?; + + Ok(()) + } + + // pub fn get_node(&self, key: &PyAny) -> Result>, Error> { + // todo!() + // } + + pub fn get(&self, key: &PyAny) -> Result, Error> { + let v: Vec = key + .iter()? + .map(|obj| HashablePyObject::new(obj?)) + .collect::>()?; + + Ok(self.0.get(&v)?) + } + + // pub fn pop_node(&mut self, key: &PyAny) -> Result>, Error> { + // todo!() + // } + + pub fn pop(&mut self, key: &PyAny) -> Result, Error> { + let v: Vec = key + .iter()? + .map(|obj| HashablePyObject::new(obj?)) + .collect::>()?; + + Ok(self.0.pop(&v)?) + } + + pub fn clear(&mut self) { + self.0.clear() + } + + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn values(&self) -> Vec<&PyObject> { + self.0.values().collect() + } + + pub fn items(&self) -> Vec<(Vec<&HashablePyObject>, &PyObject)> { + todo!() + } +} diff --git a/rust/src/tree_cache.rs b/rust/src/tree_cache/mod.rs similarity index 97% rename from rust/src/tree_cache.rs rename to rust/src/tree_cache/mod.rs index 6796229d64..0a4905b881 100644 --- a/rust/src/tree_cache.rs +++ b/rust/src/tree_cache/mod.rs @@ -2,6 +2,8 @@ use std::{collections::HashMap, hash::Hash}; use anyhow::{bail, Error}; +pub mod binding; + pub enum TreeCacheNode { Leaf(V), Branch(usize, HashMap>), @@ -114,17 +116,25 @@ impl<'a, K: Eq + Hash + 'a, V> TreeCacheNode { } } +impl Default for TreeCacheNode { + fn default() -> Self { + TreeCacheNode::new_branch() + } +} + pub struct TreeCache { root: TreeCacheNode, } -impl<'a, K: Eq + Hash + 'a, V> TreeCache { +impl TreeCache { pub fn new() -> Self { TreeCache { root: TreeCacheNode::new_branch(), } } +} +impl<'a, K: Eq + Hash + 'a, V> TreeCache { pub fn set(&mut self, key: impl IntoIterator, value: V) -> Result<(), Error> { self.root.set(key.into_iter(), value)?; @@ -224,6 +234,12 @@ impl<'a, K: Eq + Hash + 'a, V> TreeCache { } } +impl Default for TreeCache { + fn default() -> Self { + TreeCache::new() + } +} + #[cfg(test)] mod test { use std::collections::BTreeSet;