// Copyright (C) 2022 NHR@FAU, University Erlangen-Nuremberg. // All rights reserved. // Use of this source code is governed by a MIT-style // license that can be found in the LICENSE file. package lrucache import ( "bytes" "net/http" "strconv" "time" ) // HttpHandler is can be used as HTTP Middleware in order to cache requests, // for example static assets. By default, the request's raw URI is used as key and nothing else. // Results with a status code other than 200 are cached with a TTL of zero seconds, // so basically re-fetched as soon as the current fetch is done and a new request // for that URI is done. type HttpHandler struct { cache *Cache fetcher http.Handler defaultTTL time.Duration // Allows overriding the way the cache key is extracted // from the http request. The defailt is to use the RequestURI. CacheKey func(*http.Request) string } var _ http.Handler = (*HttpHandler)(nil) type cachedResponseWriter struct { w http.ResponseWriter statusCode int buf bytes.Buffer } type cachedResponse struct { headers http.Header statusCode int data []byte fetched time.Time } var _ http.ResponseWriter = (*cachedResponseWriter)(nil) func (crw *cachedResponseWriter) Header() http.Header { return crw.w.Header() } func (crw *cachedResponseWriter) Write(bytes []byte) (int, error) { return crw.buf.Write(bytes) } func (crw *cachedResponseWriter) WriteHeader(statusCode int) { crw.statusCode = statusCode } // Returns a new caching HttpHandler. If no entry in the cache is found or it was too old, `fetcher` is called with // a modified http.ResponseWriter and the response is stored in the cache. If `fetcher` sets the "Expires" header, // the ttl is set appropriately (otherwise, the default ttl passed as argument here is used). // `maxmemory` should be in the unit bytes. 
func NewHttpHandler(maxmemory int, ttl time.Duration, fetcher http.Handler) *HttpHandler {
	return &HttpHandler{
		cache:      New(maxmemory),
		defaultTTL: ttl,
		fetcher:    fetcher,
		CacheKey: func(r *http.Request) string {
			return r.RequestURI
		},
	}
}

// NewMiddleware returns a gorilla/mux style middleware that wraps the next
// handler in a caching HttpHandler (see NewHttpHandler). Each invocation of
// the returned function creates a separate cache with its own maxmemory budget.
func NewMiddleware(maxmemory int, ttl time.Duration) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return NewHttpHandler(maxmemory, ttl, next)
	}
}

// ServeHTTP serves r from the cache, or calls the fetcher and stores its
// response in the cache for the next time. Only GET requests are cached; all
// other methods are passed straight through to the fetcher.
func (h *HttpHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		// Delegate to the fetcher; calling h.ServeHTTP here would recurse forever.
		h.fetcher.ServeHTTP(rw, r)
		return
	}

	cr := h.cache.Get(h.CacheKey(r), func() (interface{}, time.Duration, int) {
		crw := &cachedResponseWriter{
			w:          rw,
			statusCode: http.StatusOK,
		}
		h.fetcher.ServeHTTP(crw, r)

		resp := &cachedResponse{
			// Clone so the cached entry does not alias this request's header map.
			headers:    crw.Header().Clone(),
			statusCode: crw.statusCode,
			data:       crw.buf.Bytes(),
			fetched:    time.Now(),
		}
		resp.headers.Set("Content-Length", strconv.Itoa(len(resp.data)))

		ttl := h.defaultTTL
		switch {
		case resp.statusCode != http.StatusOK:
			// Non-200 responses expire immediately so they get re-fetched.
			ttl = 0
		case resp.headers.Get("Expires") != "":
			if expires, err := http.ParseTime(resp.headers.Get("Expires")); err == nil {
				ttl = time.Until(expires)
				if ttl < 0 {
					// An Expires date in the past means "already stale".
					ttl = 0
				}
			}
		}
		return resp, ttl, len(resp.data)
	}).(*cachedResponse)

	header := rw.Header()
	for key, val := range cr.headers {
		header[key] = val
	}
	// Age differs per response, so set it on rw only. cr.headers is shared by
	// every request that hits this cache entry and must not be mutated here.
	header.Set("Age", strconv.Itoa(int(time.Since(cr.fetched).Seconds())))
	rw.WriteHeader(cr.statusCode)
	// A write error means the client went away; there is nothing left to do.
	_, _ = rw.Write(cr.data)
}