mirror of
https://github.com/ClusterCockpit/cc-backend
synced 2024-12-25 12:59:06 +01:00
Include lrucache external dependency
This commit is contained in:
parent
81819db436
commit
27800b651a
121
pkg/lrucache/README.md
Normal file
121
pkg/lrucache/README.md
Normal file
@ -0,0 +1,121 @@
|
||||
# In-Memory LRU Cache for Golang Applications
|
||||
|
||||
[![](https://pkg.go.dev/badge/github.com/iamlouk/lrucache?utm_source=godoc)](https://pkg.go.dev/github.com/iamlouk/lrucache)
|
||||
|
||||
This library can be embedded into your existing go applications
|
||||
and play the role *Memcached* or *Redis* might play for others.
|
||||
It is inspired by [PHP Symfony's Cache Components](https://symfony.com/doc/current/components/cache/adapters/array_cache_adapter.html),
|
||||
having a similar API. This library can not be used for persistance,
|
||||
is not properly tested yet and a bit special in a few ways described
|
||||
below (Especially with regards to the memory usage/`size`).
|
||||
|
||||
In addition to the interface described below, a `http.Handler` that can be used as middleware is provided as well.
|
||||
|
||||
- Advantages:
|
||||
- Anything (`interface{}`) can be stored as value
|
||||
- As it lives in the application itself, no serialization or de-serialization is needed
|
||||
- As it lives in the application itself, no memory moving/networking is needed
|
||||
- The computation of a new value for a key does __not__ block the full cache (only the key)
|
||||
- Disadvantages:
|
||||
- You have to provide a size estimate for every value
|
||||
- __This size estimate should not change (i.e. values should not mutate)__
|
||||
- The cache can only be accessed by one application
|
||||
|
||||
## Example
|
||||
|
||||
```go
|
||||
// Go look at the godocs and ./cache_test.go for more documentation and examples
|
||||
|
||||
maxMemory := 1000
|
||||
cache := lrucache.New(maxMemory)
|
||||
|
||||
bar = cache.Get("foo", func () (value interface{}, ttl time.Duration, size int) {
|
||||
return "bar", 10 * time.Second, len("bar")
|
||||
}).(string)
|
||||
|
||||
// bar == "bar"
|
||||
|
||||
bar = cache.Get("foo", func () (value interface{}, ttl time.Duration, size int) {
|
||||
panic("will not be called")
|
||||
}).(string)
|
||||
```
|
||||
|
||||
## Why does `cache.Get` take a function as argument?
|
||||
|
||||
*Using the mechanism described below is optional, the second argument to `Get` can be `nil` and there is a `Put` function as well.*
|
||||
|
||||
Because this library is meant to be used by multi threaded applications and the following would
|
||||
result in the same data being fetched twice if both goroutines run in parallel:
|
||||
|
||||
```go
|
||||
// This code shows what could happen with other cache libraries
|
||||
c := lrucache.New(MAX_CACHE_ENTRIES)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
go func(){
|
||||
// This code will run twice in different goroutines,
|
||||
// it could overlap. As `fetchData` probably does some
|
||||
// I/O and takes a long time, the probability of both
|
||||
// goroutines calling `fetchData` is very high!
|
||||
url := "http://example.com/foo"
|
||||
contents := c.Get(url)
|
||||
if contents == nil {
|
||||
contents = fetchData(url)
|
||||
c.Set(url, contents)
|
||||
}
|
||||
|
||||
handleData(contents.([]byte))
|
||||
}()
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
Here, if one wanted to make sure that only one of both goroutines fetches the data,
|
||||
the programmer would need to build his own synchronization. That would suck!
|
||||
|
||||
```go
|
||||
c := lrucache.New(MAX_CACHE_SIZE)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
go func(){
|
||||
url := "http://example.com/foo"
|
||||
contents := c.Get(url, func()(interface{}, time.Time, int) {
|
||||
// This closure will only be called once!
|
||||
// If another goroutine calls `c.Get` while this closure
|
||||
// is still being executed, it will wait.
|
||||
buf := fetchData(url)
|
||||
return buf, 100 * time.Second, len(buf)
|
||||
})
|
||||
|
||||
handleData(contents.([]byte))
|
||||
}()
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
This is much better as less resources are wasted and synchronization is handled by
|
||||
the library. If it gets called, the call to the closure happens synchronously. While
|
||||
it is being executed, all other cache keys can still be accessed without having to wait
|
||||
for the execution to be done.
|
||||
|
||||
## How `Get` works
|
||||
|
||||
The closure passed to `Get` will be called if the value asked for is not cached or
|
||||
expired. It should return the following values:
|
||||
|
||||
- The value corresponding to that key and to be stored in the cache
|
||||
- The time to live for that value (how long until it expires and needs to be recomputed)
|
||||
- A size estimate
|
||||
|
||||
When `maxMemory` is reached, cache entries need to be evicted. Theoretically,
|
||||
it would be possible to use reflection on every value placed in the cache
|
||||
to get its exact size in bytes. This would be very expansive and slow though.
|
||||
Also, size can change. Instead of this library calculating the size in bytes, you, the user,
|
||||
have to provide a size for every value in whatever unit you like (as long as it is the same unit everywhere).
|
||||
|
||||
Suggestions on what to use as size: `len(str)` for strings, `len(slice) * size_of_slice_type`, etc.. It is possible
|
||||
to use `1` as size for every entry, in that case at most `maxMemory` entries will be in the cache at the same time.
|
||||
|
||||
## Affects on GC
|
||||
|
||||
Because of the way a garbage collector decides when to run ([explained in the runtime package](https://pkg.go.dev/runtime)), having large amounts of data sitting in your cache might increase the memory consumption of your process by two times the maximum size of the cache. You can decrease the *target percentage* to reduce the effect, but then you might have negative performance effects when your cache is not filled.
|
288
pkg/lrucache/cache.go
Normal file
288
pkg/lrucache/cache.go
Normal file
@ -0,0 +1,288 @@
|
||||
package lrucache
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Type of the closure that must be passed to `Get` to
|
||||
// compute the value in case it is not cached.
|
||||
//
|
||||
// returned values are the computed value to be stored in the cache,
|
||||
// the duration until this value will expire and a size estimate.
|
||||
type ComputeValue func() (value interface{}, ttl time.Duration, size int)
|
||||
|
||||
type cacheEntry struct {
|
||||
key string
|
||||
value interface{}
|
||||
|
||||
expiration time.Time
|
||||
size int
|
||||
waitingForComputation int
|
||||
|
||||
next, prev *cacheEntry
|
||||
}
|
||||
|
||||
type Cache struct {
|
||||
mutex sync.Mutex
|
||||
cond *sync.Cond
|
||||
maxmemory, usedmemory int
|
||||
entries map[string]*cacheEntry
|
||||
head, tail *cacheEntry
|
||||
}
|
||||
|
||||
// Return a new instance of a LRU In-Memory Cache.
|
||||
// Read [the README](./README.md) for more information
|
||||
// on what is going on with `maxmemory`.
|
||||
func New(maxmemory int) *Cache {
|
||||
cache := &Cache{
|
||||
maxmemory: maxmemory,
|
||||
entries: map[string]*cacheEntry{},
|
||||
}
|
||||
cache.cond = sync.NewCond(&cache.mutex)
|
||||
return cache
|
||||
}
|
||||
|
||||
// Return the cached value for key `key` or call `computeValue` and
|
||||
// store its return value in the cache. If called, the closure will be
|
||||
// called synchronous and __shall not call methods on the same cache__
|
||||
// or a deadlock might ocure. If `computeValue` is nil, the cache is checked
|
||||
// and if no entry was found, nil is returned. If another goroutine is currently
|
||||
// computing that value, the result is waited for.
|
||||
func (c *Cache) Get(key string, computeValue ComputeValue) interface{} {
|
||||
now := time.Now()
|
||||
|
||||
c.mutex.Lock()
|
||||
if entry, ok := c.entries[key]; ok {
|
||||
// The expiration not being set is what shows us that
|
||||
// the computation of that value is still ongoing.
|
||||
for entry.expiration.IsZero() {
|
||||
entry.waitingForComputation += 1
|
||||
c.cond.Wait()
|
||||
entry.waitingForComputation -= 1
|
||||
}
|
||||
|
||||
if now.After(entry.expiration) {
|
||||
if !c.evictEntry(entry) {
|
||||
if entry.expiration.IsZero() {
|
||||
panic("cache entry that shoud have been waited for could not be evicted.")
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
return entry.value
|
||||
}
|
||||
} else {
|
||||
if entry != c.head {
|
||||
c.unlinkEntry(entry)
|
||||
c.insertFront(entry)
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
return entry.value
|
||||
}
|
||||
}
|
||||
|
||||
if computeValue == nil {
|
||||
c.mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
entry := &cacheEntry{
|
||||
key: key,
|
||||
waitingForComputation: 1,
|
||||
}
|
||||
|
||||
c.entries[key] = entry
|
||||
|
||||
hasPaniced := true
|
||||
defer func() {
|
||||
if hasPaniced {
|
||||
c.mutex.Lock()
|
||||
delete(c.entries, key)
|
||||
entry.expiration = now
|
||||
entry.waitingForComputation -= 1
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}()
|
||||
|
||||
c.mutex.Unlock()
|
||||
value, ttl, size := computeValue()
|
||||
c.mutex.Lock()
|
||||
hasPaniced = false
|
||||
|
||||
entry.value = value
|
||||
entry.expiration = now.Add(ttl)
|
||||
entry.size = size
|
||||
entry.waitingForComputation -= 1
|
||||
|
||||
// Only broadcast if other goroutines are actually waiting
|
||||
// for a result.
|
||||
if entry.waitingForComputation > 0 {
|
||||
// TODO: Have more than one condition variable so that there are
|
||||
// less unnecessary wakeups.
|
||||
c.cond.Broadcast()
|
||||
}
|
||||
|
||||
c.usedmemory += size
|
||||
c.insertFront(entry)
|
||||
|
||||
// Evict only entries with a size of more than zero.
|
||||
// This is the only loop in the implementation outside of the `Keys`
|
||||
// method.
|
||||
evictionCandidate := c.tail
|
||||
for c.usedmemory > c.maxmemory && evictionCandidate != nil {
|
||||
nextCandidate := evictionCandidate.prev
|
||||
if (evictionCandidate.size > 0 || now.After(evictionCandidate.expiration)) &&
|
||||
evictionCandidate.waitingForComputation == 0 {
|
||||
c.evictEntry(evictionCandidate)
|
||||
}
|
||||
evictionCandidate = nextCandidate
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
// Put a new value in the cache. If another goroutine is calling `Get` and
|
||||
// computing the value, this function waits for the computation to be done
|
||||
// before it overwrites the value.
|
||||
func (c *Cache) Put(key string, value interface{}, size int, ttl time.Duration) {
|
||||
now := time.Now()
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if entry, ok := c.entries[key]; ok {
|
||||
for entry.expiration.IsZero() {
|
||||
entry.waitingForComputation += 1
|
||||
c.cond.Wait()
|
||||
entry.waitingForComputation -= 1
|
||||
}
|
||||
|
||||
c.usedmemory -= entry.size
|
||||
entry.expiration = now.Add(ttl)
|
||||
entry.size = size
|
||||
entry.value = value
|
||||
c.usedmemory += entry.size
|
||||
|
||||
c.unlinkEntry(entry)
|
||||
c.insertFront(entry)
|
||||
return
|
||||
}
|
||||
|
||||
entry := &cacheEntry{
|
||||
key: key,
|
||||
value: value,
|
||||
expiration: now.Add(ttl),
|
||||
}
|
||||
c.entries[key] = entry
|
||||
c.insertFront(entry)
|
||||
}
|
||||
|
||||
// Remove the value at key `key` from the cache.
|
||||
// Return true if the key was in the cache and false
|
||||
// otherwise. It is possible that true is returned even
|
||||
// though the value already expired.
|
||||
// It is possible that false is returned even though the value
|
||||
// will show up in the cache if this function is called on a key
|
||||
// while that key is beeing computed.
|
||||
func (c *Cache) Del(key string) bool {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if entry, ok := c.entries[key]; ok {
|
||||
return c.evictEntry(entry)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Call f for every entry in the cache. Some sanity checks
|
||||
// and eviction of expired keys are done as well.
|
||||
// The cache is fully locked for the complete duration of this call!
|
||||
func (c *Cache) Keys(f func(key string, val interface{})) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
size := 0
|
||||
for key, e := range c.entries {
|
||||
if key != e.key {
|
||||
panic("key mismatch")
|
||||
}
|
||||
|
||||
if now.After(e.expiration) {
|
||||
if c.evictEntry(e) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if e.prev != nil {
|
||||
if e.prev.next != e {
|
||||
panic("list corrupted")
|
||||
}
|
||||
}
|
||||
|
||||
if e.next != nil {
|
||||
if e.next.prev != e {
|
||||
panic("list corrupted")
|
||||
}
|
||||
}
|
||||
|
||||
size += e.size
|
||||
f(key, e.value)
|
||||
}
|
||||
|
||||
if size != c.usedmemory {
|
||||
panic("size calculations failed")
|
||||
}
|
||||
|
||||
if c.head != nil {
|
||||
if c.tail == nil || c.head.prev != nil {
|
||||
panic("head/tail corrupted")
|
||||
}
|
||||
}
|
||||
|
||||
if c.tail != nil {
|
||||
if c.head == nil || c.tail.next != nil {
|
||||
panic("head/tail corrupted")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache) insertFront(e *cacheEntry) {
|
||||
e.next = c.head
|
||||
c.head = e
|
||||
|
||||
e.prev = nil
|
||||
if e.next != nil {
|
||||
e.next.prev = e
|
||||
}
|
||||
|
||||
if c.tail == nil {
|
||||
c.tail = e
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache) unlinkEntry(e *cacheEntry) {
|
||||
if e == c.head {
|
||||
c.head = e.next
|
||||
}
|
||||
if e.prev != nil {
|
||||
e.prev.next = e.next
|
||||
}
|
||||
if e.next != nil {
|
||||
e.next.prev = e.prev
|
||||
}
|
||||
if e == c.tail {
|
||||
c.tail = e.prev
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache) evictEntry(e *cacheEntry) bool {
|
||||
if e.waitingForComputation != 0 {
|
||||
// panic("cannot evict this entry as other goroutines need the value")
|
||||
return false
|
||||
}
|
||||
|
||||
c.unlinkEntry(e)
|
||||
c.usedmemory -= e.size
|
||||
delete(c.entries, e.key)
|
||||
return true
|
||||
}
|
219
pkg/lrucache/cache_test.go
Normal file
219
pkg/lrucache/cache_test.go
Normal file
@ -0,0 +1,219 @@
|
||||
package lrucache
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBasics(t *testing.T) {
|
||||
cache := New(123)
|
||||
|
||||
value1 := cache.Get("foo", func() (interface{}, time.Duration, int) {
|
||||
return "bar", 1 * time.Second, 0
|
||||
})
|
||||
|
||||
if value1.(string) != "bar" {
|
||||
t.Error("cache returned wrong value")
|
||||
}
|
||||
|
||||
value2 := cache.Get("foo", func() (interface{}, time.Duration, int) {
|
||||
t.Error("value should be cached")
|
||||
return "", 0, 0
|
||||
})
|
||||
|
||||
if value2.(string) != "bar" {
|
||||
t.Error("cache returned wrong value")
|
||||
}
|
||||
|
||||
existed := cache.Del("foo")
|
||||
if !existed {
|
||||
t.Error("delete did not work as expected")
|
||||
}
|
||||
|
||||
value3 := cache.Get("foo", func() (interface{}, time.Duration, int) {
|
||||
return "baz", 1 * time.Second, 0
|
||||
})
|
||||
|
||||
if value3.(string) != "baz" {
|
||||
t.Error("cache returned wrong value")
|
||||
}
|
||||
|
||||
cache.Keys(func(key string, value interface{}) {
|
||||
if key != "foo" || value.(string) != "baz" {
|
||||
t.Error("cache corrupted")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpiration(t *testing.T) {
|
||||
cache := New(123)
|
||||
|
||||
failIfCalled := func() (interface{}, time.Duration, int) {
|
||||
t.Error("Value should be cached!")
|
||||
return "", 0, 0
|
||||
}
|
||||
|
||||
val1 := cache.Get("foo", func() (interface{}, time.Duration, int) {
|
||||
return "bar", 5 * time.Millisecond, 0
|
||||
})
|
||||
val2 := cache.Get("bar", func() (interface{}, time.Duration, int) {
|
||||
return "foo", 20 * time.Millisecond, 0
|
||||
})
|
||||
|
||||
val3 := cache.Get("foo", failIfCalled).(string)
|
||||
val4 := cache.Get("bar", failIfCalled).(string)
|
||||
|
||||
if val1 != val3 || val3 != "bar" || val2 != val4 || val4 != "foo" {
|
||||
t.Error("Wrong values returned")
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
val5 := cache.Get("foo", func() (interface{}, time.Duration, int) {
|
||||
return "baz", 0, 0
|
||||
})
|
||||
val6 := cache.Get("bar", failIfCalled)
|
||||
|
||||
if val5.(string) != "baz" || val6.(string) != "foo" {
|
||||
t.Error("unexpected values")
|
||||
}
|
||||
|
||||
cache.Keys(func(key string, val interface{}) {
|
||||
if key != "bar" || val.(string) != "foo" {
|
||||
t.Error("wrong value expired")
|
||||
}
|
||||
})
|
||||
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
cache.Keys(func(key string, val interface{}) {
|
||||
t.Error("cache should be empty now")
|
||||
})
|
||||
}
|
||||
|
||||
func TestEviction(t *testing.T) {
|
||||
c := New(100)
|
||||
failIfCalled := func() (interface{}, time.Duration, int) {
|
||||
t.Error("Value should be cached!")
|
||||
return "", 0, 0
|
||||
}
|
||||
|
||||
v1 := c.Get("foo", func() (interface{}, time.Duration, int) {
|
||||
return "bar", 1 * time.Second, 1000
|
||||
})
|
||||
|
||||
v2 := c.Get("foo", func() (interface{}, time.Duration, int) {
|
||||
return "baz", 1 * time.Second, 1000
|
||||
})
|
||||
|
||||
if v1.(string) != "bar" || v2.(string) != "baz" {
|
||||
t.Error("wrong values returned")
|
||||
}
|
||||
|
||||
c.Keys(func(key string, val interface{}) {
|
||||
t.Error("cache should be empty now")
|
||||
})
|
||||
|
||||
_ = c.Get("A", func() (interface{}, time.Duration, int) {
|
||||
return "a", 1 * time.Second, 50
|
||||
})
|
||||
|
||||
_ = c.Get("B", func() (interface{}, time.Duration, int) {
|
||||
return "b", 1 * time.Second, 50
|
||||
})
|
||||
|
||||
_ = c.Get("A", failIfCalled)
|
||||
_ = c.Get("B", failIfCalled)
|
||||
_ = c.Get("C", func() (interface{}, time.Duration, int) {
|
||||
return "c", 1 * time.Second, 50
|
||||
})
|
||||
|
||||
_ = c.Get("B", failIfCalled)
|
||||
_ = c.Get("C", failIfCalled)
|
||||
|
||||
v4 := c.Get("A", func() (interface{}, time.Duration, int) {
|
||||
return "evicted", 1 * time.Second, 25
|
||||
})
|
||||
|
||||
if v4.(string) != "evicted" {
|
||||
t.Error("value should have been evicted")
|
||||
}
|
||||
|
||||
c.Keys(func(key string, val interface{}) {
|
||||
if key != "A" && key != "C" {
|
||||
t.Errorf("'%s' was not expected", key)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// I know that this is a shity test,
|
||||
// time is relative and unreliable.
|
||||
func TestConcurrency(t *testing.T) {
|
||||
c := New(100)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
numActions := 20000
|
||||
numThreads := 4
|
||||
wg.Add(numThreads)
|
||||
|
||||
var concurrentModifications int32 = 0
|
||||
|
||||
for i := 0; i < numThreads; i++ {
|
||||
go func() {
|
||||
for j := 0; j < numActions; j++ {
|
||||
_ = c.Get("key", func() (interface{}, time.Duration, int) {
|
||||
m := atomic.AddInt32(&concurrentModifications, 1)
|
||||
if m != 1 {
|
||||
t.Error("only one goroutine at a time should calculate a value for the same key")
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
atomic.AddInt32(&concurrentModifications, -1)
|
||||
return "value", 3 * time.Millisecond, 1
|
||||
})
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
c.Keys(func(key string, val interface{}) {})
|
||||
}
|
||||
|
||||
func TestPanic(t *testing.T) {
|
||||
c := New(100)
|
||||
|
||||
c.Put("bar", "baz", 3, 1*time.Minute)
|
||||
|
||||
testpanic := func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if r.(string) != "oops" {
|
||||
t.Fatal("unexpected panic value")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
_ = c.Get("foo", func() (value interface{}, ttl time.Duration, size int) {
|
||||
panic("oops")
|
||||
})
|
||||
|
||||
t.Fatal("should have paniced!")
|
||||
}
|
||||
|
||||
testpanic()
|
||||
|
||||
v := c.Get("bar", func() (value interface{}, ttl time.Duration, size int) {
|
||||
t.Fatal("should not be called!")
|
||||
return nil, 0, 0
|
||||
})
|
||||
|
||||
if v.(string) != "baz" {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
|
||||
testpanic()
|
||||
}
|
3
pkg/lrucache/go.mod
Normal file
3
pkg/lrucache/go.mod
Normal file
@ -0,0 +1,3 @@
|
||||
module github.com/iamlouk/lrucache
|
||||
|
||||
go 1.16
|
120
pkg/lrucache/handler.go
Normal file
120
pkg/lrucache/handler.go
Normal file
@ -0,0 +1,120 @@
|
||||
package lrucache
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HttpHandler is can be used as HTTP Middleware in order to cache requests,
|
||||
// for example static assets. By default, the request's raw URI is used as key and nothing else.
|
||||
// Results with a status code other than 200 are cached with a TTL of zero seconds,
|
||||
// so basically re-fetched as soon as the current fetch is done and a new request
|
||||
// for that URI is done.
|
||||
type HttpHandler struct {
|
||||
cache *Cache
|
||||
fetcher http.Handler
|
||||
defaultTTL time.Duration
|
||||
|
||||
// Allows overriding the way the cache key is extracted
|
||||
// from the http request. The defailt is to use the RequestURI.
|
||||
CacheKey func(*http.Request) string
|
||||
}
|
||||
|
||||
var _ http.Handler = (*HttpHandler)(nil)
|
||||
|
||||
type cachedResponseWriter struct {
|
||||
w http.ResponseWriter
|
||||
statusCode int
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
type cachedResponse struct {
|
||||
headers http.Header
|
||||
statusCode int
|
||||
data []byte
|
||||
fetched time.Time
|
||||
}
|
||||
|
||||
var _ http.ResponseWriter = (*cachedResponseWriter)(nil)
|
||||
|
||||
func (crw *cachedResponseWriter) Header() http.Header {
|
||||
return crw.w.Header()
|
||||
}
|
||||
|
||||
func (crw *cachedResponseWriter) Write(bytes []byte) (int, error) {
|
||||
return crw.buf.Write(bytes)
|
||||
}
|
||||
|
||||
func (crw *cachedResponseWriter) WriteHeader(statusCode int) {
|
||||
crw.statusCode = statusCode
|
||||
}
|
||||
|
||||
// Returns a new caching HttpHandler. If no entry in the cache is found or it was too old, `fetcher` is called with
|
||||
// a modified http.ResponseWriter and the response is stored in the cache. If `fetcher` sets the "Expires" header,
|
||||
// the ttl is set appropriately (otherwise, the default ttl passed as argument here is used).
|
||||
// `maxmemory` should be in the unit bytes.
|
||||
func NewHttpHandler(maxmemory int, ttl time.Duration, fetcher http.Handler) *HttpHandler {
|
||||
return &HttpHandler{
|
||||
cache: New(maxmemory),
|
||||
defaultTTL: ttl,
|
||||
fetcher: fetcher,
|
||||
CacheKey: func(r *http.Request) string {
|
||||
return r.RequestURI
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// gorilla/mux style middleware:
|
||||
func NewMiddleware(maxmemory int, ttl time.Duration) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return NewHttpHandler(maxmemory, ttl, next)
|
||||
}
|
||||
}
|
||||
|
||||
// Tries to serve a response to r from cache or calls next and stores the response to the cache for the next time.
|
||||
func (h *HttpHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
h.ServeHTTP(rw, r)
|
||||
return
|
||||
}
|
||||
|
||||
cr := h.cache.Get(h.CacheKey(r), func() (interface{}, time.Duration, int) {
|
||||
crw := &cachedResponseWriter{
|
||||
w: rw,
|
||||
statusCode: 200,
|
||||
buf: bytes.Buffer{},
|
||||
}
|
||||
|
||||
h.fetcher.ServeHTTP(crw, r)
|
||||
|
||||
cr := &cachedResponse{
|
||||
headers: rw.Header().Clone(),
|
||||
statusCode: crw.statusCode,
|
||||
data: crw.buf.Bytes(),
|
||||
fetched: time.Now(),
|
||||
}
|
||||
cr.headers.Set("Content-Length", strconv.Itoa(len(cr.data)))
|
||||
|
||||
ttl := h.defaultTTL
|
||||
if cr.statusCode != http.StatusOK {
|
||||
ttl = 0
|
||||
} else if cr.headers.Get("Expires") != "" {
|
||||
if expires, err := http.ParseTime(cr.headers.Get("Expires")); err == nil {
|
||||
ttl = time.Until(expires)
|
||||
}
|
||||
}
|
||||
|
||||
return cr, ttl, len(cr.data)
|
||||
}).(*cachedResponse)
|
||||
|
||||
for key, val := range cr.headers {
|
||||
rw.Header()[key] = val
|
||||
}
|
||||
|
||||
cr.headers.Set("Age", strconv.Itoa(int(time.Since(cr.fetched).Seconds())))
|
||||
|
||||
rw.WriteHeader(cr.statusCode)
|
||||
rw.Write(cr.data)
|
||||
}
|
72
pkg/lrucache/handler_test.go
Normal file
72
pkg/lrucache/handler_test.go
Normal file
@ -0,0 +1,72 @@
|
||||
package lrucache
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestHandlerBasics(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/test1", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
shouldBeCalled := true
|
||||
|
||||
handler := NewHttpHandler(1000, time.Second, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write([]byte("Hello World!"))
|
||||
|
||||
if !shouldBeCalled {
|
||||
t.Fatal("fetcher expected to be called")
|
||||
}
|
||||
}))
|
||||
|
||||
handler.ServeHTTP(rw, r)
|
||||
|
||||
if rw.Code != 200 {
|
||||
t.Fatal("unexpected status code")
|
||||
}
|
||||
|
||||
if !bytes.Equal(rw.Body.Bytes(), []byte("Hello World!")) {
|
||||
t.Fatal("unexpected body")
|
||||
}
|
||||
|
||||
rw = httptest.NewRecorder()
|
||||
shouldBeCalled = false
|
||||
handler.ServeHTTP(rw, r)
|
||||
|
||||
if rw.Code != 200 {
|
||||
t.Fatal("unexpected status code")
|
||||
}
|
||||
|
||||
if !bytes.Equal(rw.Body.Bytes(), []byte("Hello World!")) {
|
||||
t.Fatal("unexpected body")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerExpiration(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/test1", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
i := 1
|
||||
now := time.Now()
|
||||
|
||||
handler := NewHttpHandler(1000, 1*time.Second, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Set("Expires", now.Add(10*time.Millisecond).Format(http.TimeFormat))
|
||||
rw.Write([]byte(strconv.Itoa(i)))
|
||||
}))
|
||||
|
||||
handler.ServeHTTP(rw, r)
|
||||
if !(rw.Body.String() == strconv.Itoa(1)) {
|
||||
t.Fatal("unexpected body")
|
||||
}
|
||||
|
||||
i += 1
|
||||
|
||||
time.Sleep(11 * time.Millisecond)
|
||||
rw = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rw, r)
|
||||
if !(rw.Body.String() == strconv.Itoa(1)) {
|
||||
t.Fatal("unexpected body")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user