# SPDX-License-Identifier: AGPL-3.0-or-later
import time
import tracemalloc
import threading
import unittest
from functools import cached_property
from cachetools import TTLCache
from concurrent.futures import ThreadPoolExecutor

# ─────────────────────────────────────────────────────────────────────────────
# MODULE UNDER TEST
# ─────────────────────────────────────────────────────────────────────────────

# Global cache: holds at most 10 Dummy objects, each expiring 5 s after it
# was stored. The test suite rebinds this name in setUp / individual tests.
cache = TTLCache(ttl=5.0, maxsize=10)
# Re-entrant lock that get_dummy() holds while mutating `cache`; cachetools
# caches are not thread-safe on their own.
cache_write_lock = threading.RLock()


class Dummy:
    """Minimal cacheable object identified by ``id``.

    ``value`` stands in for an expensive computation. Because it is a
    ``cached_property``, the result is stored in the instance ``__dict__``
    on first access and reused afterwards.
    """

    def __init__(self, id):
        # ``id`` shadows the builtin, but it is the public parameter name.
        self.id = id

    @cached_property
    def value(self):
        """Return the (pretend-expensive) greeting, computed once per instance."""
        greeting = "Hello"
        return greeting


def get_dummy(id):
    """Return the cached ``Dummy`` for *id*, creating and caching one on a miss.

    The lookup and the insert run as one critical section under
    ``cache_write_lock``. Despite its name, the lock now also guards reads,
    for two reasons:

    * cachetools caches are not thread-safe, and a ``TTLCache`` read also
      updates its internal LRU bookkeeping. Unlocked reads can therefore
      race with concurrent inserts or expiry.
    * Holding the lock across miss-and-insert guarantees that concurrent
      callers asking for the same missing or expired *id* all receive the
      same object. Otherwise each caller could build and cache its own.

    ``cache`` is resolved at call time, so rebinding the module-level name
    (as the tests do) takes effect immediately.

    Returns:
        The ``Dummy`` instance currently cached for *id*.
    """
    with cache_write_lock:
        try:
            return cache[id]
        except KeyError:
            # Missing or expired: create, store, and hand back the new object
            # before any other thread can observe the miss.
            d = Dummy(id)
            cache[id] = d
            return d


# ─────────────────────────────────────────────────────────────────────────────
# TEST SUITE
# ─────────────────────────────────────────────────────────────────────────────

class TestDummyCache(unittest.TestCase):
    """Behavioral tests for ``Dummy`` and the TTL-bounded ``get_dummy`` cache.

    Each test starts from a fresh module-level ``cache`` (see ``setUp``).
    Tests that need a shorter TTL rebind the global themselves; ``get_dummy``
    reads the global at call time, so the rebinding takes effect.
    """

    def setUp(self):
        # Reset to a fresh 10-item / 1 s TTL cache before each test.
        global cache
        cache = TTLCache(maxsize=10, ttl=1.0)

    def test_cached_property(self):
        d = Dummy(1)
        # First access computes the value ...
        self.assertEqual(d.value, "Hello")
        # ... and cached_property stores it on the instance.
        self.assertIn('value', d.__dict__)

    def test_cache_eviction_basic(self):
        for i in range(20):
            get_dummy(i)
            # The cache never exceeds its maxsize of 10 ...
            self.assertLessEqual(len(cache), 10)
            # ... and only the 10 most recently inserted ids remain.
            expected = list(range(max(0, i - 9), i + 1))
            self.assertEqual(sorted(cache.keys()), expected)

    def test_ttl_expiration_basic(self):
        # Override with a very short TTL so the entry expires during the sleep.
        global cache
        cache = TTLCache(maxsize=10, ttl=0.1)
        a = get_dummy(42)
        time.sleep(0.2)
        b = get_dummy(42)
        self.assertIsNot(a, b)  # expired -> a new object was created
        c = get_dummy(42)
        self.assertIs(b, c)     # still fresh -> served from the cache

    def test_memory_usage_bounded_basic(self):
        tracemalloc.start()
        try:
            for i in range(10_000):
                get_dummy(i)
            _, peak = tracemalloc.get_traced_memory()
        finally:
            # Always stop tracing, even if get_dummy raises.
            tracemalloc.stop()
        print(f"[basic] peak memory = {peak / (1024 * 1024):.2f} MiB")
        # At most 10 entries are retained, so peak should stay under ~100 KiB.
        self.assertLess(peak, 100 * 1024)

    def test_cache_eviction_threaded(self):
        # Concurrently insert 1000 different ids from 5 worker threads.
        with ThreadPoolExecutor(max_workers=5) as pool:
            # Consume the results: Executor.map re-raises a worker's exception
            # only when its result is retrieved, so an unconsumed map() would
            # silently drop failures.
            list(pool.map(get_dummy, range(1000)))
        self.assertLessEqual(len(cache), 10)

    def test_cached_property_threaded(self):
        d = Dummy(123)
        results = []

        def reader():
            results.append(d.value)

        threads = [threading.Thread(target=reader) for _ in range(20)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Every reader must have finished. A thread that raised would leave
        # `results` short, and all(...) would still pass on the partial list.
        self.assertEqual(len(results), len(threads))
        self.assertTrue(all(r == "Hello" for r in results))
        self.assertIn('value', d.__dict__)

    def test_ttl_expiration_threaded(self):
        # Override with a very short TTL so id 7 expires during the sleep.
        global cache
        cache = TTLCache(maxsize=10, ttl=0.1)
        a = get_dummy(7)
        time.sleep(0.2)

        # Five concurrent lookups of the expired id must agree on a single
        # replacement object. This only holds if get_dummy checks and inserts
        # atomically.
        with ThreadPoolExecutor(max_workers=5) as pool:
            objs = list(pool.map(get_dummy, [7] * 5))

        first = objs[0]
        self.assertTrue(all(o is first for o in objs))
        self.assertIsNot(a, first)

    def test_memory_usage_bounded_threaded(self):  # NOTE: this can take ~15 seconds!
        errors = []

        def worker(start, end):
            # Record instead of re-raising: an exception inside a thread does
            # not fail the test by itself, so it is collected and asserted below.
            try:
                for i in range(start, end):
                    get_dummy(i)
            except Exception as exc:
                errors.append(exc)

        tracemalloc.start()
        try:
            n = 10_000
            # 20 threads x 10,000 creations = 200,000 total
            threads = [
                threading.Thread(target=worker, args=(i * n, (i + 1) * n))
                for i in range(20)
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join()
            _, peak = tracemalloc.get_traced_memory()
        finally:
            tracemalloc.stop()

        print(f"[threaded] peak memory = {peak / (1024 * 1024):.2f} MiB")
        # Check worker failures first. Without locking around every access in
        # get_dummy, the shared TTLCache can raise (e.g. KeyError) under
        # concurrent use.
        self.assertEqual(errors, [])
        self.assertLess(peak, 10 * 1024 * 1024)


if __name__ == "__main__":
    # Executed as a script: run the suite with per-test output.
    unittest.main(verbosity=2)
