diff --git a/common/net/bufconn.go b/common/net/bufconn.go index 54326cf9..8087608c 100644 --- a/common/net/bufconn.go +++ b/common/net/bufconn.go @@ -27,6 +27,10 @@ func (c *BufferedConn) Reader() *bufio.Reader { return c.r } +func (c *BufferedConn) ResetPeeked() { + c.peeked = false +} + func (c *BufferedConn) Peeked() bool { return c.peeked } diff --git a/tunnel/statistic/tracker.go b/tunnel/statistic/tracker.go index 97dd7316..e21868c5 100644 --- a/tunnel/statistic/tracker.go +++ b/tunnel/statistic/tracker.go @@ -81,7 +81,7 @@ func (tt *tcpTracker) Upstream() any { return tt.Conn } -func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) *tcpTracker { +func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule, uploadTotal int64, downloadTotal int64) *tcpTracker { uuid, _ := uuid.NewV4() if conn != nil { if tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr); ok { @@ -100,8 +100,8 @@ func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.R Metadata: metadata, Chain: conn.Chains(), Rule: "", - UploadTotal: atomic.NewInt64(0), - DownloadTotal: atomic.NewInt64(0), + UploadTotal: atomic.NewInt64(uploadTotal), + DownloadTotal: atomic.NewInt64(downloadTotal), }, extendedReader: N.NewExtendedReader(conn), extendedWriter: N.NewExtendedWriter(conn), @@ -147,7 +147,7 @@ func (ut *udpTracker) Close() error { return ut.PacketConn.Close() } -func NewUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, rule C.Rule) *udpTracker { +func NewUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, rule C.Rule, uploadTotal int64, downloadTotal int64) *udpTracker { uuid, _ := uuid.NewV4() metadata.RemoteDst = conn.RemoteDestination() @@ -160,8 +160,8 @@ func NewUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, ru Metadata: metadata, Chain: conn.Chains(), Rule: "", - UploadTotal: atomic.NewInt64(0), - DownloadTotal: atomic.NewInt64(0), + UploadTotal: atomic.NewInt64(uploadTotal), + DownloadTotal: atomic.NewInt64(downloadTotal), }, } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index b9d0e594..90fd42be 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -322,7 +322,7 @@ func handleUDPConn(packet C.PacketAdapter) { } pCtx.InjectPacketConn(rawPc) - pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule) + pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule, 0, 0) switch true { case metadata.SpecialProxy != "": @@ -367,6 +367,7 @@ func handleTCPConn(connCtx C.ConnContext) { } conn := connCtx.Conn() + conn.ResetPeeked() if sniffer.Dispatcher.Enable() && sniffingEnable { sniffer.Dispatcher.TCPSniff(conn, metadata) } @@ -400,6 +401,7 @@ func handleTCPConn(connCtx C.ConnContext) { } var peekBytes []byte + var peekLen int ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout) defer cancel() @@ -415,7 +417,7 @@ func handleTCPConn(connCtx C.ConnContext) { if err != nil { return nil, err } - if peekLen := len(peekBytes); peekLen > 0 { + if peekLen = len(peekBytes); peekLen > 0 { _, _ = conn.Discard(peekLen) } return remoteConn, err @@ -436,7 +438,7 @@ func handleTCPConn(connCtx C.ConnContext) { return } - remoteConn = statistic.NewTCPTracker(remoteConn, statistic.DefaultManager, metadata, rule) + remoteConn = statistic.NewTCPTracker(remoteConn, statistic.DefaultManager, metadata, rule, 0, int64(peekLen)) defer func(remoteConn C.Conn) { _ = remoteConn.Close() }(remoteConn)