diff --git a/component/fakeip/pool.go b/component/fakeip/pool.go index be7bdd8d..dd8749a6 100644 --- a/component/fakeip/pool.go +++ b/component/fakeip/pool.go @@ -89,6 +89,11 @@ func (p *Pool) Gateway() net.IP { return uintToIP(p.gateway) } +// PatchFrom clone cache from old pool +func (p *Pool) PatchFrom(o *Pool) { + o.cache.CloneTo(p.cache) +} + func (p *Pool) get(host string) net.IP { current := p.offset for { diff --git a/component/resolver/enhancer.go b/component/resolver/enhancer.go new file mode 100644 index 00000000..48b49bea --- /dev/null +++ b/component/resolver/enhancer.go @@ -0,0 +1,46 @@ +package resolver + +import ( + "net" +) + +var DefaultHostMapper Enhancer + +type Enhancer interface { + FakeIPEnabled() bool + MappingEnabled() bool + IsFakeIP(net.IP) bool + FindHostByIP(net.IP) (string, bool) +} + +func FakeIPEnabled() bool { + if mapper := DefaultHostMapper; mapper != nil { + return mapper.FakeIPEnabled() + } + + return false +} + +func MappingEnabled() bool { + if mapper := DefaultHostMapper; mapper != nil { + return mapper.MappingEnabled() + } + + return false +} + +func IsFakeIP(ip net.IP) bool { + if mapper := DefaultHostMapper; mapper != nil { + return mapper.IsFakeIP(ip) + } + + return false +} + +func FindHostByIP(ip net.IP) (string, bool) { + if mapper := DefaultHostMapper; mapper != nil { + return mapper.FindHostByIP(ip) + } + + return "", false +} diff --git a/dns/enhancer.go b/dns/enhancer.go new file mode 100644 index 00000000..5018affa --- /dev/null +++ b/dns/enhancer.go @@ -0,0 +1,76 @@ +package dns + +import ( + "net" + + "github.com/Dreamacro/clash/common/cache" + "github.com/Dreamacro/clash/component/fakeip" +) + +type ResolverEnhancer struct { + mode EnhancedMode + fakePool *fakeip.Pool + mapping *cache.LruCache +} + +func (h *ResolverEnhancer) FakeIPEnabled() bool { + return h.mode == FAKEIP +} + +func (h *ResolverEnhancer) MappingEnabled() bool { + return h.mode == FAKEIP || h.mode == MAPPING +} + +func (h *ResolverEnhancer) IsFakeIP(ip net.IP) bool { + if !h.FakeIPEnabled() { + return false + } + + if pool := h.fakePool; pool != nil { + return pool.Exist(ip) + } + + return false +} + +func (h *ResolverEnhancer) FindHostByIP(ip net.IP) (string, bool) { + if pool := h.fakePool; pool != nil { + if host, existed := pool.LookBack(ip); existed { + return host, true + } + } + + if mapping := h.mapping; mapping != nil { + if host, existed := h.mapping.Get(ip.String()); existed { + return host.(string), true + } + } + + return "", false +} + +func (h *ResolverEnhancer) PatchFrom(o *ResolverEnhancer) { + if h.mapping != nil && o.mapping != nil { + o.mapping.CloneTo(h.mapping) + } + + if h.fakePool != nil && o.fakePool != nil { + h.fakePool.PatchFrom(o.fakePool) + } +} + +func NewEnhancer(cfg Config) *ResolverEnhancer { + var fakePool *fakeip.Pool + var mapping *cache.LruCache + + if cfg.EnhancedMode != NORMAL { + fakePool = cfg.Pool + mapping = cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)) + } + + return &ResolverEnhancer{ + mode: cfg.EnhancedMode, + fakePool: fakePool, + mapping: mapping, + } +} diff --git a/dns/middleware.go b/dns/middleware.go index cc58c001..8aff0647 100644 --- a/dns/middleware.go +++ b/dns/middleware.go @@ -3,7 +3,9 @@ package dns import ( "net" "strings" + "time" + "github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/component/fakeip" "github.com/Dreamacro/clash/component/trie" "github.com/Dreamacro/clash/log" @@ -11,23 +13,21 @@ import ( D "github.com/miekg/dns" ) -type handler func(w D.ResponseWriter, r *D.Msg) +type handler func(r *D.Msg) (*D.Msg, error) type middleware func(next handler) handler func withHosts(hosts *trie.DomainTrie) middleware { return func(next handler) handler { - return func(w D.ResponseWriter, r *D.Msg) { + return func(r *D.Msg) (*D.Msg, error) { q := r.Question[0] if !isIPRequest(q) { - next(w, r) - return + return next(r) } record := hosts.Search(strings.TrimRight(q.Name, ".")) if record == nil { - next(w, r) - return + return next(r) } ip := record.Data.(net.IP) @@ -46,22 +46,60 @@ func withHosts(hosts *trie.DomainTrie) middleware { msg.Answer = []D.RR{rr} } else { - next(w, r) - return + return next(r) } msg.SetRcode(r, D.RcodeSuccess) msg.Authoritative = true msg.RecursionAvailable = true - w.WriteMsg(msg) + return msg, nil + } + } +} + +func withMapping(mapping *cache.LruCache) middleware { + return func(next handler) handler { + return func(r *D.Msg) (*D.Msg, error) { + q := r.Question[0] + + if !isIPRequest(q) { + return next(r) + } + + msg, err := next(r) + if err != nil { + return nil, err + } + + host := strings.TrimRight(q.Name, ".") + + for _, ans := range msg.Answer { + var ip net.IP + var ttl uint32 + + switch a := ans.(type) { + case *D.A: + ip = a.A + ttl = a.Hdr.Ttl + case *D.AAAA: + ip = a.AAAA + ttl = a.Hdr.Ttl + default: + continue + } + + mapping.SetWithExpire(ip.String(), host, time.Now().Add(time.Second*time.Duration(ttl))) + } + + return msg, nil } } } func withFakeIP(fakePool *fakeip.Pool) middleware { return func(next handler) handler { - return func(w D.ResponseWriter, r *D.Msg) { + return func(r *D.Msg) (*D.Msg, error) { q := r.Question[0] if q.Qtype == D.TypeAAAA { @@ -72,17 +110,14 @@ func withFakeIP(fakePool *fakeip.Pool) middleware { msg.Authoritative = true msg.RecursionAvailable = true - w.WriteMsg(msg) - return + return msg, nil } else if q.Qtype != D.TypeA { - next(w, r) - return + return next(r) } host := strings.TrimRight(q.Name, ".") if fakePool.LookupHost(host) { - next(w, r) - return + return next(r) } rr := &D.A{} @@ -97,13 +132,13 @@ func withFakeIP(fakePool *fakeip.Pool) middleware { msg.Authoritative = true msg.RecursionAvailable = true - w.WriteMsg(msg) + return msg, nil } } } func withResolver(resolver *Resolver) handler { - return func(w D.ResponseWriter, r *D.Msg) { + return func(r *D.Msg) (*D.Msg, error) { q := r.Question[0] // return a empty AAAA msg when ipv6 disabled @@ -115,19 +150,18 @@ func withResolver(resolver *Resolver) handler { msg.Authoritative = true msg.RecursionAvailable = true - w.WriteMsg(msg) - return + return msg, nil } msg, err := resolver.Exchange(r) if err != nil { log.Debugln("[DNS Server] Exchange %s failed: %v", q.String(), err) - D.HandleFailed(w, r) - return + return msg, err } msg.SetRcode(r, msg.Rcode) msg.Authoritative = true - w.WriteMsg(msg) + + return msg, nil } } @@ -142,15 +176,19 @@ func compose(middlewares []middleware, endpoint handler) handler { return h } -func newHandler(resolver *Resolver) handler { +func newHandler(resolver *Resolver, mapper *ResolverEnhancer) handler { middlewares := []middleware{} if resolver.hosts != nil { middlewares = append(middlewares, withHosts(resolver.hosts)) } - if resolver.FakeIPEnabled() { - middlewares = append(middlewares, withFakeIP(resolver.pool)) + if mapper.mode == FAKEIP { + middlewares = append(middlewares, withFakeIP(mapper.fakePool)) + } + + if mapper.mode != NORMAL { + middlewares = append(middlewares, withMapping(mapper.mapping)) } return compose(middlewares, withResolver(resolver)) diff --git a/dns/resolver.go b/dns/resolver.go index 21ebb00a..b7d20b5c 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -7,7 +7,6 @@ import ( "fmt" "math/rand" "net" - "strings" "time" "github.com/Dreamacro/clash/common/cache" @@ -36,10 +35,7 @@ type result struct { type Resolver struct { ipv6 bool - mapping bool - fakeip bool hosts *trie.DomainTrie - pool *fakeip.Pool main []dnsClient fallback []dnsClient fallbackFilters []fallbackFilter @@ -126,12 +122,6 @@ func (r *Resolver) exchangeWithoutCache(m *D.Msg) (msg *D.Msg, err error) { msg := result.(*D.Msg) putMsgToCache(r.lruCache, q.String(), msg) - if r.mapping || r.fakeip { - ips := r.msgToIP(msg) - for _, ip := range ips { - putMsgToCache(r.lruCache, ip.String(), msg) - } - } }() isIPReq := isIPRequest(q) @@ -152,45 +142,6 @@ func (r *Resolver) exchangeWithoutCache(m *D.Msg) (msg *D.Msg, err error) { return } -// IPToHost return fake-ip or redir-host mapping host -func (r *Resolver) IPToHost(ip net.IP) (string, bool) { - if r.fakeip { - record, existed := r.pool.LookBack(ip) - if existed { - return record, true - } - } - - cache, _ := r.lruCache.Get(ip.String()) - if cache == nil { - return "", false - } - fqdn := cache.(*D.Msg).Question[0].Name - return strings.TrimRight(fqdn, "."), true -} - -func (r *Resolver) IsMapping() bool { - return r.mapping -} - -// FakeIPEnabled returns if fake-ip is enabled -func (r *Resolver) FakeIPEnabled() bool { - return r.fakeip -} - -// IsFakeIP determine if given ip is a fake-ip -func (r *Resolver) IsFakeIP(ip net.IP) bool { - if r.FakeIPEnabled() { - return r.pool.Exist(ip) - } - return false -} - -// PatchCache overwrite lruCache to the new resolver -func (r *Resolver) PatchCache(n *Resolver) { - r.lruCache.CloneTo(n.lruCache) -} - func (r *Resolver) batchExchange(clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) { fast, ctx := picker.WithTimeout(context.Background(), time.Second*5) for _, client := range clients { @@ -318,7 +269,7 @@ type Config struct { Hosts *trie.DomainTrie } -func New(config Config) *Resolver { +func NewResolver(config Config) *Resolver { defaultResolver := &Resolver{ main: transform(config.Default, nil), lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)), @@ -328,9 +279,6 @@ func New(config Config) *Resolver { ipv6: config.IPv6, main: transform(config.Main, defaultResolver), lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)), - mapping: config.EnhancedMode == MAPPING, - fakeip: config.EnhancedMode == FAKEIP, - pool: config.Pool, hosts: config.Hosts, } diff --git a/dns/server.go b/dns/server.go index 6e6a02b6..23718c4f 100644 --- a/dns/server.go +++ b/dns/server.go @@ -27,16 +27,22 @@ func (s *Server) ServeDNS(w D.ResponseWriter, r *D.Msg) { return } - s.handler(w, r) + msg, err := s.handler(r) + if err != nil { + D.HandleFailed(w, r) + return + } + + w.WriteMsg(msg) } func (s *Server) setHandler(handler handler) { s.handler = handler } -func ReCreateServer(addr string, resolver *Resolver) error { +func ReCreateServer(addr string, resolver *Resolver, mapper *ResolverEnhancer) error { if addr == address && resolver != nil { - handler := newHandler(resolver) + handler := newHandler(resolver, mapper) server.setHandler(handler) return nil } @@ -68,7 +74,7 @@ func ReCreateServer(addr string, resolver *Resolver) error { } address = addr - handler := newHandler(resolver) + handler := newHandler(resolver, mapper) server = &Server{handler: handler} server.Server = &D.Server{Addr: addr, PacketConn: p, Handler: server} diff --git a/hub/executor/executor.go b/hub/executor/executor.go index f0f4c9b6..c24d6123 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -103,11 +103,12 @@ func updateExperimental(c *config.Config) {} func updateDNS(c *config.DNS) { if !c.Enable { resolver.DefaultResolver = nil - tunnel.SetResolver(nil) - dns.ReCreateServer("", nil) + resolver.DefaultHostMapper = nil + dns.ReCreateServer("", nil, nil) return } - r := dns.New(dns.Config{ + + cfg := dns.Config{ Main: c.NameServer, Fallback: c.Fallback, IPv6: c.IPv6, @@ -119,18 +120,20 @@ func updateDNS(c *config.DNS) { IPCIDR: c.FallbackFilter.IPCIDR, }, Default: c.DefaultNameserver, - }) + } - // reuse cache of old resolver - if resolver.DefaultResolver != nil { - if o, ok := resolver.DefaultResolver.(*dns.Resolver); ok { - o.PatchCache(r) - } + r := dns.NewResolver(cfg) + m := dns.NewEnhancer(cfg) + + // reuse cache of old host mapper + if old := resolver.DefaultHostMapper; old != nil { + m.PatchFrom(old.(*dns.ResolverEnhancer)) } resolver.DefaultResolver = r - tunnel.SetResolver(r) - if err := dns.ReCreateServer(c.Listen, r); err != nil { + resolver.DefaultHostMapper = m + + if err := dns.ReCreateServer(c.Listen, r, m); err != nil { log.Errorln("Start DNS server error: %s", err.Error()) return } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 64911bc7..b4441116 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -12,21 +12,19 @@ import ( "github.com/Dreamacro/clash/component/nat" "github.com/Dreamacro/clash/component/resolver" C "github.com/Dreamacro/clash/constant" - "github.com/Dreamacro/clash/dns" "github.com/Dreamacro/clash/log" channels "gopkg.in/eapache/channels.v1" ) var ( - tcpQueue = channels.NewInfiniteChannel() - udpQueue = channels.NewInfiniteChannel() - natTable = nat.New() - rules []C.Rule - proxies = make(map[string]C.Proxy) - providers map[string]provider.ProxyProvider - configMux sync.RWMutex - enhancedMode *dns.Resolver + tcpQueue = channels.NewInfiniteChannel() + udpQueue = channels.NewInfiniteChannel() + natTable = nat.New() + rules []C.Rule + proxies = make(map[string]C.Proxy) + providers map[string]provider.ProxyProvider + configMux sync.RWMutex // Outbound Rule mode = Rule @@ -89,11 +87,6 @@ func SetMode(m TunnelMode) { mode = m } -// SetResolver set custom dns resolver for enhanced mode -func SetResolver(r *dns.Resolver) { - enhancedMode = r -} - // processUDP starts a loop to handle udp packet func processUDP() { queue := udpQueue.Out() @@ -120,7 +113,7 @@ func process() { } func needLookupIP(metadata *C.Metadata) bool { - return enhancedMode != nil && (enhancedMode.IsMapping() || enhancedMode.FakeIPEnabled()) && metadata.Host == "" && metadata.DstIP != nil + return resolver.MappingEnabled() && metadata.Host == "" && metadata.DstIP != nil } func preHandleMetadata(metadata *C.Metadata) error { @@ -131,17 +124,17 @@ func preHandleMetadata(metadata *C.Metadata) error { // preprocess enhanced-mode metadata if needLookupIP(metadata) { - host, exist := enhancedMode.IPToHost(metadata.DstIP) + host, exist := resolver.FindHostByIP(metadata.DstIP) if exist { metadata.Host = host metadata.AddrType = C.AtypDomainName - if enhancedMode.FakeIPEnabled() { + if resolver.FakeIPEnabled() { metadata.DstIP = nil } else if node := resolver.DefaultHosts.Search(host); node != nil { // redir-host should lookup the hosts metadata.DstIP = node.Data.(net.IP) } - } else if enhancedMode.IsFakeIP(metadata.DstIP) { + } else if resolver.IsFakeIP(metadata.DstIP) { return fmt.Errorf("fake DNS record %s missing", metadata.DstIP) } } @@ -177,7 +170,7 @@ func handleUDPConn(packet *inbound.PacketAdapter) { // make a fAddr if requset ip is fakeip var fAddr net.Addr - if enhancedMode != nil && enhancedMode.IsFakeIP(metadata.DstIP) { + if resolver.IsFakeIP(metadata.DstIP) { fAddr = metadata.UDPAddr() }