diff --git a/common/mux/client.go b/common/mux/client.go index b3b3ada3..36dd3137 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -1,554 +1,21 @@ package mux import ( - "context" - "encoding/binary" - "io" - "net" - "sync" - "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing-mux" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/x/list" ) -var _ N.Dialer = (*Client)(nil) - -type Client struct { - access sync.Mutex - connections list.List[abstractSession] - ctx context.Context - dialer N.Dialer - protocol Protocol - maxConnections int - minStreams int - maxStreams int - paddingEnabled bool -} - -func NewClient(ctx context.Context, dialer N.Dialer, protocol Protocol, maxConnections int, minStreams int, maxStreams int, paddingEnabled bool) (*Client, error) { - return &Client{ - ctx: ctx, - dialer: dialer, - protocol: protocol, - maxConnections: maxConnections, - minStreams: minStreams, - maxStreams: maxStreams, - paddingEnabled: paddingEnabled, - }, nil -} - -func NewClientWithOptions(ctx context.Context, dialer N.Dialer, options option.MultiplexOptions) (*Client, error) { +func NewClientWithOptions(dialer N.Dialer, options option.MultiplexOptions) (*Client, error) { if !options.Enabled { return nil, nil } - if options.MaxConnections == 0 && options.MaxStreams == 0 { - options.MinStreams = 8 - } - protocol, err := ParseProtocol(options.Protocol) - if err != nil { - return nil, err - } - return NewClient(ctx, dialer, protocol, options.MaxConnections, options.MinStreams, options.MaxStreams, options.Padding) -} - -func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - switch N.NetworkName(network) { - case N.NetworkTCP: - stream, err := c.openStream() - if err != nil { - return nil, err - } - return &ClientConn{Conn: stream, destination: destination}, nil - case N.NetworkUDP: - stream, err := c.openStream() - if err != nil { - return nil, err - } - return bufio.NewUnbindPacketConn(&ClientPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}), nil - default: - return nil, E.Extend(N.ErrUnknownNetwork, network) - } -} - -func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - stream, err := c.openStream() - if err != nil { - return nil, err - } - return &ClientPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}, nil -} - -func (c *Client) openStream() (net.Conn, error) { - var ( - session abstractSession - stream net.Conn - err error - ) - for attempts := 0; attempts < 2; attempts++ { - session, err = c.offer() - if err != nil { - continue - } - stream, err = session.Open() - if err != nil { - continue - } - break - } - if err != nil { - return nil, err - } - return &wrapStream{stream}, nil -} - -func (c *Client) offer() (abstractSession, error) { - c.access.Lock() - defer c.access.Unlock() - - sessions := make([]abstractSession, 0, c.maxConnections) - for element := c.connections.Front(); element != nil; { - if element.Value.IsClosed() { - nextElement := element.Next() - c.connections.Remove(element) - element = nextElement - continue - } - sessions = append(sessions, element.Value) - element = element.Next() - } - session := common.MinBy(common.Filter(sessions, abstractSession.CanTakeNewRequest), abstractSession.NumStreams) - if session == nil { - return c.offerNew() - } - numStreams := session.NumStreams() - if numStreams == 0 { - return session, nil - } - if c.maxConnections > 0 { - if len(sessions) >= c.maxConnections || numStreams < c.minStreams { - return session, nil - } - } else { - if c.maxStreams > 0 && numStreams < c.maxStreams { - return session, nil - } - } - return c.offerNew() -} - -func (c *Client) offerNew() (abstractSession, error) { - conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, Destination) - if err != nil { - return nil, err - } - var version byte - if c.paddingEnabled { - version = Version1 - } else { - version = Version0 - } - conn = newProtocolConn(conn, Request{ - Version: version, - Protocol: c.protocol, - PaddingEnabled: c.paddingEnabled, + return mux.NewClient(mux.Options{ + Dialer: dialer, + Protocol: options.Protocol, + MaxConnections: options.MaxConnections, + MinStreams: options.MinStreams, + MaxStreams: options.MaxStreams, + Padding: options.Padding, }) - if c.paddingEnabled { - conn = newPaddingConn(conn) - } - session, err := c.protocol.newClient(conn) - if err != nil { - return nil, err - } - c.connections.PushBack(session) - return session, nil -} - -func (c *Client) Reset() { - c.access.Lock() - defer c.access.Unlock() - for _, session := range c.connections.Array() { - session.Close() - } - c.connections.Init() -} - -func (c *Client) Close() error { - c.access.Lock() - defer c.access.Unlock() - for _, session := range c.connections.Array() { - session.Close() - } - return nil -} - -type ClientConn struct { - net.Conn - destination M.Socksaddr - requestWrite bool - responseRead bool -} - -func (c *ClientConn) readResponse() error { - response, err := ReadStreamResponse(c.Conn) - if err != nil { - return err - } - if response.Status == statusError { - return E.New("remote error: ", response.Message) - } - return nil -} - -func (c *ClientConn) Read(b []byte) (n int, err error) { - if !c.responseRead { - err = c.readResponse() - if err != nil { - return - } - c.responseRead = true - } - return c.Conn.Read(b) -} - -func (c *ClientConn) Write(b []byte) (n int, err error) { - if c.requestWrite { - return c.Conn.Write(b) - } - request := StreamRequest{ - Network: N.NetworkTCP, - Destination: c.destination, - } - _buffer := buf.StackNewSize(streamRequestLen(request) + len(b)) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - EncodeStreamRequest(request, buffer) - buffer.Write(b) - _, err = c.Conn.Write(buffer.Bytes()) - if err != nil { - return - } - c.requestWrite = true - return len(b), nil -} - -func (c *ClientConn) ReadFrom(r io.Reader) (n int64, err error) { - if !c.requestWrite { - return bufio.ReadFrom0(c, r) - } - return bufio.Copy(c.Conn, r) -} - -func (c *ClientConn) WriteTo(w io.Writer) (n int64, err error) { - if !c.responseRead { - return bufio.WriteTo0(c, w) - } - return bufio.Copy(w, c.Conn) -} - -func (c *ClientConn) LocalAddr() net.Addr { - return c.Conn.LocalAddr() -} - -func (c *ClientConn) RemoteAddr() net.Addr { - return c.destination.TCPAddr() -} - -func (c *ClientConn) ReaderReplaceable() bool { - return c.responseRead -} - -func (c *ClientConn) WriterReplaceable() bool { - return c.requestWrite -} - -func (c *ClientConn) NeedAdditionalReadDeadline() bool { - return true -} - -func (c *ClientConn) Upstream() any { - return c.Conn -} - -type ClientPacketConn struct { - N.ExtendedConn - destination M.Socksaddr - requestWrite bool - responseRead bool -} - -func (c *ClientPacketConn) readResponse() error { - response, err := ReadStreamResponse(c.ExtendedConn) - if err != nil { - return err - } - if response.Status == statusError { - return E.New("remote error: ", response.Message) - } - return nil -} - -func (c *ClientPacketConn) Read(b []byte) (n int, err error) { - if !c.responseRead { - err = c.readResponse() - if err != nil { - return - } - c.responseRead = true - } - var length uint16 - err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) - if err != nil { - return - } - if cap(b) < int(length) { - return 0, io.ErrShortBuffer - } - return io.ReadFull(c.ExtendedConn, b[:length]) -} - -func (c *ClientPacketConn) writeRequest(payload []byte) (n int, err error) { - request := StreamRequest{ - Network: N.NetworkUDP, - Destination: c.destination, - } - rLen := streamRequestLen(request) - if len(payload) > 0 { - rLen += 2 + len(payload) - } - _buffer := buf.StackNewSize(rLen) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - EncodeStreamRequest(request, buffer) - if len(payload) > 0 { - common.Must( - binary.Write(buffer, binary.BigEndian, uint16(len(payload))), - common.Error(buffer.Write(payload)), - ) - } - _, err = c.ExtendedConn.Write(buffer.Bytes()) - if err != nil { - return - } - c.requestWrite = true - return len(payload), nil -} - -func (c *ClientPacketConn) Write(b []byte) (n int, err error) { - if !c.requestWrite { - return c.writeRequest(b) - } - err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(b))) - if err != nil { - return - } - return c.ExtendedConn.Write(b) -} - -func (c *ClientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) { - if !c.responseRead { - err = c.readResponse() - if err != nil { - return - } - c.responseRead = true - } - var length uint16 - err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) - if err != nil { - return - } - _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length)) - return -} - -func (c *ClientPacketConn) WriteBuffer(buffer *buf.Buffer) error { - if !c.requestWrite { - defer buffer.Release() - return common.Error(c.writeRequest(buffer.Bytes())) - } - bLen := buffer.Len() - binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen)) - return c.ExtendedConn.WriteBuffer(buffer) -} - -func (c *ClientPacketConn) FrontHeadroom() int { - return 2 -} - -func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { - err = c.ReadBuffer(buffer) - return -} - -func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - return c.WriteBuffer(buffer) -} - -func (c *ClientPacketConn) LocalAddr() net.Addr { - return c.ExtendedConn.LocalAddr() -} - -func (c *ClientPacketConn) RemoteAddr() net.Addr { - return c.destination.UDPAddr() -} - -func (c *ClientPacketConn) NeedAdditionalReadDeadline() bool { - return true -} - -func (c *ClientPacketConn) Upstream() any { - return c.ExtendedConn -} - -var _ N.NetPacketConn = (*ClientPacketAddrConn)(nil) - -type ClientPacketAddrConn struct { - N.ExtendedConn - destination M.Socksaddr - requestWrite bool - responseRead bool -} - -func (c *ClientPacketAddrConn) readResponse() error { - response, err := ReadStreamResponse(c.ExtendedConn) - if err != nil { - return err - } - if response.Status == statusError { - return E.New("remote error: ", response.Message) - } - return nil -} - -func (c *ClientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - if !c.responseRead { - err = c.readResponse() - if err != nil { - return - } - c.responseRead = true - } - destination, err := M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn) - if err != nil { - return - } - if destination.IsFqdn() { - addr = destination - } else { - addr = destination.UDPAddr() - } - var length uint16 - err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) - if err != nil { - return - } - if cap(p) < int(length) { - return 0, nil, io.ErrShortBuffer - } - n, err = io.ReadFull(c.ExtendedConn, p[:length]) - return -} - -func (c *ClientPacketAddrConn) writeRequest(payload []byte, destination M.Socksaddr) (n int, err error) { - request := StreamRequest{ - Network: N.NetworkUDP, - Destination: c.destination, - PacketAddr: true, - } - rLen := streamRequestLen(request) - if len(payload) > 0 { - rLen += M.SocksaddrSerializer.AddrPortLen(destination) + 2 + len(payload) - } - _buffer := buf.StackNewSize(rLen) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - EncodeStreamRequest(request, buffer) - if len(payload) > 0 { - common.Must( - M.SocksaddrSerializer.WriteAddrPort(buffer, destination), - binary.Write(buffer, binary.BigEndian, uint16(len(payload))), - common.Error(buffer.Write(payload)), - ) - } - _, err = c.ExtendedConn.Write(buffer.Bytes()) - if err != nil { - return - } - c.requestWrite = true - return len(payload), nil -} - -func (c *ClientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if !c.requestWrite { - return c.writeRequest(p, M.SocksaddrFromNet(addr)) - } - err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr)) - if err != nil { - return - } - err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p))) - if err != nil { - return - } - return c.ExtendedConn.Write(p) -} - -func (c *ClientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { - if !c.responseRead { - err = c.readResponse() - if err != nil { - return - } - c.responseRead = true - } - destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn) - if err != nil { - return - } - var length uint16 - err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) - if err != nil { - return - } - _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length)) - return -} - -func (c *ClientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - if !c.requestWrite { - defer buffer.Release() - return common.Error(c.writeRequest(buffer.Bytes(), destination)) - } - bLen := buffer.Len() - header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 2)) - common.Must( - M.SocksaddrSerializer.WriteAddrPort(header, destination), - binary.Write(header, binary.BigEndian, uint16(bLen)), - ) - return c.ExtendedConn.WriteBuffer(buffer) -} - -func (c *ClientPacketAddrConn) LocalAddr() net.Addr { - return c.ExtendedConn.LocalAddr() -} - -func (c *ClientPacketAddrConn) FrontHeadroom() int { - return 2 + M.MaxSocksaddrLength -} - -func (c *ClientPacketAddrConn) NeedAdditionalReadDeadline() bool { - return true -} - -func (c *ClientPacketAddrConn) Upstream() any { - return c.ExtendedConn } diff --git a/common/mux/h2mux.go b/common/mux/h2mux.go deleted file mode 100644 index 1449dcc7..00000000 --- a/common/mux/h2mux.go +++ /dev/null @@ -1,235 +0,0 @@ -package mux - -import ( - "context" - "crypto/tls" - "io" - "net" - "net/http" - "net/url" - "os" - "time" - - "github.com/sagernet/sing-box/transport/v2rayhttp" - "github.com/sagernet/sing/common/atomic" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" - N "github.com/sagernet/sing/common/network" - - "golang.org/x/net/http2" -) - -const idleTimeout = 30 * time.Second - -var _ abstractSession = (*H2MuxServerSession)(nil) - -type H2MuxServerSession struct { - server http2.Server - active atomic.Int32 - conn net.Conn - inbound chan net.Conn - done chan struct{} -} - -func NewH2MuxServer(conn net.Conn) *H2MuxServerSession { - session := &H2MuxServerSession{ - conn: conn, - inbound: make(chan net.Conn), - done: make(chan struct{}), - server: http2.Server{ - IdleTimeout: idleTimeout, - }, - } - go func() { - session.server.ServeConn(conn, &http2.ServeConnOpts{ - Handler: session, - }) - _ = session.Close() - }() - return session -} - -func (s *H2MuxServerSession) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - s.active.Add(1) - defer s.active.Add(-1) - writer.WriteHeader(http.StatusOK) - conn := newHTTP2Wrapper(&v2rayhttp.ServerHTTPConn{ - HTTP2Conn: v2rayhttp.NewHTTPConn(request.Body, writer), - Flusher: writer.(http.Flusher), - }) - s.inbound <- conn - select { - case <-conn.done: - case <-s.done: - } -} - -func (s *H2MuxServerSession) Open() (net.Conn, error) { - return nil, os.ErrInvalid -} - -func (s *H2MuxServerSession) Accept() (net.Conn, error) { - select { - case conn := <-s.inbound: - return conn, nil - case <-s.done: - return nil, os.ErrClosed - } -} - -func (s *H2MuxServerSession) NumStreams() int { - return int(s.active.Load()) -} - -func (s *H2MuxServerSession) Close() error { - select { - case <-s.done: - default: - close(s.done) - } - return s.conn.Close() -} - -func (s *H2MuxServerSession) IsClosed() bool { - select { - case <-s.done: - return true - default: - return false - } -} - -func (s *H2MuxServerSession) CanTakeNewRequest() bool { - return false -} - -type h2MuxConnWrapper struct { - N.ExtendedConn - done chan struct{} -} - -func newHTTP2Wrapper(conn net.Conn) *h2MuxConnWrapper { - return &h2MuxConnWrapper{ - ExtendedConn: bufio.NewExtendedConn(conn), - done: make(chan struct{}), - } -} - -func (w *h2MuxConnWrapper) Write(p []byte) (n int, err error) { - select { - case <-w.done: - return 0, net.ErrClosed - default: - } - return w.ExtendedConn.Write(p) -} - -func (w *h2MuxConnWrapper) WriteBuffer(buffer *buf.Buffer) error { - select { - case <-w.done: - return net.ErrClosed - default: - } - return w.ExtendedConn.WriteBuffer(buffer) -} - -func (w *h2MuxConnWrapper) Close() error { - select { - case <-w.done: - default: - close(w.done) - } - return w.ExtendedConn.Close() -} - -func (w *h2MuxConnWrapper) Upstream() any { - return w.ExtendedConn -} - -var _ abstractSession = (*H2MuxClientSession)(nil) - -type H2MuxClientSession struct { - transport *http2.Transport - clientConn *http2.ClientConn - done chan struct{} -} - -func NewH2MuxClient(conn net.Conn) (*H2MuxClientSession, error) { - session := &H2MuxClientSession{ - transport: &http2.Transport{ - DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { - return conn, nil - }, - ReadIdleTimeout: idleTimeout, - }, - done: make(chan struct{}), - } - session.transport.ConnPool = session - clientConn, err := session.transport.NewClientConn(conn) - if err != nil { - return nil, err - } - session.clientConn = clientConn - return session, nil -} - -func (s *H2MuxClientSession) GetClientConn(req *http.Request, addr string) (*http2.ClientConn, error) { - return s.clientConn, nil -} - -func (s *H2MuxClientSession) MarkDead(conn *http2.ClientConn) { - s.Close() -} - -func (s *H2MuxClientSession) Open() (net.Conn, error) { - pipeInReader, pipeInWriter := io.Pipe() - request := &http.Request{ - Method: http.MethodConnect, - Body: pipeInReader, - URL: &url.URL{Scheme: "https", Host: "localhost"}, - } - conn := v2rayhttp.NewLateHTTPConn(pipeInWriter) - go func() { - response, err := s.transport.RoundTrip(request) - if err != nil { - conn.Setup(nil, err) - } else if response.StatusCode != 200 { - response.Body.Close() - conn.Setup(nil, E.New("unexpected status: ", response.StatusCode, " ", response.Status)) - } else { - conn.Setup(response.Body, nil) - } - }() - return conn, nil -} - -func (s *H2MuxClientSession) Accept() (net.Conn, error) { - return nil, os.ErrInvalid -} - -func (s *H2MuxClientSession) NumStreams() int { - return s.clientConn.State().StreamsActive -} - -func (s *H2MuxClientSession) Close() error { - select { - case <-s.done: - default: - close(s.done) - } - return s.clientConn.Close() -} - -func (s *H2MuxClientSession) IsClosed() bool { - select { - case <-s.done: - return true - default: - } - return s.clientConn.State().Closed -} - -func (s *H2MuxClientSession) CanTakeNewRequest() bool { - return s.clientConn.CanTakeNewRequest() -} diff --git a/common/mux/padding.go b/common/mux/padding.go deleted file mode 100644 index 850bf254..00000000 --- a/common/mux/padding.go +++ /dev/null @@ -1,240 +0,0 @@ -package mux - -import ( - "encoding/binary" - "io" - "math/rand" - "net" - - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/rw" -) - -const kFirstPaddings = 16 - -type paddingConn struct { - N.ExtendedConn - writer N.VectorisedWriter - readPadding int - writePadding int - readRemaining int - paddingRemaining int -} - -func newPaddingConn(conn net.Conn) net.Conn { - writer, isVectorised := bufio.CreateVectorisedWriter(conn) - if isVectorised { - return &vectorisedPaddingConn{ - paddingConn{ - ExtendedConn: bufio.NewExtendedConn(conn), - writer: bufio.NewVectorisedWriter(conn), - }, - writer, - } - } else { - return &paddingConn{ - ExtendedConn: bufio.NewExtendedConn(conn), - writer: bufio.NewVectorisedWriter(conn), - } - } -} - -func (c *paddingConn) Read(p []byte) (n int, err error) { - if c.readRemaining > 0 { - if len(p) > c.readRemaining { - p = p[:c.readRemaining] - } - n, err = c.ExtendedConn.Read(p) - if err != nil { - return - } - c.readRemaining -= n - return - } - if c.paddingRemaining > 0 { - err = rw.SkipN(c.ExtendedConn, c.paddingRemaining) - if err != nil { - return - } - c.paddingRemaining = 0 - } - if c.readPadding < kFirstPaddings { - var paddingHdr []byte - if len(p) >= 4 { - paddingHdr = p[:4] - } else { - _paddingHdr := make([]byte, 4) - defer common.KeepAlive(_paddingHdr) - paddingHdr = common.Dup(_paddingHdr) - } - _, err = io.ReadFull(c.ExtendedConn, paddingHdr) - if err != nil { - return - } - originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2])) - paddingLen := int(binary.BigEndian.Uint16(paddingHdr[2:])) - if len(p) > originalDataSize { - p = p[:originalDataSize] - } - n, err = c.ExtendedConn.Read(p) - if err != nil { - return - } - c.readPadding++ - c.readRemaining = originalDataSize - n - c.paddingRemaining = paddingLen - return - } - return c.ExtendedConn.Read(p) -} - -func (c *paddingConn) Write(p []byte) (n int, err error) { - for pLen := len(p); pLen > 0; { - var data []byte - if pLen > 65535 { - data = p[:65535] - p = p[65535:] - pLen -= 65535 - } else { - data = p - pLen = 0 - } - var writeN int - writeN, err = c.write(data) - n += writeN - if err != nil { - break - } - } - return n, err -} - -func (c *paddingConn) write(p []byte) (n int, err error) { - if c.writePadding < kFirstPaddings { - paddingLen := 256 + rand.Intn(512) - _buffer := buf.StackNewSize(4 + len(p) + paddingLen) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - header := buffer.Extend(4) - binary.BigEndian.PutUint16(header[:2], uint16(len(p))) - binary.BigEndian.PutUint16(header[2:], uint16(paddingLen)) - common.Must1(buffer.Write(p)) - buffer.Extend(paddingLen) - _, err = c.ExtendedConn.Write(buffer.Bytes()) - if err == nil { - n = len(p) - } - c.writePadding++ - return - } - return c.ExtendedConn.Write(p) -} - -func (c *paddingConn) ReadBuffer(buffer *buf.Buffer) error { - p := buffer.FreeBytes() - if c.readRemaining > 0 { - if len(p) > c.readRemaining { - p = p[:c.readRemaining] - } - n, err := c.ExtendedConn.Read(p) - if err != nil { - return err - } - c.readRemaining -= n - buffer.Truncate(n) - return nil - } - if c.paddingRemaining > 0 { - err := rw.SkipN(c.ExtendedConn, c.paddingRemaining) - if err != nil { - return err - } - c.paddingRemaining = 0 - } - if c.readPadding < kFirstPaddings { - var paddingHdr []byte - if len(p) >= 4 { - paddingHdr = p[:4] - } else { - _paddingHdr := make([]byte, 4) - defer common.KeepAlive(_paddingHdr) - paddingHdr = common.Dup(_paddingHdr) - } - _, err := io.ReadFull(c.ExtendedConn, paddingHdr) - if err != nil { - return err - } - originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2])) - paddingLen := int(binary.BigEndian.Uint16(paddingHdr[2:])) - - if len(p) > originalDataSize { - p = p[:originalDataSize] - } - n, err := c.ExtendedConn.Read(p) - if err != nil { - return err - } - c.readPadding++ - c.readRemaining = originalDataSize - n - c.paddingRemaining = paddingLen - buffer.Truncate(n) - return nil - } - return c.ExtendedConn.ReadBuffer(buffer) -} - -func (c *paddingConn) WriteBuffer(buffer *buf.Buffer) error { - if c.writePadding < kFirstPaddings { - bufferLen := buffer.Len() - if bufferLen > 65535 { - return common.Error(c.Write(buffer.Bytes())) - } - paddingLen := 256 + rand.Intn(512) - header := buffer.ExtendHeader(4) - binary.BigEndian.PutUint16(header[:2], uint16(bufferLen)) - binary.BigEndian.PutUint16(header[2:], uint16(paddingLen)) - buffer.Extend(paddingLen) - c.writePadding++ - } - return c.ExtendedConn.WriteBuffer(buffer) -} - -func (c *paddingConn) FrontHeadroom() int { - return 4 + 256 + 1024 -} - -type vectorisedPaddingConn struct { - paddingConn - writer N.VectorisedWriter -} - -func (c *vectorisedPaddingConn) WriteVectorised(buffers []*buf.Buffer) error { - if c.writePadding < kFirstPaddings { - bufferLen := buf.LenMulti(buffers) - if bufferLen > 65535 { - defer buf.ReleaseMulti(buffers) - for _, buffer := range buffers { - _, err := c.Write(buffer.Bytes()) - if err != nil { - return err - } - } - return nil - } - paddingLen := 256 + rand.Intn(512) - header := buf.NewSize(4) - common.Must( - binary.Write(header, binary.BigEndian, uint16(bufferLen)), - binary.Write(header, binary.BigEndian, uint16(paddingLen)), - ) - c.writePadding++ - padding := buf.NewSize(paddingLen) - padding.Extend(paddingLen) - buffers = append(append([]*buf.Buffer{header}, buffers...), padding) - } - return c.writer.WriteVectorised(buffers) -} diff --git a/common/mux/protocol.go b/common/mux/protocol.go index d48a92ee..abb0e268 100644 --- a/common/mux/protocol.go +++ b/common/mux/protocol.go @@ -1,299 +1,14 @@ package mux import ( - "encoding/binary" - "io" - "math/rand" - "net" - - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/rw" - "github.com/sagernet/smux" - - "github.com/hashicorp/yamux" + "github.com/sagernet/sing-mux" ) -var Destination = M.Socksaddr{ - Fqdn: "sp.mux.sing-box.arpa", - Port: 444, -} - -const ( - ProtocolSMux Protocol = iota - ProtocolYAMux - ProtocolH2Mux +type ( + Client = mux.Client ) -type Protocol byte - -func ParseProtocol(name string) (Protocol, error) { - switch name { - case "", "smux": - return ProtocolSMux, nil - case "yamux": - return ProtocolYAMux, nil - case "h2mux": - return ProtocolH2Mux, nil - default: - return ProtocolSMux, E.New("unknown multiplex protocol: ", name) - } -} - -func (p Protocol) newServer(conn net.Conn) (abstractSession, error) { - switch p { - case ProtocolSMux: - session, err := smux.Server(conn, smuxConfig()) - if err != nil { - return nil, err - } - return &smuxSession{session}, nil - case ProtocolYAMux: - session, err := yamux.Server(conn, yaMuxConfig()) - if err != nil { - return nil, err - } - return &yamuxSession{session}, nil - case ProtocolH2Mux: - return NewH2MuxServer(conn), nil - default: - panic("unknown protocol") - } -} - -func (p Protocol) newClient(conn net.Conn) (abstractSession, error) { - switch p { - case ProtocolSMux: - session, err := smux.Client(conn, smuxConfig()) - if err != nil { - return nil, err - } - return &smuxSession{session}, nil - case ProtocolYAMux: - session, err := yamux.Client(conn, yaMuxConfig()) - if err != nil { - return nil, err - } - return &yamuxSession{session}, nil - case ProtocolH2Mux: - return NewH2MuxClient(conn) - default: - panic("unknown protocol") - } -} - -func smuxConfig() *smux.Config { - config := smux.DefaultConfig() - config.KeepAliveDisabled = true - return config -} - -func yaMuxConfig() *yamux.Config { - config := yamux.DefaultConfig() - config.LogOutput = io.Discard - config.StreamCloseTimeout = C.TCPTimeout - config.StreamOpenTimeout = C.TCPTimeout - return config -} - -func (p Protocol) String() string { - switch p { - case ProtocolSMux: - return "smux" - case ProtocolYAMux: - return "yamux" - case ProtocolH2Mux: - return "h2mux" - default: - return "unknown" - } -} - -const ( - Version0 = iota - Version1 +var ( + Destination = mux.Destination + HandleConnection = mux.HandleConnection ) - -type Request struct { - Version byte - Protocol Protocol - PaddingEnabled bool -} - -func ReadRequest(reader io.Reader) (*Request, error) { - version, err := rw.ReadByte(reader) - if err != nil { - return nil, err - } - if version < Version0 || version > Version1 { - return nil, E.New("unsupported version: ", version) - } - protocol, err := rw.ReadByte(reader) - if err != nil { - return nil, err - } - var paddingEnabled bool - if version == Version1 { - err = binary.Read(reader, binary.BigEndian, &paddingEnabled) - if err != nil { - return nil, err - } - if paddingEnabled { - var paddingLen uint16 - err = binary.Read(reader, binary.BigEndian, &paddingLen) - if err != nil { - return nil, err - } - err = rw.SkipN(reader, int(paddingLen)) - if err != nil { - return nil, err - } - } - } - return &Request{Version: version, Protocol: Protocol(protocol), PaddingEnabled: paddingEnabled}, nil -} - -func EncodeRequest(request Request, payload []byte) *buf.Buffer { - var requestLen int - requestLen += 2 - var paddingLen uint16 - if request.Version == Version1 { - requestLen += 1 - if request.PaddingEnabled { - requestLen += 2 - paddingLen = uint16(256 + rand.Intn(512)) - requestLen += int(paddingLen) - } - } - buffer := buf.NewSize(requestLen + len(payload)) - common.Must( - buffer.WriteByte(request.Version), - buffer.WriteByte(byte(request.Protocol)), - ) - if request.Version == Version1 { - common.Must(binary.Write(buffer, binary.BigEndian, request.PaddingEnabled)) - if request.PaddingEnabled { - common.Must(binary.Write(buffer, binary.BigEndian, paddingLen)) - buffer.Extend(int(paddingLen)) - } - } - common.Must1(buffer.Write(payload)) - return buffer -} - -const ( - flagUDP = 1 - flagAddr = 2 - statusSuccess = 0 - statusError = 1 -) - -type StreamRequest struct { - Network string - Destination M.Socksaddr - PacketAddr bool -} - -func ReadStreamRequest(reader io.Reader) (*StreamRequest, error) { - var flags uint16 - err := binary.Read(reader, binary.BigEndian, &flags) - if err != nil { - return nil, err - } - destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) - if err != nil { - return nil, err - } - var network string - var udpAddr bool - if flags&flagUDP == 0 { - network = N.NetworkTCP - } else { - network = N.NetworkUDP - udpAddr = flags&flagAddr != 0 - } - return &StreamRequest{network, destination, udpAddr}, nil -} - -func streamRequestLen(request StreamRequest) int { - var rLen int - rLen += 1 // version - rLen += 2 // flags - rLen += M.SocksaddrSerializer.AddrPortLen(request.Destination) - return rLen -} - -func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) { - destination := request.Destination - var flags uint16 - if request.Network == N.NetworkUDP { - flags |= flagUDP - } - if request.PacketAddr { - flags |= flagAddr - if !destination.IsValid() { - destination = Destination - } - } - common.Must( - binary.Write(buffer, binary.BigEndian, flags), - M.SocksaddrSerializer.WriteAddrPort(buffer, destination), - ) -} - -type StreamResponse struct { - Status uint8 - Message string -} - -func ReadStreamResponse(reader io.Reader) (*StreamResponse, error) { - var response StreamResponse - status, err := rw.ReadByte(reader) - if err != nil { - return nil, err - } - response.Status = status - if status == statusError { - response.Message, err = rw.ReadVString(reader) - if err != nil { - return nil, err - } - } - return &response, nil -} - -type wrapStream struct { - net.Conn -} - -func (w *wrapStream) Read(p []byte) (n int, err error) { - n, err = w.Conn.Read(p) - err = wrapError(err) - return -} - -func (w *wrapStream) Write(p []byte) (n int, err error) { - n, err = w.Conn.Write(p) - err = wrapError(err) - return -} - -func (w *wrapStream) WriteIsThreadUnsafe() { -} - -func (w *wrapStream) Upstream() any { - return w.Conn -} - -func wrapError(err error) error { - switch err { - case yamux.ErrStreamClosed: - return io.EOF - default: - return err - } -} diff --git a/common/mux/service.go b/common/mux/service.go deleted file mode 100644 index eef7e1d7..00000000 --- a/common/mux/service.go +++ /dev/null @@ -1,272 +0,0 @@ -package mux - -import ( - "context" - "encoding/binary" - "net" - - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/rw" - "github.com/sagernet/sing/common/task" -) - -func NewConnection(ctx context.Context, router adapter.Router, errorHandler E.Handler, logger log.ContextLogger, conn net.Conn, metadata adapter.InboundContext) error { - request, err := ReadRequest(conn) - if err != nil { - return err - } - if request.PaddingEnabled { - conn = newPaddingConn(conn) - } - session, err := request.Protocol.newServer(conn) - if err != nil { - return err - } - var group task.Group - group.Append0(func(ctx context.Context) error { - var stream net.Conn - for { - stream, err = session.Accept() - if err != nil { - return err - } - go newConnection(ctx, router, errorHandler, logger, stream, metadata) - } - }) - group.Cleanup(func() { - session.Close() - }) - return group.Run(ctx) -} - -func newConnection(ctx context.Context, router adapter.Router, errorHandler E.Handler, logger log.ContextLogger, stream net.Conn, metadata adapter.InboundContext) { - stream = &wrapStream{stream} - request, err := ReadStreamRequest(stream) - if err != nil { - logger.ErrorContext(ctx, err) - return - } - metadata.Destination = request.Destination - if request.Network == N.NetworkTCP { - logger.InfoContext(ctx, "inbound multiplex connection to ", metadata.Destination) - hErr := router.RouteConnection(ctx, &ServerConn{ExtendedConn: bufio.NewExtendedConn(stream)}, metadata) - stream.Close() - if hErr != nil { - errorHandler.NewError(ctx, hErr) - } - } else { - var packetConn N.PacketConn - if !request.PacketAddr { - logger.InfoContext(ctx, "inbound multiplex packet connection to ", metadata.Destination) - packetConn = &ServerPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: request.Destination} - } else { - logger.InfoContext(ctx, "inbound multiplex packet connection") - packetConn = &ServerPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream)} - } - hErr := router.RoutePacketConnection(ctx, packetConn, metadata) - stream.Close() - if hErr != nil { - errorHandler.NewError(ctx, hErr) - } - } -} - -var _ N.HandshakeConn = (*ServerConn)(nil) - -type ServerConn struct { - N.ExtendedConn - responseWrite bool -} - -func (c *ServerConn) HandshakeFailure(err error) error { - errMessage := err.Error() - _buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - common.Must( - buffer.WriteByte(statusError), - rw.WriteVString(_buffer, errMessage), - ) - return c.ExtendedConn.WriteBuffer(buffer) -} - -func (c *ServerConn) Write(b []byte) (n int, err error) { - if c.responseWrite { - return c.ExtendedConn.Write(b) - } - _buffer := buf.StackNewSize(1 + len(b)) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - common.Must( - buffer.WriteByte(statusSuccess), - common.Error(buffer.Write(b)), - ) - _, err = c.ExtendedConn.Write(buffer.Bytes()) - if err != nil { - return - } - c.responseWrite = true - return len(b), nil -} - -func (c *ServerConn) WriteBuffer(buffer *buf.Buffer) error { - if c.responseWrite { - return c.ExtendedConn.WriteBuffer(buffer) - } - buffer.ExtendHeader(1)[0] = statusSuccess - c.responseWrite = true - return c.ExtendedConn.WriteBuffer(buffer) -} - -func (c *ServerConn) FrontHeadroom() int { - if !c.responseWrite { - return 1 - } - return 0 -} - -func (c *ServerConn) NeedAdditionalReadDeadline() bool { - return true -} - -func (c *ServerConn) Upstream() any { - return c.ExtendedConn -} - -var ( - _ N.HandshakeConn = (*ServerPacketConn)(nil) - _ N.PacketConn = (*ServerPacketConn)(nil) -) - -type ServerPacketConn struct { - N.ExtendedConn - destination M.Socksaddr - responseWrite bool -} - -func (c *ServerPacketConn) HandshakeFailure(err error) error { - errMessage := err.Error() - _buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - common.Must( - buffer.WriteByte(statusError), - rw.WriteVString(_buffer, errMessage), - ) - return c.ExtendedConn.WriteBuffer(buffer) -} - -func (c *ServerPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { - var length uint16 - err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) - if err != nil { - return - } - _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length)) - if err != nil { - return - } - destination = c.destination - return -} - -func (c *ServerPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - pLen := buffer.Len() - common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen))) - if !c.responseWrite { - buffer.ExtendHeader(1)[0] = statusSuccess - c.responseWrite = true - } - return c.ExtendedConn.WriteBuffer(buffer) -} - -func (c *ServerPacketConn) NeedAdditionalReadDeadline() bool { - return true -} - -func (c *ServerPacketConn) Upstream() any { - return c.ExtendedConn -} - -func (c *ServerPacketConn) FrontHeadroom() int { - if !c.responseWrite { - return 3 - } - return 2 -} - -var ( - _ N.HandshakeConn = (*ServerPacketAddrConn)(nil) - _ N.PacketConn = (*ServerPacketAddrConn)(nil) -) - -type ServerPacketAddrConn struct { - N.ExtendedConn - responseWrite bool -} - -func (c *ServerPacketAddrConn) HandshakeFailure(err error) error { - errMessage := err.Error() - _buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - common.Must( - buffer.WriteByte(statusError), - rw.WriteVString(_buffer, errMessage), - ) - return c.ExtendedConn.WriteBuffer(buffer) -} - -func (c *ServerPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { - destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn) - if err != nil { - return - } - var length uint16 - err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) - if err != nil { - return - } - _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length)) - if err != nil { - return - } - return -} - -func (c *ServerPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - pLen := buffer.Len() - common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen))) - common.Must(M.SocksaddrSerializer.WriteAddrPort(buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))), destination)) - if !c.responseWrite { - buffer.ExtendHeader(1)[0] = statusSuccess - c.responseWrite = true - } - return c.ExtendedConn.WriteBuffer(buffer) -} - -func (c *ServerPacketAddrConn) NeedAdditionalReadDeadline() bool { - return true -} - -func (c *ServerPacketAddrConn) Upstream() any { - return c.ExtendedConn -} - -func (c *ServerPacketAddrConn) FrontHeadroom() int { - if !c.responseWrite { - return 3 + M.MaxSocksaddrLength - } - return 2 + M.MaxSocksaddrLength -} diff --git a/common/mux/session.go b/common/mux/session.go deleted file mode 100644 index db9b6ee4..00000000 --- a/common/mux/session.go +++ /dev/null @@ -1,111 +0,0 @@ -package mux - -import ( - "io" - "net" - - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - N "github.com/sagernet/sing/common/network" - "github.com/sagernet/smux" - - "github.com/hashicorp/yamux" -) - -type abstractSession interface { - Open() (net.Conn, error) - Accept() (net.Conn, error) - NumStreams() int - Close() error - IsClosed() bool - CanTakeNewRequest() bool -} - -var _ abstractSession = (*smuxSession)(nil) - -type smuxSession struct { - *smux.Session -} - -func (s *smuxSession) Open() (net.Conn, error) { - return s.OpenStream() -} - -func (s *smuxSession) Accept() (net.Conn, error) { - return s.AcceptStream() -} - -func (s *smuxSession) CanTakeNewRequest() bool { - return true -} - -type yamuxSession struct { - *yamux.Session -} - -func (y *yamuxSession) CanTakeNewRequest() bool { - return true -} - -type protocolConn struct { - net.Conn - request Request - protocolWritten bool -} - -func newProtocolConn(conn net.Conn, request Request) net.Conn { - writer, isVectorised := bufio.CreateVectorisedWriter(conn) - if isVectorised { - return &vectorisedProtocolConn{ - protocolConn{ - Conn: conn, - request: request, - }, - writer, - } - } else { - return &protocolConn{ - Conn: conn, - request: request, - } - } -} - -func (c *protocolConn) Write(p []byte) (n int, err error) { - if c.protocolWritten { - return c.Conn.Write(p) - } - buffer := EncodeRequest(c.request, p) - n, err = c.Conn.Write(buffer.Bytes()) - buffer.Release() - if err == nil { - n-- - } - c.protocolWritten = true - return n, err -} - -func (c *protocolConn) ReadFrom(r io.Reader) (n int64, err error) { - if !c.protocolWritten { - return bufio.ReadFrom0(c, r) - } - return bufio.Copy(c.Conn, r) -} - -func (c *protocolConn) Upstream() any { - return c.Conn -} - -type vectorisedProtocolConn struct { - protocolConn - writer N.VectorisedWriter -} - -func (c *vectorisedProtocolConn) WriteVectorised(buffers []*buf.Buffer) error { - if c.protocolWritten { - return c.writer.WriteVectorised(buffers) - } - c.protocolWritten = true - buffer := EncodeRequest(c.request, nil) - return c.writer.WriteVectorised(append([]*buf.Buffer{buffer}, buffers...)) -} diff --git a/docs/configuration/shared/multiplex.md b/docs/configuration/shared/multiplex.md index f6256d09..833efaff 100644 --- a/docs/configuration/shared/multiplex.md +++ b/docs/configuration/shared/multiplex.md @@ -31,7 +31,7 @@ Multiplex protocol. | yamux | https://github.com/hashicorp/yamux | | h2mux | https://golang.org/x/net/http2 | -SMux is used by default. +h2mux is used by default. #### max_connections diff --git a/docs/configuration/shared/multiplex.zh.md b/docs/configuration/shared/multiplex.zh.md index aae78d41..99a0ba03 100644 --- a/docs/configuration/shared/multiplex.zh.md +++ b/docs/configuration/shared/multiplex.zh.md @@ -30,7 +30,7 @@ | yamux | https://github.com/hashicorp/yamux | | h2mux | https://golang.org/x/net/http2 | -默认使用 SMux。 +默认使用 h2mux。 #### max_connections diff --git a/go.mod b/go.mod index e897b881..8d77d4b4 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,6 @@ require ( github.com/go-chi/cors v1.2.1 github.com/go-chi/render v1.0.2 github.com/gofrs/uuid/v5 v5.0.0 - github.com/hashicorp/yamux v0.1.1 github.com/insomniacslk/dhcp v0.0.0-20230407062729-974c6f05fe16 github.com/logrusorgru/aurora v2.0.3+incompatible github.com/mholt/acmez v1.1.0 @@ -25,8 +24,9 @@ require ( github.com/sagernet/gomobile v0.0.0-20230413023804-244d7ff07035 github.com/sagernet/quic-go v0.0.0-20230202071646-a8c8afb18b32 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 - github.com/sagernet/sing v0.2.4 + github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207 github.com/sagernet/sing-dns v0.1.5-0.20230415085626-111ecf799dfc + github.com/sagernet/sing-mux v0.0.0-20230424015424-9b0d527c3bb0 github.com/sagernet/sing-shadowsocks v0.2.2-0.20230417102954-f77257340507 github.com/sagernet/sing-shadowtls v0.1.2-0.20230417103049-4f682e05f19b github.com/sagernet/sing-tun v0.1.5-0.20230422121432-209ec123ca7b @@ -63,6 +63,7 @@ require ( github.com/golang/protobuf v1.5.2 // indirect github.com/google/btree v1.0.1 // indirect github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect + github.com/hashicorp/yamux v0.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/josharian/native v1.1.0 // indirect github.com/klauspost/compress v1.15.15 // indirect diff --git a/go.sum b/go.sum index 9728dec9..8b4ebe7a 100644 --- a/go.sum +++ b/go.sum @@ -111,10 +111,12 @@ github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byL github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU= github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk= -github.com/sagernet/sing v0.2.4 h1:gC8BR5sglbJZX23RtMyFa8EETP9YEUADhfbEzU1yVbo= -github.com/sagernet/sing v0.2.4/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w= +github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207 h1:+dDVjW20IT+e8maKryaDeRY2+RFmTFdrQeIzqE2WOss= +github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w= github.com/sagernet/sing-dns v0.1.5-0.20230415085626-111ecf799dfc h1:hmbuqKv48SAjiKPoqtJGvS5pEHVPZjTHq9CPwQY2cZ4= github.com/sagernet/sing-dns v0.1.5-0.20230415085626-111ecf799dfc/go.mod h1:ZKuuqgsHRxDahYrzgSgy4vIAGGuKPlIf4hLcNzYzLkY= +github.com/sagernet/sing-mux v0.0.0-20230424015424-9b0d527c3bb0 h1:87jyxzTjq01VgEiUVSMNRKjCfsSfp/QwyUVT37eXY50= +github.com/sagernet/sing-mux v0.0.0-20230424015424-9b0d527c3bb0/go.mod h1:pF+RnLvCAOhECrvauy6LYOpBakJ/vuaF1Wm4lPsWryI= github.com/sagernet/sing-shadowsocks v0.2.2-0.20230417102954-f77257340507 h1:bAHZCdWqJkb8LEW98+YsMVDXGRMUVjka8IC+St6ot88= github.com/sagernet/sing-shadowsocks v0.2.2-0.20230417102954-f77257340507/go.mod h1:UJjvQGw0lyYaDGIDvUraL16fwaAEH1WFw1Y6sUcMPog= github.com/sagernet/sing-shadowtls v0.1.2-0.20230417103049-4f682e05f19b h1:ouW/6IDCrxkBe19YSbdCd7buHix7b+UZ6BM4Zz74XF4= diff --git a/outbound/shadowsocks.go b/outbound/shadowsocks.go index 08c0a688..1e2e2e63 100644 --- a/outbound/shadowsocks.go +++ b/outbound/shadowsocks.go @@ -58,7 +58,7 @@ func NewShadowsocks(ctx context.Context, router adapter.Router, logger log.Conte } uotOptions := common.PtrValueOrDefault(options.UDPOverTCPOptions) if !uotOptions.Enabled { - outbound.multiplexDialer, err = mux.NewClientWithOptions(ctx, (*shadowsocksDialer)(outbound), common.PtrValueOrDefault(options.MultiplexOptions)) + outbound.multiplexDialer, err = mux.NewClientWithOptions((*shadowsocksDialer)(outbound), common.PtrValueOrDefault(options.MultiplexOptions)) if err != nil { return nil, err } diff --git a/outbound/trojan.go b/outbound/trojan.go index 4a72b22d..c33f9867 100644 --- a/outbound/trojan.go +++ b/outbound/trojan.go @@ -58,7 +58,7 @@ func NewTrojan(ctx context.Context, router adapter.Router, logger log.ContextLog return nil, E.Cause(err, "create client transport: ", options.Transport.Type) } } - outbound.multiplexDialer, err = mux.NewClientWithOptions(ctx, (*trojanDialer)(outbound), common.PtrValueOrDefault(options.Multiplex)) + outbound.multiplexDialer, err = mux.NewClientWithOptions((*trojanDialer)(outbound), common.PtrValueOrDefault(options.Multiplex)) if err != nil { return nil, err } diff --git a/outbound/vmess.go b/outbound/vmess.go index e0030bf3..288bc42d 100644 --- a/outbound/vmess.go +++ b/outbound/vmess.go @@ -59,7 +59,7 @@ func NewVMess(ctx context.Context, router adapter.Router, logger log.ContextLogg return nil, E.Cause(err, "create client transport: ", options.Transport.Type) } } - outbound.multiplexDialer, err = mux.NewClientWithOptions(ctx, (*vmessDialer)(outbound), common.PtrValueOrDefault(options.Multiplex)) + outbound.multiplexDialer, err = mux.NewClientWithOptions((*vmessDialer)(outbound), common.PtrValueOrDefault(options.Multiplex)) if err != nil { return nil, err } diff --git a/route/router.go b/route/router.go index b6c776e1..d89cd2f7 100644 --- a/route/router.go +++ b/route/router.go @@ -598,7 +598,8 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad switch metadata.Destination.Fqdn { case mux.Destination.Fqdn: r.logger.InfoContext(ctx, "inbound multiplex connection") - return mux.NewConnection(ctx, r, r, r.logger, conn, metadata) + handler := adapter.NewUpstreamHandler(metadata, r.RouteConnection, r.RoutePacketConnection, r) + return mux.HandleConnection(ctx, handler, r.logger, conn, adapter.UpstreamMetadata(metadata)) case vmess.MuxDestination.Fqdn: r.logger.InfoContext(ctx, "inbound legacy multiplex connection") return vmess.HandleMuxConnection(ctx, conn, adapter.NewUpstreamHandler(metadata, r.RouteConnection, r.RoutePacketConnection, r))