diff --git a/adapters/inbound/socket.go b/adapters/inbound/socket.go index 8d701ad6..8c0ef49d 100644 --- a/adapters/inbound/socket.go +++ b/adapters/inbound/socket.go @@ -29,9 +29,12 @@ func (s *SocketAdapter) Conn() net.Conn { } // NewSocket is SocketAdapter generator -func NewSocket(target socks.Addr, conn net.Conn) *SocketAdapter { +func NewSocket(target socks.Addr, conn net.Conn, source C.SourceType) *SocketAdapter { + metadata := parseSocksAddr(target) + metadata.Source = source + return &SocketAdapter{ conn: conn, - metadata: parseSocksAddr(target), + metadata: metadata, } } diff --git a/adapters/inbound/util.go b/adapters/inbound/util.go index 06499f04..a1f5cd08 100644 --- a/adapters/inbound/util.go +++ b/adapters/inbound/util.go @@ -10,32 +10,26 @@ import ( ) func parseSocksAddr(target socks.Addr) *C.Metadata { - var host, port string - var ip net.IP + metadata := &C.Metadata{ + NetWork: C.TCP, + AddrType: int(target[0]), + } switch target[0] { case socks.AtypDomainName: - host = string(target[2 : 2+target[1]]) - port = strconv.Itoa((int(target[2+target[1]]) << 8) | int(target[2+target[1]+1])) - ipAddr, err := net.ResolveIPAddr("ip", host) - if err == nil { - ip = ipAddr.IP - } + metadata.Host = string(target[2 : 2+target[1]]) + metadata.Port = strconv.Itoa((int(target[2+target[1]]) << 8) | int(target[2+target[1]+1])) case socks.AtypIPv4: - ip = net.IP(target[1 : 1+net.IPv4len]) - port = strconv.Itoa((int(target[1+net.IPv4len]) << 8) | int(target[1+net.IPv4len+1])) + ip := net.IP(target[1 : 1+net.IPv4len]) + metadata.IP = &ip + metadata.Port = strconv.Itoa((int(target[1+net.IPv4len]) << 8) | int(target[1+net.IPv4len+1])) case socks.AtypIPv6: - ip = net.IP(target[1 : 1+net.IPv6len]) - port = strconv.Itoa((int(target[1+net.IPv6len]) << 8) | int(target[1+net.IPv6len+1])) + ip := net.IP(target[1 : 1+net.IPv6len]) + metadata.IP = &ip + metadata.Port = strconv.Itoa((int(target[1+net.IPv6len]) << 8) | int(target[1+net.IPv6len+1])) } - return &C.Metadata{ - NetWork: C.TCP, - AddrType: int(target[0]), - Host: host, - IP: &ip, - Port: port, - } + return metadata } func parseHTTPAddr(request *http.Request) *C.Metadata { @@ -44,28 +38,26 @@ func parseHTTPAddr(request *http.Request) *C.Metadata { if port == "" { port = "80" } - ipAddr, err := net.ResolveIPAddr("ip", host) - var resolveIP *net.IP - if err == nil { - resolveIP = &ipAddr.IP - } - var addType int - ip := net.ParseIP(host) - switch { - case ip == nil: - addType = socks.AtypDomainName - case ip.To4() == nil: - addType = socks.AtypIPv6 - default: - addType = socks.AtypIPv4 - } - - return &C.Metadata{ + metadata := &C.Metadata{ NetWork: C.TCP, - AddrType: addType, + Source: C.HTTP, + AddrType: C.AtypDomainName, Host: host, - IP: resolveIP, + IP: nil, Port: port, } + + ip := net.ParseIP(host) + if ip != nil { + switch { + case ip.To4() == nil: + metadata.AddrType = C.AtypIPv6 + default: + metadata.AddrType = C.AtypIPv4 + } + metadata.IP = &ip + } + + return metadata } diff --git a/adapters/outbound/direct.go b/adapters/outbound/direct.go index 1a6a1144..052d3f3b 100644 --- a/adapters/outbound/direct.go +++ b/adapters/outbound/direct.go @@ -33,7 +33,12 @@ func (d *Direct) Type() C.AdapterType { } func (d *Direct) Generator(metadata *C.Metadata) (adapter C.ProxyAdapter, err error) { - c, err := net.DialTimeout("tcp", net.JoinHostPort(metadata.String(), metadata.Port), tcpTimeout) + address := net.JoinHostPort(metadata.Host, metadata.Port) + if metadata.IP != nil { + address = net.JoinHostPort(metadata.IP.String(), metadata.Port) + } + + c, err := net.DialTimeout("tcp", address, tcpTimeout) if err != nil { return } diff --git a/common/cache/cache.go b/common/cache/cache.go new file mode 100644 index 00000000..f33591ea --- /dev/null +++ b/common/cache/cache.go @@ -0,0 +1,91 @@ +package cache + +import ( + "runtime" + "sync" + "time" +) + +// Cache store element with a expired time +type Cache struct { + *cache +} + +type cache struct { + mapping sync.Map + janitor *janitor +} + +type element struct { + Expired time.Time + Payload interface{} +} + +// Put element in Cache with its ttl +func (c *cache) Put(key interface{}, payload interface{}, ttl time.Duration) { + c.mapping.Store(key, &element{ + Payload: payload, + Expired: time.Now().Add(ttl), + }) +} + +// Get element in Cache, and drop when it expired +func (c *cache) Get(key interface{}) interface{} { + item, exist := c.mapping.Load(key) + if !exist { + return nil + } + elm := item.(*element) + // expired + if time.Since(elm.Expired) > 0 { + c.mapping.Delete(key) + return nil + } + return elm.Payload +} + +func (c *cache) cleanup() { + c.mapping.Range(func(k, v interface{}) bool { + key := k.(string) + elm := v.(*element) + if time.Since(elm.Expired) > 0 { + c.mapping.Delete(key) + } + return true + }) +} + +type janitor struct { + interval time.Duration + stop chan struct{} +} + +func (j *janitor) process(c *cache) { + ticker := time.NewTicker(j.interval) + for { + select { + case <-ticker.C: + c.cleanup() + case <-j.stop: + ticker.Stop() + return + } + } +} + +func stopJanitor(c *Cache) { + c.janitor.stop <- struct{}{} +} + +// New return *Cache +func New(interval time.Duration) *Cache { + j := &janitor{ + interval: interval, + stop: make(chan struct{}), + } + c := &cache{janitor: j} + go j.process(c) + C := &Cache{c} + runtime.SetFinalizer(C, stopJanitor) + return C +} diff --git a/common/cache/cache_test.go b/common/cache/cache_test.go new file mode 100644 index 00000000..101ca869 --- /dev/null +++ b/common/cache/cache_test.go @@ -0,0 +1,70 @@ +package cache + +import ( + "runtime" + "testing" + "time" +) + +func TestCache_Basic(t *testing.T) { + interval := 200 * time.Millisecond + ttl := 20 * time.Millisecond + c := New(interval) + c.Put("int", 1, ttl) + c.Put("string", "a", ttl) + + i := c.Get("int") + if i.(int) != 1 { + t.Error("should recv 1") + } + + s := c.Get("string") + if s.(string) != "a" { + t.Error("should recv 'a'") + } +} + +func TestCache_TTL(t *testing.T) { + interval := 200 * time.Millisecond + ttl := 20 * time.Millisecond + c := New(interval) + c.Put("int", 1, ttl) + + i := c.Get("int") + if i.(int) != 1 { + t.Error("should recv 1") + } + + time.Sleep(ttl * 2) + i = c.Get("int") + if i != nil { + t.Error("should recv nil") + } +} + +func TestCache_AutoCleanup(t *testing.T) { + interval := 10 * time.Millisecond + ttl := 15 * time.Millisecond + c := New(interval) + c.Put("int", 1, ttl) + + time.Sleep(ttl * 2) + i := c.Get("int") + if i != nil { + t.Error("should recv nil") + } +} + +func TestCache_AutoGC(t *testing.T) { + sign := make(chan struct{}) + go func() { + interval := 10 * time.Millisecond + ttl := 15 * time.Millisecond + c := New(interval) + c.Put("int", 1, ttl) + sign <- struct{}{} + }() + + <-sign + runtime.GC() +} diff --git a/common/picker/picker.go b/common/picker/picker.go new file mode 100644 index 00000000..07e2076d --- /dev/null +++ b/common/picker/picker.go @@ -0,0 +1,22 @@ +package picker + +import "context" + +func SelectFast(ctx context.Context, in <-chan interface{}) <-chan interface{} { + out := make(chan interface{}) + go func() { + select { + case p, open := <-in: + if open { + out <- p + } + case <-ctx.Done(): + } + + close(out) + for range in { + } + }() + + return out +} diff --git a/common/picker/picker_test.go b/common/picker/picker_test.go new file mode 100644 index 00000000..f33627f7 --- /dev/null +++ b/common/picker/picker_test.go @@ -0,0 +1,44 @@ +package picker + +import ( + "context" + "testing" + "time" +) + +func sleepAndSend(delay int, in chan<- interface{}, input interface{}) { + time.Sleep(time.Millisecond * time.Duration(delay)) + in <- input +} + +func sleepAndClose(delay int, in chan interface{}) { + time.Sleep(time.Millisecond * time.Duration(delay)) + close(in) +} + +func TestPicker_Basic(t *testing.T) { + in := make(chan interface{}) + fast := SelectFast(context.Background(), in) + go sleepAndSend(20, in, 1) + go sleepAndSend(30, in, 2) + go sleepAndClose(40, in) + + number, exist := <-fast + if !exist || number != 1 { + t.Error("should recv 1", exist, number) + } +} + +func TestPicker_Timeout(t *testing.T) { + in := make(chan interface{}) + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*5) + defer cancel() + fast := SelectFast(ctx, in) + go sleepAndSend(20, in, 1) + go sleepAndClose(30, in) + + _, exist := <-fast + if exist { + t.Error("should recv false") + } +} diff --git a/config/config.go b/config/config.go index 51ed4f2e..775f9497 100644 --- a/config/config.go +++ b/config/config.go @@ -3,12 +3,15 @@ package config import ( "fmt" "io/ioutil" + "net" + "net/url" "os" "strings" adapters "github.com/Dreamacro/clash/adapters/outbound" "github.com/Dreamacro/clash/common/structure" C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/dns" "github.com/Dreamacro/clash/log" R "github.com/Dreamacro/clash/rules" T "github.com/Dreamacro/clash/tunnel" @@ -28,28 +31,49 @@ type General struct { Secret string `json:"secret,omitempty"` } -type rawConfig struct { - Port int `yaml:"port"` - SocksPort int `yaml:"socks-port"` - RedirPort int `yaml:"redir-port"` - AllowLan bool `yaml:"allow-lan"` - Mode string `yaml:"mode"` - LogLevel string `yaml:"log-level"` - ExternalController string `yaml:"external-controller"` - Secret string `yaml:"secret"` - - Proxy []map[string]interface{} `yaml:"Proxy"` - ProxyGroup []map[string]interface{} `yaml:"Proxy Group"` - Rule []string `yaml:"Rule"` +// DNS config +type DNS struct { + Enable bool `yaml:"enable"` + IPv6 bool `yaml:"ipv6"` + NameServer []dns.NameServer `yaml:"nameserver"` + Fallback []dns.NameServer `yaml:"fallback"` + Listen string `yaml:"listen"` + EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"` } // Config is clash config manager type Config struct { General *General + DNS *DNS Rules []C.Rule Proxies map[string]C.Proxy } +type rawDNS struct { + Enable bool `yaml:"enable"` + IPv6 bool `yaml:"ipv6"` + NameServer []string `yaml:"nameserver"` + Fallback []string `yaml:"fallback"` + Listen string `yaml:"listen"` + EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"` +} + +type rawConfig struct { + Port int `yaml:"port"` + SocksPort int `yaml:"socks-port"` + RedirPort int `yaml:"redir-port"` + AllowLan bool `yaml:"allow-lan"` + Mode T.Mode `yaml:"mode"` + LogLevel log.LogLevel `yaml:"log-level"` + ExternalController string `yaml:"external-controller"` + Secret string `yaml:"secret"` + + DNS *rawDNS `yaml:"dns"` + Proxy []map[string]interface{} `yaml:"Proxy"` + ProxyGroup []map[string]interface{} `yaml:"Proxy Group"` + Rule []string `yaml:"Rule"` +} + func readConfig(path string) (*rawConfig, error) { if _, err := os.Stat(path); os.IsNotExist(err) { return nil, err @@ -66,8 +90,8 @@ func readConfig(path string) (*rawConfig, error) { // config with some default value rawConfig := &rawConfig{ AllowLan: false, - Mode: T.Rule.String(), - LogLevel: log.INFO.String(), + Mode: T.Rule, + LogLevel: log.INFO, Rule: []string{}, Proxy: []map[string]interface{}{}, ProxyGroup: []map[string]interface{}{}, @@ -103,6 +127,12 @@ func Parse(path string) (*Config, error) { } config.Rules = rules + dnsCfg, err := parseDNS(rawCfg.DNS) + if err != nil { + return nil, err + } + config.DNS = dnsCfg + return config, nil } @@ -111,20 +141,10 @@ func parseGeneral(cfg *rawConfig) (*General, error) { socksPort := cfg.SocksPort redirPort := cfg.RedirPort allowLan := cfg.AllowLan - logLevelString := cfg.LogLevel - modeString := cfg.Mode externalController := cfg.ExternalController secret := cfg.Secret - - mode, exist := T.ModeMapping[modeString] - if !exist { - return nil, fmt.Errorf("General.mode value invalid") - } - - logLevel, exist := log.LogLevelMapping[logLevelString] - if !exist { - return nil, fmt.Errorf("General.log-level value invalid") - } + mode := cfg.Mode + logLevel := cfg.LogLevel general := &General{ Port: port, @@ -310,3 +330,78 @@ func parseRules(cfg *rawConfig) ([]C.Rule, error) { return rules, nil } + +func hostWithDefaultPort(host string, defPort string) (string, error) { + if !strings.Contains(host, ":") { + host += ":" + } + + hostname, port, err := net.SplitHostPort(host) + if err != nil { + return "", err + } + + if port == "" { + port = defPort + } + + return net.JoinHostPort(hostname, port), nil +} + +func parseNameServer(servers []string) ([]dns.NameServer, error) { + nameservers := []dns.NameServer{} + log.Debugln("%#v", servers) + + for idx, server := range servers { + // parse without scheme .e.g 8.8.8.8:53 + if host, err := hostWithDefaultPort(server, "53"); err == nil { + nameservers = append( + nameservers, + dns.NameServer{Addr: host}, + ) + continue + } + + u, err := url.Parse(server) + if err != nil { + return nil, fmt.Errorf("DNS NameServer[%d] format error: %s", idx, err.Error()) + } + + if u.Scheme != "tls" { + return nil, fmt.Errorf("DNS NameServer[%d] unsupport scheme: %s", idx, u.Scheme) + } + + host, err := hostWithDefaultPort(u.Host, "853") + nameservers = append( + nameservers, + dns.NameServer{ + Net: "tcp-tls", + Addr: host, + }, + ) + } + + return nameservers, nil +} + +func parseDNS(cfg *rawDNS) (*DNS, error) { + if cfg.Enable && len(cfg.NameServer) == 0 { + return nil, fmt.Errorf("If DNS configuration is turned on, NameServer cannot be empty") + } + + dnsCfg := &DNS{ + Enable: cfg.Enable, + Listen: cfg.Listen, + EnhancedMode: cfg.EnhancedMode, + } + + if nameserver, err := parseNameServer(cfg.NameServer); err == nil { + dnsCfg.NameServer = nameserver + } + + if fallback, err := parseNameServer(cfg.Fallback); err == nil { + dnsCfg.Fallback = fallback + } + + return dnsCfg, nil +} diff --git a/constant/metadata.go b/constant/metadata.go index dfc274e7..1cb19033 100644 --- a/constant/metadata.go +++ b/constant/metadata.go @@ -15,6 +15,7 @@ const ( HTTP SourceType = iota SOCKS + REDIR ) type NetWork int diff --git a/dns/client.go b/dns/client.go new file mode 100644 index 00000000..162f1640 --- /dev/null +++ b/dns/client.go @@ -0,0 +1,263 @@ +package dns + +import ( + "context" + "crypto/tls" + "errors" + "net" + "strings" + "sync" + "time" + + "github.com/Dreamacro/clash/common/cache" + "github.com/Dreamacro/clash/common/picker" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" + + D "github.com/miekg/dns" + geoip2 "github.com/oschwald/geoip2-golang" +) + +var ( + globalSessionCache = tls.NewLRUClientSessionCache(64) + + mmdb *geoip2.Reader + once sync.Once + resolver *Resolver +) + +type Resolver struct { + ipv6 bool + mapping bool + fallback []*nameserver + main []*nameserver + cache *cache.Cache +} + +type result struct { + Msg *D.Msg + Error error +} + +func isIPRequest(q D.Question) bool { + if q.Qclass == D.ClassINET && (q.Qtype == D.TypeA || q.Qtype == D.TypeAAAA) { + return true + } + return false +} + +func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, err error) { + if len(m.Question) == 0 { + return nil, errors.New("should have one question at least") + } + + q := m.Question[0] + cache := r.cache.Get(q.String()) + if cache != nil { + return cache.(*D.Msg).Copy(), nil + } + defer func() { + if msg != nil { + putMsgToCache(r.cache, q.String(), msg) + if r.mapping { + ips, err := r.msgToIP(msg) + if err != nil { + log.Debugln("[DNS] msg to ip error: %s", err.Error()) + return + } + for _, ip := range ips { + putMsgToCache(r.cache, ip.String(), msg) + } + } + } + }() + + isIPReq := isIPRequest(q) + if isIPReq { + msg, err = r.resolveIP(m) + return + } + + msg, err = r.exchange(r.main, m) + return +} + +func (r *Resolver) exchange(servers []*nameserver, m *D.Msg) (msg *D.Msg, err error) { + in := make(chan interface{}) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + fast := picker.SelectFast(ctx, in) + + wg := sync.WaitGroup{} + wg.Add(len(servers)) + for _, server := range servers { + go func(s *nameserver) { + defer wg.Done() + msg, _, err := s.Client.Exchange(m, s.Address) + if err != nil || msg.Rcode != D.RcodeSuccess { + return + } + in <- &result{Msg: msg, Error: err} + }(server) + } + + // release in channel + go func() { + wg.Wait() + close(in) + }() + + elm, exist := <-fast + if !exist { + return nil, errors.New("All DNS requests failed") + } + + resp := elm.(*result) + msg, err = resp.Msg, resp.Error + return +} + +func (r *Resolver) resolveIP(m *D.Msg) (msg *D.Msg, err error) { + msgCh := r.resolve(r.main, m) + if r.fallback == nil { + res := <-msgCh + msg, err = res.Msg, res.Error + return + } + fallbackMsg := r.resolve(r.fallback, m) + res := <-msgCh + if res.Error == nil { + if mmdb == nil { + return nil, errors.New("GeoIP can't use") + } + + ips, _ := r.msgToIP(res.Msg) + if record, _ := mmdb.Country(ips[0]); record.Country.IsoCode == "CN" || record.Country.IsoCode == "" { + // release channel + go func() { <-fallbackMsg }() + msg = res.Msg + return + } + } + + res = <-fallbackMsg + msg, err = res.Msg, res.Error + return +} + +func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) { + query := &D.Msg{} + dnsType := D.TypeA + if r.ipv6 { + dnsType = D.TypeAAAA + } + query.SetQuestion(D.Fqdn(host), dnsType) + + msg, err := r.Exchange(query) + if err != nil { + return nil, err + } + + var ips []net.IP + ips, err = r.msgToIP(msg) + if err != nil { + return nil, err + } + + ip = ips[0] + return +} + +func (r *Resolver) msgToIP(msg *D.Msg) ([]net.IP, error) { + var ips []net.IP + + for _, answer := range msg.Answer { + if r.ipv6 { + ans, ok := answer.(*D.AAAA) + if !ok { + continue + } + ips = append(ips, ans.AAAA) + continue + } + + ans, ok := answer.(*D.A) + if !ok { + continue + } + ips = append(ips, ans.A) + } + + if len(ips) == 0 { + return nil, errors.New("Can't parse msg") + } + + return ips, nil +} + +func (r *Resolver) IPToHost(ip net.IP) (string, bool) { + cache := r.cache.Get(ip.String()) + if cache == nil { + return "", false + } + fqdn := cache.(*D.Msg).Question[0].Name + return strings.TrimRight(fqdn, "."), true +} + +func (r *Resolver) resolve(client []*nameserver, msg *D.Msg) <-chan *result { + ch := make(chan *result) + go func() { + res, err := r.exchange(client, msg) + ch <- &result{Msg: res, Error: err} + }() + return ch +} + +type NameServer struct { + Net string + Addr string +} + +type nameserver struct { + Client *D.Client + Address string +} + +type Config struct { + Main, Fallback []NameServer + IPv6 bool + EnhancedMode EnhancedMode +} + +func transform(servers []NameServer) []*nameserver { + var ret []*nameserver + for _, s := range servers { + ret = append(ret, &nameserver{ + Client: &D.Client{ + Net: s.Net, + TLSConfig: &tls.Config{ + ClientSessionCache: globalSessionCache, + }, + }, + Address: s.Addr, + }) + } + return ret +} + +func New(config Config) *Resolver { + once.Do(func() { + mmdb, _ = geoip2.Open(C.Path.MMDB()) + }) + + r := &Resolver{ + main: transform(config.Main), + ipv6: config.IPv6, + cache: cache.New(time.Second * 60), + mapping: config.EnhancedMode == MAPPING, + } + if config.Fallback != nil { + r.fallback = transform(config.Fallback) + } + return r +} diff --git a/dns/server.go b/dns/server.go new file mode 100644 index 00000000..2ece2e97 --- /dev/null +++ b/dns/server.go @@ -0,0 +1,62 @@ +package dns + +import ( + "net" + + D "github.com/miekg/dns" +) + +var ( + address string + server = &Server{} +) + +type Server struct { + *D.Server + r *Resolver +} + +func (s *Server) ServeDNS(w D.ResponseWriter, r *D.Msg) { + msg, err := s.r.Exchange(r) + + if err != nil { + D.HandleFailed(w, r) + return + } + msg.SetReply(r) + w.WriteMsg(msg) +} + +func ReCreateServer(addr string, resolver *Resolver) error { + if server.Server != nil { + server.Shutdown() + } + + if addr == address { + return nil + } + + _, port, err := net.SplitHostPort(addr) + if port == "0" || port == "" || err != nil { + return nil + } + + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return err + } + + p, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return err + } + + address = addr + server = &Server{r: resolver} + server.Server = &D.Server{Addr: addr, PacketConn: p, Handler: server} + + go func() { + server.ActivateAndServe() + }() + return nil +} diff --git a/dns/util.go b/dns/util.go new file mode 100644 index 00000000..c64e4978 --- /dev/null +++ b/dns/util.go @@ -0,0 +1,85 @@ +package dns + +import ( + "encoding/json" + "errors" + "time" + + "github.com/Dreamacro/clash/common/cache" + "github.com/Dreamacro/clash/log" + yaml "gopkg.in/yaml.v2" + + D "github.com/miekg/dns" +) + +var ( + // EnhancedModeMapping is a mapping for EnhancedMode enum + EnhancedModeMapping = map[string]EnhancedMode{ + FAKEIP.String(): FAKEIP, + MAPPING.String(): MAPPING, + } +) + +const ( + FAKEIP EnhancedMode = iota + MAPPING +) + +type EnhancedMode int + +// UnmarshalYAML unserialize EnhancedMode with yaml +func (e *EnhancedMode) UnmarshalYAML(unmarshal func(interface{}) error) error { + var tp string + if err := unmarshal(&tp); err != nil { + return err + } + mode, exist := EnhancedModeMapping[tp] + if !exist { + return errors.New("invalid mode") + } + *e = mode + return nil +} + +// MarshalYAML serialize EnhancedMode with yaml +func (e EnhancedMode) MarshalYAML() ([]byte, error) { + return yaml.Marshal(e.String()) +} + +// UnmarshalJSON unserialize EnhancedMode with json +func (e *EnhancedMode) UnmarshalJSON(data []byte) error { + var tp string + json.Unmarshal(data, &tp) + mode, exist := EnhancedModeMapping[tp] + if !exist { + return errors.New("invalid mode") + } + *e = mode + return nil +} + +// MarshalJSON serialize EnhancedMode with json +func (e EnhancedMode) MarshalJSON() ([]byte, error) { + return json.Marshal(e.String()) +} + +func (e EnhancedMode) String() string { + switch e { + case FAKEIP: + return "fakeip" + case MAPPING: + return "redir-host" + default: + return "unknown" + } +} + +func putMsgToCache(c *cache.Cache, key string, msg *D.Msg) { + if len(msg.Answer) == 0 { + log.Debugln("[DNS] answer length is zero: %#v", msg) + return + } + + ttl := time.Duration(msg.Answer[0].Header().Ttl) * time.Second + c.Put(key, msg, ttl) +} diff --git a/go.mod b/go.mod index 2e536bf5..595d53ef 100644 --- a/go.mod +++ b/go.mod @@ -1,17 +1,20 @@ module github.com/Dreamacro/clash require ( - github.com/Dreamacro/go-shadowsocks2 v0.1.2-0.20181019110427-0a03f1a25270 + github.com/Dreamacro/go-shadowsocks2 v0.1.2 github.com/eapache/queue v1.1.0 // indirect github.com/go-chi/chi v3.3.3+incompatible github.com/go-chi/cors v1.0.0 github.com/go-chi/render v1.0.1 github.com/gofrs/uuid v3.1.0+incompatible github.com/gorilla/websocket v1.4.0 + github.com/miekg/dns v1.1.0 github.com/oschwald/geoip2-golang v1.2.1 github.com/oschwald/maxminddb-golang v1.3.0 // indirect - github.com/sirupsen/logrus v1.1.0 - golang.org/x/crypto v0.0.0-20181009213950-7c1a557ab941 + github.com/sirupsen/logrus v1.2.0 + golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85 + golang.org/x/net v0.0.0-20181108082009-03003ca0c849 // indirect + golang.org/x/sync v0.0.0-20181108010431-42b317875d0f // indirect gopkg.in/eapache/channels.v1 v1.1.0 - gopkg.in/yaml.v2 v2.2.1 + gopkg.in/yaml.v2 v2.2.2 ) diff --git a/go.sum b/go.sum index 185a7fdf..4d0e310b 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/Dreamacro/go-shadowsocks2 v0.1.2-0.20181019110427-0a03f1a25270 h1:ugkI+Yw5ArFnhF8KTbJxyWIyvxMCa8jWyUF+wIAulhM= -github.com/Dreamacro/go-shadowsocks2 v0.1.2-0.20181019110427-0a03f1a25270/go.mod h1:DlkXRxmh5K+99aTPQaVjsZ1fAZNFw42vXGcOjR3Otps= +github.com/Dreamacro/go-shadowsocks2 v0.1.2 h1:8KgWbwAw5PJF+i6F3tI2iW/Em9WDtAuDG4obot8bGLM= +github.com/Dreamacro/go-shadowsocks2 v0.1.2/go.mod h1:J5YbNUiKtaD7EJmQ4O9ruUTY9+IgrflPgm63K1nUE0I= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -16,22 +16,28 @@ github.com/gofrs/uuid v3.1.0+incompatible h1:q2rtkjaKT4YEr6E1kamy0Ha4RtepWlQBedy github.com/gofrs/uuid v3.1.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gorilla/websocket v1.4.0 h1:WDFjx/TMzVgy9VdMMQi2K2Emtwi2QcUQsztZ/zLaH/Q= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= -github.com/konsorten/go-windows-terminal-sequences v0.0.0-20180402223658-b729f2633dfe h1:CHRGQ8V7OlCYtwaKPJi3iA7J+YdNKdo8j7nG5IgDhjs= -github.com/konsorten/go-windows-terminal-sequences v0.0.0-20180402223658-b729f2633dfe/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/miekg/dns v1.1.0 h1:yv9O9RJbvVFkvW8PKYqp4x7HQkc5RWwmUY/L8MdUaIg= +github.com/miekg/dns v1.1.0/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/oschwald/geoip2-golang v1.2.1 h1:3iz+jmeJc6fuCyWeKgtXSXu7+zvkxJbHFXkMT5FVebU= github.com/oschwald/geoip2-golang v1.2.1/go.mod h1:0LTTzix/Ao1uMvOhAV4iLU0Lz7eCrP94qZWBTDKf0iE= github.com/oschwald/maxminddb-golang v1.3.0 h1:oTh8IBSj10S5JNlUDg5WjJ1QdBMdeaZIkPEVfESSWgE= github.com/oschwald/maxminddb-golang v1.3.0/go.mod h1:3jhIUymTJ5VREKyIhWm66LJiQt04F0UCDdodShpjWsY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/sirupsen/logrus v1.1.0 h1:65VZabgUiV9ktjGM5nTq0+YurgTyX+YI2lSSfDjI+qU= -github.com/sirupsen/logrus v1.1.0/go.mod h1:zrgwTnHtNr00buQ1vSptGe8m1f/BbgsPukg8qsT7A+A= +github.com/sirupsen/logrus v1.2.0 h1:juTguoYk5qI21pwyTXY3B3Y5cOTH3ZUyZCg1v/mihuo= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20181009213950-7c1a557ab941 h1:qBTHLajHecfu+xzRI9PqVDcqx7SdHj9d4B+EzSn3tAc= -golang.org/x/crypto v0.0.0-20181009213950-7c1a557ab941/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 h1:I6FyU15t786LL7oL/hn43zqTuEGr4PN7F4XJ1p4E3Y8= +golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85 h1:et7+NAX3lLIk5qUCTA9QelBjGE/NkhzYw/mhnr0s7nI= +golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/net v0.0.0-20181108082009-03003ca0c849 h1:FSqE2GGG7wzsYUsWiQ8MZrvEd1EOyU3NCF0AW3Wtltg= +golang.org/x/net v0.0.0-20181108082009-03003ca0c849/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f h1:Bl/8QSvNqXvPGPGXa2z5xUTmV7VDcZyvRZ+QQXkXTZQ= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181005133103-4497e2df6f9e h1:EfdBzeKbFSvOjoIqSZcfS8wp0FBLokGBEs9lz1OtSg0= golang.org/x/sys v0.0.0-20181005133103-4497e2df6f9e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -39,5 +45,5 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/eapache/channels.v1 v1.1.0 h1:5bGAyKKvyCTWjSj7mhefG6Lc68VyN4MH1v8/7OoeeB4= gopkg.in/eapache/channels.v1 v1.1.0/go.mod h1:BHIBujSvu9yMTrTYbTCjDD43gUhtmaOtTWDe7sTv1js= -gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= -gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 87cf0032..6db5c54d 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -4,6 +4,7 @@ import ( adapters "github.com/Dreamacro/clash/adapters/outbound" "github.com/Dreamacro/clash/config" C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/dns" "github.com/Dreamacro/clash/log" P "github.com/Dreamacro/clash/proxy" T "github.com/Dreamacro/clash/tunnel" @@ -26,6 +27,8 @@ func ApplyConfig(cfg *config.Config, force bool) { } updateProxies(cfg.Proxies) updateRules(cfg.Rules) + updateGeneral(cfg.General) + updateDNS(cfg.DNS) } func GetGeneral() *config.General { @@ -40,6 +43,22 @@ func GetGeneral() *config.General { } } +func updateDNS(c *config.DNS) { + if c.Enable == false { + T.Instance().SetResolver(nil) + dns.ReCreateServer("", nil) + return + } + r := dns.New(dns.Config{ + Main: c.NameServer, + Fallback: c.Fallback, + IPv6: c.IPv6, + EnhancedMode: c.EnhancedMode, + }) + T.Instance().SetResolver(r) + dns.ReCreateServer(c.Listen, r) +} + func updateProxies(proxies map[string]C.Proxy) { tunnel := T.Instance() oldProxies := tunnel.Proxies() diff --git a/log/level.go b/log/level.go index a361f3c9..e95e36d2 100644 --- a/log/level.go +++ b/log/level.go @@ -3,18 +3,16 @@ package log import ( "encoding/json" "errors" - - yaml "gopkg.in/yaml.v2" ) var ( // LogLevelMapping is a mapping for LogLevel enum LogLevelMapping = map[string]LogLevel{ - "error": ERROR, - "warning": WARNING, - "info": INFO, - "debug": DEBUG, - "silent": SILENT, + ERROR.String(): ERROR, + WARNING.String(): WARNING, + INFO.String(): INFO, + DEBUG.String(): DEBUG, + SILENT.String(): SILENT, } ) @@ -28,10 +26,10 @@ const ( type LogLevel int -// UnmarshalYAML unserialize Mode with yaml -func (l *LogLevel) UnmarshalYAML(data []byte) error { +// UnmarshalYAML unserialize LogLevel with yaml +func (l *LogLevel) UnmarshalYAML(unmarshal func(interface{}) error) error { var tp string - yaml.Unmarshal(data, &tp) + unmarshal(&tp) level, exist := LogLevelMapping[tp] if !exist { return errors.New("invalid mode") @@ -40,12 +38,7 @@ func (l *LogLevel) UnmarshalYAML(data []byte) error { return nil } -// MarshalYAML serialize Mode with yaml -func (l LogLevel) MarshalYAML() ([]byte, error) { - return yaml.Marshal(l.String()) -} - -// UnmarshalJSON unserialize Mode with json +// UnmarshalJSON unserialize LogLevel with json func (l *LogLevel) UnmarshalJSON(data []byte) error { var tp string json.Unmarshal(data, &tp) @@ -57,7 +50,7 @@ func (l *LogLevel) UnmarshalJSON(data []byte) error { return nil } -// MarshalJSON serialize Mode with json +// MarshalJSON serialize LogLevel with json func (l LogLevel) MarshalJSON() ([]byte, error) { return json.Marshal(l.String()) } @@ -75,6 +68,6 @@ func (l LogLevel) String() string { case SILENT: return "silent" default: - return "unknow" + return "unknown" } } diff --git a/log/log.go b/log/log.go index c8060661..77a835cd 100644 --- a/log/log.go +++ b/log/log.go @@ -14,6 +14,10 @@ var ( level = INFO ) +func init() { + log.SetLevel(log.DebugLevel) +} + type Event struct { LogLevel LogLevel Payload string @@ -47,6 +51,10 @@ func Debugln(format string, v ...interface{}) { print(event) } +func Fatalln(format string, v ...interface{}) { + log.Fatalf(format, v...) +} + func Subscribe() observable.Subscription { sub, _ := source.Subscribe() return sub diff --git a/proxy/http/server.go b/proxy/http/server.go index d95d91ad..e66f5493 100644 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -14,18 +14,18 @@ var ( tun = tunnel.Instance() ) -type httpListener struct { +type HttpListener struct { net.Listener address string closed bool } -func NewHttpProxy(addr string) (*httpListener, error) { +func NewHttpProxy(addr string) (*HttpListener, error) { l, err := net.Listen("tcp", addr) if err != nil { return nil, err } - hl := &httpListener{l, addr, false} + hl := &HttpListener{l, addr, false} go func() { log.Infoln("HTTP proxy listening at: %s", addr) @@ -44,12 +44,12 @@ func NewHttpProxy(addr string) (*httpListener, error) { return hl, nil } -func (l *httpListener) Close() { +func (l *HttpListener) Close() { l.closed = true l.Listener.Close() } -func (l *httpListener) Address() string { +func (l *HttpListener) Address() string { return l.address } diff --git a/proxy/listener.go b/proxy/listener.go index 82922aab..7fbcfd77 100644 --- a/proxy/listener.go +++ b/proxy/listener.go @@ -13,9 +13,9 @@ import ( var ( allowLan = false - socksListener listener - httpListener listener - redirListener listener + socksListener *socks.SockListener + httpListener *http.HttpListener + redirListener *redir.RedirListener ) type listener interface { diff --git a/proxy/redir/tcp.go b/proxy/redir/tcp.go index ba55e458..ac88710a 100644 --- a/proxy/redir/tcp.go +++ b/proxy/redir/tcp.go @@ -4,6 +4,7 @@ import ( "net" "github.com/Dreamacro/clash/adapters/inbound" + C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/tunnel" ) @@ -12,18 +13,18 @@ var ( tun = tunnel.Instance() ) -type redirListener struct { +type RedirListener struct { net.Listener address string closed bool } -func NewRedirProxy(addr string) (*redirListener, error) { +func NewRedirProxy(addr string) (*RedirListener, error) { l, err := net.Listen("tcp", addr) if err != nil { return nil, err } - rl := &redirListener{l, addr, false} + rl := &RedirListener{l, addr, false} go func() { log.Infoln("Redir proxy listening at: %s", addr) @@ -42,12 +43,12 @@ func NewRedirProxy(addr string) (*redirListener, error) { return rl, nil } -func (l *redirListener) Close() { +func (l *RedirListener) Close() { l.closed = true l.Listener.Close() } -func (l *redirListener) Address() string { +func (l *RedirListener) Address() string { return l.address } @@ -58,5 +59,5 @@ func handleRedir(conn net.Conn) { return } conn.(*net.TCPConn).SetKeepAlive(true) - tun.Add(adapters.NewSocket(target, conn)) + tun.Add(adapters.NewSocket(target, conn, C.REDIR)) } diff --git a/proxy/socks/tcp.go b/proxy/socks/tcp.go index badb8b18..83220c09 100644 --- a/proxy/socks/tcp.go +++ b/proxy/socks/tcp.go @@ -4,6 +4,7 @@ import ( "net" "github.com/Dreamacro/clash/adapters/inbound" + C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/tunnel" @@ -14,19 +15,19 @@ var ( tun = tunnel.Instance() ) -type sockListener struct { +type SockListener struct { net.Listener address string closed bool } -func NewSocksProxy(addr string) (*sockListener, error) { +func NewSocksProxy(addr string) (*SockListener, error) { l, err := net.Listen("tcp", addr) if err != nil { return nil, err } - sl := &sockListener{l, addr, false} + sl := &SockListener{l, addr, false} go func() { log.Infoln("SOCKS proxy listening at: %s", addr) for { @@ -44,12 +45,12 @@ func NewSocksProxy(addr string) (*sockListener, error) { return sl, nil } -func (l *sockListener) Close() { +func (l *SockListener) Close() { l.closed = true l.Listener.Close() } -func (l *sockListener) Address() string { +func (l *SockListener) Address() string { return l.address } @@ -60,5 +61,5 @@ func handleSocks(conn net.Conn) { return } conn.(*net.TCPConn).SetKeepAlive(true) - tun.Add(adapters.NewSocket(target, conn)) + tun.Add(adapters.NewSocket(target, conn, C.SOCKS)) } diff --git a/tunnel/mode.go b/tunnel/mode.go index e9eb531f..69d0d6f8 100644 --- a/tunnel/mode.go +++ b/tunnel/mode.go @@ -10,9 +10,9 @@ type Mode int var ( // ModeMapping is a mapping for Mode enum ModeMapping = map[string]Mode{ - "Global": Global, - "Rule": Rule, - "Direct": Direct, + Global.String(): Global, + Rule.String(): Rule, + Direct.String(): Direct, } ) @@ -34,6 +34,18 @@ func (m *Mode) UnmarshalJSON(data []byte) error { return nil } +// UnmarshalYAML unserialize Mode with yaml +func (m *Mode) UnmarshalYAML(unmarshal func(interface{}) error) error { + var tp string + unmarshal(&tp) + mode, exist := ModeMapping[tp] + if !exist { + return errors.New("invalid mode") + } + *m = mode + return nil +} + // MarshalJSON serialize Mode func (m Mode) MarshalJSON() ([]byte, error) { return json.Marshal(m.String()) diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 264e0e8c..930bd070 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -1,11 +1,13 @@ package tunnel import ( + "net" "sync" "time" InboundAdapter "github.com/Dreamacro/clash/adapters/inbound" C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/dns" "github.com/Dreamacro/clash/log" "gopkg.in/eapache/channels.v1" @@ -23,6 +25,7 @@ type Tunnel struct { proxies map[string]C.Proxy configLock *sync.RWMutex traffic *C.Traffic + resolver *dns.Resolver // Outbound Rule mode Mode @@ -72,6 +75,11 @@ func (t *Tunnel) SetMode(mode Mode) { t.mode = mode } +// SetResolver change the resolver of tunnel +func (t *Tunnel) SetResolver(resolver *dns.Resolver) { + t.resolver = resolver +} + func (t *Tunnel) process() { queue := t.queue.Out() for { @@ -81,10 +89,41 @@ func (t *Tunnel) process() { } } +func (t *Tunnel) resolveIP(host string) (net.IP, error) { + if t.resolver == nil { + ipAddr, err := net.ResolveIPAddr("ip", host) + if err != nil { + return nil, err + } + + return ipAddr.IP, nil + } + + return t.resolver.ResolveIP(host) +} + func (t *Tunnel) handleConn(localConn C.ServerAdapter) { defer localConn.Close() metadata := localConn.Metadata() + if metadata.Source == C.REDIR && t.resolver != nil { + host, exist := t.resolver.IPToHost(*metadata.IP) + if exist { + metadata.Host = host + metadata.AddrType = C.AtypDomainName + } + } else if metadata.IP == nil && metadata.AddrType == C.AtypDomainName { + ip, err := t.resolveIP(metadata.Host) + if err != nil { + log.Debugln("[DNS] resolve %s error: %s", metadata.Host, err.Error()) + } else { + log.Debugln("[DNS] %s --> %s", metadata.Host, ip.String()) + metadata.IP = &ip + } + } else { + log.Debugln("[DNS] unknown%#v", metadata) + } + var proxy C.Proxy switch t.mode { case Direct: