diff --git a/synapse/state.py b/synapse/state.py index b9d5627a82..66e1a685e8 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -41,7 +41,7 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) -SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR) +SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR) EVICTION_TIMEOUT_SECONDS = 60 * 60 @@ -77,6 +77,9 @@ class _StateCacheEntry(object): else: self.state_id = _gen_state_id() + def __len__(self): + return len(self.state) + class StateHandler(object): """ Responsible for doing state conflict resolution. @@ -99,6 +102,7 @@ class StateHandler(object): clock=self.clock, max_len=SIZE_OF_CACHE, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, + iterable=True, reset_expiry_on_get=True, ) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 5620a655eb..963ef999d5 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -169,7 +169,7 @@ class SQLBaseStore(object): max_entries=hs.config.event_cache_size) self._state_group_cache = DictionaryCache( - "*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR + "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR ) self._event_fetch_lock = threading.Condition() diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 5d18037c7c..768e0a4451 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -390,7 +390,8 @@ class RoomMemberStore(SQLBaseStore): room_id, state_group, state_ids, ) - @cachedInlineCallbacks(num_args=2, cache_context=True) + @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, + max_entries=100000) def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, cache_context, event=None): # We don't use `state_group`, it's there so that we can cache based diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7f466c40ac..7d34dd03bf 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -284,7 +284,7 @@ class StateStore(SQLBaseStore): return [r[0] for r in results] return self.runInteraction("get_current_state_for_key", f) - @cached(num_args=2, max_entries=1000) + @cached(num_args=2, max_entries=100000, iterable=True) def _get_state_group_from_group(self, group, types): raise NotImplementedError() diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 8dba61d49f..a9ea97fd46 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -17,7 +17,7 @@ import logging from synapse.util.async import ObservableDeferred from synapse.util import unwrapFirstError from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.treecache import TreeCache +from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry from synapse.util.logcontext import ( PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn ) @@ -42,6 +42,25 @@ _CacheSentinel = object() CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) +class CacheEntry(object): + __slots__ = [ + "deferred", "sequence", "callbacks", "invalidated" + ] + + def __init__(self, deferred, sequence, callbacks): + self.deferred = deferred + self.sequence = sequence + self.callbacks = set(callbacks) + self.invalidated = False + + def invalidate(self): + if not self.invalidated: + self.invalidated = True + for callback in self.callbacks: + callback() + self.callbacks.clear() + + class Cache(object): __slots__ = ( "cache", @@ -51,12 +70,16 @@ class Cache(object): "sequence", "thread", "metrics", + "_pending_deferred_cache", ) - def __init__(self, name, max_entries=1000, keylen=1, tree=False): + def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False): cache_type = TreeCache if tree else dict + self._pending_deferred_cache = cache_type() + self.cache = LruCache( - max_size=max_entries, keylen=keylen, cache_type=cache_type + max_size=max_entries, keylen=keylen, cache_type=cache_type, + size_callback=(lambda d: len(d.result)) if iterable else None, ) self.name = name @@ -76,7 +99,15 @@ class Cache(object): ) def get(self, key, default=_CacheSentinel, callback=None): - val = self.cache.get(key, _CacheSentinel, callback=callback) + callbacks = [callback] if callback else [] + val = self._pending_deferred_cache.get(key, _CacheSentinel) + if val is not _CacheSentinel: + if val.sequence == self.sequence: + val.callbacks.update(callbacks) + self.metrics.inc_hits() + return val.deferred + + val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) if val is not _CacheSentinel: self.metrics.inc_hits() return val @@ -88,15 +119,39 @@ class Cache(object): else: return default - def update(self, sequence, key, value, callback=None): + def set(self, key, value, callback=None): + callbacks = [callback] if callback else [] self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - self.prefill(key, value, callback=callback) + entry = CacheEntry( + deferred=value, + sequence=self.sequence, + callbacks=callbacks, + ) + + entry.callbacks.update(callbacks) + + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry: + existing_entry.invalidate() + + self._pending_deferred_cache[key] = entry + + def shuffle(result): + if self.sequence == entry.sequence: + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry is entry: + self.cache.set(key, entry.deferred, entry.callbacks) + else: + entry.invalidate() + else: + entry.invalidate() + return result + + entry.deferred.addCallback(shuffle) def prefill(self, key, value, callback=None): - self.cache.set(key, value, callback=callback) + callbacks = [callback] if callback else [] + self.cache.set(key, value, callbacks=callbacks) def invalidate(self, key): self.check_thread() @@ -108,6 +163,10 @@ class Cache(object): # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) self.sequence += 1 + entry = self._pending_deferred_cache.pop(key, None) + if entry: + entry.invalidate() + self.cache.pop(key, None) def invalidate_many(self, key): @@ -119,6 +178,12 @@ class Cache(object): self.sequence += 1 self.cache.del_multi(key) + val = self._pending_deferred_cache.pop(key, None) + if val is not None: + entry_dict, _ = val + for entry in iterate_tree_cache_entry(entry_dict): + entry.invalidate() + def invalidate_all(self): self.check_thread() self.sequence += 1 @@ -155,7 +220,7 @@ class CacheDescriptor(object): """ def __init__(self, orig, max_entries=1000, num_args=1, tree=False, - inlineCallbacks=False, cache_context=False): + inlineCallbacks=False, cache_context=False, iterable=False): max_entries = int(max_entries * CACHE_SIZE_FACTOR) self.orig = orig @@ -169,6 +234,8 @@ class CacheDescriptor(object): self.num_args = num_args self.tree = tree + self.iterable = iterable + all_args = inspect.getargspec(orig) self.arg_names = all_args.args[1:num_args + 1] @@ -203,6 +270,7 @@ class CacheDescriptor(object): max_entries=self.max_entries, keylen=self.num_args, tree=self.tree, + iterable=self.iterable, ) @functools.wraps(self.orig) @@ -243,11 +311,6 @@ class CacheDescriptor(object): return preserve_context_over_deferred(observer) except KeyError: - # Get the sequence number of the cache before reading from the - # database so that we can tell if the cache is invalidated - # while the SELECT is executing (SYN-369) - sequence = cache.sequence - ret = defer.maybeDeferred( preserve_context_over_fn, self.function_to_call, @@ -261,7 +324,7 @@ class CacheDescriptor(object): ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - cache.update(sequence, cache_key, ret, callback=invalidate_callback) + cache.set(cache_key, ret, callback=invalidate_callback) return preserve_context_over_deferred(ret.observe()) @@ -359,7 +422,6 @@ class CacheListDescriptor(object): missing.append(arg) if missing: - sequence = cache.sequence args_to_call = dict(arg_dict) args_to_call[self.list_name] = missing @@ -382,8 +444,8 @@ class CacheListDescriptor(object): key = list(keyargs) key[self.list_pos] = arg - cache.update( - sequence, tuple(key), observer, + cache.set( + tuple(key), observer, callback=invalidate_callback ) @@ -421,17 +483,20 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): self.cache.invalidate(self.key) -def cached(max_entries=1000, num_args=1, tree=False, cache_context=False): +def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, + iterable=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, tree=tree, cache_context=cache_context, + iterable=iterable, ) -def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False): +def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False, + iterable=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, @@ -439,6 +504,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex tree=tree, inlineCallbacks=True, cache_context=cache_context, + iterable=iterable, ) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index b0ca1bb79d..cb6933c61c 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -23,7 +23,9 @@ import logging logger = logging.getLogger(__name__) -DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) +class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))): + def __len__(self): + return len(self.value) class DictionaryCache(object): @@ -32,7 +34,7 @@ class DictionaryCache(object): """ def __init__(self, name, max_entries=1000): - self.cache = LruCache(max_size=max_entries) + self.cache = LruCache(max_size=max_entries, size_callback=len) self.name = name self.sequence = 0 diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 080388958f..2987c38a2d 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -15,6 +15,7 @@ from synapse.util.caches import register_cache +from collections import OrderedDict import logging @@ -23,7 +24,7 @@ logger = logging.getLogger(__name__) class ExpiringCache(object): def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, - reset_expiry_on_get=False): + reset_expiry_on_get=False, iterable=False): """ Args: cache_name (str): Name of this cache, used for logging. @@ -36,6 +37,8 @@ class ExpiringCache(object): evicted based on time. reset_expiry_on_get (bool): If true, will reset the expiry time for an item on access. Defaults to False. + iterable (bool): If true, the size is calculated by summing the + sizes of all entries, rather than the number of entries. """ self._cache_name = cache_name @@ -47,9 +50,13 @@ class ExpiringCache(object): self._reset_expiry_on_get = reset_expiry_on_get - self._cache = {} + self._cache = OrderedDict() - self.metrics = register_cache(cache_name, self._cache) + self.metrics = register_cache(cache_name, self) + + self.iterable = iterable + + self._size_estimate = 0 def start(self): if not self._expiry_ms: @@ -65,15 +72,14 @@ class ExpiringCache(object): now = self._clock.time_msec() self._cache[key] = _CacheEntry(now, value) - # Evict if there are now too many items - if self._max_len and len(self._cache.keys()) > self._max_len: - sorted_entries = sorted( - self._cache.items(), - key=lambda item: item[1].time, - ) + if self.iterable: + self._size_estimate += len(value) - for k, _ in sorted_entries[self._max_len:]: - self._cache.pop(k) + # Evict if there are now too many items + while self._max_len and len(self) > self._max_len: + _key, value = self._cache.popitem(last=False) + if self.iterable: + self._size_estimate -= len(value.value) def __getitem__(self, key): try: @@ -99,7 +105,7 @@ class ExpiringCache(object): # zero expiry time means don't expire. This should never get called # since we have this check in start too. return - begin_length = len(self._cache) + begin_length = len(self) now = self._clock.time_msec() @@ -110,15 +116,20 @@ class ExpiringCache(object): keys_to_delete.add(key) for k in keys_to_delete: - self._cache.pop(k) + value = self._cache.pop(k) + if self.iterable: + self._size_estimate -= len(value.value) logger.debug( "[%s] _prune_cache before: %d, after len: %d", - self._cache_name, begin_length, len(self._cache) + self._cache_name, begin_length, len(self) ) def __len__(self): - return len(self._cache) + if self.iterable: + return self._size_estimate + else: + return len(self._cache) class _CacheEntry(object): diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 9c4c679175..072f9a9d19 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -49,7 +49,7 @@ class LruCache(object): Can also set callbacks on objects when getting/setting which are fired when that key gets invalidated/evicted. """ - def __init__(self, max_size, keylen=1, cache_type=dict): + def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None): cache = cache_type() self.cache = cache # Used for introspection. list_root = _Node(None, None, None, None) @@ -58,6 +58,12 @@ class LruCache(object): lock = threading.Lock() + def evict(): + while cache_len() > max_size: + todelete = list_root.prev_node + delete_node(todelete) + cache.pop(todelete.key, None) + def synchronized(f): @wraps(f) def inner(*args, **kwargs): @@ -66,6 +72,16 @@ class LruCache(object): return inner + cached_cache_len = [0] + if size_callback is not None: + def cache_len(): + return cached_cache_len[0] + else: + def cache_len(): + return len(cache) + + self.len = synchronized(cache_len) + def add_node(key, value, callbacks=set()): prev_node = list_root next_node = prev_node.next_node @@ -74,6 +90,9 @@ class LruCache(object): next_node.prev_node = node cache[key] = node + if size_callback: + cached_cache_len[0] += size_callback(node.value) + def move_node_to_front(node): prev_node = node.prev_node next_node = node.next_node @@ -92,23 +111,25 @@ class LruCache(object): prev_node.next_node = next_node next_node.prev_node = prev_node + if size_callback: + cached_cache_len[0] -= size_callback(node.value) + for cb in node.callbacks: cb() node.callbacks.clear() @synchronized - def cache_get(key, default=None, callback=None): + def cache_get(key, default=None, callbacks=[]): node = cache.get(key, None) if node is not None: move_node_to_front(node) - if callback: - node.callbacks.add(callback) + node.callbacks.update(callbacks) return node.value else: return default @synchronized - def cache_set(key, value, callback=None): + def cache_set(key, value, callbacks=[]): node = cache.get(key, None) if node is not None: if value != node.value: @@ -116,21 +137,18 @@ class LruCache(object): cb() node.callbacks.clear() - if callback: - node.callbacks.add(callback) + if size_callback: + cached_cache_len[0] -= size_callback(node.value) + cached_cache_len[0] += size_callback(value) + + node.callbacks.update(callbacks) move_node_to_front(node) node.value = value else: - if callback: - callbacks = set([callback]) - else: - callbacks = set() - add_node(key, value, callbacks) - if len(cache) > max_size: - todelete = list_root.prev_node - delete_node(todelete) - cache.pop(todelete.key, None) + add_node(key, value, set(callbacks)) + + evict() @synchronized def cache_set_default(key, value): @@ -139,10 +157,7 @@ class LruCache(object): return node.value else: add_node(key, value) - if len(cache) > max_size: - todelete = list_root.prev_node - delete_node(todelete) - cache.pop(todelete.key, None) + evict() return value @synchronized @@ -175,10 +190,6 @@ class LruCache(object): cb() cache.clear() - @synchronized - def cache_len(): - return len(cache) - @synchronized def cache_contains(key): return key in cache @@ -190,7 +201,7 @@ class LruCache(object): self.pop = cache_pop if cache_type is TreeCache: self.del_multi = cache_del_multi - self.len = cache_len + self.len = synchronized(cache_len) self.contains = cache_contains self.clear = cache_clear diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index c31585aea3..fcc341a6b7 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -65,12 +65,27 @@ class TreeCache(object): return popped def values(self): - return [e.value for e in self.root.values()] + return list(iterate_tree_cache_entry(self.root)) def __len__(self): return self.size +def iterate_tree_cache_entry(d): + """Helper function to iterate over the leaves of a tree, i.e. a dict of that + can contain dicts. + """ + if isinstance(d, dict): + for value_d in d.itervalues(): + for value in iterate_tree_cache_entry(value_d): + yield value + else: + if isinstance(d, _Entry): + yield d.value + else: + yield d + + class _Entry(object): __slots__ = ["value"] diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index ab6095564a..8361dd8cee 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -241,7 +241,7 @@ class CacheDecoratorTestCase(unittest.TestCase): callcount2 = [0] class A(object): - @cached(max_entries=2) + @cached(max_entries=20) # HACK: This makes it 2 due to cache factor def func(self, key): callcount[0] += 1 return key @@ -258,6 +258,10 @@ class CacheDecoratorTestCase(unittest.TestCase): self.assertEquals(callcount[0], 2) self.assertEquals(callcount2[0], 2) + yield a.func2("foo") + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + yield a.func("foo3") self.assertEquals(callcount[0], 3) diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py new file mode 100644 index 0000000000..31d24adb8b --- /dev/null +++ b/tests/util/test_expiring_cache.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .. import unittest + +from synapse.util.caches.expiringcache import ExpiringCache + +from tests.utils import MockClock + + +class ExpiringCacheTestCase(unittest.TestCase): + + def test_get_set(self): + clock = MockClock() + cache = ExpiringCache("test", clock, max_len=1) + + cache["key"] = "value" + self.assertEquals(cache.get("key"), "value") + self.assertEquals(cache["key"], "value") + + def test_eviction(self): + clock = MockClock() + cache = ExpiringCache("test", clock, max_len=2) + + cache["key"] = "value" + cache["key2"] = "value2" + self.assertEquals(cache.get("key"), "value") + self.assertEquals(cache.get("key2"), "value2") + + cache["key3"] = "value3" + self.assertEquals(cache.get("key"), None) + self.assertEquals(cache.get("key2"), "value2") + self.assertEquals(cache.get("key3"), "value3") + + def test_iterable_eviction(self): + clock = MockClock() + cache = ExpiringCache("test", clock, max_len=5, iterable=True) + + cache["key"] = [1] + cache["key2"] = [2, 3] + cache["key3"] = [4, 5] + + self.assertEquals(cache.get("key"), [1]) + self.assertEquals(cache.get("key2"), [2, 3]) + self.assertEquals(cache.get("key3"), [4, 5]) + + cache["key4"] = [6, 7] + self.assertEquals(cache.get("key"), None) + self.assertEquals(cache.get("key2"), None) + self.assertEquals(cache.get("key3"), [4, 5]) + self.assertEquals(cache.get("key4"), [6, 7]) + + def test_time_eviction(self): + clock = MockClock() + cache = ExpiringCache("test", clock, expiry_ms=1000) + cache.start() + + cache["key"] = 1 + clock.advance_time(0.5) + cache["key2"] = 2 + + self.assertEquals(cache.get("key"), 1) + self.assertEquals(cache.get("key2"), 2) + + clock.advance_time(0.9) + self.assertEquals(cache.get("key"), None) + self.assertEquals(cache.get("key2"), 2) + + clock.advance_time(1) + self.assertEquals(cache.get("key"), None) + self.assertEquals(cache.get("key2"), None) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 1eba5b535e..dfb78cb8bd 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -93,7 +93,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): cache.set("key", "value") self.assertFalse(m.called) - cache.get("key", callback=m) + cache.get("key", callbacks=[m]) self.assertFalse(m.called) cache.get("key", "value") @@ -112,10 +112,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase): cache.set("key", "value") self.assertFalse(m.called) - cache.get("key", callback=m) + cache.get("key", callbacks=[m]) self.assertFalse(m.called) - cache.get("key", callback=m) + cache.get("key", callbacks=[m]) self.assertFalse(m.called) cache.set("key", "value2") @@ -128,7 +128,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m = Mock() cache = LruCache(1) - cache.set("key", "value", m) + cache.set("key", "value", callbacks=[m]) self.assertFalse(m.called) cache.set("key", "value") @@ -144,7 +144,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m = Mock() cache = LruCache(1) - cache.set("key", "value", m) + cache.set("key", "value", callbacks=[m]) self.assertFalse(m.called) cache.pop("key") @@ -163,10 +163,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m4 = Mock() cache = LruCache(4, 2, cache_type=TreeCache) - cache.set(("a", "1"), "value", m1) - cache.set(("a", "2"), "value", m2) - cache.set(("b", "1"), "value", m3) - cache.set(("b", "2"), "value", m4) + cache.set(("a", "1"), "value", callbacks=[m1]) + cache.set(("a", "2"), "value", callbacks=[m2]) + cache.set(("b", "1"), "value", callbacks=[m3]) + cache.set(("b", "2"), "value", callbacks=[m4]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) @@ -185,8 +185,8 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m2 = Mock() cache = LruCache(5) - cache.set("key1", "value", m1) - cache.set("key2", "value", m2) + cache.set("key1", "value", callbacks=[m1]) + cache.set("key2", "value", callbacks=[m2]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) @@ -202,14 +202,14 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m3 = Mock(name="m3") cache = LruCache(2) - cache.set("key1", "value", m1) - cache.set("key2", "value", m2) + cache.set("key1", "value", callbacks=[m1]) + cache.set("key2", "value", callbacks=[m2]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 0) - cache.set("key3", "value", m3) + cache.set("key3", "value", callbacks=[m3]) self.assertEquals(m1.call_count, 1) self.assertEquals(m2.call_count, 0) @@ -227,8 +227,33 @@ class LruCacheCallbacksTestCase(unittest.TestCase): self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 0) - cache.set("key1", "value", m1) + cache.set("key1", "value", callbacks=[m1]) self.assertEquals(m1.call_count, 1) self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 1) + + +class LruCacheSizedTestCase(unittest.TestCase): + + def test_evict(self): + cache = LruCache(5, size_callback=len) + cache["key1"] = [0] + cache["key2"] = [1, 2] + cache["key3"] = [3] + cache["key4"] = [4] + + self.assertEquals(cache["key1"], [0]) + self.assertEquals(cache["key2"], [1, 2]) + self.assertEquals(cache["key3"], [3]) + self.assertEquals(cache["key4"], [4]) + self.assertEquals(len(cache), 5) + + cache["key5"] = [5, 6] + + self.assertEquals(len(cache), 4) + self.assertEquals(cache.get("key1"), None) + self.assertEquals(cache.get("key2"), None) + self.assertEquals(cache["key3"], [3]) + self.assertEquals(cache["key4"], [4]) + self.assertEquals(cache["key5"], [5, 6])