diff --git a/common/cache/cache_test.go b/common/cache/cache_test.go index 101ca869..cf4a3914 100644 --- a/common/cache/cache_test.go +++ b/common/cache/cache_test.go @@ -4,6 +4,8 @@ import ( "runtime" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestCache_Basic(t *testing.T) { @@ -14,32 +16,30 @@ func TestCache_Basic(t *testing.T) { c.Put("string", "a", ttl) i := c.Get("int") - if i.(int) != 1 { - t.Error("should recv 1") - } + assert.Equal(t, i.(int), 1, "should recv 1") s := c.Get("string") - if s.(string) != "a" { - t.Error("should recv 'a'") - } + assert.Equal(t, s.(string), "a", "should recv 'a'") } func TestCache_TTL(t *testing.T) { interval := 200 * time.Millisecond ttl := 20 * time.Millisecond + now := time.Now() c := New(interval) c.Put("int", 1, ttl) + c.Put("int2", 2, ttl) i := c.Get("int") - if i.(int) != 1 { - t.Error("should recv 1") - } + _, expired := c.GetWithExpire("int2") + assert.Equal(t, i.(int), 1, "should recv 1") + assert.True(t, now.Before(expired)) time.Sleep(ttl * 2) i = c.Get("int") - if i != nil { - t.Error("should recv nil") - } + j, _ := c.GetWithExpire("int2") + assert.Nil(t, i, "should recv nil") + assert.Nil(t, j, "should recv nil") } func TestCache_AutoCleanup(t *testing.T) { @@ -50,9 +50,9 @@ func TestCache_AutoCleanup(t *testing.T) { time.Sleep(ttl * 2) i := c.Get("int") - if i != nil { - t.Error("should recv nil") - } + j, _ := c.GetWithExpire("int") + assert.Nil(t, i, "should recv nil") + assert.Nil(t, j, "should recv nil") } func TestCache_AutoGC(t *testing.T) { diff --git a/common/cache/lrucache.go b/common/cache/lrucache.go new file mode 100644 index 00000000..5a139bf7 --- /dev/null +++ b/common/cache/lrucache.go @@ -0,0 +1,148 @@ +package cache + +// Modified by https://github.com/die-net/lrucache + +import ( + "container/list" + "sync" + "time" +) + +// Option is part of Functional Options Pattern +type Option func(*LruCache) + +// WithUpdateAgeOnGet update expires when Get element +func WithUpdateAgeOnGet() Option { + return func(l *LruCache) { + l.updateAgeOnGet = true + } +} + +// WithAge defined element max age (second) +func WithAge(maxAge int64) Option { + return func(l *LruCache) { + l.maxAge = maxAge + } +} + +// WithSize defined max length of LruCache +func WithSize(maxSize int) Option { + return func(l *LruCache) { + l.maxSize = maxSize + } +} + +// LruCache is a thread-safe, in-memory lru-cache that evicts the +// least recently used entries from memory when (if set) the entries are +// older than maxAge (in seconds). Use the New constructor to create one. +type LruCache struct { + maxAge int64 + maxSize int + mu sync.Mutex + cache map[interface{}]*list.Element + lru *list.List // Front is least-recent + updateAgeOnGet bool +} + +// NewLRUCache creates an LruCache +func NewLRUCache(options ...Option) *LruCache { + lc := &LruCache{ + lru: list.New(), + cache: make(map[interface{}]*list.Element), + } + + for _, option := range options { + option(lc) + } + + return lc +} + +// Get returns the interface{} representation of a cached response and a bool +// set to true if the key was found. +func (c *LruCache) Get(key interface{}) (interface{}, bool) { + c.mu.Lock() + defer c.mu.Unlock() + + le, ok := c.cache[key] + if !ok { + return nil, false + } + + if c.maxAge > 0 && le.Value.(*entry).expires <= time.Now().Unix() { + c.deleteElement(le) + c.maybeDeleteOldest() + + return nil, false + } + + c.lru.MoveToBack(le) + entry := le.Value.(*entry) + if c.maxAge > 0 && c.updateAgeOnGet { + entry.expires = time.Now().Unix() + c.maxAge + } + value := entry.value + + return value, true +} + +// Set stores the interface{} representation of a response for a given key. +func (c *LruCache) Set(key interface{}, value interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + + expires := int64(0) + if c.maxAge > 0 { + expires = time.Now().Unix() + c.maxAge + } + + if le, ok := c.cache[key]; ok { + c.lru.MoveToBack(le) + e := le.Value.(*entry) + e.value = value + e.expires = expires + } else { + e := &entry{key: key, value: value, expires: expires} + c.cache[key] = c.lru.PushBack(e) + + if c.maxSize > 0 { + if len := c.lru.Len(); len > c.maxSize { + c.deleteElement(c.lru.Front()) + } + } + } + + c.maybeDeleteOldest() +} + +// Delete removes the value associated with a key. +func (c *LruCache) Delete(key string) { + c.mu.Lock() + + if le, ok := c.cache[key]; ok { + c.deleteElement(le) + } + + c.mu.Unlock() +} + +func (c *LruCache) maybeDeleteOldest() { + if c.maxAge > 0 { + now := time.Now().Unix() + for le := c.lru.Front(); le != nil && le.Value.(*entry).expires <= now; le = c.lru.Front() { + c.deleteElement(le) + } + } +} + +func (c *LruCache) deleteElement(le *list.Element) { + c.lru.Remove(le) + e := le.Value.(*entry) + delete(c.cache, e.key) +} + +type entry struct { + key interface{} + value interface{} + expires int64 +} diff --git a/common/cache/lrucache_test.go b/common/cache/lrucache_test.go new file mode 100644 index 00000000..31f9a919 --- /dev/null +++ b/common/cache/lrucache_test.go @@ -0,0 +1,117 @@ +package cache + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var entries = []struct { + key string + value string +}{ + {"1", "one"}, + {"2", "two"}, + {"3", "three"}, + {"4", "four"}, + {"5", "five"}, +} + +func TestLRUCache(t *testing.T) { + c := NewLRUCache() + + for _, e := range entries { + c.Set(e.key, e.value) + } + + c.Delete("missing") + _, ok := c.Get("missing") + assert.False(t, ok) + + for _, e := range entries { + value, ok := c.Get(e.key) + if assert.True(t, ok) { + assert.Equal(t, e.value, value.(string)) + } + } + + for _, e := range entries { + c.Delete(e.key) + + _, ok := c.Get(e.key) + assert.False(t, ok) + } +} + +func TestLRUMaxAge(t *testing.T) { + c := NewLRUCache(WithAge(86400)) + + now := time.Now().Unix() + expected := now + 86400 + + // Add one expired entry + c.Set("foo", "bar") + c.lru.Back().Value.(*entry).expires = now + + // Reset + c.Set("foo", "bar") + e := c.lru.Back().Value.(*entry) + assert.True(t, e.expires >= now) + c.lru.Back().Value.(*entry).expires = now + + // Set a few and verify expiration times + for _, s := range entries { + c.Set(s.key, s.value) + e := c.lru.Back().Value.(*entry) + assert.True(t, e.expires >= expected && e.expires <= expected+10) + } + + // Make sure we can get them all + for _, s := range entries { + _, ok := c.Get(s.key) + assert.True(t, ok) + } + + // Expire all entries + for _, s := range entries { + le, ok := c.cache[s.key] + if assert.True(t, ok) { + le.Value.(*entry).expires = now + } + } + + // Get one expired entry, which should clear all expired entries + _, ok := c.Get("3") + assert.False(t, ok) + assert.Equal(t, c.lru.Len(), 0) +} + +func TestLRUpdateOnGet(t *testing.T) { + c := NewLRUCache(WithAge(86400), WithUpdateAgeOnGet()) + + now := time.Now().Unix() + expires := now + 86400/2 + + // Add one expired entry + c.Set("foo", "bar") + c.lru.Back().Value.(*entry).expires = expires + + _, ok := c.Get("foo") + assert.True(t, ok) + assert.True(t, c.lru.Back().Value.(*entry).expires > expires) +} + +func TestMaxSize(t *testing.T) { + c := NewLRUCache(WithSize(2)) + // Add one expired entry + c.Set("foo", "bar") + _, ok := c.Get("foo") + assert.True(t, ok) + + c.Set("bar", "foo") + c.Set("baz", "foo") + + _, ok = c.Get("foo") + assert.False(t, ok) +} diff --git a/component/fakeip/pool.go b/component/fakeip/pool.go index 32d5d574..8b7a2766 100644 --- a/component/fakeip/pool.go +++ b/component/fakeip/pool.go @@ -4,22 +4,72 @@ import ( "errors" "net" "sync" + + "github.com/Dreamacro/clash/common/cache" ) // Pool is a implementation about fake ip generator without storage type Pool struct { - max uint32 - min uint32 - offset uint32 - mux *sync.Mutex + max uint32 + min uint32 + gateway uint32 + offset uint32 + mux *sync.Mutex + cache *cache.LruCache } -// Get return a new fake ip -func (p *Pool) Get() net.IP { +// Lookup return a fake ip with host +func (p *Pool) Lookup(host string) net.IP { p.mux.Lock() defer p.mux.Unlock() - ip := uintToIP(p.min + p.offset) - p.offset = (p.offset + 1) % (p.max - p.min) + if ip, exist := p.cache.Get(host); exist { + return ip.(net.IP) + } + + ip := p.get(host) + p.cache.Set(host, ip) + return ip +} + +// LookBack return host with the fake ip +func (p *Pool) LookBack(ip net.IP) (string, bool) { + p.mux.Lock() + defer p.mux.Unlock() + + if ip = ip.To4(); ip == nil { + return "", false + } + + n := ipToUint(ip.To4()) + offset := n - p.min + 1 + + if host, exist := p.cache.Get(offset); exist { + return host.(string), true + } + + return "", false +} + +// Gateway return gateway ip +func (p *Pool) Gateway() net.IP { + return uintToIP(p.gateway) +} + +func (p *Pool) get(host string) net.IP { + current := p.offset + for { + p.offset = (p.offset + 1) % (p.max - p.min) + // Avoid infinite loops + if p.offset == current { + break + } + + if _, exist := p.cache.Get(p.offset); !exist { + break + } + } + ip := uintToIP(p.min + p.offset - 1) + p.cache.Set(p.offset, host) return ip } @@ -36,8 +86,8 @@ func uintToIP(v uint32) net.IP { } // New return Pool instance -func New(ipnet *net.IPNet) (*Pool, error) { - min := ipToUint(ipnet.IP) + 1 +func New(ipnet *net.IPNet, size int) (*Pool, error) { + min := ipToUint(ipnet.IP) + 2 ones, bits := ipnet.Mask.Size() total := 1<