From 207371aeaeec4cac5a8ca695a7aaf8a846138260 Mon Sep 17 00:00:00 2001 From: Dreamacro <305009791@qq.com> Date: Sun, 27 Oct 2019 21:44:07 +0800 Subject: [PATCH] Feature: add experimental connections API --- adapters/inbound/http.go | 1 + adapters/inbound/https.go | 3 + adapters/inbound/util.go | 11 ++-- component/fakeip/pool.go | 3 +- constant/metadata.go | 41 ++++++++++--- constant/rule.go | 4 +- constant/traffic.go | 55 ----------------- hub/route/connections.go | 91 ++++++++++++++++++++++++++++ hub/route/server.go | 3 +- rules/geoip.go | 2 +- rules/ipcidr.go | 2 +- tunnel/connection.go | 12 ++-- tunnel/manager.go | 87 +++++++++++++++++++++++++++ tunnel/tracker.go | 122 ++++++++++++++++++++++++++++++++++++++ tunnel/tunnel.go | 29 ++++----- tunnel/util.go | 29 --------- 16 files changed, 365 insertions(+), 130 deletions(-) delete mode 100644 constant/traffic.go create mode 100644 hub/route/connections.go create mode 100644 tunnel/manager.go create mode 100644 tunnel/tracker.go delete mode 100644 tunnel/util.go diff --git a/adapters/inbound/http.go b/adapters/inbound/http.go index f60544f8..7aeb4c65 100644 --- a/adapters/inbound/http.go +++ b/adapters/inbound/http.go @@ -23,6 +23,7 @@ func (h *HTTPAdapter) Metadata() *C.Metadata { // NewHTTP is HTTPAdapter generator func NewHTTP(request *http.Request, conn net.Conn) *HTTPAdapter { metadata := parseHTTPAddr(request) + metadata.Type = C.HTTP if ip, port, err := parseAddr(conn.RemoteAddr().String()); err == nil { metadata.SrcIP = ip metadata.SrcPort = port diff --git a/adapters/inbound/https.go b/adapters/inbound/https.go index a880af74..124b80c2 100644 --- a/adapters/inbound/https.go +++ b/adapters/inbound/https.go @@ -3,11 +3,14 @@ package adapters import ( "net" "net/http" + + C "github.com/Dreamacro/clash/constant" ) // NewHTTPS is HTTPAdapter generator func NewHTTPS(request *http.Request, conn net.Conn) *SocketAdapter { metadata := parseHTTPAddr(request) + metadata.Type = C.HTTPCONNECT if ip, port, err := parseAddr(conn.RemoteAddr().String()); err == nil { metadata.SrcIP = ip metadata.SrcPort = port diff --git a/adapters/inbound/util.go b/adapters/inbound/util.go index 9d977cf8..9a16a6cf 100644 --- a/adapters/inbound/util.go +++ b/adapters/inbound/util.go @@ -20,11 +20,11 @@ func parseSocksAddr(target socks5.Addr) *C.Metadata { metadata.DstPort = strconv.Itoa((int(target[2+target[1]]) << 8) | int(target[2+target[1]+1])) case socks5.AtypIPv4: ip := net.IP(target[1 : 1+net.IPv4len]) - metadata.DstIP = &ip + metadata.DstIP = ip metadata.DstPort = strconv.Itoa((int(target[1+net.IPv4len]) << 8) | int(target[1+net.IPv4len+1])) case socks5.AtypIPv6: ip := net.IP(target[1 : 1+net.IPv6len]) - metadata.DstIP = &ip + metadata.DstIP = ip metadata.DstPort = strconv.Itoa((int(target[1+net.IPv6len]) << 8) | int(target[1+net.IPv6len+1])) } @@ -40,7 +40,6 @@ func parseHTTPAddr(request *http.Request) *C.Metadata { metadata := &C.Metadata{ NetWork: C.TCP, - Type: C.HTTP, AddrType: C.AtypDomainName, Host: host, DstIP: nil, @@ -55,18 +54,18 @@ func parseHTTPAddr(request *http.Request) *C.Metadata { default: metadata.AddrType = C.AtypIPv4 } - metadata.DstIP = &ip + metadata.DstIP = ip } return metadata } -func parseAddr(addr string) (*net.IP, string, error) { +func parseAddr(addr string) (net.IP, string, error) { host, port, err := net.SplitHostPort(addr) if err != nil { return nil, "", err } ip := net.ParseIP(host) - return &ip, port, nil + return ip, port, nil } diff --git a/component/fakeip/pool.go b/component/fakeip/pool.go index e71ddfaa..86a92ee2 100644 --- a/component/fakeip/pool.go +++ b/component/fakeip/pool.go @@ -14,7 +14,7 @@ type Pool struct { min uint32 gateway uint32 offset uint32 - mux *sync.Mutex + mux sync.Mutex cache *cache.LruCache } @@ -111,7 +111,6 @@ func New(ipnet *net.IPNet, size int) (*Pool, error) { min: min, max: max, gateway: min - 1, - mux: &sync.Mutex{}, cache: cache.NewLRUCache(cache.WithSize(size * 2)), }, nil } diff --git a/constant/metadata.go b/constant/metadata.go index cc870361..afcd3443 100644 --- a/constant/metadata.go +++ b/constant/metadata.go @@ -1,6 +1,7 @@ package constant import ( + "encoding/json" "net" ) @@ -14,6 +15,7 @@ const ( UDP HTTP Type = iota + HTTPCONNECT SOCKS REDIR ) @@ -27,18 +29,41 @@ func (n *NetWork) String() string { return "udp" } +func (n NetWork) MarshalJSON() ([]byte, error) { + return json.Marshal(n.String()) +} + type Type int +func (t Type) String() string { + switch t { + case HTTP: + return "HTTP" + case HTTPCONNECT: + return "HTTP Connect" + case SOCKS: + return "Socks5" + case REDIR: + return "Redir" + default: + return "Unknown" + } +} + +func (t Type) MarshalJSON() ([]byte, error) { + return json.Marshal(t.String()) +} + // Metadata is used to store connection address type Metadata struct { - NetWork NetWork - Type Type - SrcIP *net.IP - DstIP *net.IP - SrcPort string - DstPort string - AddrType int - Host string + NetWork NetWork `json:"network"` + Type Type `json:"type"` + SrcIP net.IP `json:"sourceIP"` + DstIP net.IP `json:"destinationIP"` + SrcPort string `json:"sourcePort"` + DstPort string `json:"destinationPort"` + AddrType int `json:"-"` + Host string `json:"host"` } func (m *Metadata) RemoteAddress() string { diff --git a/constant/rule.go b/constant/rule.go index 40756264..1ac599dd 100644 --- a/constant/rule.go +++ b/constant/rule.go @@ -24,7 +24,7 @@ func (rt RuleType) String() string { case DomainKeyword: return "DomainKeyword" case GEOIP: - return "GEOIP" + return "GeoIP" case IPCIDR: return "IPCIDR" case SrcIPCIDR: @@ -34,7 +34,7 @@ func (rt RuleType) String() string { case DstPort: return "DstPort" case MATCH: - return "MATCH" + return "Match" default: return "Unknown" } diff --git a/constant/traffic.go b/constant/traffic.go deleted file mode 100644 index edf67368..00000000 --- a/constant/traffic.go +++ /dev/null @@ -1,55 +0,0 @@ -package constant - -import ( - "time" -) - -type Traffic struct { - up chan int64 - down chan int64 - upCount int64 - downCount int64 - upTotal int64 - downTotal int64 - interval time.Duration -} - -func (t *Traffic) Up() chan<- int64 { - return t.up -} - -func (t *Traffic) Down() chan<- int64 { - return t.down -} - -func (t *Traffic) Now() (up int64, down int64) { - return t.upTotal, t.downTotal -} - -func (t *Traffic) handle() { - go t.handleCh(t.up, &t.upCount, &t.upTotal) - go t.handleCh(t.down, &t.downCount, &t.downTotal) -} - -func (t *Traffic) handleCh(ch <-chan int64, count *int64, total *int64) { - ticker := time.NewTicker(t.interval) - for { - select { - case n := <-ch: - *count += n - case <-ticker.C: - *total = *count - *count = 0 - } - } -} - -func NewTraffic(interval time.Duration) *Traffic { - t := &Traffic{ - up: make(chan int64), - down: make(chan int64), - interval: interval, - } - go t.handle() - return t -} diff --git a/hub/route/connections.go b/hub/route/connections.go new file mode 100644 index 00000000..6642b6f2 --- /dev/null +++ b/hub/route/connections.go @@ -0,0 +1,91 @@ +package route + +import ( + "bytes" + "encoding/json" + "net/http" + "strconv" + "time" + + T "github.com/Dreamacro/clash/tunnel" + "github.com/gorilla/websocket" + + "github.com/go-chi/chi" + "github.com/go-chi/render" +) + +func connectionRouter() http.Handler { + r := chi.NewRouter() + r.Get("/", getConnections) + r.Delete("/", closeAllConnections) + r.Delete("/{id}", closeConnection) + return r +} + +func getConnections(w http.ResponseWriter, r *http.Request) { + if !websocket.IsWebSocketUpgrade(r) { + snapshot := T.DefaultManager.Snapshot() + render.JSON(w, r, snapshot) + return + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + + intervalStr := r.URL.Query().Get("interval") + interval := 1000 + if intervalStr != "" { + t, err := strconv.Atoi(intervalStr) + if err != nil { + render.Status(r, http.StatusBadRequest) + render.JSON(w, r, ErrBadRequest) + return + } + + interval = t + } + + buf := &bytes.Buffer{} + sendSnapshot := func() error { + buf.Reset() + snapshot := T.DefaultManager.Snapshot() + if err := json.NewEncoder(buf).Encode(snapshot); err != nil { + return err + } + + return conn.WriteMessage(websocket.TextMessage, buf.Bytes()) + } + + if err := sendSnapshot(); err != nil { + return + } + + tick := time.NewTicker(time.Millisecond * time.Duration(interval)) + for range tick.C { + if err := sendSnapshot(); err != nil { + break + } + } +} + +func closeConnection(w http.ResponseWriter, r *http.Request) { + id := chi.URLParam(r, "id") + snapshot := T.DefaultManager.Snapshot() + for _, c := range snapshot.Connections { + if id == c.ID() { + c.Close() + break + } + } + render.NoContent(w, r) +} + +func closeAllConnections(w http.ResponseWriter, r *http.Request) { + snapshot := T.DefaultManager.Snapshot() + for _, c := range snapshot.Connections { + c.Close() + } + render.NoContent(w, r) +} diff --git a/hub/route/server.go b/hub/route/server.go index 394ad406..560c338f 100644 --- a/hub/route/server.go +++ b/hub/route/server.go @@ -67,6 +67,7 @@ func Start(addr string, secret string) { r.Mount("/configs", configRouter()) r.Mount("/proxies", proxyRouter()) r.Mount("/rules", ruleRouter()) + r.Mount("/connections", connectionRouter()) }) if uiPath != "" { @@ -140,7 +141,7 @@ func traffic(w http.ResponseWriter, r *http.Request) { } tick := time.NewTicker(time.Second) - t := T.Instance().Traffic() + t := T.DefaultManager buf := &bytes.Buffer{} var err error for range tick.C { diff --git a/rules/geoip.go b/rules/geoip.go index 07cc2eff..3ffce9e8 100644 --- a/rules/geoip.go +++ b/rules/geoip.go @@ -27,7 +27,7 @@ func (g *GEOIP) IsMatch(metadata *C.Metadata) bool { if metadata.DstIP == nil { return false } - record, _ := mmdb.Country(*metadata.DstIP) + record, _ := mmdb.Country(metadata.DstIP) return record.Country.IsoCode == g.country } diff --git a/rules/ipcidr.go b/rules/ipcidr.go index e8160300..78a649ab 100644 --- a/rules/ipcidr.go +++ b/rules/ipcidr.go @@ -24,7 +24,7 @@ func (i *IPCIDR) IsMatch(metadata *C.Metadata) bool { if i.isSourceIP { ip = metadata.SrcIP } - return ip != nil && i.ipnet.Contains(*ip) + return ip != nil && i.ipnet.Contains(ip) } func (i *IPCIDR) Adapter() string { diff --git a/tunnel/connection.go b/tunnel/connection.go index 1f9f335b..87faa5bd 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -13,12 +13,11 @@ import ( ) func (t *Tunnel) handleHTTP(request *adapters.HTTPAdapter, outbound net.Conn) { - conn := newTrafficTrack(outbound, t.traffic) req := request.R host := req.Host inboundReeder := bufio.NewReader(request) - outboundReeder := bufio.NewReader(conn) + outboundReeder := bufio.NewReader(outbound) for { keepAlive := strings.TrimSpace(strings.ToLower(req.Header.Get("Proxy-Connection"))) == "keep-alive" @@ -26,7 +25,7 @@ func (t *Tunnel) handleHTTP(request *adapters.HTTPAdapter, outbound net.Conn) { req.Header.Set("Connection", "close") req.RequestURI = "" adapters.RemoveHopByHopHeaders(req.Header) - err := req.Write(conn) + err := req.Write(outbound) if err != nil { break } @@ -91,7 +90,7 @@ func (t *Tunnel) handleUDPToRemote(conn net.Conn, pc net.PacketConn, addr net.Ad if _, err = pc.WriteTo(buf[:n], addr); err != nil { return } - t.traffic.Up() <- int64(n) + DefaultManager.Upload() <- int64(n) } func (t *Tunnel) handleUDPToLocal(conn net.Conn, pc net.PacketConn, key string, timeout time.Duration) { @@ -111,13 +110,12 @@ func (t *Tunnel) handleUDPToLocal(conn net.Conn, pc net.PacketConn, key string, if err != nil { return } - t.traffic.Down() <- int64(n) + DefaultManager.Download() <- int64(n) } } func (t *Tunnel) handleSocket(request *adapters.SocketAdapter, outbound net.Conn) { - conn := newTrafficTrack(outbound, t.traffic) - relay(request, conn) + relay(request, outbound) } // relay copies between left and right bidirectionally. diff --git a/tunnel/manager.go b/tunnel/manager.go new file mode 100644 index 00000000..d4d627e0 --- /dev/null +++ b/tunnel/manager.go @@ -0,0 +1,87 @@ +package tunnel + +import ( + "sync" + "time" +) + +var DefaultManager *Manager + +func init() { + DefaultManager = &Manager{ + upload: make(chan int64), + download: make(chan int64), + } + DefaultManager.handle() +} + +type Manager struct { + connections sync.Map + upload chan int64 + download chan int64 + uploadTemp int64 + downloadTemp int64 + uploadBlip int64 + downloadBlip int64 + uploadTotal int64 + downloadTotal int64 +} + +func (m *Manager) Join(c tracker) { + m.connections.Store(c.ID(), c) +} + +func (m *Manager) Leave(c tracker) { + m.connections.Delete(c.ID()) +} + +func (m *Manager) Upload() chan<- int64 { + return m.upload +} + +func (m *Manager) Download() chan<- int64 { + return m.download +} + +func (m *Manager) Now() (up int64, down int64) { + return m.uploadBlip, m.downloadBlip +} + +func (m *Manager) Snapshot() *Snapshot { + connections := []tracker{} + m.connections.Range(func(key, value interface{}) bool { + connections = append(connections, value.(tracker)) + return true + }) + + return &Snapshot{ + UploadTotal: m.uploadTotal, + DownloadTotal: m.downloadTotal, + Connections: connections, + } +} + +func (m *Manager) handle() { + go m.handleCh(m.upload, &m.uploadTemp, &m.uploadBlip, &m.uploadTotal) + go m.handleCh(m.download, &m.downloadTemp, &m.downloadBlip, &m.downloadTotal) +} + +func (m *Manager) handleCh(ch <-chan int64, temp *int64, blip *int64, total *int64) { + ticker := time.NewTicker(time.Second) + for { + select { + case n := <-ch: + *temp += n + *total += n + case <-ticker.C: + *blip = *temp + *temp = 0 + } + } +} + +type Snapshot struct { + DownloadTotal int64 `json:"downloadTotal"` + UploadTotal int64 `json:"uploadTotal"` + Connections []tracker `json:"connections"` +} diff --git a/tunnel/tracker.go b/tunnel/tracker.go new file mode 100644 index 00000000..77e442b8 --- /dev/null +++ b/tunnel/tracker.go @@ -0,0 +1,122 @@ +package tunnel + +import ( + "net" + "time" + + C "github.com/Dreamacro/clash/constant" + "github.com/gofrs/uuid" +) + +type tracker interface { + ID() string + Close() error +} + +type trackerInfo struct { + UUID uuid.UUID `json:"id"` + Metadata *C.Metadata `json:"metadata"` + UploadTotal int64 `json:"upload"` + DownloadTotal int64 `json:"download"` + Start time.Time `json:"start"` + Chain C.Chain `json:"chains"` + Rule string `json:"rule"` +} + +type tcpTracker struct { + C.Conn `json:"-"` + *trackerInfo + manager *Manager +} + +func (tt *tcpTracker) ID() string { + return tt.UUID.String() +} + +func (tt *tcpTracker) Read(b []byte) (int, error) { + n, err := tt.Conn.Read(b) + download := int64(n) + tt.manager.Download() <- download + tt.DownloadTotal += download + return n, err +} + +func (tt *tcpTracker) Write(b []byte) (int, error) { + n, err := tt.Conn.Write(b) + upload := int64(n) + tt.manager.Upload() <- upload + tt.UploadTotal += upload + return n, err +} + +func (tt *tcpTracker) Close() error { + tt.manager.Leave(tt) + return tt.Conn.Close() +} + +func newTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) *tcpTracker { + uuid, _ := uuid.NewV4() + t := &tcpTracker{ + Conn: conn, + manager: manager, + trackerInfo: &trackerInfo{ + UUID: uuid, + Start: time.Now(), + Metadata: metadata, + Chain: conn.Chains(), + Rule: rule.RuleType().String(), + }, + } + + manager.Join(t) + return t +} + +type udpTracker struct { + C.PacketConn `json:"-"` + *trackerInfo + manager *Manager +} + +func (ut *udpTracker) ID() string { + return ut.UUID.String() +} + +func (ut *udpTracker) ReadFrom(b []byte) (int, net.Addr, error) { + n, addr, err := ut.PacketConn.ReadFrom(b) + download := int64(n) + ut.manager.Download() <- download + ut.DownloadTotal += download + return n, addr, err +} + +func (ut *udpTracker) WriteTo(b []byte, addr net.Addr) (int, error) { + n, err := ut.PacketConn.WriteTo(b, addr) + upload := int64(n) + ut.manager.Upload() <- upload + ut.UploadTotal += upload + return n, err +} + +func (ut *udpTracker) Close() error { + ut.manager.Leave(ut) + return ut.PacketConn.Close() +} + +func newUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, rule C.Rule) *udpTracker { + uuid, _ := uuid.NewV4() + ut := &udpTracker{ + PacketConn: conn, + manager: manager, + trackerInfo: &trackerInfo{ + UUID: uuid, + Start: time.Now(), + Metadata: metadata, + Chain: conn.Chains(), + Rule: rule.RuleType().String(), + }, + } + + manager.Join(ut) + return ut +} diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 32647f1a..7eb16af1 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -30,8 +30,7 @@ type Tunnel struct { natTable *nat.Table rules []C.Rule proxies map[string]C.Proxy - configMux *sync.RWMutex - traffic *C.Traffic + configMux sync.RWMutex // experimental features ignoreResolveFail bool @@ -50,11 +49,6 @@ func (t *Tunnel) Add(req C.ServerAdapter) { } } -// Traffic return traffic of all connections -func (t *Tunnel) Traffic() *C.Traffic { - return t.traffic -} - // Rules return all rules func (t *Tunnel) Rules() []C.Rule { return t.rules @@ -123,7 +117,7 @@ func (t *Tunnel) needLookupIP(metadata *C.Metadata) bool { func (t *Tunnel) resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) { // preprocess enhanced-mode metadata if t.needLookupIP(metadata) { - host, exist := dns.DefaultResolver.IPToHost(*metadata.DstIP) + host, exist := dns.DefaultResolver.IPToHost(metadata.DstIP) if exist { metadata.Host = host metadata.AddrType = C.AtypDomainName @@ -188,8 +182,8 @@ func (t *Tunnel) handleUDPConn(localConn C.ServerAdapter) { wg.Done() return } - pc = rawPc addr = nAddr + pc = newUDPTracker(rawPc, DefaultManager, metadata, rule) if rule != nil { log.Infoln("%s --> %v match %s using %s", metadata.SrcIP.String(), metadata.String(), rule.RuleType().String(), rawPc.Chains().String()) @@ -231,6 +225,7 @@ func (t *Tunnel) handleTCPConn(localConn C.ServerAdapter) { log.Warnln("dial %s error: %s", proxy.Name(), err.Error()) return } + remoteConn = newTCPTracker(remoteConn, DefaultManager, metadata, rule) defer remoteConn.Close() if rule != nil { @@ -259,7 +254,7 @@ func (t *Tunnel) match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { if node := dns.DefaultHosts.Search(metadata.Host); node != nil { ip := node.Data.(net.IP) - metadata.DstIP = &ip + metadata.DstIP = ip resolved = true } @@ -273,7 +268,7 @@ func (t *Tunnel) match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { log.Debugln("[DNS] resolve %s error: %s", metadata.Host, err.Error()) } else { log.Debugln("[DNS] %s --> %s", metadata.Host, ip.String()) - metadata.DstIP = &ip + metadata.DstIP = ip } resolved = true } @@ -296,13 +291,11 @@ func (t *Tunnel) match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { func newTunnel() *Tunnel { return &Tunnel{ - tcpQueue: channels.NewInfiniteChannel(), - udpQueue: channels.NewInfiniteChannel(), - natTable: nat.New(), - proxies: make(map[string]C.Proxy), - configMux: &sync.RWMutex{}, - traffic: C.NewTraffic(time.Second), - mode: Rule, + tcpQueue: channels.NewInfiniteChannel(), + udpQueue: channels.NewInfiniteChannel(), + natTable: nat.New(), + proxies: make(map[string]C.Proxy), + mode: Rule, } } diff --git a/tunnel/util.go b/tunnel/util.go deleted file mode 100644 index 15c4d3da..00000000 --- a/tunnel/util.go +++ /dev/null @@ -1,29 +0,0 @@ -package tunnel - -import ( - "net" - - C "github.com/Dreamacro/clash/constant" -) - -// TrafficTrack record traffic of net.Conn -type TrafficTrack struct { - net.Conn - traffic *C.Traffic -} - -func (tt *TrafficTrack) Read(b []byte) (int, error) { - n, err := tt.Conn.Read(b) - tt.traffic.Down() <- int64(n) - return n, err -} - -func (tt *TrafficTrack) Write(b []byte) (int, error) { - n, err := tt.Conn.Write(b) - tt.traffic.Up() <- int64(n) - return n, err -} - -func newTrafficTrack(conn net.Conn, traffic *C.Traffic) *TrafficTrack { - return &TrafficTrack{traffic: traffic, Conn: conn} -}