From b1cf4dc1a22eb4fc45233166aa502875e25294d4 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Tue, 5 Apr 2022 20:23:16 +0800 Subject: [PATCH] Refactor: lrucache use generics --- common/cache/lrucache.go | 97 ++++++++++++++++++----------------- common/cache/lrucache_test.go | 41 +++++++-------- component/fakeip/memory.go | 40 +++++++++------ component/fakeip/pool.go | 5 +- dns/enhancer.go | 8 +-- dns/middleware.go | 2 +- dns/resolver.go | 8 +-- dns/util.go | 2 +- 8 files changed, 106 insertions(+), 97 deletions(-) diff --git a/common/cache/lrucache.go b/common/cache/lrucache.go index 0bea06f6..82eca7f4 100644 --- a/common/cache/lrucache.go +++ b/common/cache/lrucache.go @@ -9,43 +9,43 @@ import ( ) // Option is part of Functional Options Pattern -type Option func(*LruCache) +type Option[K comparable, V any] func(*LruCache[K, V]) // EvictCallback is used to get a callback when a cache entry is evicted type EvictCallback = func(key any, value any) // WithEvict set the evict callback -func WithEvict(cb EvictCallback) Option { - return func(l *LruCache) { +func WithEvict[K comparable, V any](cb EvictCallback) Option[K, V] { + return func(l *LruCache[K, V]) { l.onEvict = cb } } // WithUpdateAgeOnGet update expires when Get element -func WithUpdateAgeOnGet() Option { - return func(l *LruCache) { +func WithUpdateAgeOnGet[K comparable, V any]() Option[K, V] { + return func(l *LruCache[K, V]) { l.updateAgeOnGet = true } } // WithAge defined element max age (second) -func WithAge(maxAge int64) Option { - return func(l *LruCache) { +func WithAge[K comparable, V any](maxAge int64) Option[K, V] { + return func(l *LruCache[K, V]) { l.maxAge = maxAge } } // WithSize defined max length of LruCache -func WithSize(maxSize int) Option { - return func(l *LruCache) { +func WithSize[K comparable, V any](maxSize int) Option[K, V] { + return func(l *LruCache[K, V]) { l.maxSize = maxSize } } // WithStale decide whether Stale return is enabled. // If this feature is enabled, element will not get Evicted according to `WithAge`. -func WithStale(stale bool) Option { - return func(l *LruCache) { +func WithStale[K comparable, V any](stale bool) Option[K, V] { + return func(l *LruCache[K, V]) { l.staleReturn = stale } } @@ -53,7 +53,7 @@ func WithStale(stale bool) Option { // LruCache is a thread-safe, in-memory lru-cache that evicts the // least recently used entries from memory when (if set) the entries are // older than maxAge (in seconds). Use the New constructor to create one. -type LruCache struct { +type LruCache[K comparable, V any] struct { maxAge int64 maxSize int mu sync.Mutex @@ -65,8 +65,8 @@ type LruCache struct { } // NewLRUCache creates an LruCache -func NewLRUCache(options ...Option) *LruCache { - lc := &LruCache{ +func NewLRUCache[K comparable, V any](options ...Option[K, V]) *LruCache[K, V] { + lc := &LruCache[K, V]{ lru: list.New(), cache: make(map[any]*list.Element), } @@ -80,12 +80,12 @@ func NewLRUCache(options ...Option) *LruCache { // Get returns the any representation of a cached response and a bool // set to true if the key was found. -func (c *LruCache) Get(key any) (any, bool) { - entry := c.get(key) - if entry == nil { - return nil, false +func (c *LruCache[K, V]) Get(key K) (V, bool) { + el := c.get(key) + if el == nil { + return getZero[V](), false } - value := entry.value + value := el.value return value, true } @@ -94,17 +94,17 @@ func (c *LruCache) Get(key any) (any, bool) { // a time.Time Give expected expires, // and a bool set to true if the key was found. // This method will NOT check the maxAge of element and will NOT update the expires. -func (c *LruCache) GetWithExpire(key any) (any, time.Time, bool) { - entry := c.get(key) - if entry == nil { - return nil, time.Time{}, false +func (c *LruCache[K, V]) GetWithExpire(key K) (V, time.Time, bool) { + el := c.get(key) + if el == nil { + return getZero[V](), time.Time{}, false } - return entry.value, time.Unix(entry.expires, 0), true + return el.value, time.Unix(el.expires, 0), true } // Exist returns if key exist in cache but not put item to the head of linked list -func (c *LruCache) Exist(key any) bool { +func (c *LruCache[K, V]) Exist(key K) bool { c.mu.Lock() defer c.mu.Unlock() @@ -113,7 +113,7 @@ func (c *LruCache) Exist(key any) bool { } // Set stores the any representation of a response for a given key. -func (c *LruCache) Set(key any, value any) { +func (c *LruCache[K, V]) Set(key K, value V) { expires := int64(0) if c.maxAge > 0 { expires = time.Now().Unix() + c.maxAge @@ -123,21 +123,21 @@ func (c *LruCache) Set(key any, value any) { // SetWithExpire stores the any representation of a response for a given key and given expires. // The expires time will round to second. -func (c *LruCache) SetWithExpire(key any, value any, expires time.Time) { +func (c *LruCache[K, V]) SetWithExpire(key K, value V, expires time.Time) { c.mu.Lock() defer c.mu.Unlock() if le, ok := c.cache[key]; ok { c.lru.MoveToBack(le) - e := le.Value.(*entry) + e := le.Value.(*entry[K, V]) e.value = value e.expires = expires.Unix() } else { - e := &entry{key: key, value: value, expires: expires.Unix()} + e := &entry[K, V]{key: key, value: value, expires: expires.Unix()} c.cache[key] = c.lru.PushBack(e) if c.maxSize > 0 { - if len := c.lru.Len(); len > c.maxSize { + if elLen := c.lru.Len(); elLen > c.maxSize { c.deleteElement(c.lru.Front()) } } @@ -147,7 +147,7 @@ func (c *LruCache) SetWithExpire(key any, value any, expires time.Time) { } // CloneTo clone and overwrite elements to another LruCache -func (c *LruCache) CloneTo(n *LruCache) { +func (c *LruCache[K, V]) CloneTo(n *LruCache[K, V]) { c.mu.Lock() defer c.mu.Unlock() @@ -158,12 +158,12 @@ func (c *LruCache) CloneTo(n *LruCache) { n.cache = make(map[any]*list.Element) for e := c.lru.Front(); e != nil; e = e.Next() { - elm := e.Value.(*entry) + elm := e.Value.(*entry[K, V]) n.cache[elm.key] = n.lru.PushBack(elm) } } -func (c *LruCache) get(key any) *entry { +func (c *LruCache[K, V]) get(key K) *entry[K, V] { c.mu.Lock() defer c.mu.Unlock() @@ -172,7 +172,7 @@ func (c *LruCache) get(key any) *entry { return nil } - if !c.staleReturn && c.maxAge > 0 && le.Value.(*entry).expires <= time.Now().Unix() { + if !c.staleReturn && c.maxAge > 0 && le.Value.(*entry[K, V]).expires <= time.Now().Unix() { c.deleteElement(le) c.maybeDeleteOldest() @@ -180,15 +180,15 @@ func (c *LruCache) get(key any) *entry { } c.lru.MoveToBack(le) - entry := le.Value.(*entry) + el := le.Value.(*entry[K, V]) if c.maxAge > 0 && c.updateAgeOnGet { - entry.expires = time.Now().Unix() + c.maxAge + el.expires = time.Now().Unix() + c.maxAge } - return entry + return el } // Delete removes the value associated with a key. -func (c *LruCache) Delete(key any) { +func (c *LruCache[K, V]) Delete(key K) { c.mu.Lock() if le, ok := c.cache[key]; ok { @@ -198,25 +198,25 @@ func (c *LruCache) Delete(key any) { c.mu.Unlock() } -func (c *LruCache) maybeDeleteOldest() { +func (c *LruCache[K, V]) maybeDeleteOldest() { if !c.staleReturn && c.maxAge > 0 { now := time.Now().Unix() - for le := c.lru.Front(); le != nil && le.Value.(*entry).expires <= now; le = c.lru.Front() { + for le := c.lru.Front(); le != nil && le.Value.(*entry[K, V]).expires <= now; le = c.lru.Front() { c.deleteElement(le) } } } -func (c *LruCache) deleteElement(le *list.Element) { +func (c *LruCache[K, V]) deleteElement(le *list.Element) { c.lru.Remove(le) - e := le.Value.(*entry) + e := le.Value.(*entry[K, V]) delete(c.cache, e.key) if c.onEvict != nil { c.onEvict(e.key, e.value) } } -func (c *LruCache) Clear() error { +func (c *LruCache[K, V]) Clear() error { c.mu.Lock() c.cache = make(map[any]*list.Element) @@ -225,8 +225,13 @@ func (c *LruCache) Clear() error { return nil } -type entry struct { - key any - value any +type entry[K comparable, V any] struct { + key K + value V expires int64 } + +func getZero[T any]() T { + var result T + return result +} diff --git a/common/cache/lrucache_test.go b/common/cache/lrucache_test.go index 1a910b4a..487c184e 100644 --- a/common/cache/lrucache_test.go +++ b/common/cache/lrucache_test.go @@ -19,7 +19,7 @@ var entries = []struct { } func TestLRUCache(t *testing.T) { - c := NewLRUCache() + c := NewLRUCache[string, string]() for _, e := range entries { c.Set(e.key, e.value) @@ -32,7 +32,7 @@ func TestLRUCache(t *testing.T) { for _, e := range entries { value, ok := c.Get(e.key) if assert.True(t, ok) { - assert.Equal(t, e.value, value.(string)) + assert.Equal(t, e.value, value) } } @@ -45,25 +45,25 @@ func TestLRUCache(t *testing.T) { } func TestLRUMaxAge(t *testing.T) { - c := NewLRUCache(WithAge(86400)) + c := NewLRUCache[string, string](WithAge[string, string](86400)) now := time.Now().Unix() expected := now + 86400 // Add one expired entry c.Set("foo", "bar") - c.lru.Back().Value.(*entry).expires = now + c.lru.Back().Value.(*entry[string, string]).expires = now // Reset c.Set("foo", "bar") - e := c.lru.Back().Value.(*entry) + e := c.lru.Back().Value.(*entry[string, string]) assert.True(t, e.expires >= now) - c.lru.Back().Value.(*entry).expires = now + c.lru.Back().Value.(*entry[string, string]).expires = now // Set a few and verify expiration times for _, s := range entries { c.Set(s.key, s.value) - e := c.lru.Back().Value.(*entry) + e := c.lru.Back().Value.(*entry[string, string]) assert.True(t, e.expires >= expected && e.expires <= expected+10) } @@ -77,7 +77,7 @@ func TestLRUMaxAge(t *testing.T) { for _, s := range entries { le, ok := c.cache[s.key] if assert.True(t, ok) { - le.Value.(*entry).expires = now + le.Value.(*entry[string, string]).expires = now } } @@ -88,22 +88,22 @@ func TestLRUMaxAge(t *testing.T) { } func TestLRUpdateOnGet(t *testing.T) { - c := NewLRUCache(WithAge(86400), WithUpdateAgeOnGet()) + c := NewLRUCache[string, string](WithAge[string, string](86400), WithUpdateAgeOnGet[string, string]()) now := time.Now().Unix() expires := now + 86400/2 // Add one expired entry c.Set("foo", "bar") - c.lru.Back().Value.(*entry).expires = expires + c.lru.Back().Value.(*entry[string, string]).expires = expires _, ok := c.Get("foo") assert.True(t, ok) - assert.True(t, c.lru.Back().Value.(*entry).expires > expires) + assert.True(t, c.lru.Back().Value.(*entry[string, string]).expires > expires) } func TestMaxSize(t *testing.T) { - c := NewLRUCache(WithSize(2)) + c := NewLRUCache[string, string](WithSize[string, string](2)) // Add one expired entry c.Set("foo", "bar") _, ok := c.Get("foo") @@ -117,7 +117,7 @@ func TestMaxSize(t *testing.T) { } func TestExist(t *testing.T) { - c := NewLRUCache(WithSize(1)) + c := NewLRUCache[int, int](WithSize[int, int](1)) c.Set(1, 2) assert.True(t, c.Exist(1)) c.Set(2, 3) @@ -130,7 +130,7 @@ func TestEvict(t *testing.T) { temp = key.(int) + value.(int) } - c := NewLRUCache(WithEvict(evict), WithSize(1)) + c := NewLRUCache[int, int](WithEvict[int, int](evict), WithSize[int, int](1)) c.Set(1, 2) c.Set(2, 3) @@ -138,21 +138,22 @@ func TestEvict(t *testing.T) { } func TestSetWithExpire(t *testing.T) { - c := NewLRUCache(WithAge(1)) + c := NewLRUCache[int, *struct{}](WithAge[int, *struct{}](1)) now := time.Now().Unix() tenSecBefore := time.Unix(now-10, 0) - c.SetWithExpire(1, 2, tenSecBefore) + c.SetWithExpire(1, &struct{}{}, tenSecBefore) // res is expected not to exist, and expires should be empty time.Time res, expires, exist := c.GetWithExpire(1) - assert.Equal(t, nil, res) + + assert.True(t, nil == res) assert.Equal(t, time.Time{}, expires) assert.Equal(t, false, exist) } func TestStale(t *testing.T) { - c := NewLRUCache(WithAge(1), WithStale(true)) + c := NewLRUCache[int, int](WithAge[int, int](1), WithStale[int, int](true)) now := time.Now().Unix() tenSecBefore := time.Unix(now-10, 0) @@ -165,11 +166,11 @@ func TestStale(t *testing.T) { } func TestCloneTo(t *testing.T) { - o := NewLRUCache(WithSize(10)) + o := NewLRUCache[string, int](WithSize[string, int](10)) o.Set("1", 1) o.Set("2", 2) - n := NewLRUCache(WithSize(2)) + n := NewLRUCache[string, int](WithSize[string, int](2)) n.Set("3", 3) n.Set("4", 4) diff --git a/component/fakeip/memory.go b/component/fakeip/memory.go index a7ff3708..2568b1d9 100644 --- a/component/fakeip/memory.go +++ b/component/fakeip/memory.go @@ -7,16 +7,15 @@ import ( ) type memoryStore struct { - cache *cache.LruCache + cacheIP *cache.LruCache[string, net.IP] + cacheHost *cache.LruCache[uint32, string] } // GetByHost implements store.GetByHost func (m *memoryStore) GetByHost(host string) (net.IP, bool) { - if elm, exist := m.cache.Get(host); exist { - ip := elm.(net.IP) - + if ip, exist := m.cacheIP.Get(host); exist { // ensure ip --> host on head of linked list - m.cache.Get(ipToUint(ip.To4())) + m.cacheHost.Get(ipToUint(ip.To4())) return ip, true } @@ -25,16 +24,14 @@ func (m *memoryStore) GetByHost(host string) (net.IP, bool) { // PutByHost implements store.PutByHost func (m *memoryStore) PutByHost(host string, ip net.IP) { - m.cache.Set(host, ip) + m.cacheIP.Set(host, ip) } // GetByIP implements store.GetByIP func (m *memoryStore) GetByIP(ip net.IP) (string, bool) { - if elm, exist := m.cache.Get(ipToUint(ip.To4())); exist { - host := elm.(string) - + if host, exist := m.cacheHost.Get(ipToUint(ip.To4())); exist { // ensure host --> ip on head of linked list - m.cache.Get(host) + m.cacheIP.Get(host) return host, true } @@ -43,32 +40,41 @@ func (m *memoryStore) GetByIP(ip net.IP) (string, bool) { // PutByIP implements store.PutByIP func (m *memoryStore) PutByIP(ip net.IP, host string) { - m.cache.Set(ipToUint(ip.To4()), host) + m.cacheHost.Set(ipToUint(ip.To4()), host) } // DelByIP implements store.DelByIP func (m *memoryStore) DelByIP(ip net.IP) { ipNum := ipToUint(ip.To4()) - if elm, exist := m.cache.Get(ipNum); exist { - m.cache.Delete(elm.(string)) + if host, exist := m.cacheHost.Get(ipNum); exist { + m.cacheIP.Delete(host) } - m.cache.Delete(ipNum) + m.cacheHost.Delete(ipNum) } // Exist implements store.Exist func (m *memoryStore) Exist(ip net.IP) bool { - return m.cache.Exist(ipToUint(ip.To4())) + return m.cacheHost.Exist(ipToUint(ip.To4())) } // CloneTo implements store.CloneTo // only for memoryStore to memoryStore func (m *memoryStore) CloneTo(store store) { if ms, ok := store.(*memoryStore); ok { - m.cache.CloneTo(ms.cache) + m.cacheIP.CloneTo(ms.cacheIP) + m.cacheHost.CloneTo(ms.cacheHost) } } // FlushFakeIP implements store.FlushFakeIP func (m *memoryStore) FlushFakeIP() error { - return m.cache.Clear() + _ = m.cacheIP.Clear() + return m.cacheHost.Clear() +} + +func newMemoryStore(size int) *memoryStore { + return &memoryStore{ + cacheIP: cache.NewLRUCache[string, net.IP](cache.WithSize[string, net.IP](size)), + cacheHost: cache.NewLRUCache[uint32, string](cache.WithSize[uint32, string](size)), + } } diff --git a/component/fakeip/pool.go b/component/fakeip/pool.go index e93873c9..a55e5463 100644 --- a/component/fakeip/pool.go +++ b/component/fakeip/pool.go @@ -5,7 +5,6 @@ import ( "net" "sync" - "github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/component/profile/cachefile" "github.com/Dreamacro/clash/component/trie" ) @@ -175,9 +174,7 @@ func New(options Options) (*Pool, error) { cache: cachefile.Cache(), } } else { - pool.store = &memoryStore{ - cache: cache.NewLRUCache(cache.WithSize(options.Size * 2)), - } + pool.store = newMemoryStore(options.Size) } return pool, nil diff --git a/dns/enhancer.go b/dns/enhancer.go index 016ff02a..9d708caa 100644 --- a/dns/enhancer.go +++ b/dns/enhancer.go @@ -11,7 +11,7 @@ import ( type ResolverEnhancer struct { mode C.DNSMode fakePool *fakeip.Pool - mapping *cache.LruCache + mapping *cache.LruCache[string, string] } func (h *ResolverEnhancer) FakeIPEnabled() bool { @@ -67,7 +67,7 @@ func (h *ResolverEnhancer) FindHostByIP(ip net.IP) (string, bool) { if mapping := h.mapping; mapping != nil { if host, existed := h.mapping.Get(ip.String()); existed { - return host.(string), true + return host, true } } @@ -99,11 +99,11 @@ func (h *ResolverEnhancer) FlushFakeIP() error { func NewEnhancer(cfg Config) *ResolverEnhancer { var fakePool *fakeip.Pool - var mapping *cache.LruCache + var mapping *cache.LruCache[string, string] if cfg.EnhancedMode != C.DNSNormal { fakePool = cfg.Pool - mapping = cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)) + mapping = cache.NewLRUCache[string, string](cache.WithSize[string, string](4096), cache.WithStale[string, string](true)) } return &ResolverEnhancer{ diff --git a/dns/middleware.go b/dns/middleware.go index dc7cbe33..5958fe93 100644 --- a/dns/middleware.go +++ b/dns/middleware.go @@ -63,7 +63,7 @@ func withHosts(hosts *trie.DomainTrie) middleware { } } -func withMapping(mapping *cache.LruCache) middleware { +func withMapping(mapping *cache.LruCache[string, string]) middleware { return func(next handler) handler { return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { q := r.Question[0] diff --git a/dns/resolver.go b/dns/resolver.go index c5e32867..86c3d2d5 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -39,7 +39,7 @@ type Resolver struct { fallbackDomainFilters []fallbackDomainFilter fallbackIPFilters []fallbackIPFilter group singleflight.Group - lruCache *cache.LruCache + lruCache *cache.LruCache[string, *D.Msg] policy *trie.DomainTrie proxyServer []dnsClient } @@ -103,7 +103,7 @@ func (r *Resolver) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, e cache, expireTime, hit := r.lruCache.GetWithExpire(q.String()) if hit { now := time.Now() - msg = cache.(*D.Msg).Copy() + msg = cache.Copy() if expireTime.Before(now) { setMsgTTL(msg, uint32(1)) // Continue fetch go r.exchangeWithoutCache(ctx, m) @@ -336,13 +336,13 @@ type Config struct { func NewResolver(config Config) *Resolver { defaultResolver := &Resolver{ main: transform(config.Default, nil), - lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)), + lruCache: cache.NewLRUCache[string, *D.Msg](cache.WithSize[string, *D.Msg](4096), cache.WithStale[string, *D.Msg](true)), } r := &Resolver{ ipv6: config.IPv6, main: transform(config.Main, defaultResolver), - lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)), + lruCache: cache.NewLRUCache[string, *D.Msg](cache.WithSize[string, *D.Msg](4096), cache.WithStale[string, *D.Msg](true)), hosts: config.Hosts, } diff --git a/dns/util.go b/dns/util.go index ffe35cdb..4b4fdbb3 100644 --- a/dns/util.go +++ b/dns/util.go @@ -16,7 +16,7 @@ import ( D "github.com/miekg/dns" ) -func putMsgToCache(c *cache.LruCache, key string, msg *D.Msg) { +func putMsgToCache(c *cache.LruCache[string, *D.Msg], key string, msg *D.Msg) { var ttl uint32 switch { case len(msg.Answer) != 0: