diff --git a/pkg/lrucache/README.md b/pkg/lrucache/README.md
new file mode 100644
index 0000000..8cd2751
--- /dev/null
+++ b/pkg/lrucache/README.md
@@ -0,0 +1,121 @@
+# In-Memory LRU Cache for Golang Applications
+
+[![](https://pkg.go.dev/badge/github.com/iamlouk/lrucache?utm_source=godoc)](https://pkg.go.dev/github.com/iamlouk/lrucache)
+
+This library can be embedded into your existing Go applications
+and play the role *Memcached* or *Redis* might play for others.
+It is inspired by [PHP Symfony's Cache Components](https://symfony.com/doc/current/components/cache/adapters/array_cache_adapter.html)
+and has a similar API. This library cannot be used for persistence,
+is not yet properly tested, and is a bit special in a few ways described
+below (especially with regard to memory usage/`size`).
+
+In addition to the interface described below, an `http.Handler` that can be used as middleware is provided as well.
+
+- Advantages:
+  - Anything (`interface{}`) can be stored as a value
+  - As the cache lives in the application itself, no serialization or de-serialization is needed
+  - As the cache lives in the application itself, no copying or networking is needed
+  - The computation of a new value for a key does __not__ block the full cache (only that key)
+- Disadvantages:
+  - You have to provide a size estimate for every value
+  - __This size estimate should not change (i.e. values should not mutate)__
+  - The cache can only be accessed by one application
+
+## Example
+
+```go
+// See the godocs and ./cache_test.go for more documentation and examples.
+
+maxMemory := 1000
+cache := lrucache.New(maxMemory)
+
+bar := cache.Get("foo", func() (value interface{}, ttl time.Duration, size int) {
+	return "bar", 10 * time.Second, len("bar")
+}).(string)
+
+// bar == "bar"
+
+bar = cache.Get("foo", func() (value interface{}, ttl time.Duration, size int) {
+	panic("will not be called")
+}).(string)
+```
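+
+The cache can also be filled and invalidated explicitly. A minimal sketch
+using the `Put` and `Del` methods described below (key, value and TTL are
+made up for illustration):
+
+```go
+// Store a value directly, bypassing the closure-based API:
+// key, value, size estimate, time to live.
+cache.Put("foo", "bar", len("bar"), 10*time.Second)
+
+// Del removes the entry and reports whether the key was present.
+if cache.Del("foo") {
+	// "foo" was cached and has now been removed
+}
+```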
+
+## Why does `cache.Get` take a function as an argument?
+
+*Using the mechanism described below is optional; the second argument to `Get` can be `nil` and there is a `Put` function as well.*
+
+Because this library is meant to be used by multi-threaded applications, and the following would
+result in the same data being fetched twice if both goroutines run in parallel:
+
+```go
+// This code shows what could happen with other cache libraries
+c := lrucache.New(MAX_CACHE_ENTRIES)
+
+for i := 0; i < 2; i++ {
+	go func() {
+		// This code will run twice in different goroutines,
+		// so the calls can overlap. As `fetchData` probably does some
+		// I/O and takes a long time, the probability of both
+		// goroutines calling `fetchData` is very high!
+		url := "http://example.com/foo"
+		contents := c.Get(url)
+		if contents == nil {
+			contents = fetchData(url)
+			c.Set(url, contents)
+		}
+
+		handleData(contents.([]byte))
+	}()
+}
+```
+
+Here, if one wanted to make sure that only one of the two goroutines fetches the data,
+one would need to build one's own synchronization. That would suck!
+
+```go
+c := lrucache.New(MAX_CACHE_SIZE)
+
+for i := 0; i < 2; i++ {
+	go func() {
+		url := "http://example.com/foo"
+		contents := c.Get(url, func() (interface{}, time.Duration, int) {
+			// This closure will only be called once!
+			// If another goroutine calls `c.Get` while this closure
+			// is still being executed, it will wait.
+			buf := fetchData(url)
+			return buf, 100 * time.Second, len(buf)
+		})
+
+		handleData(contents.([]byte))
+	}()
+}
+```
+
+This is much better: fewer resources are wasted and the synchronization is handled by
+the library. If the closure gets called, the call happens synchronously. While
+it is being executed, all other cache keys can still be accessed without having to wait
+for the execution to be done.
+
+## How `Get` works
+
+The closure passed to `Get` will be called if the value asked for is not cached or has
+expired. It should return the following values:
+
+- The value corresponding to that key, to be stored in the cache
+- The time to live for that value (how long until it expires and needs to be recomputed)
+- A size estimate
+
+When `maxMemory` is reached, cache entries need to be evicted. Theoretically,
+it would be possible to use reflection on every value placed in the cache
+to get its exact size in bytes. This would be very expensive and slow, though.
+Also, a value's size can change. Instead of this library calculating the size in bytes, you, the user,
+have to provide a size for every value in whatever unit you like (as long as it is the same unit everywhere).
+
+Suggestions on what to use as size: `len(str)` for strings, `len(slice) * size_of_slice_type` for slices, etc. It is also possible
+to use `1` as the size of every entry; in that case, at most `maxMemory` entries will be in the cache at the same time.
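+
+For example (a hypothetical sketch: the keys are made up, and the unit
+chosen here is bytes):
+
+```go
+s := "hello"
+_ = cache.Get("a-string", func() (interface{}, time.Duration, int) {
+	// For strings, the byte length is a reasonable size estimate.
+	return s, time.Minute, len(s)
+})
+
+nums := make([]int64, 128)
+_ = cache.Get("a-slice", func() (interface{}, time.Duration, int) {
+	// For slices: element count times element size (8 bytes per int64).
+	return nums, time.Minute, len(nums) * 8
+})
+```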
+
+## Effects on GC
+
+Because of the way the Go garbage collector decides when to run ([explained in the runtime package](https://pkg.go.dev/runtime)), having large amounts of data sitting in your cache might increase the memory consumption of your process by up to two times the maximum size of the cache. You can decrease the *target percentage* to reduce the effect, but then you might see negative performance effects when your cache is not filled.
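+
+If that matters for your deployment, the *target percentage* can be lowered
+at process startup. A hedged sketch using the standard `runtime/debug`
+package (the value `50` is arbitrary):
+
+```go
+import "runtime/debug"
+
+func init() {
+	// Run the GC more often (the default is 100, i.e. run once the heap
+	// has grown by 100% since the last collection).
+	debug.SetGCPercent(50)
+}
+```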
diff --git a/pkg/lrucache/cache.go b/pkg/lrucache/cache.go
new file mode 100644
index 0000000..aedfd5c
--- /dev/null
+++ b/pkg/lrucache/cache.go
@@ -0,0 +1,288 @@
+package lrucache
+
+import (
+	"sync"
+	"time"
+)
+
+// ComputeValue is the type of the closure that must be passed to `Get`
+// to compute the value in case it is not cached.
+//
+// The returned values are the computed value to be stored in the cache,
+// the duration until this value will expire and a size estimate.
+type ComputeValue func() (value interface{}, ttl time.Duration, size int)
+
+type cacheEntry struct {
+	key   string
+	value interface{}
+
+	expiration            time.Time
+	size                  int
+	waitingForComputation int
+
+	next, prev *cacheEntry
+}
+
+type Cache struct {
+	mutex                 sync.Mutex
+	cond                  *sync.Cond
+	maxmemory, usedmemory int
+	entries               map[string]*cacheEntry
+	head, tail            *cacheEntry
+}
+
+// New returns a new instance of an in-memory LRU cache.
+// Read [the README](./README.md) for more information
+// on what is going on with `maxmemory`.
+func New(maxmemory int) *Cache {
+	cache := &Cache{
+		maxmemory: maxmemory,
+		entries:   map[string]*cacheEntry{},
+	}
+	cache.cond = sync.NewCond(&cache.mutex)
+	return cache
+}
+
+// Get returns the cached value for key `key` or calls `computeValue` and
+// stores its return value in the cache. If called, the closure is called
+// synchronously and __shall not call methods on the same cache__,
+// or a deadlock might occur. If `computeValue` is nil, the cache is checked
+// and, if no entry is found, nil is returned. If another goroutine is currently
+// computing that value, the result is waited for.
+func (c *Cache) Get(key string, computeValue ComputeValue) interface{} {
+	now := time.Now()
+
+	c.mutex.Lock()
+	if entry, ok := c.entries[key]; ok {
+		// The expiration not being set is what shows us that
+		// the computation of that value is still ongoing.
+		for entry.expiration.IsZero() {
+			entry.waitingForComputation += 1
+			c.cond.Wait()
+			entry.waitingForComputation -= 1
+		}
+
+		if now.After(entry.expiration) {
+			if !c.evictEntry(entry) {
+				if entry.expiration.IsZero() {
+					panic("cache entry that should have been waited for could not be evicted.")
+				}
+				c.mutex.Unlock()
+				return entry.value
+			}
+		} else {
+			if entry != c.head {
+				c.unlinkEntry(entry)
+				c.insertFront(entry)
+			}
+			c.mutex.Unlock()
+			return entry.value
+		}
+	}
+
+	if computeValue == nil {
+		c.mutex.Unlock()
+		return nil
+	}
+
+	entry := &cacheEntry{
+		key:                   key,
+		waitingForComputation: 1,
+	}
+
+	c.entries[key] = entry
+
+	hasPanicked := true
+	defer func() {
+		if hasPanicked {
+			c.mutex.Lock()
+			delete(c.entries, key)
+			entry.expiration = now
+			entry.waitingForComputation -= 1
+			// Wake up goroutines waiting on this entry so that they do
+			// not block forever after a panic in `computeValue`.
+			if entry.waitingForComputation > 0 {
+				c.cond.Broadcast()
+			}
+		}
+		c.mutex.Unlock()
+	}()
+
+	c.mutex.Unlock()
+	value, ttl, size := computeValue()
+	c.mutex.Lock()
+	hasPanicked = false
+
+	entry.value = value
+	entry.expiration = now.Add(ttl)
+	entry.size = size
+	entry.waitingForComputation -= 1
+
+	// Only broadcast if other goroutines are actually waiting
+	// for a result.
+	if entry.waitingForComputation > 0 {
+		// TODO: Have more than one condition variable so that there are
+		// fewer unnecessary wakeups.
+		c.cond.Broadcast()
+	}
+
+	c.usedmemory += size
+	c.insertFront(entry)
+
+	// Evict entries, starting at the tail, until the cache fits into
+	// `maxmemory` again. Only entries with a size greater than zero or
+	// that have expired are candidates. This is the only loop in the
+	// implementation outside of the `Keys` method.
+	evictionCandidate := c.tail
+	for c.usedmemory > c.maxmemory && evictionCandidate != nil {
+		nextCandidate := evictionCandidate.prev
+		if (evictionCandidate.size > 0 || now.After(evictionCandidate.expiration)) &&
+			evictionCandidate.waitingForComputation == 0 {
+			c.evictEntry(evictionCandidate)
+		}
+		evictionCandidate = nextCandidate
+	}
+
+	return value
+}
+
+// Put adds a new value to the cache. If another goroutine is calling `Get` and
+// computing the value, this function waits for the computation to be done
+// before it overwrites the value.
+func (c *Cache) Put(key string, value interface{}, size int, ttl time.Duration) {
+	now := time.Now()
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+
+	if entry, ok := c.entries[key]; ok {
+		for entry.expiration.IsZero() {
+			entry.waitingForComputation += 1
+			c.cond.Wait()
+			entry.waitingForComputation -= 1
+		}
+
+		c.usedmemory -= entry.size
+		entry.expiration = now.Add(ttl)
+		entry.size = size
+		entry.value = value
+		c.usedmemory += entry.size
+
+		c.unlinkEntry(entry)
+		c.insertFront(entry)
+		return
+	}
+
+	entry := &cacheEntry{
+		key:        key,
+		value:      value,
+		size:       size,
+		expiration: now.Add(ttl),
+	}
+	c.entries[key] = entry
+	c.usedmemory += size
+	c.insertFront(entry)
+}
+
+// Del removes the value at key `key` from the cache.
+// It returns true if the key was in the cache and false
+// otherwise. It is possible that true is returned even
+// though the value already expired.
+// It is possible that false is returned even though the value
+// will show up in the cache if this function is called on a key
+// while that key is being computed.
+func (c *Cache) Del(key string) bool {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+
+	if entry, ok := c.entries[key]; ok {
+		return c.evictEntry(entry)
+	}
+	return false
+}
+
+// Keys calls f for every entry in the cache. Some sanity checks
+// and eviction of expired keys are done as well.
+// The cache is fully locked for the complete duration of this call!
+func (c *Cache) Keys(f func(key string, val interface{})) {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+
+	now := time.Now()
+
+	size := 0
+	for key, e := range c.entries {
+		if key != e.key {
+			panic("key mismatch")
+		}
+
+		if now.After(e.expiration) {
+			if c.evictEntry(e) {
+				continue
+			}
+		}
+
+		if e.prev != nil {
+			if e.prev.next != e {
+				panic("list corrupted")
+			}
+		}
+
+		if e.next != nil {
+			if e.next.prev != e {
+				panic("list corrupted")
+			}
+		}
+
+		size += e.size
+		f(key, e.value)
+	}
+
+	if size != c.usedmemory {
+		panic("size calculations failed")
+	}
+
+	if c.head != nil {
+		if c.tail == nil || c.head.prev != nil {
+			panic("head/tail corrupted")
+		}
+	}
+
+	if c.tail != nil {
+		if c.head == nil || c.tail.next != nil {
+			panic("head/tail corrupted")
+		}
+	}
+}
+
+func (c *Cache) insertFront(e *cacheEntry) {
+	e.next = c.head
+	c.head = e
+
+	e.prev = nil
+	if e.next != nil {
+		e.next.prev = e
+	}
+
+	if c.tail == nil {
+		c.tail = e
+	}
+}
+
+func (c *Cache) unlinkEntry(e *cacheEntry) {
+	if e == c.head {
+		c.head = e.next
+	}
+	if e.prev != nil {
+		e.prev.next = e.next
+	}
+	if e.next != nil {
+		e.next.prev = e.prev
+	}
+	if e == c.tail {
+		c.tail = e.prev
+	}
+}
+
+func (c *Cache) evictEntry(e *cacheEntry) bool {
+	if e.waitingForComputation != 0 {
+		// Cannot evict this entry as other goroutines need the value.
+		return false
+	}
+
+	c.unlinkEntry(e)
+	c.usedmemory -= e.size
+	delete(c.entries, e.key)
+	return true
+}
diff --git a/pkg/lrucache/cache_test.go b/pkg/lrucache/cache_test.go
new file mode 100644
index 0000000..bfab653
--- /dev/null
+++ b/pkg/lrucache/cache_test.go
@@ -0,0 +1,219 @@
+package lrucache
+
+import (
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+)
+
+func TestBasics(t *testing.T) {
+	cache := New(123)
+
+	value1 := cache.Get("foo", func() (interface{}, time.Duration, int) {
+		return "bar", 1 * time.Second, 0
+	})
+
+	if value1.(string) != "bar" {
+		t.Error("cache returned wrong value")
+	}
+
+	value2 := cache.Get("foo", func() (interface{}, time.Duration, int) {
+		t.Error("value should be cached")
+		return "", 0, 0
+	})
+
+	if value2.(string) != "bar" {
+		t.Error("cache returned wrong value")
+	}
+
+	existed := cache.Del("foo")
+	if !existed {
+		t.Error("delete did not work as expected")
+	}
+
+	value3 := cache.Get("foo", func() (interface{}, time.Duration, int) {
+		return "baz", 1 * time.Second, 0
+	})
+
+	if value3.(string) != "baz" {
+		t.Error("cache returned wrong value")
+	}
+
+	cache.Keys(func(key string, value interface{}) {
+		if key != "foo" || value.(string) != "baz" {
+			t.Error("cache corrupted")
+		}
+	})
+}
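+
+// A small illustrative sketch (an addition to the suite) of how `Put`
+// and a `Get` with a nil closure interact.
+func TestPutGet(t *testing.T) {
+	c := New(100)
+
+	// Store a value directly: key, value, size estimate, ttl.
+	c.Put("answer", 42, 1, 1*time.Minute)
+
+	// With a nil closure, Get only does a lookup and never computes.
+	if v := c.Get("answer", nil); v.(int) != 42 {
+		t.Error("Put followed by Get should return the stored value")
+	}
+
+	// A miss with a nil closure returns nil.
+	if v := c.Get("missing", nil); v != nil {
+		t.Error("Get with nil closure should return nil on a miss")
+	}
+}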
+
+func TestExpiration(t *testing.T) {
+	cache := New(123)
+
+	failIfCalled := func() (interface{}, time.Duration, int) {
+		t.Error("Value should be cached!")
+		return "", 0, 0
+	}
+
+	val1 := cache.Get("foo", func() (interface{}, time.Duration, int) {
+		return "bar", 5 * time.Millisecond, 0
+	})
+	val2 := cache.Get("bar", func() (interface{}, time.Duration, int) {
+		return "foo", 20 * time.Millisecond, 0
+	})
+
+	val3 := cache.Get("foo", failIfCalled).(string)
+	val4 := cache.Get("bar", failIfCalled).(string)
+
+	if val1 != val3 || val3 != "bar" || val2 != val4 || val4 != "foo" {
+		t.Error("wrong values returned")
+	}
+
+	time.Sleep(10 * time.Millisecond)
+
+	val5 := cache.Get("foo", func() (interface{}, time.Duration, int) {
+		return "baz", 0, 0
+	})
+	val6 := cache.Get("bar", failIfCalled)
+
+	if val5.(string) != "baz" || val6.(string) != "foo" {
+		t.Error("unexpected values")
+	}
+
+	cache.Keys(func(key string, val interface{}) {
+		if key != "bar" || val.(string) != "foo" {
+			t.Error("wrong value expired")
+		}
+	})
+
+	time.Sleep(15 * time.Millisecond)
+	cache.Keys(func(key string, val interface{}) {
+		t.Error("cache should be empty now")
+	})
+}
+
+func TestEviction(t *testing.T) {
+	c := New(100)
+	failIfCalled := func() (interface{}, time.Duration, int) {
+		t.Error("Value should be cached!")
+		return "", 0, 0
+	}
+
+	v1 := c.Get("foo", func() (interface{}, time.Duration, int) {
+		return "bar", 1 * time.Second, 1000
+	})
+
+	v2 := c.Get("foo", func() (interface{}, time.Duration, int) {
+		return "baz", 1 * time.Second, 1000
+	})
+
+	if v1.(string) != "bar" || v2.(string) != "baz" {
+		t.Error("wrong values returned")
+	}
+
+	c.Keys(func(key string, val interface{}) {
+		t.Error("cache should be empty now")
+	})
+
+	_ = c.Get("A", func() (interface{}, time.Duration, int) {
+		return "a", 1 * time.Second, 50
+	})
+
+	_ = c.Get("B", func() (interface{}, time.Duration, int) {
+		return "b", 1 * time.Second, 50
+	})
+
+	_ = c.Get("A", failIfCalled)
+	_ = c.Get("B", failIfCalled)
+	_ = c.Get("C", func() (interface{}, time.Duration, int) {
+		return "c", 1 * time.Second, 50
+	})
+
+	_ = c.Get("B", failIfCalled)
+	_ = c.Get("C", failIfCalled)
+
+	v4 := c.Get("A", func() (interface{}, time.Duration, int) {
+		return "evicted", 1 * time.Second, 25
+	})
+
+	if v4.(string) != "evicted" {
+		t.Error("value should have been evicted")
+	}
+
+	c.Keys(func(key string, val interface{}) {
+		if key != "A" && key != "C" {
+			t.Errorf("'%s' was not expected", key)
+		}
+	})
+}
+
+// I know that this is a shitty test:
+// time is relative and unreliable.
+func TestConcurrency(t *testing.T) {
+	c := New(100)
+	var wg sync.WaitGroup
+
+	numActions := 20000
+	numThreads := 4
+	wg.Add(numThreads)
+
+	var concurrentModifications int32 = 0
+
+	for i := 0; i < numThreads; i++ {
+		go func() {
+			for j := 0; j < numActions; j++ {
+				_ = c.Get("key", func() (interface{}, time.Duration, int) {
+					m := atomic.AddInt32(&concurrentModifications, 1)
+					if m != 1 {
+						t.Error("only one goroutine at a time should calculate a value for the same key")
+					}
+
+					time.Sleep(1 * time.Millisecond)
+					atomic.AddInt32(&concurrentModifications, -1)
+					return "value", 3 * time.Millisecond, 1
+				})
+			}
+
+			wg.Done()
+		}()
+	}
+
+	wg.Wait()
+
+	c.Keys(func(key string, val interface{}) {})
+}
+
+func TestPanic(t *testing.T) {
+	c := New(100)
+
+	c.Put("bar", "baz", 3, 1*time.Minute)
+
+	testpanic := func() {
+		defer func() {
+			if r := recover(); r != nil {
+				if r.(string) != "oops" {
+					t.Fatal("unexpected panic value")
+				}
+			}
+		}()
+
+		_ = c.Get("foo", func() (value interface{}, ttl time.Duration, size int) {
+			panic("oops")
+		})
+
+		t.Fatal("should have panicked!")
+	}
+
+	testpanic()
+
+	v := c.Get("bar", func() (value interface{}, ttl time.Duration, size int) {
+		t.Fatal("should not be called!")
+		return nil, 0, 0
+	})
+
+	if v.(string) != "baz" {
+		t.Fatal("unexpected value")
+	}
+
+	testpanic()
+}
diff --git a/pkg/lrucache/go.mod b/pkg/lrucache/go.mod
new file mode 100644
index 0000000..b5574a7
--- /dev/null
+++ b/pkg/lrucache/go.mod
@@ -0,0 +1,3 @@
+module github.com/iamlouk/lrucache
+
+go 1.16
diff --git a/pkg/lrucache/handler.go b/pkg/lrucache/handler.go
new file mode 100644
index 0000000..e83ba10
--- /dev/null
+++ b/pkg/lrucache/handler.go
@@ -0,0 +1,120 @@
+package lrucache
+
+import (
+	"bytes"
+	"net/http"
+	"strconv"
+	"time"
+)
+
+// HttpHandler can be used as HTTP middleware in order to cache requests,
+// for example static assets. By default, the request's raw URI is used as the key and nothing else.
+// Results with a status code other than 200 are cached with a TTL of zero seconds,
+// so they are basically re-fetched as soon as the current fetch is done and a new request
+// for that URI comes in.
+type HttpHandler struct {
+	cache      *Cache
+	fetcher    http.Handler
+	defaultTTL time.Duration
+
+	// CacheKey allows overriding the way the cache key is extracted
+	// from the http request. The default is to use the RequestURI.
+	CacheKey func(*http.Request) string
+}
+
+var _ http.Handler = (*HttpHandler)(nil)
+
+type cachedResponseWriter struct {
+	w          http.ResponseWriter
+	statusCode int
+	buf        bytes.Buffer
+}
+
+type cachedResponse struct {
+	headers    http.Header
+	statusCode int
+	data       []byte
+	fetched    time.Time
+}
+
+var _ http.ResponseWriter = (*cachedResponseWriter)(nil)
+
+func (crw *cachedResponseWriter) Header() http.Header {
+	return crw.w.Header()
+}
+
+func (crw *cachedResponseWriter) Write(bytes []byte) (int, error) {
+	return crw.buf.Write(bytes)
+}
+
+func (crw *cachedResponseWriter) WriteHeader(statusCode int) {
+	crw.statusCode = statusCode
+}
+
+// NewHttpHandler returns a new caching HttpHandler. If no entry is found in the cache or it is too old, `fetcher` is called with
+// a modified http.ResponseWriter and the response is stored in the cache. If `fetcher` sets the "Expires" header,
+// the TTL is set accordingly (otherwise, the default TTL passed as an argument here is used).
+// `maxmemory` should be in the unit bytes.
+func NewHttpHandler(maxmemory int, ttl time.Duration, fetcher http.Handler) *HttpHandler {
+	return &HttpHandler{
+		cache:      New(maxmemory),
+		defaultTTL: ttl,
+		fetcher:    fetcher,
+		CacheKey: func(r *http.Request) string {
+			return r.RequestURI
+		},
+	}
+}
+
+// NewMiddleware returns a gorilla/mux style middleware.
+func NewMiddleware(maxmemory int, ttl time.Duration) func(http.Handler) http.Handler {
+	return func(next http.Handler) http.Handler {
+		return NewHttpHandler(maxmemory, ttl, next)
+	}
+}
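+
+// A hypothetical usage sketch for the middleware (the routes and sizes
+// are made up for illustration):
+//
+//	mux := http.NewServeMux()
+//	mux.Handle("/assets/", http.FileServer(http.Dir("./assets")))
+//	cached := NewMiddleware(16*1024*1024, 10*time.Minute)(mux)
+//	http.ListenAndServe(":8080", cached)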
+
+// ServeHTTP tries to serve the response to r from the cache; on a miss, it calls the
+// wrapped handler and stores the response in the cache for the next time.
+func (h *HttpHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
+	if r.Method != http.MethodGet {
+		// Only GET requests are cached; pass everything else
+		// straight through to the wrapped handler.
+		h.fetcher.ServeHTTP(rw, r)
+		return
+	}
+
+	cr := h.cache.Get(h.CacheKey(r), func() (interface{}, time.Duration, int) {
+		crw := &cachedResponseWriter{
+			w:          rw,
+			statusCode: 200,
+			buf:        bytes.Buffer{},
+		}
+
+		h.fetcher.ServeHTTP(crw, r)
+
+		cr := &cachedResponse{
+			headers:    rw.Header().Clone(),
+			statusCode: crw.statusCode,
+			data:       crw.buf.Bytes(),
+			fetched:    time.Now(),
+		}
+		cr.headers.Set("Content-Length", strconv.Itoa(len(cr.data)))
+
+		ttl := h.defaultTTL
+		if cr.statusCode != http.StatusOK {
+			ttl = 0
+		} else if cr.headers.Get("Expires") != "" {
+			if expires, err := http.ParseTime(cr.headers.Get("Expires")); err == nil {
+				ttl = time.Until(expires)
+			}
+		}
+
+		return cr, ttl, len(cr.data)
+	}).(*cachedResponse)
+
+	for key, val := range cr.headers {
+		rw.Header()[key] = val
+	}
+
+	rw.Header().Set("Age", strconv.Itoa(int(time.Since(cr.fetched).Seconds())))
+
+	rw.WriteHeader(cr.statusCode)
+	rw.Write(cr.data)
+}
diff --git a/pkg/lrucache/handler_test.go b/pkg/lrucache/handler_test.go
new file mode 100644
index 0000000..a241089
--- /dev/null
+++ b/pkg/lrucache/handler_test.go
@@ -0,0 +1,72 @@
+package lrucache
+
+import (
+	"bytes"
+	"net/http"
+	"net/http/httptest"
+	"strconv"
+	"testing"
+	"time"
+)
+
+func TestHandlerBasics(t *testing.T) {
+	r := httptest.NewRequest(http.MethodGet, "/test1", nil)
+	rw := httptest.NewRecorder()
+	shouldBeCalled := true
+
+	handler := NewHttpHandler(1000, time.Second, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+		rw.Write([]byte("Hello World!"))
+
+		if !shouldBeCalled {
+			t.Fatal("fetcher should not have been called")
+		}
+	}))
+
+	handler.ServeHTTP(rw, r)
+
+	if rw.Code != 200 {
+		t.Fatal("unexpected status code")
+	}
+
+	if !bytes.Equal(rw.Body.Bytes(), []byte("Hello World!")) {
+		t.Fatal("unexpected body")
+	}
+
+	rw = httptest.NewRecorder()
+	shouldBeCalled = false
+	handler.ServeHTTP(rw, r)
+
+	if rw.Code != 200 {
+		t.Fatal("unexpected status code")
+	}
+
+	if !bytes.Equal(rw.Body.Bytes(), []byte("Hello World!")) {
+		t.Fatal("unexpected body")
+	}
+}
+
+func TestHandlerExpiration(t *testing.T) {
+	r := httptest.NewRequest(http.MethodGet, "/test1", nil)
+	rw := httptest.NewRecorder()
+	i := 1
+	now := time.Now()
+
+	handler := NewHttpHandler(1000, 1*time.Second, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+		rw.Header().Set("Expires", now.Add(10*time.Millisecond).Format(http.TimeFormat))
+		rw.Write([]byte(strconv.Itoa(i)))
+	}))
+
+	handler.ServeHTTP(rw, r)
+	if rw.Body.String() != strconv.Itoa(1) {
+		t.Fatal("unexpected body")
+	}
+
+	i += 1
+
+	time.Sleep(11 * time.Millisecond)
+	rw = httptest.NewRecorder()
+	handler.ServeHTTP(rw, r)
+	if rw.Body.String() != strconv.Itoa(2) {
+		t.Fatal("the entry should have expired and been re-fetched")
+	}
+}
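+
+// A small extra sanity check sketching the non-GET path: such requests
+// bypass the cache and hit the wrapped handler every time.
+func TestHandlerSkipsNonGet(t *testing.T) {
+	calls := 0
+	handler := NewHttpHandler(1000, time.Minute, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+		calls++
+		rw.Write([]byte("ok"))
+	}))
+
+	for i := 0; i < 2; i++ {
+		r := httptest.NewRequest(http.MethodPost, "/test2", nil)
+		rw := httptest.NewRecorder()
+		handler.ServeHTTP(rw, r)
+	}
+
+	if calls != 2 {
+		t.Fatalf("expected the fetcher to be called twice, got %d calls", calls)
+	}
+}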