Chore: split enhanced mode instance (#936)

Co-authored-by: Dreamacro <305009791@qq.com>
This commit is contained in:
Kr328 2020-09-17 10:48:42 +08:00 committed by GitHub
parent e773f95f21
commit 558ac6b965
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 228 additions and 113 deletions

View File

@ -89,6 +89,11 @@ func (p *Pool) Gateway() net.IP {
return uintToIP(p.gateway) return uintToIP(p.gateway)
} }
// PatchFrom clone cache from old pool
func (p *Pool) PatchFrom(o *Pool) {
o.cache.CloneTo(p.cache)
}
func (p *Pool) get(host string) net.IP { func (p *Pool) get(host string) net.IP {
current := p.offset current := p.offset
for { for {

View File

@ -0,0 +1,46 @@
package resolver
import (
"net"
)
var DefaultHostMapper Enhancer
type Enhancer interface {
FakeIPEnabled() bool
MappingEnabled() bool
IsFakeIP(net.IP) bool
FindHostByIP(net.IP) (string, bool)
}
func FakeIPEnabled() bool {
if mapper := DefaultHostMapper; mapper != nil {
return mapper.FakeIPEnabled()
}
return false
}
func MappingEnabled() bool {
if mapper := DefaultHostMapper; mapper != nil {
return mapper.MappingEnabled()
}
return false
}
func IsFakeIP(ip net.IP) bool {
if mapper := DefaultHostMapper; mapper != nil {
return mapper.IsFakeIP(ip)
}
return false
}
func FindHostByIP(ip net.IP) (string, bool) {
if mapper := DefaultHostMapper; mapper != nil {
return mapper.FindHostByIP(ip)
}
return "", false
}

76
dns/enhancer.go Normal file
View File

@ -0,0 +1,76 @@
package dns
import (
"net"
"github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/component/fakeip"
)
type ResolverEnhancer struct {
mode EnhancedMode
fakePool *fakeip.Pool
mapping *cache.LruCache
}
func (h *ResolverEnhancer) FakeIPEnabled() bool {
return h.mode == FAKEIP
}
func (h *ResolverEnhancer) MappingEnabled() bool {
return h.mode == FAKEIP || h.mode == MAPPING
}
func (h *ResolverEnhancer) IsFakeIP(ip net.IP) bool {
if !h.FakeIPEnabled() {
return false
}
if pool := h.fakePool; pool != nil {
return pool.Exist(ip)
}
return false
}
func (h *ResolverEnhancer) FindHostByIP(ip net.IP) (string, bool) {
if pool := h.fakePool; pool != nil {
if host, existed := pool.LookBack(ip); existed {
return host, true
}
}
if mapping := h.mapping; mapping != nil {
if host, existed := h.mapping.Get(ip.String()); existed {
return host.(string), true
}
}
return "", false
}
func (h *ResolverEnhancer) PatchFrom(o *ResolverEnhancer) {
if h.mapping != nil && o.mapping != nil {
o.mapping.CloneTo(h.mapping)
}
if h.fakePool != nil && o.fakePool != nil {
h.fakePool.PatchFrom(o.fakePool)
}
}
func NewEnhancer(cfg Config) *ResolverEnhancer {
var fakePool *fakeip.Pool
var mapping *cache.LruCache
if cfg.EnhancedMode != NORMAL {
fakePool = cfg.Pool
mapping = cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true))
}
return &ResolverEnhancer{
mode: cfg.EnhancedMode,
fakePool: fakePool,
mapping: mapping,
}
}

View File

@ -3,7 +3,9 @@ package dns
import ( import (
"net" "net"
"strings" "strings"
"time"
"github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/component/fakeip" "github.com/Dreamacro/clash/component/fakeip"
"github.com/Dreamacro/clash/component/trie" "github.com/Dreamacro/clash/component/trie"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
@ -11,23 +13,21 @@ import (
D "github.com/miekg/dns" D "github.com/miekg/dns"
) )
type handler func(w D.ResponseWriter, r *D.Msg) type handler func(r *D.Msg) (*D.Msg, error)
type middleware func(next handler) handler type middleware func(next handler) handler
func withHosts(hosts *trie.DomainTrie) middleware { func withHosts(hosts *trie.DomainTrie) middleware {
return func(next handler) handler { return func(next handler) handler {
return func(w D.ResponseWriter, r *D.Msg) { return func(r *D.Msg) (*D.Msg, error) {
q := r.Question[0] q := r.Question[0]
if !isIPRequest(q) { if !isIPRequest(q) {
next(w, r) return next(r)
return
} }
record := hosts.Search(strings.TrimRight(q.Name, ".")) record := hosts.Search(strings.TrimRight(q.Name, "."))
if record == nil { if record == nil {
next(w, r) return next(r)
return
} }
ip := record.Data.(net.IP) ip := record.Data.(net.IP)
@ -46,22 +46,60 @@ func withHosts(hosts *trie.DomainTrie) middleware {
msg.Answer = []D.RR{rr} msg.Answer = []D.RR{rr}
} else { } else {
next(w, r) return next(r)
return
} }
msg.SetRcode(r, D.RcodeSuccess) msg.SetRcode(r, D.RcodeSuccess)
msg.Authoritative = true msg.Authoritative = true
msg.RecursionAvailable = true msg.RecursionAvailable = true
w.WriteMsg(msg) return msg, nil
}
}
}
func withMapping(mapping *cache.LruCache) middleware {
return func(next handler) handler {
return func(r *D.Msg) (*D.Msg, error) {
q := r.Question[0]
if !isIPRequest(q) {
return next(r)
}
msg, err := next(r)
if err != nil {
return nil, err
}
host := strings.TrimRight(q.Name, ".")
for _, ans := range msg.Answer {
var ip net.IP
var ttl uint32
switch a := ans.(type) {
case *D.A:
ip = a.A
ttl = a.Hdr.Ttl
case *D.AAAA:
ip = a.AAAA
ttl = a.Hdr.Ttl
default:
continue
}
mapping.SetWithExpire(ip.String(), host, time.Now().Add(time.Second*time.Duration(ttl)))
}
return msg, nil
} }
} }
} }
func withFakeIP(fakePool *fakeip.Pool) middleware { func withFakeIP(fakePool *fakeip.Pool) middleware {
return func(next handler) handler { return func(next handler) handler {
return func(w D.ResponseWriter, r *D.Msg) { return func(r *D.Msg) (*D.Msg, error) {
q := r.Question[0] q := r.Question[0]
if q.Qtype == D.TypeAAAA { if q.Qtype == D.TypeAAAA {
@ -72,17 +110,14 @@ func withFakeIP(fakePool *fakeip.Pool) middleware {
msg.Authoritative = true msg.Authoritative = true
msg.RecursionAvailable = true msg.RecursionAvailable = true
w.WriteMsg(msg) return msg, nil
return
} else if q.Qtype != D.TypeA { } else if q.Qtype != D.TypeA {
next(w, r) return next(r)
return
} }
host := strings.TrimRight(q.Name, ".") host := strings.TrimRight(q.Name, ".")
if fakePool.LookupHost(host) { if fakePool.LookupHost(host) {
next(w, r) return next(r)
return
} }
rr := &D.A{} rr := &D.A{}
@ -97,13 +132,13 @@ func withFakeIP(fakePool *fakeip.Pool) middleware {
msg.Authoritative = true msg.Authoritative = true
msg.RecursionAvailable = true msg.RecursionAvailable = true
w.WriteMsg(msg) return msg, nil
} }
} }
} }
func withResolver(resolver *Resolver) handler { func withResolver(resolver *Resolver) handler {
return func(w D.ResponseWriter, r *D.Msg) { return func(r *D.Msg) (*D.Msg, error) {
q := r.Question[0] q := r.Question[0]
// return a empty AAAA msg when ipv6 disabled // return a empty AAAA msg when ipv6 disabled
@ -115,19 +150,18 @@ func withResolver(resolver *Resolver) handler {
msg.Authoritative = true msg.Authoritative = true
msg.RecursionAvailable = true msg.RecursionAvailable = true
w.WriteMsg(msg) return msg, nil
return
} }
msg, err := resolver.Exchange(r) msg, err := resolver.Exchange(r)
if err != nil { if err != nil {
log.Debugln("[DNS Server] Exchange %s failed: %v", q.String(), err) log.Debugln("[DNS Server] Exchange %s failed: %v", q.String(), err)
D.HandleFailed(w, r) return msg, err
return
} }
msg.SetRcode(r, msg.Rcode) msg.SetRcode(r, msg.Rcode)
msg.Authoritative = true msg.Authoritative = true
w.WriteMsg(msg)
return msg, nil
} }
} }
@ -142,15 +176,19 @@ func compose(middlewares []middleware, endpoint handler) handler {
return h return h
} }
func newHandler(resolver *Resolver) handler { func newHandler(resolver *Resolver, mapper *ResolverEnhancer) handler {
middlewares := []middleware{} middlewares := []middleware{}
if resolver.hosts != nil { if resolver.hosts != nil {
middlewares = append(middlewares, withHosts(resolver.hosts)) middlewares = append(middlewares, withHosts(resolver.hosts))
} }
if resolver.FakeIPEnabled() { if mapper.mode == FAKEIP {
middlewares = append(middlewares, withFakeIP(resolver.pool)) middlewares = append(middlewares, withFakeIP(mapper.fakePool))
}
if mapper.mode != NORMAL {
middlewares = append(middlewares, withMapping(mapper.mapping))
} }
return compose(middlewares, withResolver(resolver)) return compose(middlewares, withResolver(resolver))

View File

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"strings"
"time" "time"
"github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/common/cache"
@ -36,10 +35,7 @@ type result struct {
type Resolver struct { type Resolver struct {
ipv6 bool ipv6 bool
mapping bool
fakeip bool
hosts *trie.DomainTrie hosts *trie.DomainTrie
pool *fakeip.Pool
main []dnsClient main []dnsClient
fallback []dnsClient fallback []dnsClient
fallbackFilters []fallbackFilter fallbackFilters []fallbackFilter
@ -126,12 +122,6 @@ func (r *Resolver) exchangeWithoutCache(m *D.Msg) (msg *D.Msg, err error) {
msg := result.(*D.Msg) msg := result.(*D.Msg)
putMsgToCache(r.lruCache, q.String(), msg) putMsgToCache(r.lruCache, q.String(), msg)
if r.mapping || r.fakeip {
ips := r.msgToIP(msg)
for _, ip := range ips {
putMsgToCache(r.lruCache, ip.String(), msg)
}
}
}() }()
isIPReq := isIPRequest(q) isIPReq := isIPRequest(q)
@ -152,45 +142,6 @@ func (r *Resolver) exchangeWithoutCache(m *D.Msg) (msg *D.Msg, err error) {
return return
} }
// IPToHost return fake-ip or redir-host mapping host
func (r *Resolver) IPToHost(ip net.IP) (string, bool) {
if r.fakeip {
record, existed := r.pool.LookBack(ip)
if existed {
return record, true
}
}
cache, _ := r.lruCache.Get(ip.String())
if cache == nil {
return "", false
}
fqdn := cache.(*D.Msg).Question[0].Name
return strings.TrimRight(fqdn, "."), true
}
func (r *Resolver) IsMapping() bool {
return r.mapping
}
// FakeIPEnabled returns if fake-ip is enabled
func (r *Resolver) FakeIPEnabled() bool {
return r.fakeip
}
// IsFakeIP determine if given ip is a fake-ip
func (r *Resolver) IsFakeIP(ip net.IP) bool {
if r.FakeIPEnabled() {
return r.pool.Exist(ip)
}
return false
}
// PatchCache overwrite lruCache to the new resolver
func (r *Resolver) PatchCache(n *Resolver) {
r.lruCache.CloneTo(n.lruCache)
}
func (r *Resolver) batchExchange(clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) { func (r *Resolver) batchExchange(clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) {
fast, ctx := picker.WithTimeout(context.Background(), time.Second*5) fast, ctx := picker.WithTimeout(context.Background(), time.Second*5)
for _, client := range clients { for _, client := range clients {
@ -318,7 +269,7 @@ type Config struct {
Hosts *trie.DomainTrie Hosts *trie.DomainTrie
} }
func New(config Config) *Resolver { func NewResolver(config Config) *Resolver {
defaultResolver := &Resolver{ defaultResolver := &Resolver{
main: transform(config.Default, nil), main: transform(config.Default, nil),
lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)), lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)),
@ -328,9 +279,6 @@ func New(config Config) *Resolver {
ipv6: config.IPv6, ipv6: config.IPv6,
main: transform(config.Main, defaultResolver), main: transform(config.Main, defaultResolver),
lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)), lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)),
mapping: config.EnhancedMode == MAPPING,
fakeip: config.EnhancedMode == FAKEIP,
pool: config.Pool,
hosts: config.Hosts, hosts: config.Hosts,
} }

View File

@ -27,16 +27,22 @@ func (s *Server) ServeDNS(w D.ResponseWriter, r *D.Msg) {
return return
} }
s.handler(w, r) msg, err := s.handler(r)
if err != nil {
D.HandleFailed(w, r)
return
}
w.WriteMsg(msg)
} }
func (s *Server) setHandler(handler handler) { func (s *Server) setHandler(handler handler) {
s.handler = handler s.handler = handler
} }
func ReCreateServer(addr string, resolver *Resolver) error { func ReCreateServer(addr string, resolver *Resolver, mapper *ResolverEnhancer) error {
if addr == address && resolver != nil { if addr == address && resolver != nil {
handler := newHandler(resolver) handler := newHandler(resolver, mapper)
server.setHandler(handler) server.setHandler(handler)
return nil return nil
} }
@ -68,7 +74,7 @@ func ReCreateServer(addr string, resolver *Resolver) error {
} }
address = addr address = addr
handler := newHandler(resolver) handler := newHandler(resolver, mapper)
server = &Server{handler: handler} server = &Server{handler: handler}
server.Server = &D.Server{Addr: addr, PacketConn: p, Handler: server} server.Server = &D.Server{Addr: addr, PacketConn: p, Handler: server}

View File

@ -103,11 +103,12 @@ func updateExperimental(c *config.Config) {}
func updateDNS(c *config.DNS) { func updateDNS(c *config.DNS) {
if !c.Enable { if !c.Enable {
resolver.DefaultResolver = nil resolver.DefaultResolver = nil
tunnel.SetResolver(nil) resolver.DefaultHostMapper = nil
dns.ReCreateServer("", nil) dns.ReCreateServer("", nil, nil)
return return
} }
r := dns.New(dns.Config{
cfg := dns.Config{
Main: c.NameServer, Main: c.NameServer,
Fallback: c.Fallback, Fallback: c.Fallback,
IPv6: c.IPv6, IPv6: c.IPv6,
@ -119,18 +120,20 @@ func updateDNS(c *config.DNS) {
IPCIDR: c.FallbackFilter.IPCIDR, IPCIDR: c.FallbackFilter.IPCIDR,
}, },
Default: c.DefaultNameserver, Default: c.DefaultNameserver,
})
// reuse cache of old resolver
if resolver.DefaultResolver != nil {
if o, ok := resolver.DefaultResolver.(*dns.Resolver); ok {
o.PatchCache(r)
} }
r := dns.NewResolver(cfg)
m := dns.NewEnhancer(cfg)
// reuse cache of old host mapper
if old := resolver.DefaultHostMapper; old != nil {
m.PatchFrom(old.(*dns.ResolverEnhancer))
} }
resolver.DefaultResolver = r resolver.DefaultResolver = r
tunnel.SetResolver(r) resolver.DefaultHostMapper = m
if err := dns.ReCreateServer(c.Listen, r); err != nil {
if err := dns.ReCreateServer(c.Listen, r, m); err != nil {
log.Errorln("Start DNS server error: %s", err.Error()) log.Errorln("Start DNS server error: %s", err.Error())
return return
} }

View File

@ -12,7 +12,6 @@ import (
"github.com/Dreamacro/clash/component/nat" "github.com/Dreamacro/clash/component/nat"
"github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/component/resolver"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/dns"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
channels "gopkg.in/eapache/channels.v1" channels "gopkg.in/eapache/channels.v1"
@ -26,7 +25,6 @@ var (
proxies = make(map[string]C.Proxy) proxies = make(map[string]C.Proxy)
providers map[string]provider.ProxyProvider providers map[string]provider.ProxyProvider
configMux sync.RWMutex configMux sync.RWMutex
enhancedMode *dns.Resolver
// Outbound Rule // Outbound Rule
mode = Rule mode = Rule
@ -89,11 +87,6 @@ func SetMode(m TunnelMode) {
mode = m mode = m
} }
// SetResolver set custom dns resolver for enhanced mode
func SetResolver(r *dns.Resolver) {
enhancedMode = r
}
// processUDP starts a loop to handle udp packet // processUDP starts a loop to handle udp packet
func processUDP() { func processUDP() {
queue := udpQueue.Out() queue := udpQueue.Out()
@ -120,7 +113,7 @@ func process() {
} }
func needLookupIP(metadata *C.Metadata) bool { func needLookupIP(metadata *C.Metadata) bool {
return enhancedMode != nil && (enhancedMode.IsMapping() || enhancedMode.FakeIPEnabled()) && metadata.Host == "" && metadata.DstIP != nil return resolver.MappingEnabled() && metadata.Host == "" && metadata.DstIP != nil
} }
func preHandleMetadata(metadata *C.Metadata) error { func preHandleMetadata(metadata *C.Metadata) error {
@ -131,17 +124,17 @@ func preHandleMetadata(metadata *C.Metadata) error {
// preprocess enhanced-mode metadata // preprocess enhanced-mode metadata
if needLookupIP(metadata) { if needLookupIP(metadata) {
host, exist := enhancedMode.IPToHost(metadata.DstIP) host, exist := resolver.FindHostByIP(metadata.DstIP)
if exist { if exist {
metadata.Host = host metadata.Host = host
metadata.AddrType = C.AtypDomainName metadata.AddrType = C.AtypDomainName
if enhancedMode.FakeIPEnabled() { if resolver.FakeIPEnabled() {
metadata.DstIP = nil metadata.DstIP = nil
} else if node := resolver.DefaultHosts.Search(host); node != nil { } else if node := resolver.DefaultHosts.Search(host); node != nil {
// redir-host should lookup the hosts // redir-host should lookup the hosts
metadata.DstIP = node.Data.(net.IP) metadata.DstIP = node.Data.(net.IP)
} }
} else if enhancedMode.IsFakeIP(metadata.DstIP) { } else if resolver.IsFakeIP(metadata.DstIP) {
return fmt.Errorf("fake DNS record %s missing", metadata.DstIP) return fmt.Errorf("fake DNS record %s missing", metadata.DstIP)
} }
} }
@ -177,7 +170,7 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
// make a fAddr if requset ip is fakeip // make a fAddr if requset ip is fakeip
var fAddr net.Addr var fAddr net.Addr
if enhancedMode != nil && enhancedMode.IsFakeIP(metadata.DstIP) { if resolver.IsFakeIP(metadata.DstIP) {
fAddr = metadata.UDPAddr() fAddr = metadata.UDPAddr()
} }