From 966eeae41b557e906a0b55e9740223a788706e1a Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Mon, 23 Sep 2024 09:35:48 +0800 Subject: [PATCH] chore: rewrite bbolt cachefile implements never use returned byte slices outside the transaction, ref: https://pkg.go.dev/go.etcd.io/bbolt#hdr-Caveats --- common/utils/hash.go | 45 ++++++++++ component/fakeip/cachefile.go | 30 +++---- component/fakeip/pool.go | 4 +- component/fakeip/pool_test.go | 4 +- component/profile/cachefile/cache.go | 104 ++++------------------- component/profile/cachefile/fakeip.go | 115 ++++++++++++++++++++++++++ component/resource/fetcher.go | 9 +- component/resource/vehicle.go | 13 +-- component/updater/update_geo.go | 18 ++-- constant/path.go | 7 +- constant/provider/hash.go | 29 ------- constant/provider/interface.go | 2 +- 12 files changed, 212 insertions(+), 168 deletions(-) create mode 100644 common/utils/hash.go create mode 100644 component/profile/cachefile/fakeip.go delete mode 100644 constant/provider/hash.go diff --git a/common/utils/hash.go b/common/utils/hash.go new file mode 100644 index 00000000..38ba15b4 --- /dev/null +++ b/common/utils/hash.go @@ -0,0 +1,45 @@ +package utils + +import ( + "crypto/md5" + "encoding/hex" +) + +// HashType wraps the hash array inside a struct +// so it can easily be changed to another hash algorithm someday +type HashType struct { + md5 [md5.Size]byte // MD5 +} + +func MakeHash(data []byte) HashType { + return HashType{md5.Sum(data)} +} + +func MakeHashFromBytes(hashBytes []byte) (h HashType) { + if len(hashBytes) != md5.Size { + return + } + copy(h.md5[:], hashBytes) + return +} + +func (h HashType) Equal(hash HashType) bool { + return h.md5 == hash.md5 +} + +func (h HashType) Bytes() []byte { + return h.md5[:] +} + +func (h HashType) String() string { + return hex.EncodeToString(h.Bytes()) +} + +func (h HashType) Len() int { + return len(h.md5) +} + +func (h HashType) IsValid() bool { + var zero HashType + return h != zero +} diff --git a/component/fakeip/cachefile.go 
b/component/fakeip/cachefile.go index 6f0cc48b..92d09721 100644 --- a/component/fakeip/cachefile.go +++ b/component/fakeip/cachefile.go @@ -7,46 +7,32 @@ import ( ) type cachefileStore struct { - cache *cachefile.CacheFile + cache *cachefile.FakeIpStore } // GetByHost implements store.GetByHost func (c *cachefileStore) GetByHost(host string) (netip.Addr, bool) { - elm := c.cache.GetFakeip([]byte(host)) - if elm == nil { - return netip.Addr{}, false - } - - if len(elm) == 4 { - return netip.AddrFrom4(*(*[4]byte)(elm)), true - } else { - return netip.AddrFrom16(*(*[16]byte)(elm)), true - } + return c.cache.GetByHost(host) } // PutByHost implements store.PutByHost func (c *cachefileStore) PutByHost(host string, ip netip.Addr) { - c.cache.PutFakeip([]byte(host), ip.AsSlice()) + c.cache.PutByHost(host, ip) } // GetByIP implements store.GetByIP func (c *cachefileStore) GetByIP(ip netip.Addr) (string, bool) { - elm := c.cache.GetFakeip(ip.AsSlice()) - if elm == nil { - return "", false - } - return string(elm), true + return c.cache.GetByIP(ip) } // PutByIP implements store.PutByIP func (c *cachefileStore) PutByIP(ip netip.Addr, host string) { - c.cache.PutFakeip(ip.AsSlice(), []byte(host)) + c.cache.PutByIP(ip, host) } // DelByIP implements store.DelByIP func (c *cachefileStore) DelByIP(ip netip.Addr) { - addr := ip.AsSlice() - c.cache.DelFakeipPair(addr, c.cache.GetFakeip(addr)) + c.cache.DelByIP(ip) } // Exist implements store.Exist @@ -63,3 +49,7 @@ func (c *cachefileStore) CloneTo(store store) {} func (c *cachefileStore) FlushFakeIP() error { return c.cache.FlushFakeIP() } + +func newCachefileStore(cache *cachefile.CacheFile) *cachefileStore { + return &cachefileStore{cache.FakeIpStore()} +} diff --git a/component/fakeip/pool.go b/component/fakeip/pool.go index 12c06332..41b848b3 100644 --- a/component/fakeip/pool.go +++ b/component/fakeip/pool.go @@ -201,9 +201,7 @@ func New(options Options) (*Pool, error) { ipnet: options.IPNet, } if options.Persistence { - 
pool.store = &cachefileStore{ - cache: cachefile.Cache(), - } + pool.store = newCachefileStore(cachefile.Cache()) } else { pool.store = newMemoryStore(options.Size) } diff --git a/component/fakeip/pool_test.go b/component/fakeip/pool_test.go index ee607b68..be78b87c 100644 --- a/component/fakeip/pool_test.go +++ b/component/fakeip/pool_test.go @@ -43,9 +43,7 @@ func createCachefileStore(options Options) (*Pool, string, error) { return nil, "", err } - pool.store = &cachefileStore{ - cache: &cachefile.CacheFile{DB: db}, - } + pool.store = newCachefileStore(&cachefile.CacheFile{DB: db}) return pool, f.Name(), nil } diff --git a/component/profile/cachefile/cache.go b/component/profile/cachefile/cache.go index 0591c92b..6a918041 100644 --- a/component/profile/cachefile/cache.go +++ b/component/profile/cachefile/cache.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/component/profile" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/log" @@ -71,93 +72,19 @@ func (c *CacheFile) SelectedMap() map[string]string { return mapping } -func (c *CacheFile) PutFakeip(key, value []byte) error { - if c.DB == nil { - return nil - } - - err := c.DB.Batch(func(t *bbolt.Tx) error { - bucket, err := t.CreateBucketIfNotExists(bucketFakeip) - if err != nil { - return err - } - return bucket.Put(key, value) - }) - if err != nil { - log.Warnln("[CacheFile] write cache to %s failed: %s", c.DB.Path(), err.Error()) - } - - return err -} - -func (c *CacheFile) DelFakeipPair(ip, host []byte) error { - if c.DB == nil { - return nil - } - - err := c.DB.Batch(func(t *bbolt.Tx) error { - bucket, err := t.CreateBucketIfNotExists(bucketFakeip) - if err != nil { - return err - } - err = bucket.Delete(ip) - if len(host) > 0 { - if err := bucket.Delete(host); err != nil { - return err - } - } - return err - }) - if err != nil { - log.Warnln("[CacheFile] write cache to %s failed: %s", c.DB.Path(), err.Error()) - } - - 
return err -} - -func (c *CacheFile) GetFakeip(key []byte) []byte { - if c.DB == nil { - return nil - } - - tx, err := c.DB.Begin(false) - if err != nil { - return nil - } - defer tx.Rollback() - - bucket := tx.Bucket(bucketFakeip) - if bucket == nil { - return nil - } - - return bucket.Get(key) -} - -func (c *CacheFile) FlushFakeIP() error { - err := c.DB.Batch(func(t *bbolt.Tx) error { - bucket := t.Bucket(bucketFakeip) - if bucket == nil { - return nil - } - return t.DeleteBucket(bucketFakeip) - }) - return err -} - -func (c *CacheFile) SetETagWithHash(url string, hash []byte, etag string) { +func (c *CacheFile) SetETagWithHash(url string, hash utils.HashType, etag string) { if c.DB == nil { return } - lenHash := len(hash) + lenHash := hash.Len() if lenHash > math.MaxUint8 { return // maybe panic is better } data := make([]byte, 1, 1+lenHash+len(etag)) data[0] = uint8(lenHash) - data = append(data, hash...) + data = append(data, hash.Bytes()...) data = append(data, etag...) err := c.DB.Batch(func(t *bbolt.Tx) error { @@ -173,28 +100,27 @@ func (c *CacheFile) SetETagWithHash(url string, hash []byte, etag string) { return } } -func (c *CacheFile) GetETagWithHash(key string) (hash []byte, etag string) { +func (c *CacheFile) GetETagWithHash(key string) (hash utils.HashType, etag string) { if c.DB == nil { return } - var value []byte c.DB.View(func(t *bbolt.Tx) error { if bucket := t.Bucket(bucketETag); bucket != nil { if v := bucket.Get([]byte(key)); v != nil { - value = v + if len(v) == 0 { + return nil + } + lenHash := int(v[0]) + if len(v) < 1+lenHash { + return nil + } + hash = utils.MakeHashFromBytes(v[1 : 1+lenHash]) + etag = string(v[1+lenHash:]) } } return nil }) - if len(value) == 0 { - return - } - lenHash := int(value[0]) - if len(value) < 1+lenHash { - return - } - hash = value[1 : 1+lenHash] - etag = string(value[1+lenHash:]) + return } diff --git a/component/profile/cachefile/fakeip.go b/component/profile/cachefile/fakeip.go new file mode 100644 index 
00000000..20a09f9c --- /dev/null +++ b/component/profile/cachefile/fakeip.go @@ -0,0 +1,115 @@ +package cachefile + +import ( + "net/netip" + + "github.com/metacubex/mihomo/log" + + "github.com/metacubex/bbolt" +) + +type FakeIpStore struct { + *CacheFile +} + +func (c *CacheFile) FakeIpStore() *FakeIpStore { + return &FakeIpStore{c} +} + +func (c *FakeIpStore) GetByHost(host string) (ip netip.Addr, exist bool) { + if c.DB == nil { + return + } + c.DB.View(func(t *bbolt.Tx) error { + if bucket := t.Bucket(bucketFakeip); bucket != nil { + if v := bucket.Get([]byte(host)); v != nil { + ip, exist = netip.AddrFromSlice(v) + } + } + return nil + }) + return +} + +func (c *FakeIpStore) PutByHost(host string, ip netip.Addr) { + if c.DB == nil { + return + } + err := c.DB.Batch(func(t *bbolt.Tx) error { + bucket, err := t.CreateBucketIfNotExists(bucketFakeip) + if err != nil { + return err + } + return bucket.Put([]byte(host), ip.AsSlice()) + }) + if err != nil { + log.Warnln("[CacheFile] write cache to %s failed: %s", c.DB.Path(), err.Error()) + } +} + +func (c *FakeIpStore) GetByIP(ip netip.Addr) (host string, exist bool) { + if c.DB == nil { + return + } + c.DB.View(func(t *bbolt.Tx) error { + if bucket := t.Bucket(bucketFakeip); bucket != nil { + if v := bucket.Get(ip.AsSlice()); v != nil { + host, exist = string(v), true + } + } + return nil + }) + return +} + +func (c *FakeIpStore) PutByIP(ip netip.Addr, host string) { + if c.DB == nil { + return + } + err := c.DB.Batch(func(t *bbolt.Tx) error { + bucket, err := t.CreateBucketIfNotExists(bucketFakeip) + if err != nil { + return err + } + return bucket.Put(ip.AsSlice(), []byte(host)) + }) + if err != nil { + log.Warnln("[CacheFile] write cache to %s failed: %s", c.DB.Path(), err.Error()) + } +} + +func (c *FakeIpStore) DelByIP(ip netip.Addr) { + if c.DB == nil { + return + } + + addr := ip.AsSlice() + err := c.DB.Batch(func(t *bbolt.Tx) error { + bucket, err := t.CreateBucketIfNotExists(bucketFakeip) + if err != nil { 
+ return err + } + host := bucket.Get(addr) + err = bucket.Delete(addr) + if len(host) > 0 { + if err = bucket.Delete(host); err != nil { + return err + } + } + return err + }) + if err != nil { + log.Warnln("[CacheFile] write cache to %s failed: %s", c.DB.Path(), err.Error()) + } +} + +func (c *FakeIpStore) FlushFakeIP() error { + err := c.DB.Batch(func(t *bbolt.Tx) error { + bucket := t.Bucket(bucketFakeip) + if bucket == nil { + return nil + } + return t.DeleteBucket(bucketFakeip) + }) + return err +} diff --git a/component/resource/fetcher.go b/component/resource/fetcher.go index 3e2ec239..39beee85 100644 --- a/component/resource/fetcher.go +++ b/component/resource/fetcher.go @@ -5,6 +5,7 @@ import ( "os" "time" + "github.com/metacubex/mihomo/common/utils" types "github.com/metacubex/mihomo/constant/provider" "github.com/metacubex/mihomo/log" @@ -21,7 +22,7 @@ type Fetcher[V any] struct { name string vehicle types.Vehicle updatedAt time.Time - hash types.HashType + hash utils.HashType parser Parser[V] interval time.Duration onUpdate func(V) @@ -55,7 +56,7 @@ func (f *Fetcher[V]) Initial() (V, error) { // local file exists, use it first buf, err = os.ReadFile(f.vehicle.Path()) modTime := stat.ModTime() - contents, _, err = f.loadBuf(buf, types.MakeHash(buf), false) + contents, _, err = f.loadBuf(buf, utils.MakeHash(buf), false) f.updatedAt = modTime // reset updatedAt to file's modTime if err == nil { @@ -89,10 +90,10 @@ func (f *Fetcher[V]) Update() (V, bool, error) { } func (f *Fetcher[V]) SideUpdate(buf []byte) (V, bool, error) { - return f.loadBuf(buf, types.MakeHash(buf), true) + return f.loadBuf(buf, utils.MakeHash(buf), true) } -func (f *Fetcher[V]) loadBuf(buf []byte, hash types.HashType, updateFile bool) (V, bool, error) { +func (f *Fetcher[V]) loadBuf(buf []byte, hash utils.HashType, updateFile bool) (V, bool, error) { now := time.Now() if f.hash.Equal(hash) { if updateFile { diff --git a/component/resource/vehicle.go b/component/resource/vehicle.go 
index f30e22d0..b24adfa9 100644 --- a/component/resource/vehicle.go +++ b/component/resource/vehicle.go @@ -9,6 +9,7 @@ import ( "path/filepath" "time" + "github.com/metacubex/mihomo/common/utils" mihomoHttp "github.com/metacubex/mihomo/component/http" "github.com/metacubex/mihomo/component/profile/cachefile" types "github.com/metacubex/mihomo/constant/provider" @@ -61,12 +62,12 @@ func (f *FileVehicle) Url() string { return "file://" + f.path } -func (f *FileVehicle) Read(ctx context.Context, oldHash types.HashType) (buf []byte, hash types.HashType, err error) { +func (f *FileVehicle) Read(ctx context.Context, oldHash utils.HashType) (buf []byte, hash utils.HashType, err error) { buf, err = os.ReadFile(f.path) if err != nil { return } - hash = types.MakeHash(buf) + hash = utils.MakeHash(buf) return } @@ -110,14 +111,14 @@ func (h *HTTPVehicle) Write(buf []byte) error { return safeWrite(h.path, buf) } -func (h *HTTPVehicle) Read(ctx context.Context, oldHash types.HashType) (buf []byte, hash types.HashType, err error) { +func (h *HTTPVehicle) Read(ctx context.Context, oldHash utils.HashType) (buf []byte, hash utils.HashType, err error) { ctx, cancel := context.WithTimeout(ctx, h.timeout) defer cancel() header := h.header setIfNoneMatch := false if etag && oldHash.IsValid() { hashBytes, etag := cachefile.Cache().GetETagWithHash(h.url) - if oldHash.EqualBytes(hashBytes) && etag != "" { + if oldHash.Equal(hashBytes) && etag != "" { if header == nil { header = http.Header{} } else { @@ -143,9 +144,9 @@ func (h *HTTPVehicle) Read(ctx context.Context, oldHash types.HashType) (buf []b if err != nil { return } - hash = types.MakeHash(buf) + hash = utils.MakeHash(buf) if etag { - cachefile.Cache().SetETagWithHash(h.url, hash.Bytes(), resp.Header.Get("ETag")) + cachefile.Cache().SetETagWithHash(h.url, hash, resp.Header.Get("ETag")) } return } diff --git a/component/updater/update_geo.go b/component/updater/update_geo.go index 454cd84d..b5dc9677 100644 --- 
a/component/updater/update_geo.go +++ b/component/updater/update_geo.go @@ -10,12 +10,12 @@ import ( "github.com/metacubex/mihomo/common/atomic" "github.com/metacubex/mihomo/common/batch" + "github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/component/geodata" _ "github.com/metacubex/mihomo/component/geodata/standard" "github.com/metacubex/mihomo/component/mmdb" "github.com/metacubex/mihomo/component/resource" C "github.com/metacubex/mihomo/constant" - P "github.com/metacubex/mihomo/constant/provider" "github.com/metacubex/mihomo/log" "github.com/oschwald/maxminddb-golang" @@ -46,9 +46,9 @@ func SetGeoUpdateInterval(newGeoUpdateInterval int) { func UpdateMMDB() (err error) { vehicle := resource.NewHTTPVehicle(geodata.MmdbUrl(), C.Path.MMDB(), "", nil, defaultHttpTimeout) - var oldHash P.HashType + var oldHash utils.HashType if buf, err := os.ReadFile(vehicle.Path()); err == nil { - oldHash = P.MakeHash(buf) + oldHash = utils.MakeHash(buf) } data, hash, err := vehicle.Read(context.Background(), oldHash) if err != nil { @@ -77,9 +77,9 @@ func UpdateMMDB() (err error) { func UpdateASN() (err error) { vehicle := resource.NewHTTPVehicle(geodata.ASNUrl(), C.Path.ASN(), "", nil, defaultHttpTimeout) - var oldHash P.HashType + var oldHash utils.HashType if buf, err := os.ReadFile(vehicle.Path()); err == nil { - oldHash = P.MakeHash(buf) + oldHash = utils.MakeHash(buf) } data, hash, err := vehicle.Read(context.Background(), oldHash) if err != nil { @@ -110,9 +110,9 @@ func UpdateGeoIp() (err error) { geoLoader, err := geodata.GetGeoDataLoader("standard") vehicle := resource.NewHTTPVehicle(geodata.GeoIpUrl(), C.Path.GeoIP(), "", nil, defaultHttpTimeout) - var oldHash P.HashType + var oldHash utils.HashType if buf, err := os.ReadFile(vehicle.Path()); err == nil { - oldHash = P.MakeHash(buf) + oldHash = utils.MakeHash(buf) } data, hash, err := vehicle.Read(context.Background(), oldHash) if err != nil { @@ -140,9 +140,9 @@ func UpdateGeoSite() (err error) { 
geoLoader, err := geodata.GetGeoDataLoader("standard") vehicle := resource.NewHTTPVehicle(geodata.GeoSiteUrl(), C.Path.GeoSite(), "", nil, defaultHttpTimeout) - var oldHash P.HashType + var oldHash utils.HashType if buf, err := os.ReadFile(vehicle.Path()); err == nil { - oldHash = P.MakeHash(buf) + oldHash = utils.MakeHash(buf) } data, hash, err := vehicle.Read(context.Background(), oldHash) if err != nil { diff --git a/constant/path.go b/constant/path.go index 02279371..1594441c 100644 --- a/constant/path.go +++ b/constant/path.go @@ -1,14 +1,13 @@ package constant import ( - "crypto/md5" - "encoding/hex" "os" P "path" "path/filepath" "strconv" "strings" + "github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/constant/features" ) @@ -89,8 +88,8 @@ func (p *path) IsSafePath(path string) bool { } func (p *path) GetPathByHash(prefix, name string) string { - hash := md5.Sum([]byte(name)) - filename := hex.EncodeToString(hash[:]) + hash := utils.MakeHash([]byte(name)) + filename := hash.String() return filepath.Join(p.HomeDir(), prefix, filename) } diff --git a/constant/provider/hash.go b/constant/provider/hash.go deleted file mode 100644 index b95ffe23..00000000 --- a/constant/provider/hash.go +++ /dev/null @@ -1,29 +0,0 @@ -package provider - -import ( - "bytes" - "crypto/md5" -) - -type HashType [md5.Size]byte // MD5 - -func MakeHash(data []byte) HashType { - return md5.Sum(data) -} - -func (h HashType) Equal(hash HashType) bool { - return h == hash -} - -func (h HashType) EqualBytes(hashBytes []byte) bool { - return bytes.Equal(hashBytes, h[:]) -} - -func (h HashType) Bytes() []byte { - return h[:] -} - -func (h HashType) IsValid() bool { - var zero HashType - return h != zero -} diff --git a/constant/provider/interface.go b/constant/provider/interface.go index 511e8f18..065b801a 100644 --- a/constant/provider/interface.go +++ b/constant/provider/interface.go @@ -32,7 +32,7 @@ func (v VehicleType) String() string { } type Vehicle interface { - 
Read(ctx context.Context, oldHash HashType) (buf []byte, hash HashType, err error) + Read(ctx context.Context, oldHash utils.HashType) (buf []byte, hash utils.HashType, err error) Write(buf []byte) error Path() string Url() string