diff --git a/common/qtls/wrapper.go b/common/qtls/wrapper.go deleted file mode 100644 index 67ca47a6..00000000 --- a/common/qtls/wrapper.go +++ /dev/null @@ -1,120 +0,0 @@ -package qtls - -import ( - "context" - "crypto/tls" - "net" - "net/http" - - "github.com/sagernet/quic-go" - "github.com/sagernet/quic-go/http3" - M "github.com/sagernet/sing/common/metadata" - aTLS "github.com/sagernet/sing/common/tls" -) - -type QUICConfig interface { - Dial(ctx context.Context, conn net.PacketConn, addr net.Addr, config *quic.Config) (quic.Connection, error) - DialEarly(ctx context.Context, conn net.PacketConn, addr net.Addr, config *quic.Config) (quic.EarlyConnection, error) - CreateTransport(conn net.PacketConn, quicConnPtr *quic.EarlyConnection, serverAddr M.Socksaddr, quicConfig *quic.Config, enableDatagrams bool) http.RoundTripper -} - -type QUICServerConfig interface { - Listen(conn net.PacketConn, config *quic.Config) (QUICListener, error) - ListenEarly(conn net.PacketConn, config *quic.Config) (QUICEarlyListener, error) - ConfigureHTTP3() -} - -type QUICListener interface { - Accept(ctx context.Context) (quic.Connection, error) - Close() error - Addr() net.Addr -} - -type QUICEarlyListener interface { - Accept(ctx context.Context) (quic.EarlyConnection, error) - Close() error - Addr() net.Addr -} - -func Dial(ctx context.Context, conn net.PacketConn, addr net.Addr, config aTLS.Config, quicConfig *quic.Config) (quic.Connection, error) { - if quicTLSConfig, isQUICConfig := config.(QUICConfig); isQUICConfig { - return quicTLSConfig.Dial(ctx, conn, addr, quicConfig) - } - tlsConfig, err := config.Config() - if err != nil { - return nil, err - } - return quic.Dial(ctx, conn, addr, tlsConfig, quicConfig) -} - -func DialEarly(ctx context.Context, conn net.PacketConn, addr net.Addr, config aTLS.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) { - if quicTLSConfig, isQUICConfig := config.(QUICConfig); isQUICConfig { - return quicTLSConfig.DialEarly(ctx, conn, addr, quicConfig) - } - tlsConfig, err := config.Config() - if err != nil { - return nil, err - } - return quic.DialEarly(ctx, conn, addr, tlsConfig, quicConfig) -} - -func CreateTransport(conn net.PacketConn, quicConnPtr *quic.EarlyConnection, serverAddr M.Socksaddr, config aTLS.Config, quicConfig *quic.Config, enableDatagrams bool) (http.RoundTripper, error) { - if quicTLSConfig, isQUICConfig := config.(QUICConfig); isQUICConfig { - return quicTLSConfig.CreateTransport(conn, quicConnPtr, serverAddr, quicConfig, enableDatagrams), nil - } - tlsConfig, err := config.Config() - if err != nil { - return nil, err - } - return &http3.RoundTripper{ - TLSClientConfig: tlsConfig, - QuicConfig: quicConfig, - EnableDatagrams: enableDatagrams, - Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - quicConn, err := quic.DialEarly(ctx, conn, serverAddr.UDPAddr(), tlsCfg, cfg) - if err != nil { - return nil, err - } - *quicConnPtr = quicConn - return quicConn, nil - }, - }, nil -} - -func Listen(conn net.PacketConn, config aTLS.ServerConfig, quicConfig *quic.Config) (QUICListener, error) { - if quicTLSConfig, isQUICConfig := config.(QUICServerConfig); isQUICConfig { - return quicTLSConfig.Listen(conn, quicConfig) - } - tlsConfig, err := config.Config() - if err != nil { - return nil, err - } - return quic.Listen(conn, tlsConfig, quicConfig) -} - -func ListenEarly(conn net.PacketConn, config aTLS.ServerConfig, quicConfig *quic.Config) (QUICEarlyListener, error) { - if quicTLSConfig, isQUICConfig := config.(QUICServerConfig); isQUICConfig { - return quicTLSConfig.ListenEarly(conn, quicConfig) - } - tlsConfig, err := config.Config() - if err != nil { - return nil, err - } - return quic.ListenEarly(conn, tlsConfig, quicConfig) -} - -func ConfigureHTTP3(config aTLS.ServerConfig) error { - if len(config.NextProtos()) == 0 { - config.SetNextProtos([]string{http3.NextProtoH3}) - } - if quicTLSConfig, isQUICConfig := config.(QUICServerConfig); isQUICConfig { - quicTLSConfig.ConfigureHTTP3() - return nil - } - tlsConfig, err := config.Config() - if err != nil { - return err - } - http3.ConfigureTLSConfig(tlsConfig) - return nil -} diff --git a/common/tls/ech_quic.go b/common/tls/ech_quic.go index 4ed4cec1..b9cc5ede 100644 --- a/common/tls/ech_quic.go +++ b/common/tls/ech_quic.go @@ -10,13 +10,13 @@ import ( "github.com/sagernet/cloudflare-tls" "github.com/sagernet/quic-go/ech" "github.com/sagernet/quic-go/http3_ech" - "github.com/sagernet/sing-box/common/qtls" + "github.com/sagernet/sing-quic" M "github.com/sagernet/sing/common/metadata" ) var ( - _ qtls.QUICConfig = (*echClientConfig)(nil) - _ qtls.QUICServerConfig = (*echServerConfig)(nil) + _ qtls.Config = (*echClientConfig)(nil) + _ qtls.ServerConfig = (*echServerConfig)(nil) ) func (c *echClientConfig) Dial(ctx context.Context, conn net.PacketConn, addr net.Addr, config *quic.Config) (quic.Connection, error) { @@ -43,11 +43,11 @@ func (c *echClientConfig) CreateTransport(conn net.PacketConn, quicConnPtr *quic } } -func (c *echServerConfig) Listen(conn net.PacketConn, config *quic.Config) (qtls.QUICListener, error) { +func (c *echServerConfig) Listen(conn net.PacketConn, config *quic.Config) (qtls.Listener, error) { return quic.Listen(conn, c.config, config) } -func (c *echServerConfig) ListenEarly(conn net.PacketConn, config *quic.Config) (qtls.QUICEarlyListener, error) { +func (c *echServerConfig) ListenEarly(conn net.PacketConn, config *quic.Config) (qtls.EarlyListener, error) { return quic.ListenEarly(conn, c.config, config) } diff --git a/go.mod b/go.mod index e749e371..bb884db9 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( github.com/sagernet/sing v0.2.10-0.20230912050851-1453c7c8c20d github.com/sagernet/sing-dns v0.1.9-0.20230911082806-425022bdc92b github.com/sagernet/sing-mux v0.1.3-0.20230908032617-759a1886a400 + github.com/sagernet/sing-quic v0.0.0-20230915093242-b55f3531e703 github.com/sagernet/sing-shadowsocks v0.2.5-0.20230907005610-126234728ca0 github.com/sagernet/sing-shadowsocks2 v0.1.4-0.20230907005906-5d2917b29248 github.com/sagernet/sing-shadowtls v0.1.4 diff --git a/go.sum b/go.sum index 46261ee9..06278cd0 100644 --- a/go.sum +++ b/go.sum @@ -118,6 +118,8 @@ github.com/sagernet/sing-dns v0.1.9-0.20230911082806-425022bdc92b h1:m/UWg2voyb9 github.com/sagernet/sing-dns v0.1.9-0.20230911082806-425022bdc92b/go.mod h1:Kg98PBJEg/08jsNFtmZWmPomhskn9Ausn50ecNm4M+8= github.com/sagernet/sing-mux v0.1.3-0.20230908032617-759a1886a400 h1:LtpYd5c5AJtUSxjyH4KjUS8HT+2XgilyozjbCq/x3EM= github.com/sagernet/sing-mux v0.1.3-0.20230908032617-759a1886a400/go.mod h1:TKxqIvfQQgd36jp2tzsPavGjYTVZilV+atip1cssjIY= +github.com/sagernet/sing-quic v0.0.0-20230915093242-b55f3531e703 h1:BbJZ5RkY3jQk5P9G5Ra0VhmDNKdT0aIP1FszEDyQL+o= +github.com/sagernet/sing-quic v0.0.0-20230915093242-b55f3531e703/go.mod h1:Mh5Senu4XDuX+RxSPQEoUB0j6kVmGais2h62Cnfj6Xk= github.com/sagernet/sing-shadowsocks v0.2.5-0.20230907005610-126234728ca0 h1:9wHYWxH+fcs01PM2+DylA8LNNY3ElnZykQo9rysng8U= github.com/sagernet/sing-shadowsocks v0.2.5-0.20230907005610-126234728ca0/go.mod h1:80fNKP0wnqlu85GZXV1H1vDPC/2t+dQbFggOw4XuFUM= github.com/sagernet/sing-shadowsocks2 v0.1.4-0.20230907005906-5d2917b29248 h1:JTFfy/LDmVFEK4KZJEujmC1iO8+aoF4unYhhZZRzRq4= diff --git a/inbound/hysteria.go b/inbound/hysteria.go index 93a32e9d..8087f6f5 100644 --- a/inbound/hysteria.go +++ b/inbound/hysteria.go @@ -9,12 +9,12 @@ import ( "github.com/sagernet/quic-go" "github.com/sagernet/quic-go/congestion" "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/qtls" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/hysteria" + "github.com/sagernet/sing-quic" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" E "github.com/sagernet/sing/common/exceptions" @@ -36,7 +36,7 @@ type Hysteria struct { xplusKey []byte sendBPS uint64 recvBPS uint64 - listener qtls.QUICListener + listener qtls.Listener udpAccess sync.RWMutex udpSessionId uint32 udpSessions map[uint32]chan *hysteria.UDPMessage diff --git a/inbound/hysteria2.go b/inbound/hysteria2.go index 4b7bb9a1..f2726e67 100644 --- a/inbound/hysteria2.go +++ b/inbound/hysteria2.go @@ -14,7 +14,7 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-box/transport/hysteria2" + "github.com/sagernet/sing-quic/hysteria2" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" E "github.com/sagernet/sing/common/exceptions" @@ -25,8 +25,9 @@ var _ adapter.Inbound = (*Hysteria2)(nil) type Hysteria2 struct { myInboundAdapter - tlsConfig tls.ServerConfig - server *hysteria2.Server + tlsConfig tls.ServerConfig + service *hysteria2.Service[int] + userNameList []string } func NewHysteria2(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.Hysteria2InboundOptions) (*Hysteria2, error) { @@ -84,16 +85,13 @@ func NewHysteria2(ctx context.Context, router adapter.Router, logger log.Context }, tlsConfig: tlsConfig, } - server, err := hysteria2.NewServer(hysteria2.ServerOptions{ - Context: ctx, - Logger: logger, - SendBPS: uint64(options.UpMbps * 1024 * 1024), - ReceiveBPS: uint64(options.DownMbps * 1024 * 1024), - SalamanderPassword: salamanderPassword, - TLSConfig: tlsConfig, - Users: common.Map(options.Users, func(it option.Hysteria2User) hysteria2.User { - return hysteria2.User(it) - }), + service, err := hysteria2.NewService[int](hysteria2.ServiceOptions{ + Context: ctx, + Logger: logger, + SendBPS: uint64(options.UpMbps * 1024 * 1024), + ReceiveBPS: uint64(options.DownMbps * 1024 * 1024), + SalamanderPassword: salamanderPassword, + TLSConfig: tlsConfig, IgnoreClientBandwidth: options.IgnoreClientBandwidth, Handler: adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newConnection, inbound.newPacketConnection, nil), MasqueradeHandler: masqueradeHandler, @@ -101,7 +99,17 @@ func NewHysteria2(ctx context.Context, router adapter.Router, logger log.Context if err != nil { return nil, err } - inbound.server = server + userList := make([]int, 0, len(options.Users)) + userNameList := make([]string, 0, len(options.Users)) + userPasswordList := make([]string, 0, len(options.Users)) + for index, user := range options.Users { + userList = append(userList, index) + userNameList = append(userNameList, user.Name) + userPasswordList = append(userPasswordList, user.Password) + } + service.UpdateUsers(userList, userPasswordList) + inbound.service = service + inbound.userNameList = userNameList return inbound, nil } @@ -109,14 +117,20 @@ func (h *Hysteria2) newConnection(ctx context.Context, conn net.Conn, metadata a ctx = log.ContextWithNewID(ctx) h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) metadata = h.createMetadata(conn, metadata) - metadata.User, _ = auth.UserFromContext[string](ctx) + userID, _ := auth.UserFromContext[int](ctx) + if userName := h.userNameList[userID]; userName != "" { + metadata.User = userName + } return h.router.RouteConnection(ctx, conn, metadata) } func (h *Hysteria2) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { ctx = log.ContextWithNewID(ctx) metadata = h.createPacketMetadata(conn, metadata) - metadata.User, _ = auth.UserFromContext[string](ctx) + userID, _ := auth.UserFromContext[int](ctx) + if userName := h.userNameList[userID]; userName != "" { + metadata.User = userName + } h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) return h.router.RoutePacketConnection(ctx, conn, metadata) } @@ -132,13 +146,13 @@ func (h *Hysteria2) Start() error { if err != nil { return err } - return h.server.Start(packetConn) + return h.service.Start(packetConn) } func (h *Hysteria2) Close() error { return common.Close( &h.myInboundAdapter, h.tlsConfig, - common.PtrOrNil(h.server), + common.PtrOrNil(h.service), ) } diff --git a/inbound/naive_quic.go b/inbound/naive_quic.go index 7a17b01f..9f99bf27 100644 --- a/inbound/naive_quic.go +++ b/inbound/naive_quic.go @@ -5,7 +5,7 @@ package inbound import ( "github.com/sagernet/quic-go" "github.com/sagernet/quic-go/http3" - "github.com/sagernet/sing-box/common/qtls" + "github.com/sagernet/sing-quic" E "github.com/sagernet/sing/common/exceptions" ) diff --git a/inbound/tuic.go b/inbound/tuic.go index a8547bf6..e6714f0d 100644 --- a/inbound/tuic.go +++ b/inbound/tuic.go @@ -12,7 +12,7 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-box/transport/tuic" + "github.com/sagernet/sing-quic/tuic" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" E "github.com/sagernet/sing/common/exceptions" @@ -25,8 +25,9 @@ var _ adapter.Inbound = (*TUIC)(nil) type TUIC struct { myInboundAdapter - server *tuic.Server - tlsConfig tls.ServerConfig + tlsConfig tls.ServerConfig + server *tuic.Service[int] + userNameList []string } func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TUICInboundOptions) (*TUIC, error) { @@ -38,17 +39,6 @@ func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogge if err != nil { return nil, err } - var users []tuic.User - for index, user := range options.Users { - if user.UUID == "" { - return nil, E.New("missing uuid for user ", index) - } - userUUID, err := uuid.FromString(user.UUID) - if err != nil { - return nil, E.Cause(err, "invalid uuid for user ", index) - } - users = append(users, tuic.User{Name: user.Name, UUID: userUUID, Password: user.Password}) - } inbound := &TUIC{ myInboundAdapter: myInboundAdapter{ protocol: C.TypeTUIC, @@ -60,11 +50,10 @@ func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogge listenOptions: options.ListenOptions, }, } - server, err := tuic.NewServer(tuic.ServerOptions{ + service, err := tuic.NewService[int](tuic.ServiceOptions{ Context: ctx, Logger: logger, TLSConfig: tlsConfig, - Users: users, CongestionControl: options.CongestionControl, AuthTimeout: time.Duration(options.AuthTimeout), ZeroRTTHandshake: options.ZeroRTTHandshake, @@ -74,7 +63,26 @@ func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogge if err != nil { return nil, err } - inbound.server = server + var userList []int + var userNameList []string + var userUUIDList [][16]byte + var userPasswordList []string + for index, user := range options.Users { + if user.UUID == "" { + return nil, E.New("missing uuid for user ", index) + } + userUUID, err := uuid.FromString(user.UUID) + if err != nil { + return nil, E.Cause(err, "invalid uuid for user ", index) + } + userList = append(userList, index) + userNameList = append(userNameList, user.Name) + userUUIDList = append(userUUIDList, userUUID) + userPasswordList = append(userPasswordList, user.Password) + } + service.UpdateUsers(userList, userUUIDList, userPasswordList) + inbound.server = service + inbound.userNameList = userNameList return inbound, nil } @@ -82,14 +90,20 @@ func (h *TUIC) newConnection(ctx context.Context, conn net.Conn, metadata adapte ctx = log.ContextWithNewID(ctx) h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) metadata = h.createMetadata(conn, metadata) - metadata.User, _ = auth.UserFromContext[string](ctx) + userID, _ := auth.UserFromContext[int](ctx) + if userName := h.userNameList[userID]; userName != "" { + metadata.User = userName + } return h.router.RouteConnection(ctx, conn, metadata) } func (h *TUIC) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { ctx = log.ContextWithNewID(ctx) metadata = h.createPacketMetadata(conn, metadata) - metadata.User, _ = auth.UserFromContext[string](ctx) + userID, _ := auth.UserFromContext[int](ctx) + if userName := h.userNameList[userID]; userName != "" { + metadata.User = userName + } h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) return h.router.RoutePacketConnection(ctx, conn, metadata) } diff --git a/outbound/hysteria.go b/outbound/hysteria.go index c236f759..dfe26996 100644 --- a/outbound/hysteria.go +++ b/outbound/hysteria.go @@ -11,12 +11,12 @@ import ( "github.com/sagernet/quic-go/congestion" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" - "github.com/sagernet/sing-box/common/qtls" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/hysteria" + "github.com/sagernet/sing-quic" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" diff --git a/outbound/hysteria2.go b/outbound/hysteria2.go index f974e9a8..9bd4b310 100644 --- a/outbound/hysteria2.go +++ b/outbound/hysteria2.go @@ -13,7 +13,7 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-box/transport/hysteria2" + "github.com/sagernet/sing-quic/hysteria2" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" diff --git a/outbound/tuic.go b/outbound/tuic.go index e8c3f700..c0983323 100644 --- a/outbound/tuic.go +++ b/outbound/tuic.go @@ -14,7 +14,7 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-box/transport/tuic" + "github.com/sagernet/sing-quic/tuic" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" diff --git a/transport/hysteria2/client.go b/transport/hysteria2/client.go deleted file mode 100644 index 62471ef7..00000000 --- a/transport/hysteria2/client.go +++ /dev/null @@ -1,314 +0,0 @@ -package hysteria2 - -import ( - "context" - "io" - "net" - "net/http" - "net/url" - "os" - "runtime" - "sync" - "time" - - "github.com/sagernet/quic-go" - "github.com/sagernet/sing-box/common/qtls" - "github.com/sagernet/sing-box/common/tls" - "github.com/sagernet/sing-box/transport/hysteria2/congestion" - "github.com/sagernet/sing-box/transport/hysteria2/internal/protocol" - tuicCongestion "github.com/sagernet/sing-box/transport/tuic/congestion" - "github.com/sagernet/sing/common/baderror" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" -) - -const ( - defaultStreamReceiveWindow = 8388608 // 8MB - defaultConnReceiveWindow = defaultStreamReceiveWindow * 5 / 2 // 20MB - defaultMaxIdleTimeout = 30 * time.Second - defaultKeepAlivePeriod = 10 * time.Second -) - -type ClientOptions struct { - Context context.Context - Dialer N.Dialer - ServerAddress M.Socksaddr - SendBPS uint64 - ReceiveBPS uint64 - SalamanderPassword string - Password string - TLSConfig tls.Config - UDPDisabled bool -} - -type Client struct { - ctx context.Context - dialer N.Dialer - serverAddr M.Socksaddr - sendBPS uint64 - receiveBPS uint64 - salamanderPassword string - password string - tlsConfig tls.Config - quicConfig *quic.Config - udpDisabled bool - - connAccess sync.RWMutex - conn *clientQUICConnection -} - -func NewClient(options ClientOptions) (*Client, error) { - quicConfig := &quic.Config{ - DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), - EnableDatagrams: true, - InitialStreamReceiveWindow: defaultStreamReceiveWindow, - MaxStreamReceiveWindow: defaultStreamReceiveWindow, - InitialConnectionReceiveWindow: defaultConnReceiveWindow, - MaxConnectionReceiveWindow: defaultConnReceiveWindow, - MaxIdleTimeout: defaultMaxIdleTimeout, - KeepAlivePeriod: defaultKeepAlivePeriod, - } - return &Client{ - ctx: options.Context, - dialer: options.Dialer, - serverAddr: options.ServerAddress, - sendBPS: options.SendBPS, - receiveBPS: options.ReceiveBPS, - salamanderPassword: options.SalamanderPassword, - password: options.Password, - tlsConfig: options.TLSConfig, - quicConfig: quicConfig, - udpDisabled: options.UDPDisabled, - }, nil -} - -func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) { - conn := c.conn - if conn != nil && conn.active() { - return conn, nil - } - c.connAccess.Lock() - defer c.connAccess.Unlock() - conn = c.conn - if conn != nil && conn.active() { - return conn, nil - } - conn, err := c.offerNew(ctx) - if err != nil { - return nil, err - } - return conn, nil -} - -func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) { - udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.serverAddr) - if err != nil { - return nil, err - } - var packetConn net.PacketConn - packetConn = bufio.NewUnbindPacketConn(udpConn) - if c.salamanderPassword != "" { - packetConn = NewSalamanderConn(packetConn, []byte(c.salamanderPassword)) - } - var quicConn quic.EarlyConnection - http3Transport, err := qtls.CreateTransport(packetConn, &quicConn, c.serverAddr, c.tlsConfig, c.quicConfig, true) - if err != nil { - udpConn.Close() - return nil, err - } - request := &http.Request{ - Method: http.MethodPost, - URL: &url.URL{ - Scheme: "https", - Host: protocol.URLHost, - Path: protocol.URLPath, - }, - Header: make(http.Header), - } - protocol.AuthRequestToHeader(request.Header, protocol.AuthRequest{Auth: c.password, Rx: c.receiveBPS}) - response, err := http3Transport.RoundTrip(request.WithContext(ctx)) - if err != nil { - if quicConn != nil { - quicConn.CloseWithError(0, "") - } - udpConn.Close() - return nil, err - } - if response.StatusCode != protocol.StatusAuthOK { - if quicConn != nil { - quicConn.CloseWithError(0, "") - } - udpConn.Close() - return nil, E.New("authentication failed, status code: ", response.StatusCode) - } - response.Body.Close() - authResponse := protocol.AuthResponseFromHeader(response.Header) - actualTx := authResponse.Rx - if actualTx == 0 || actualTx > c.sendBPS { - actualTx = c.sendBPS - } - if !authResponse.RxAuto && actualTx > 0 { - quicConn.SetCongestionControl(congestion.NewBrutalSender(actualTx)) - } else { - quicConn.SetCongestionControl(tuicCongestion.NewBBRSender( - tuicCongestion.DefaultClock{}, - tuicCongestion.GetInitialPacketSize(quicConn.RemoteAddr()), - tuicCongestion.InitialCongestionWindow*tuicCongestion.InitialMaxDatagramSize, - tuicCongestion.DefaultBBRMaxCongestionWindow*tuicCongestion.InitialMaxDatagramSize, - )) - } - conn := &clientQUICConnection{ - quicConn: quicConn, - rawConn: udpConn, - connDone: make(chan struct{}), - udpDisabled: c.udpDisabled || !authResponse.UDPEnabled, - udpConnMap: make(map[uint32]*udpPacketConn), - } - if !c.udpDisabled { - go c.loopMessages(conn) - } - c.conn = conn - return conn, nil -} - -func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) { - conn, err := c.offer(ctx) - if err != nil { - return nil, err - } - stream, err := conn.quicConn.OpenStream() - if err != nil { - return nil, err - } - return &clientConn{ - Stream: stream, - destination: destination, - }, nil -} - -func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) { - if c.udpDisabled { - return nil, os.ErrInvalid - } - conn, err := c.offer(ctx) - if err != nil { - return nil, err - } - if conn.udpDisabled { - return nil, E.New("UDP disabled by server") - } - var sessionID uint32 - clientPacketConn := newUDPPacketConn(ctx, conn.quicConn, func() { - conn.udpAccess.Lock() - delete(conn.udpConnMap, sessionID) - conn.udpAccess.Unlock() - }) - conn.udpAccess.Lock() - sessionID = conn.udpSessionID - conn.udpSessionID++ - conn.udpConnMap[sessionID] = clientPacketConn - conn.udpAccess.Unlock() - clientPacketConn.sessionID = sessionID - return clientPacketConn, nil -} - -func (c *Client) CloseWithError(err error) error { - conn := c.conn - if conn != nil { - conn.closeWithError(err) - } - return nil -} - -type clientQUICConnection struct { - quicConn quic.Connection - rawConn io.Closer - closeOnce sync.Once - connDone chan struct{} - connErr error - udpDisabled bool - udpAccess sync.RWMutex - udpConnMap map[uint32]*udpPacketConn - udpSessionID uint32 -} - -func (c *clientQUICConnection) active() bool { - select { - case <-c.quicConn.Context().Done(): - return false - default: - } - select { - case <-c.connDone: - return false - default: - } - return true -} - -func (c *clientQUICConnection) closeWithError(err error) { - c.closeOnce.Do(func() { - c.connErr = err - close(c.connDone) - c.quicConn.CloseWithError(0, "") - }) -} - -type clientConn struct { - quic.Stream - destination M.Socksaddr - requestWritten bool - responseRead bool -} - -func (c *clientConn) NeedHandshake() bool { - return !c.requestWritten -} - -func (c *clientConn) Read(p []byte) (n int, err error) { - if c.responseRead { - n, err = c.Stream.Read(p) - return n, baderror.WrapQUIC(err) - } - status, errorMessage, err := protocol.ReadTCPResponse(c.Stream) - if err != nil { - return 0, baderror.WrapQUIC(err) - } - if !status { - err = E.New("remote error: ", errorMessage) - return - } - c.responseRead = true - n, err = c.Stream.Read(p) - return n, baderror.WrapQUIC(err) -} - -func (c *clientConn) Write(p []byte) (n int, err error) { - if !c.requestWritten { - buffer := protocol.WriteTCPRequest(c.destination.String(), p) - defer buffer.Release() - _, err = c.Stream.Write(buffer.Bytes()) - if err != nil { - return - } - c.requestWritten = true - return len(p), nil - } - n, err = c.Stream.Write(p) - return n, baderror.WrapQUIC(err) -} - -func (c *clientConn) LocalAddr() net.Addr { - return M.Socksaddr{} -} - -func (c *clientConn) RemoteAddr() net.Addr { - return M.Socksaddr{} -} - -func (c *clientConn) Close() error { - c.Stream.CancelRead(0) - return c.Stream.Close() -} diff --git a/transport/hysteria2/client_paclet.go b/transport/hysteria2/client_paclet.go deleted file mode 100644 index 21198bf5..00000000 --- a/transport/hysteria2/client_paclet.go +++ /dev/null @@ -1,47 +0,0 @@ -package hysteria2 - -import E "github.com/sagernet/sing/common/exceptions" - -func (c *Client) loopMessages(conn *clientQUICConnection) { - for { - message, err := conn.quicConn.ReceiveMessage(c.ctx) - if err != nil { - conn.closeWithError(E.Cause(err, "receive message")) - return - } - go func() { - hErr := c.handleMessage(conn, message) - if hErr != nil { - conn.closeWithError(E.Cause(hErr, "handle message")) - } - }() - } -} - -func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error { - message := allocMessage() - err := decodeUDPMessage(message, data) - if err != nil { - message.release() - return E.Cause(err, "decode UDP message") - } - conn.handleUDPMessage(message) - return nil -} - -func (c *clientQUICConnection) handleUDPMessage(message *udpMessage) { - c.udpAccess.RLock() - udpConn, loaded := c.udpConnMap[message.sessionID] - c.udpAccess.RUnlock() - if !loaded { - message.releaseMessage() - return - } - select { - case <-udpConn.ctx.Done(): - message.releaseMessage() - return - default: - } - udpConn.inputPacket(message) -} diff --git a/transport/hysteria2/congestion/brutal.go b/transport/hysteria2/congestion/brutal.go deleted file mode 100644 index c52350c8..00000000 --- a/transport/hysteria2/congestion/brutal.go +++ /dev/null @@ -1,151 +0,0 @@ -package congestion - -import ( - "time" - - "github.com/sagernet/quic-go/congestion" -) - -const ( - initMaxDatagramSize = 1252 - - pktInfoSlotCount = 4 - minSampleCount = 50 - minAckRate = 0.8 -) - -var _ congestion.CongestionControl = &BrutalSender{} - -type BrutalSender struct { - rttStats congestion.RTTStatsProvider - bps congestion.ByteCount - maxDatagramSize congestion.ByteCount - pacer *pacer - - pktInfoSlots [pktInfoSlotCount]pktInfo - ackRate float64 -} - -type pktInfo struct { - Timestamp int64 - AckCount uint64 - LossCount uint64 -} - -func NewBrutalSender(bps uint64) *BrutalSender { - bs := &BrutalSender{ - bps: congestion.ByteCount(bps), - maxDatagramSize: initMaxDatagramSize, - ackRate: 1, - } - bs.pacer = newPacer(func() congestion.ByteCount { - return congestion.ByteCount(float64(bs.bps) / bs.ackRate) - }) - return bs -} - -func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) { - b.rttStats = rttStats -} - -func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time { - return b.pacer.TimeUntilSend() -} - -func (b *BrutalSender) HasPacingBudget(now time.Time) bool { - return b.pacer.Budget(now) >= b.maxDatagramSize -} - -func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool { - return bytesInFlight < b.GetCongestionWindow() -} - -func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount { - rtt := b.rttStats.SmoothedRTT() - if rtt <= 0 { - return 10240 - } - return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * 1.5 / b.ackRate) -} - -func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount, - packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool, -) { - b.pacer.SentPacket(sentTime, bytes) -} - -func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount, - priorInFlight congestion.ByteCount, eventTime time.Time, -) { - currentTimestamp := eventTime.Unix() - slot := currentTimestamp % pktInfoSlotCount - if b.pktInfoSlots[slot].Timestamp == currentTimestamp { - b.pktInfoSlots[slot].AckCount++ - } else { - // uninitialized slot or too old, reset - b.pktInfoSlots[slot].Timestamp = currentTimestamp - b.pktInfoSlots[slot].AckCount = 1 - b.pktInfoSlots[slot].LossCount = 0 - } - b.updateAckRate(currentTimestamp) -} - -func (b *BrutalSender) OnPacketLost(number congestion.PacketNumber, lostBytes congestion.ByteCount, - priorInFlight congestion.ByteCount, -) { - currentTimestamp := time.Now().Unix() - slot := currentTimestamp % pktInfoSlotCount - if b.pktInfoSlots[slot].Timestamp == currentTimestamp { - b.pktInfoSlots[slot].LossCount++ - } else { - // uninitialized slot or too old, reset - b.pktInfoSlots[slot].Timestamp = currentTimestamp - b.pktInfoSlots[slot].AckCount = 0 - b.pktInfoSlots[slot].LossCount = 1 - } - b.updateAckRate(currentTimestamp) -} - -func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) { - b.maxDatagramSize = size - b.pacer.SetMaxDatagramSize(size) -} - -func (b *BrutalSender) updateAckRate(currentTimestamp int64) { - minTimestamp := currentTimestamp - pktInfoSlotCount - var ackCount, lossCount uint64 - for _, info := range b.pktInfoSlots { - if info.Timestamp < minTimestamp { - continue - } - ackCount += info.AckCount - lossCount += info.LossCount - } - if ackCount+lossCount < minSampleCount { - b.ackRate = 1 - } - rate := float64(ackCount) / float64(ackCount+lossCount) - if rate < minAckRate { - b.ackRate = minAckRate - } - b.ackRate = rate -} - -func (b *BrutalSender) InSlowStart() bool { - return false -} - -func (b *BrutalSender) InRecovery() bool { - return false -} - -func (b *BrutalSender) MaybeExitSlowStart() {} - -func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {} - -func maxDuration(a, b time.Duration) time.Duration { - if a > b { - return a - } - return b -} diff --git a/transport/hysteria2/congestion/pacer.go b/transport/hysteria2/congestion/pacer.go deleted file mode 100644 index 878985e5..00000000 --- a/transport/hysteria2/congestion/pacer.go +++ /dev/null @@ -1,86 +0,0 @@ -package congestion - -import ( - "math" - "time" - - "github.com/sagernet/quic-go/congestion" -) - -const ( - maxBurstPackets = 10 - minPacingDelay = time.Millisecond -) - -// The pacer implements a token bucket pacing algorithm. -type pacer struct { - budgetAtLastSent congestion.ByteCount - maxDatagramSize congestion.ByteCount - lastSentTime time.Time - getBandwidth func() congestion.ByteCount // in bytes/s -} - -func newPacer(getBandwidth func() congestion.ByteCount) *pacer { - p := &pacer{ - budgetAtLastSent: maxBurstPackets * initMaxDatagramSize, - maxDatagramSize: initMaxDatagramSize, - getBandwidth: getBandwidth, - } - return p -} - -func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) { - budget := p.Budget(sendTime) - if size > budget { - p.budgetAtLastSent = 0 - } else { - p.budgetAtLastSent = budget - size - } - p.lastSentTime = sendTime -} - -func (p *pacer) Budget(now time.Time) congestion.ByteCount { - if p.lastSentTime.IsZero() { - return p.maxBurstSize() - } - budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 - return minByteCount(p.maxBurstSize(), budget) -} - -func (p *pacer) maxBurstSize() congestion.ByteCount { - return maxByteCount( - congestion.ByteCount((minPacingDelay+time.Millisecond).Nanoseconds())*p.getBandwidth()/1e9, - maxBurstPackets*p.maxDatagramSize, - ) -} - -// TimeUntilSend returns when the next packet should be sent. -// It returns the zero value of time.Time if a packet can be sent immediately. -func (p *pacer) TimeUntilSend() time.Time { - if p.budgetAtLastSent >= p.maxDatagramSize { - return time.Time{} - } - return p.lastSentTime.Add(maxDuration( - minPacingDelay, - time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/ - float64(p.getBandwidth())))*time.Nanosecond, - )) -} - -func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) { - p.maxDatagramSize = s -} - -func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount { - if a < b { - return b - } - return a -} - -func minByteCount(a, b congestion.ByteCount) congestion.ByteCount { - if a < b { - return a - } - return b -} diff --git a/transport/hysteria2/internal/protocol/http.go b/transport/hysteria2/internal/protocol/http.go deleted file mode 100644 index abcc1a4f..00000000 --- a/transport/hysteria2/internal/protocol/http.go +++ /dev/null @@ -1,68 +0,0 @@ -package protocol - -import ( - "net/http" - "strconv" -) - -const ( - URLHost = "hysteria" - URLPath = "/auth" - - RequestHeaderAuth = "Hysteria-Auth" - ResponseHeaderUDPEnabled = "Hysteria-UDP" - CommonHeaderCCRX = "Hysteria-CC-RX" - CommonHeaderPadding = "Hysteria-Padding" - - StatusAuthOK = 233 -) - -// AuthRequest is what client sends to server for authentication. -type AuthRequest struct { - Auth string - Rx uint64 // 0 = unknown, client asks server to use bandwidth detection -} - -// AuthResponse is what server sends to client when authentication is passed. -type AuthResponse struct { - UDPEnabled bool - Rx uint64 // 0 = unlimited - RxAuto bool // true = server asks client to use bandwidth detection -} - -func AuthRequestFromHeader(h http.Header) AuthRequest { - rx, _ := strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64) - return AuthRequest{ - Auth: h.Get(RequestHeaderAuth), - Rx: rx, - } -} - -func AuthRequestToHeader(h http.Header, req AuthRequest) { - h.Set(RequestHeaderAuth, req.Auth) - h.Set(CommonHeaderCCRX, strconv.FormatUint(req.Rx, 10)) - h.Set(CommonHeaderPadding, authRequestPadding.String()) -} - -func AuthResponseFromHeader(h http.Header) AuthResponse { - resp := AuthResponse{} - resp.UDPEnabled, _ = strconv.ParseBool(h.Get(ResponseHeaderUDPEnabled)) - rxStr := h.Get(CommonHeaderCCRX) - if rxStr == "auto" { - // Special case for server requesting client to use bandwidth detection - resp.RxAuto = true - } else { - resp.Rx, _ = strconv.ParseUint(rxStr, 10, 64) - } - return resp -} - -func AuthResponseToHeader(h http.Header, resp AuthResponse) { - h.Set(ResponseHeaderUDPEnabled, strconv.FormatBool(resp.UDPEnabled)) - if resp.RxAuto { - h.Set(CommonHeaderCCRX, "auto") - } else { - h.Set(CommonHeaderCCRX, strconv.FormatUint(resp.Rx, 10)) - } - h.Set(CommonHeaderPadding, authResponsePadding.String()) -} diff --git a/transport/hysteria2/internal/protocol/padding.go b/transport/hysteria2/internal/protocol/padding.go deleted file mode 100644 index 9895cdcc..00000000 --- a/transport/hysteria2/internal/protocol/padding.go +++ /dev/null @@ -1,31 +0,0 @@ -package protocol - -import ( - "math/rand" -) - -const ( - paddingChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" -) - -// padding specifies a half-open range [Min, Max). -type padding struct { - Min int - Max int -} - -func (p padding) String() string { - n := p.Min + rand.Intn(p.Max-p.Min) - bs := make([]byte, n) - for i := range bs { - bs[i] = paddingChars[rand.Intn(len(paddingChars))] - } - return string(bs) -} - -var ( - authRequestPadding = padding{Min: 256, Max: 2048} - authResponsePadding = padding{Min: 256, Max: 2048} - tcpRequestPadding = padding{Min: 64, Max: 512} - tcpResponsePadding = padding{Min: 128, Max: 1024} -) diff --git a/transport/hysteria2/internal/protocol/proxy.go b/transport/hysteria2/internal/protocol/proxy.go deleted file mode 100644 index 795b3cbf..00000000 --- a/transport/hysteria2/internal/protocol/proxy.go +++ /dev/null @@ -1,266 +0,0 @@ -package protocol - -import ( - "bytes" - "encoding/binary" - "fmt" - "io" - - "github.com/sagernet/quic-go/quicvarint" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/rw" -) - -const ( - FrameTypeTCPRequest = 0x401 - - // Max length values are for preventing DoS attacks - - MaxAddressLength = 2048 - MaxMessageLength = 2048 - MaxPaddingLength = 4096 - - MaxUDPSize = 4096 - - maxVarInt1 = 63 - maxVarInt2 = 16383 - maxVarInt4 = 1073741823 - maxVarInt8 = 4611686018427387903 -) - -// TCPRequest format: -// 0x401 (QUIC varint) -// Address length (QUIC varint) -// Address (bytes) -// Padding length (QUIC varint) -// Padding (bytes) - -func ReadTCPRequest(r io.Reader) (string, error) { - bReader := quicvarint.NewReader(r) - addrLen, err := quicvarint.Read(bReader) - if err != nil { - return "", err - } - if addrLen == 0 || addrLen > MaxAddressLength { - return "", E.New("invalid address length") - } - addrBuf := make([]byte, addrLen) - _, err = io.ReadFull(r, addrBuf) - if err != nil { - return "", err - } - paddingLen, err := quicvarint.Read(bReader) - if err != nil { - return "", err - } - if paddingLen > MaxPaddingLength { - return "", E.New("invalid padding length") - } - if paddingLen > 0 { - _, err = io.CopyN(io.Discard, r, int64(paddingLen)) - if err != nil { - return "", err - } - } - return string(addrBuf), nil -} - -func WriteTCPRequest(addr string, payload []byte) *buf.Buffer { - padding := tcpRequestPadding.String() - paddingLen := len(padding) - addrLen := len(addr) - sz := int(quicvarint.Len(FrameTypeTCPRequest)) + - int(quicvarint.Len(uint64(addrLen))) + addrLen + - int(quicvarint.Len(uint64(paddingLen))) + paddingLen - buffer := buf.NewSize(sz + len(payload)) - bufferContent := buffer.Extend(sz) - i := varintPut(bufferContent, FrameTypeTCPRequest) - i += varintPut(bufferContent[i:], uint64(addrLen)) - i += copy(bufferContent[i:], addr) - i += varintPut(bufferContent[i:], uint64(paddingLen)) - copy(bufferContent[i:], padding) - buffer.Write(payload) - return buffer -} - -// TCPResponse format: -// Status (byte, 0=ok, 1=error) -// Message length (QUIC varint) -// Message (bytes) -// Padding length (QUIC varint) -// Padding (bytes) - -func ReadTCPResponse(r io.Reader) (bool, string, error) { - var status [1]byte - if _, err := io.ReadFull(r, status[:]); err != nil { - return false, "", err - } - bReader := quicvarint.NewReader(r) - msg, err := ReadVString(bReader) - if err != nil { - return false, "", err - } - paddingLen, err := quicvarint.Read(bReader) - if err != nil { - return false, "", err - } - if paddingLen > MaxPaddingLength { - return false, "", E.New("invalid padding length") - } - if paddingLen > 0 { - _, err = io.CopyN(io.Discard, r, int64(paddingLen)) - if err != nil { - return false, "", err - } - } - return status[0] == 0, msg, nil -} - -func WriteTCPResponse(ok bool, msg string, payload []byte) *buf.Buffer { - padding := tcpResponsePadding.String() - paddingLen := len(padding) - msgLen := len(msg) - sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen + - int(quicvarint.Len(uint64(paddingLen))) + paddingLen - buffer := buf.NewSize(sz + len(payload)) - if ok { - buffer.WriteByte(0) - } else { - buffer.WriteByte(1) - } - WriteVString(buffer, msg) - WriteUVariant(buffer, uint64(paddingLen)) - buffer.Extend(paddingLen) - buffer.Write(payload) - return buffer -} - -// UDPMessage format: -// Session ID (uint32 BE) -// Packet ID (uint16 BE) -// Fragment ID (uint8) -// Fragment count (uint8) -// Address length (QUIC varint) -// Address (bytes) -// Data... - -type UDPMessage struct { - SessionID uint32 // 4 - PacketID uint16 // 2 - FragID uint8 // 1 - FragCount uint8 // 1 - Addr string // varint + bytes - Data []byte -} - -func (m *UDPMessage) HeaderSize() int { - lAddr := len(m.Addr) - return 4 + 2 + 1 + 1 + int(quicvarint.Len(uint64(lAddr))) + lAddr -} - -func (m *UDPMessage) Size() int { - return m.HeaderSize() + len(m.Data) -} - -func (m *UDPMessage) Serialize(buf []byte) int { - // Make sure the buffer is big enough - if len(buf) < m.Size() { - return -1 - } - binary.BigEndian.PutUint32(buf, m.SessionID) - binary.BigEndian.PutUint16(buf[4:], m.PacketID) - buf[6] = m.FragID - buf[7] = m.FragCount - i := varintPut(buf[8:], uint64(len(m.Addr))) - i += copy(buf[8+i:], m.Addr) - i += copy(buf[8+i:], m.Data) - return 8 + i -} - -func ParseUDPMessage(msg []byte) (*UDPMessage, error) { - m := &UDPMessage{} - buf := bytes.NewBuffer(msg) - if err := binary.Read(buf, binary.BigEndian, &m.SessionID); err != nil { - return nil, err - } - if err := binary.Read(buf, binary.BigEndian, &m.PacketID); err != nil { - return nil, err - } - if err := binary.Read(buf, binary.BigEndian, &m.FragID); err != nil { - return nil, err - } - if err := binary.Read(buf, binary.BigEndian, &m.FragCount); err != nil { - return nil, err - } - lAddr, err := quicvarint.Read(buf) - if err != nil { - return nil, err - } - if lAddr == 0 || lAddr > MaxMessageLength { - return nil, E.New("invalid address length") - } - bs := buf.Bytes() - m.Addr = string(bs[:lAddr]) - m.Data = bs[lAddr:] - return m, nil -} - -func ReadVString(reader io.Reader) (string, error) { - length, err := quicvarint.Read(quicvarint.NewReader(reader)) - if err != nil { - return "", err - } - value, err := rw.ReadBytes(reader, int(length)) - if err != nil { - return "", err - } - return string(value), nil -} - -func WriteVString(writer io.Writer, value string) error { - err := WriteUVariant(writer, uint64(len(value))) - if err != nil { - return err - } - return rw.WriteString(writer, value) -} - -func WriteUVariant(writer io.Writer, value uint64) error { - var b [8]byte - return common.Error(writer.Write(b[:varintPut(b[:], value)])) -} - -// varintPut is like quicvarint.Append, but instead of appending to a slice, -// it writes to a fixed-size buffer. Returns the number of bytes written. -func varintPut(b []byte, i uint64) int { - if i <= maxVarInt1 { - b[0] = uint8(i) - return 1 - } - if i <= maxVarInt2 { - b[0] = uint8(i>>8) | 0x40 - b[1] = uint8(i) - return 2 - } - if i <= maxVarInt4 { - b[0] = uint8(i>>24) | 0x80 - b[1] = uint8(i >> 16) - b[2] = uint8(i >> 8) - b[3] = uint8(i) - return 4 - } - if i <= maxVarInt8 { - b[0] = uint8(i>>56) | 0xc0 - b[1] = uint8(i >> 48) - b[2] = uint8(i >> 40) - b[3] = uint8(i >> 32) - b[4] = uint8(i >> 24) - b[5] = uint8(i >> 16) - b[6] = uint8(i >> 8) - b[7] = uint8(i) - return 8 - } - panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) -} diff --git a/transport/hysteria2/packet.go b/transport/hysteria2/packet.go deleted file mode 100644 index 100a30d6..00000000 --- a/transport/hysteria2/packet.go +++ /dev/null @@ -1,450 +0,0 @@ -package hysteria2 - -import ( - "bytes" - "context" - "encoding/binary" - "errors" - "io" - "math" - "net" - "os" - "sync" - "time" - - "github.com/sagernet/quic-go" - "github.com/sagernet/quic-go/quicvarint" - "github.com/sagernet/sing-box/transport/hysteria2/internal/protocol" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/atomic" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/cache" - M "github.com/sagernet/sing/common/metadata" -) - -var udpMessagePool = sync.Pool{ - New: func() interface{} { - return new(udpMessage) - }, -} - -func allocMessage() *udpMessage { - message := udpMessagePool.Get().(*udpMessage) - message.referenced = true - return message -} - -func releaseMessages(messages []*udpMessage) { - for _, message := range messages { - if message != nil { - message.release() - } - } -} - -type udpMessage struct { - sessionID uint32 - packetID uint16 - fragmentID uint8 - fragmentTotal uint8 - destination string - data *buf.Buffer - referenced bool -} - -func (m *udpMessage) release() { - if !m.referenced { - return - } - *m = udpMessage{} - udpMessagePool.Put(m) -} - -func (m *udpMessage) releaseMessage() { - m.data.Release() - m.release() -} - -func (m *udpMessage) pack() *buf.Buffer { - buffer := buf.NewSize(m.headerSize() + m.data.Len()) - common.Must( - binary.Write(buffer, binary.BigEndian, m.sessionID), - binary.Write(buffer, binary.BigEndian, m.packetID), - binary.Write(buffer, binary.BigEndian, m.fragmentID), - binary.Write(buffer, binary.BigEndian, m.fragmentTotal), - protocol.WriteVString(buffer, m.destination), - common.Error(buffer.Write(m.data.Bytes())), - ) - return buffer -} - -func (m *udpMessage) headerSize() int { - return 8 + int(quicvarint.Len(uint64(len(m.destination)))) + len(m.destination) -} - -func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage { - if message.data.Len() <= maxPacketSize { - return []*udpMessage{message} - } - var fragments []*udpMessage - originPacket := message.data.Bytes() - udpMTU := maxPacketSize - message.headerSize() - for remaining := len(originPacket); remaining > 0; remaining -= udpMTU { - fragment := allocMessage() - *fragment = *message - if remaining > udpMTU { - fragment.data = buf.As(originPacket[:udpMTU]) - originPacket = originPacket[udpMTU:] - } else { - fragment.data = buf.As(originPacket) - originPacket = nil - } - fragments = append(fragments, fragment) - } - fragmentTotal := uint16(len(fragments)) - for index, fragment := range fragments { - fragment.fragmentID = uint8(index) - fragment.fragmentTotal = uint8(fragmentTotal) - /*if index > 0 { - fragment.destination = "" - // not work in hysteria - }*/ - } - return fragments -} - -type udpPacketConn struct { - ctx context.Context - cancel common.ContextCancelCauseFunc - sessionID uint32 - quicConn quic.Connection - data chan *udpMessage - udpMTU int - udpMTUTime time.Time - packetId atomic.Uint32 - closeOnce sync.Once - defragger *udpDefragger - onDestroy func() -} - -func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy func()) *udpPacketConn { - ctx, cancel := common.ContextWithCancelCause(ctx) - return &udpPacketConn{ - ctx: ctx, - cancel: cancel, - quicConn: quicConn, - data: make(chan *udpMessage, 64), - defragger: newUDPDefragger(), - onDestroy: onDestroy, - } -} - -func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) { - select { - case p := <-c.data: - buffer = p.data - destination = M.ParseSocksaddr(p.destination) - p.release() - return - case <-c.ctx.Done(): - return nil, M.Socksaddr{}, io.ErrClosedPipe - } -} - -func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { - select { - case p := <-c.data: - _, err = buffer.ReadOnceFrom(p.data) - destination = M.ParseSocksaddr(p.destination) - p.releaseMessage() - return - case <-c.ctx.Done(): - return M.Socksaddr{}, io.ErrClosedPipe - } -} - -func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) { - select { - case p := <-c.data: - _, err = newBuffer().ReadOnceFrom(p.data) - destination = M.ParseSocksaddr(p.destination) - p.releaseMessage() - return - case <-c.ctx.Done(): - return M.Socksaddr{}, io.ErrClosedPipe - } -} - -func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - select { - case pkt := <-c.data: - n = copy(p, pkt.data.Bytes()) - destination := M.ParseSocksaddr(pkt.destination) - if destination.IsFqdn() { - addr = destination - } else { - addr = destination.UDPAddr() - } - pkt.releaseMessage() - return n, addr, nil - case <-c.ctx.Done(): - return 0, nil, io.ErrClosedPipe - } -} - -func (c *udpPacketConn) needFragment() bool { - nowTime := time.Now() - if c.udpMTU > 0 && nowTime.Sub(c.udpMTUTime) < 5*time.Second { - c.udpMTUTime = nowTime - return true - } - return false -} - -func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - defer buffer.Release() - select { - case <-c.ctx.Done(): - return net.ErrClosed - default: - } - if buffer.Len() > 0xffff { - return quic.ErrMessageTooLarge(0xffff) - } - packetId := c.packetId.Add(1) - if packetId > math.MaxUint16 { - c.packetId.Store(0) - packetId = 0 - } - message := allocMessage() - *message = udpMessage{ - sessionID: c.sessionID, - packetID: uint16(packetId), - fragmentTotal: 1, - destination: destination.String(), - data: buffer, - } - defer message.releaseMessage() - var err error - if c.needFragment() && buffer.Len() > c.udpMTU { - err = c.writePackets(fragUDPMessage(message, c.udpMTU)) - } else { - err = c.writePacket(message) - } - if err == nil { - return nil - } - var tooLargeErr quic.ErrMessageTooLarge - if !errors.As(err, &tooLargeErr) { - return err - } - c.udpMTU = int(tooLargeErr) - c.udpMTUTime = time.Now() - return c.writePackets(fragUDPMessage(message, c.udpMTU)) -} - -func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - select { - case <-c.ctx.Done(): - return 0, net.ErrClosed - default: - } - if len(p) > 0xffff { - return 0, quic.ErrMessageTooLarge(0xffff) - } - packetId := c.packetId.Add(1) - if packetId > math.MaxUint16 { - c.packetId.Store(0) - packetId = 0 - } - message := allocMessage() - *message = udpMessage{ - sessionID: c.sessionID, - packetID: uint16(packetId), - fragmentTotal: 1, - destination: addr.String(), - data: buf.As(p), - } - if c.needFragment() && len(p) > c.udpMTU { - err = c.writePackets(fragUDPMessage(message, c.udpMTU)) - if err == nil { - return len(p), nil - } - } else { - err = c.writePacket(message) - } - if err == nil { - return len(p), nil - } - var tooLargeErr quic.ErrMessageTooLarge - if !errors.As(err, &tooLargeErr) { - return - } - c.udpMTU = int(tooLargeErr) - c.udpMTUTime = time.Now() - err = c.writePackets(fragUDPMessage(message, c.udpMTU)) - if err == nil { - return len(p), nil - } - return -} - -func (c *udpPacketConn) inputPacket(message *udpMessage) { - if message.fragmentTotal <= 1 { - select { - case c.data <- message: - default: - } - } else { - newMessage := c.defragger.feed(message) - if newMessage != nil { - select { - case c.data <- newMessage: - default: - } - } - } -} - -func (c *udpPacketConn) writePackets(messages []*udpMessage) error { - defer releaseMessages(messages) - for _, message := range messages { - err := c.writePacket(message) - if err != nil { - return err - } - } - return nil -} - -func (c *udpPacketConn) writePacket(message *udpMessage) error { - buffer := message.pack() - defer buffer.Release() - return c.quicConn.SendMessage(buffer.Bytes()) -} - -func (c *udpPacketConn) Close() error { - c.closeOnce.Do(func() { - c.closeWithError(os.ErrClosed) - c.onDestroy() - }) - return nil -} - -func (c *udpPacketConn) closeWithError(err error) { - c.cancel(err) -} - -func (c *udpPacketConn) LocalAddr() net.Addr { - return c.quicConn.LocalAddr() -} - -func (c *udpPacketConn) SetDeadline(t time.Time) error { - return os.ErrInvalid -} - -func (c *udpPacketConn) SetReadDeadline(t time.Time) error { - return os.ErrInvalid -} - -func (c *udpPacketConn) SetWriteDeadline(t time.Time) error { - return os.ErrInvalid -} - -type udpDefragger struct { - packetMap *cache.LruCache[uint16, *packetItem] -} - -func newUDPDefragger() *udpDefragger { - return &udpDefragger{ - packetMap: cache.New( - cache.WithAge[uint16, *packetItem](10), - cache.WithUpdateAgeOnGet[uint16, *packetItem](), - cache.WithEvict[uint16, *packetItem](func(key uint16, value *packetItem) { - releaseMessages(value.messages) - }), - ), - } -} - -type packetItem struct { - access sync.Mutex - messages []*udpMessage - count uint8 -} - -func (d *udpDefragger) feed(m *udpMessage) *udpMessage { - if m.fragmentTotal <= 1 { - return m - } - if m.fragmentID >= m.fragmentTotal { - return nil - } - item, _ := d.packetMap.LoadOrStore(m.packetID, newPacketItem) - item.access.Lock() - defer item.access.Unlock() - if int(m.fragmentTotal) != len(item.messages) { - releaseMessages(item.messages) - item.messages = make([]*udpMessage, m.fragmentTotal) - item.count = 1 - item.messages[m.fragmentID] = m - return nil - } - if item.messages[m.fragmentID] != nil { - return nil - } - item.messages[m.fragmentID] = m - item.count++ - if int(item.count) != len(item.messages) { - return nil - } - newMessage := allocMessage() - newMessage.sessionID = m.sessionID - newMessage.packetID = m.packetID - newMessage.destination = item.messages[0].destination - var finalLength int - for _, message := range item.messages { - finalLength += message.data.Len() - } - if finalLength > 0 { - newMessage.data = buf.NewSize(finalLength) - for _, message := range item.messages { - newMessage.data.Write(message.data.Bytes()) - message.releaseMessage() - } - item.messages = nil - return newMessage - } - item.messages = nil - return nil -} - -func newPacketItem() *packetItem { - return new(packetItem) -} - -func decodeUDPMessage(message *udpMessage, data []byte) error { - reader := bytes.NewReader(data) - err := binary.Read(reader, binary.BigEndian, &message.sessionID) - if err != nil { - return err - } - err = binary.Read(reader, binary.BigEndian, &message.packetID) - if err != nil { - return err - } - err = binary.Read(reader, binary.BigEndian, &message.fragmentID) - if err != nil { - return err - } - err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal) - if err != nil { - return err - } - message.destination, err = protocol.ReadVString(reader) - if err != nil { - return err - } - message.data = buf.As(data[len(data)-reader.Len():]) - return nil -} diff --git a/transport/hysteria2/salamander.go b/transport/hysteria2/salamander.go deleted file mode 100644 index 9b734d52..00000000 --- a/transport/hysteria2/salamander.go +++ /dev/null @@ -1,106 +0,0 @@ -package hysteria2 - -import ( - "net" - - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - - "golang.org/x/crypto/blake2b" -) - -const salamanderSaltLen = 8 - -const ObfsTypeSalamander = "salamander" - -type Salamander struct { - net.PacketConn - password []byte -} - -func NewSalamanderConn(conn net.PacketConn, password []byte) net.PacketConn { - writer, isVectorised := bufio.CreateVectorisedPacketWriter(conn) - if isVectorised { - return &VectorisedSalamander{ - Salamander: Salamander{ - PacketConn: conn, - password: password, - }, - writer: writer, - } - } else { - return &Salamander{ - PacketConn: conn, - password: password, - } - } -} - -func (s *Salamander) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - n, addr, err = s.PacketConn.ReadFrom(p) - if err != nil { - return - } - if n <= salamanderSaltLen { - return 0, nil, E.New("salamander: packet too short") - } - key := blake2b.Sum256(append(s.password, p[:salamanderSaltLen]...)) - for index, c := range p[salamanderSaltLen:n] { - p[index] = c ^ key[index%blake2b.Size256] - } - return n - salamanderSaltLen, addr, nil -} - -func (s *Salamander) WriteTo(p []byte, addr net.Addr) (n int, err error) { - buffer := buf.NewSize(len(p) + salamanderSaltLen) - defer buffer.Release() - buffer.WriteRandom(salamanderSaltLen) - key := blake2b.Sum256(append(s.password, buffer.Bytes()...)) - for index, c := range p { - common.Must(buffer.WriteByte(c ^ key[index%blake2b.Size256])) - } - _, err = s.PacketConn.WriteTo(buffer.Bytes(), addr) - if err != nil { - return - } - return len(p), nil -} - -type VectorisedSalamander struct { - Salamander - writer N.VectorisedPacketWriter -} - -func (s *VectorisedSalamander) WriteTo(p []byte, addr net.Addr) (n int, err error) { - buffer := buf.NewSize(salamanderSaltLen) - buffer.WriteRandom(salamanderSaltLen) - key := blake2b.Sum256(append(s.password, buffer.Bytes()...)) - for i := range p { - p[i] ^= key[i%blake2b.Size256] - } - err = s.writer.WriteVectorisedPacket([]*buf.Buffer{buffer, buf.As(p)}, M.SocksaddrFromNet(addr)) - if err != nil { - return - } - return len(p), nil -} - -func (s *VectorisedSalamander) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error { - header := buf.NewSize(salamanderSaltLen) - defer header.Release() - header.WriteRandom(salamanderSaltLen) - key := blake2b.Sum256(append(s.password, header.Bytes()...)) - var bufferIndex int - for _, buffer := range buffers { - content := buffer.Bytes() - for index, c := range content { - content[bufferIndex+index] = c ^ key[bufferIndex+index%blake2b.Size256] - } - bufferIndex += len(content) - } - return s.writer.WriteVectorisedPacket(append([]*buf.Buffer{header}, buffers...), destination) -} diff --git a/transport/hysteria2/server.go b/transport/hysteria2/server.go deleted file mode 100644 index e9bb3904..00000000 --- a/transport/hysteria2/server.go +++ /dev/null @@ -1,344 +0,0 @@ -package hysteria2 - -import ( - "context" - "io" - "net" - "net/http" - "os" - "runtime" - "strings" - "sync" - - "github.com/sagernet/quic-go" - "github.com/sagernet/quic-go/http3" - "github.com/sagernet/sing-box/common/qtls" - "github.com/sagernet/sing-box/common/tls" - "github.com/sagernet/sing-box/transport/hysteria2/congestion" - "github.com/sagernet/sing-box/transport/hysteria2/internal/protocol" - tuicCongestion "github.com/sagernet/sing-box/transport/tuic/congestion" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/auth" - "github.com/sagernet/sing/common/baderror" - E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" -) - -type ServerOptions struct { - Context context.Context - Logger logger.Logger - SendBPS uint64 - ReceiveBPS uint64 - IgnoreClientBandwidth bool - SalamanderPassword string - TLSConfig tls.ServerConfig - Users []User - UDPDisabled bool - Handler ServerHandler - MasqueradeHandler http.Handler -} - -type User struct { - Name string - Password string -} - -type ServerHandler interface { - N.TCPConnectionHandler - N.UDPConnectionHandler -} - -type Server struct { - ctx context.Context - logger logger.Logger - sendBPS uint64 - receiveBPS uint64 - ignoreClientBandwidth bool - salamanderPassword string - tlsConfig tls.ServerConfig - quicConfig *quic.Config - userMap map[string]User - udpDisabled bool - handler ServerHandler - masqueradeHandler http.Handler - quicListener io.Closer -} - -func NewServer(options ServerOptions) (*Server, error) { - quicConfig := &quic.Config{ - DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), - EnableDatagrams: !options.UDPDisabled, - MaxIncomingStreams: 1 << 60, - InitialStreamReceiveWindow: defaultStreamReceiveWindow, - MaxStreamReceiveWindow: defaultStreamReceiveWindow, - InitialConnectionReceiveWindow: defaultConnReceiveWindow, - MaxConnectionReceiveWindow: defaultConnReceiveWindow, - MaxIdleTimeout: defaultMaxIdleTimeout, - KeepAlivePeriod: defaultKeepAlivePeriod, - } - if len(options.Users) == 0 { - return nil, E.New("missing users") - } - userMap := make(map[string]User) - for _, user := range options.Users { - userMap[user.Password] = user - } - if options.MasqueradeHandler == nil { - options.MasqueradeHandler = http.NotFoundHandler() - } - return &Server{ - ctx: options.Context, - logger: options.Logger, - sendBPS: options.SendBPS, - receiveBPS: options.ReceiveBPS, - ignoreClientBandwidth: options.IgnoreClientBandwidth, - salamanderPassword: options.SalamanderPassword, - tlsConfig: options.TLSConfig, - quicConfig: quicConfig, - userMap: userMap, - udpDisabled: options.UDPDisabled, - handler: options.Handler, - masqueradeHandler: options.MasqueradeHandler, - }, nil -} - -func (s *Server) Start(conn net.PacketConn) error { - if s.salamanderPassword != "" { - conn = NewSalamanderConn(conn, []byte(s.salamanderPassword)) - } - err := qtls.ConfigureHTTP3(s.tlsConfig) - if err != nil { - return err - } - listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig) - if err != nil { - return err - } - s.quicListener = listener - go s.loopConnections(listener) - return nil -} - -func (s *Server) Close() error { - return common.Close( - s.quicListener, - ) -} - -func (s *Server) loopConnections(listener qtls.QUICListener) { - for { - connection, err := listener.Accept(s.ctx) - if err != nil { - if strings.Contains(err.Error(), "server closed") { - s.logger.Debug(E.Cause(err, "listener closed")) - } else { - s.logger.Error(E.Cause(err, "listener closed")) - } - return - } - go s.handleConnection(connection) - } -} - -func (s *Server) handleConnection(connection quic.Connection) { - session := &serverSession{ - Server: s, - ctx: s.ctx, - quicConn: connection, - source: M.SocksaddrFromNet(connection.RemoteAddr()), - connDone: make(chan struct{}), - udpConnMap: make(map[uint32]*udpPacketConn), - } - httpServer := http3.Server{ - Handler: session, - StreamHijacker: session.handleStream0, - } - _ = httpServer.ServeQUICConn(connection) - _ = connection.CloseWithError(0, "") -} - -type serverSession struct { - *Server - ctx context.Context - quicConn quic.Connection - source M.Socksaddr - connAccess sync.Mutex - connDone chan struct{} - connErr error - authenticated bool - authUser *User - udpAccess sync.RWMutex - udpConnMap map[uint32]*udpPacketConn -} - -func (s *serverSession) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodPost && r.Host == protocol.URLHost && r.URL.Path == protocol.URLPath { - if s.authenticated { - protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{ - UDPEnabled: !s.udpDisabled, - Rx: s.receiveBPS, - RxAuto: s.ignoreClientBandwidth, - }) - w.WriteHeader(protocol.StatusAuthOK) - return - } - request := protocol.AuthRequestFromHeader(r.Header) - user, loaded := s.userMap[request.Auth] - if !loaded { - s.masqueradeHandler.ServeHTTP(w, r) - return - } - s.authUser = &user - s.authenticated = true - if !s.ignoreClientBandwidth && request.Rx > 0 { - var sendBps uint64 - if s.sendBPS > 0 && s.sendBPS < request.Rx { - sendBps = s.sendBPS - } else { - sendBps = request.Rx - } - s.quicConn.SetCongestionControl(congestion.NewBrutalSender(sendBps)) - } else { - s.quicConn.SetCongestionControl(tuicCongestion.NewBBRSender( - tuicCongestion.DefaultClock{}, - tuicCongestion.GetInitialPacketSize(s.quicConn.RemoteAddr()), - tuicCongestion.InitialCongestionWindow*tuicCongestion.InitialMaxDatagramSize, - tuicCongestion.DefaultBBRMaxCongestionWindow*tuicCongestion.InitialMaxDatagramSize, - )) - } - protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{ - UDPEnabled: !s.udpDisabled, - Rx: s.receiveBPS, - RxAuto: s.ignoreClientBandwidth, - }) - w.WriteHeader(protocol.StatusAuthOK) - if s.ctx.Done() != nil { - go func() { - select { - case <-s.ctx.Done(): - s.closeWithError(s.ctx.Err()) - case <-s.connDone: - } - }() - } - if !s.udpDisabled { - go s.loopMessages() - } - } else { - s.masqueradeHandler.ServeHTTP(w, r) - } -} - -func (s *serverSession) handleStream0(frameType http3.FrameType, connection quic.Connection, stream quic.Stream, err error) (bool, error) { - if !s.authenticated || err != nil { - return false, nil - } - if frameType != protocol.FrameTypeTCPRequest { - return false, nil - } - go func() { - hErr := s.handleStream(stream) - stream.CancelRead(0) - stream.Close() - if hErr != nil { - stream.CancelRead(0) - stream.Close() - s.logger.Error(E.Cause(hErr, "handle stream request")) - } - }() - return true, nil -} - -func (s *serverSession) handleStream(stream quic.Stream) error { - destinationString, err := protocol.ReadTCPRequest(stream) - if err != nil { - return E.New("read TCP request") - } - ctx := s.ctx - if s.authUser.Name != "" { - ctx = auth.ContextWithUser(s.ctx, s.authUser.Name) - } - _ = s.handler.NewConnection(ctx, &serverConn{Stream: stream}, M.Metadata{ - Source: s.source, - Destination: M.ParseSocksaddr(destinationString), - }) - return nil -} - -func (s *serverSession) closeWithError(err error) { - s.connAccess.Lock() - defer s.connAccess.Unlock() - select { - case <-s.connDone: - return - default: - s.connErr = err - close(s.connDone) - } - if E.IsClosedOrCanceled(err) { - s.logger.Debug(E.Cause(err, "connection failed")) - } else { - s.logger.Error(E.Cause(err, "connection failed")) - } - _ = s.quicConn.CloseWithError(0, "") -} - -type serverConn struct { - quic.Stream - responseWritten bool -} - -func (c *serverConn) HandshakeFailure(err error) error { - if c.responseWritten { - return os.ErrClosed - } - c.responseWritten = true - buffer := protocol.WriteTCPResponse(false, err.Error(), nil) - defer buffer.Release() - return common.Error(c.Stream.Write(buffer.Bytes())) -} - -func (c *serverConn) HandshakeSuccess() error { - if c.responseWritten { - return nil - } - c.responseWritten = true - buffer := protocol.WriteTCPResponse(true, "", nil) - defer buffer.Release() - return common.Error(c.Stream.Write(buffer.Bytes())) -} - -func (c *serverConn) Read(p []byte) (n int, err error) { - n, err = c.Stream.Read(p) - return n, baderror.WrapQUIC(err) -} - -func (c *serverConn) Write(p []byte) (n int, err error) { - if !c.responseWritten { - c.responseWritten = true - buffer := protocol.WriteTCPResponse(true, "", p) - defer buffer.Release() - _, err = c.Stream.Write(buffer.Bytes()) - if err != nil { - return 0, baderror.WrapQUIC(err) - } - return len(p), nil - } - n, err = c.Stream.Write(p) - return n, baderror.WrapQUIC(err) -} - -func (c *serverConn) LocalAddr() net.Addr { - return M.Socksaddr{} -} - -func (c *serverConn) RemoteAddr() net.Addr { - return M.Socksaddr{} -} - -func (c *serverConn) Close() error { - c.Stream.CancelRead(0) - return c.Stream.Close() -} diff --git a/transport/hysteria2/server_packet.go b/transport/hysteria2/server_packet.go deleted file mode 100644 index d84b5927..00000000 --- a/transport/hysteria2/server_packet.go +++ /dev/null @@ -1,55 +0,0 @@ -package hysteria2 - -import ( - "github.com/sagernet/sing/common" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" -) - -func (s *serverSession) loopMessages() { - for { - message, err := s.quicConn.ReceiveMessage(s.ctx) - if err != nil { - s.closeWithError(E.Cause(err, "receive message")) - return - } - hErr := s.handleMessage(message) - if hErr != nil { - s.closeWithError(E.Cause(hErr, "handle message")) - return - } - } -} - -func (s *serverSession) handleMessage(data []byte) error { - message := allocMessage() - err := decodeUDPMessage(message, data) - if err != nil { - message.release() - return E.Cause(err, "decode UDP message") - } - s.handleUDPMessage(message) - return nil -} - -func (s *serverSession) handleUDPMessage(message *udpMessage) { - s.udpAccess.RLock() - udpConn, loaded := s.udpConnMap[message.sessionID] - s.udpAccess.RUnlock() - if !loaded || common.Done(udpConn.ctx) { - udpConn = newUDPPacketConn(s.ctx, s.quicConn, func() { - s.udpAccess.Lock() - delete(s.udpConnMap, message.sessionID) - s.udpAccess.Unlock() - }) - udpConn.sessionID = message.sessionID - s.udpAccess.Lock() - s.udpConnMap[message.sessionID] = udpConn - s.udpAccess.Unlock() - go s.handler.NewPacketConnection(udpConn.ctx, udpConn, M.Metadata{ - Source: s.source, - Destination: M.ParseSocksaddr(message.destination), - }) - } - udpConn.inputPacket(message) -} diff --git a/transport/tuic/address.go b/transport/tuic/address.go deleted file mode 100644 index 22b18fa9..00000000 --- a/transport/tuic/address.go +++ /dev/null @@ -1,10 +0,0 @@ -package tuic - -import M "github.com/sagernet/sing/common/metadata" - -var addressSerializer = M.NewSerializer( - M.AddressFamilyByte(0x00, M.AddressFamilyFqdn), - M.AddressFamilyByte(0x01, M.AddressFamilyIPv4), - M.AddressFamilyByte(0x02, M.AddressFamilyIPv6), - M.AddressFamilyByte(0xff, M.AddressFamilyEmpty), -) diff --git a/transport/tuic/client.go b/transport/tuic/client.go deleted file mode 100644 index 88967723..00000000 --- a/transport/tuic/client.go +++ /dev/null @@ -1,307 +0,0 @@ -//go:build with_quic - -package tuic - -import ( - "context" - "io" - "net" - "runtime" - "sync" - "time" - - "github.com/sagernet/quic-go" - "github.com/sagernet/sing-box/common/qtls" - "github.com/sagernet/sing-box/common/tls" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/baderror" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - - "github.com/gofrs/uuid/v5" -) - -type ClientOptions struct { - Context context.Context - Dialer N.Dialer - ServerAddress M.Socksaddr - TLSConfig tls.Config - UUID uuid.UUID - Password string - CongestionControl string - UDPStream bool - ZeroRTTHandshake bool - Heartbeat time.Duration -} - -type Client struct { - ctx context.Context - dialer N.Dialer - serverAddr M.Socksaddr - tlsConfig tls.Config - quicConfig *quic.Config - uuid uuid.UUID - password string - congestionControl string - udpStream bool - zeroRTTHandshake bool - heartbeat time.Duration - - connAccess sync.RWMutex - conn *clientQUICConnection -} - -func NewClient(options ClientOptions) (*Client, error) { - if options.Heartbeat == 0 { - options.Heartbeat = 10 * time.Second - } - quicConfig := &quic.Config{ - DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), - MaxDatagramFrameSize: 1400, - EnableDatagrams: true, - MaxIncomingUniStreams: 1 << 60, - } - switch options.CongestionControl { - case "": - options.CongestionControl = "cubic" - case "cubic", "new_reno", "bbr": - default: - return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl) - } - return &Client{ - ctx: options.Context, - dialer: options.Dialer, - serverAddr: options.ServerAddress, - tlsConfig: options.TLSConfig, - quicConfig: quicConfig, - uuid: options.UUID, - password: options.Password, - congestionControl: options.CongestionControl, - udpStream: options.UDPStream, - zeroRTTHandshake: options.ZeroRTTHandshake, - heartbeat: options.Heartbeat, - }, nil -} - -func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) { - conn := c.conn - if conn != nil && conn.active() { - return conn, nil - } - c.connAccess.Lock() - defer c.connAccess.Unlock() - conn = c.conn - if conn != nil && conn.active() { - return conn, nil - } - conn, err := c.offerNew(ctx) - if err != nil { - return nil, err - } - return conn, nil -} - -func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) { - udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.serverAddr) - if err != nil { - return nil, err - } - var quicConn quic.Connection - if c.zeroRTTHandshake { - quicConn, err = qtls.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig) - } else { - quicConn, err = qtls.Dial(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig) - } - if err != nil { - udpConn.Close() - return nil, E.Cause(err, "open connection") - } - setCongestion(c.ctx, quicConn, c.congestionControl) - conn := &clientQUICConnection{ - quicConn: quicConn, - rawConn: udpConn, - connDone: make(chan struct{}), - udpConnMap: make(map[uint16]*udpPacketConn), - } - go func() { - hErr := c.clientHandshake(quicConn) - if hErr != nil { - conn.closeWithError(hErr) - } - }() - if c.udpStream { - go c.loopUniStreams(conn) - } - go c.loopMessages(conn) - go c.loopHeartbeats(conn) - c.conn = conn - return conn, nil -} - -func (c *Client) clientHandshake(conn quic.Connection) error { - authStream, err := conn.OpenUniStream() - if err != nil { - return E.Cause(err, "open handshake stream") - } - defer authStream.Close() - handshakeState := conn.ConnectionState() - tuicAuthToken, err := handshakeState.ExportKeyingMaterial(string(c.uuid[:]), []byte(c.password), 32) - if err != nil { - return E.Cause(err, "export keying material") - } - authRequest := buf.NewSize(AuthenticateLen) - authRequest.WriteByte(Version) - authRequest.WriteByte(CommandAuthenticate) - authRequest.Write(c.uuid[:]) - authRequest.Write(tuicAuthToken) - return common.Error(authStream.Write(authRequest.Bytes())) -} - -func (c *Client) loopHeartbeats(conn *clientQUICConnection) { - ticker := time.NewTicker(c.heartbeat) - defer ticker.Stop() - for { - select { - case <-conn.connDone: - return - case <-ticker.C: - err := conn.quicConn.SendMessage([]byte{Version, CommandHeartbeat}) - if err != nil { - conn.closeWithError(E.Cause(err, "send heartbeat")) - } - } - } -} - -func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) { - conn, err := c.offer(ctx) - if err != nil { - return nil, err - } - stream, err := conn.quicConn.OpenStream() - if err != nil { - return nil, err - } - return &clientConn{ - Stream: stream, - parent: conn, - destination: destination, - }, nil -} - -func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) { - conn, err := c.offer(ctx) - if err != nil { - return nil, err - } - var sessionID uint16 - clientPacketConn := newUDPPacketConn(ctx, conn.quicConn, c.udpStream, false, func() { - conn.udpAccess.Lock() - delete(conn.udpConnMap, sessionID) - conn.udpAccess.Unlock() - }) - conn.udpAccess.Lock() - sessionID = conn.udpSessionID - conn.udpSessionID++ - conn.udpConnMap[sessionID] = clientPacketConn - conn.udpAccess.Unlock() - clientPacketConn.sessionID = sessionID - return clientPacketConn, nil -} - -func (c *Client) CloseWithError(err error) error { - conn := c.conn - if conn != nil { - conn.closeWithError(err) - } - return nil -} - -type clientQUICConnection struct { - quicConn quic.Connection - rawConn io.Closer - closeOnce sync.Once - connDone chan struct{} - connErr error - udpAccess sync.RWMutex - udpConnMap map[uint16]*udpPacketConn - udpSessionID uint16 -} - -func (c *clientQUICConnection) active() bool { - select { - case <-c.quicConn.Context().Done(): - return false - default: - } - select { - case <-c.connDone: - return false - default: - } - return true -} - -func (c *clientQUICConnection) closeWithError(err error) { - c.closeOnce.Do(func() { - c.connErr = err - close(c.connDone) - _ = c.quicConn.CloseWithError(0, "") - _ = c.rawConn.Close() - }) -} - -type clientConn struct { - quic.Stream - parent *clientQUICConnection - destination M.Socksaddr - requestWritten bool -} - -func (c *clientConn) NeedHandshake() bool { - return !c.requestWritten -} - -func (c *clientConn) Read(b []byte) (n int, err error) { - n, err = c.Stream.Read(b) - return n, baderror.WrapQUIC(err) -} - -func (c *clientConn) Write(b []byte) (n int, err error) { - if !c.requestWritten { - request := buf.NewSize(2 + addressSerializer.AddrPortLen(c.destination) + len(b)) - defer request.Release() - request.WriteByte(Version) - request.WriteByte(CommandConnect) - err = addressSerializer.WriteAddrPort(request, c.destination) - if err != nil { - return - } - request.Write(b) - _, err = c.Stream.Write(request.Bytes()) - if err != nil { - c.parent.closeWithError(E.Cause(err, "create new connection")) - return 0, baderror.WrapQUIC(err) - } - c.requestWritten = true - return len(b), nil - } - n, err = c.Stream.Write(b) - return n, baderror.WrapQUIC(err) -} - -func (c *clientConn) Close() error { - c.Stream.CancelRead(0) - return c.Stream.Close() -} - -func (c *clientConn) LocalAddr() net.Addr { - return M.Socksaddr{} -} - -func (c *clientConn) RemoteAddr() net.Addr { - return c.destination -} diff --git a/transport/tuic/client_packet.go b/transport/tuic/client_packet.go deleted file mode 100644 index eb660848..00000000 --- a/transport/tuic/client_packet.go +++ /dev/null @@ -1,112 +0,0 @@ -//go:build with_quic - -package tuic - -import ( - "io" - - "github.com/sagernet/quic-go" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" -) - -func (c *Client) loopMessages(conn *clientQUICConnection) { - for { - message, err := conn.quicConn.ReceiveMessage(c.ctx) - if err != nil { - conn.closeWithError(E.Cause(err, "receive message")) - return - } - go func() { - hErr := c.handleMessage(conn, message) - if hErr != nil { - conn.closeWithError(E.Cause(hErr, "handle message")) - } - }() - } -} - -func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error { - if len(data) < 2 { - return E.New("invalid message") - } - if data[0] != Version { - return E.New("unknown version ", data[0]) - } - switch data[1] { - case CommandPacket: - message := allocMessage() - err := decodeUDPMessage(message, data[2:]) - if err != nil { - message.release() - return E.Cause(err, "decode UDP message") - } - conn.handleUDPMessage(message) - return nil - case CommandHeartbeat: - return nil - default: - return E.New("unknown command ", data[0]) - } -} - -func (c *Client) loopUniStreams(conn *clientQUICConnection) { - for { - stream, err := conn.quicConn.AcceptUniStream(c.ctx) - if err != nil { - conn.closeWithError(E.Cause(err, "handle uni stream")) - return - } - go func() { - hErr := c.handleUniStream(conn, stream) - if hErr != nil { - conn.closeWithError(hErr) - } - }() - } -} - -func (c *Client) handleUniStream(conn *clientQUICConnection, stream quic.ReceiveStream) error { - defer stream.CancelRead(0) - buffer := buf.NewPacket() - defer buffer.Release() - _, err := buffer.ReadAtLeastFrom(stream, 2) - if err != nil { - return err - } - version, _ := buffer.ReadByte() - if version != Version { - return E.New("unknown version ", version) - } - command, _ := buffer.ReadByte() - if command != CommandPacket { - return E.New("unknown command ", command) - } - reader := io.MultiReader(bufio.NewCachedReader(stream, buffer), stream) - message := allocMessage() - err = readUDPMessage(message, reader) - if err != nil { - message.release() - return err - } - conn.handleUDPMessage(message) - return nil -} - -func (c *clientQUICConnection) handleUDPMessage(message *udpMessage) { - c.udpAccess.RLock() - udpConn, loaded := c.udpConnMap[message.sessionID] - c.udpAccess.RUnlock() - if !loaded { - message.releaseMessage() - return - } - select { - case <-udpConn.ctx.Done(): - message.releaseMessage() - return - default: - } - udpConn.inputPacket(message) -} diff --git a/transport/tuic/congestion.go b/transport/tuic/congestion.go deleted file mode 100644 index 71f74838..00000000 --- a/transport/tuic/congestion.go +++ /dev/null @@ -1,46 +0,0 @@ -package tuic - -import ( - "context" - "time" - - "github.com/sagernet/quic-go" - "github.com/sagernet/sing-box/transport/tuic/congestion" - "github.com/sagernet/sing/common/ntp" -) - -func setCongestion(ctx context.Context, connection quic.Connection, congestionName string) { - timeFunc := ntp.TimeFuncFromContext(ctx) - if timeFunc == nil { - timeFunc = time.Now - } - switch congestionName { - case "cubic": - connection.SetCongestionControl( - congestion.NewCubicSender( - congestion.DefaultClock{TimeFunc: timeFunc}, - congestion.GetInitialPacketSize(connection.RemoteAddr()), - false, - nil, - ), - ) - case "new_reno": - connection.SetCongestionControl( - congestion.NewCubicSender( - congestion.DefaultClock{TimeFunc: timeFunc}, - congestion.GetInitialPacketSize(connection.RemoteAddr()), - true, - nil, - ), - ) - case "bbr": - connection.SetCongestionControl( - congestion.NewBBRSender( - congestion.DefaultClock{}, - congestion.GetInitialPacketSize(connection.RemoteAddr()), - congestion.InitialCongestionWindow*congestion.InitialMaxDatagramSize, - congestion.DefaultBBRMaxCongestionWindow*congestion.InitialMaxDatagramSize, - ), - ) - } -} diff --git a/transport/tuic/congestion/README.md b/transport/tuic/congestion/README.md deleted file mode 100644 index 6aa0309d..00000000 --- a/transport/tuic/congestion/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# congestion - -mod from https://github.com/MetaCubeX/Clash.Meta/tree/53f9e1ee7104473da2b4ff5da29965563084482d/transport/tuic/congestion \ No newline at end of file diff --git a/transport/tuic/congestion/bandwidth.go b/transport/tuic/congestion/bandwidth.go deleted file mode 100644 index 23393bad..00000000 --- a/transport/tuic/congestion/bandwidth.go +++ /dev/null @@ -1,25 +0,0 @@ -package congestion - -import ( - "math" - "time" - - "github.com/sagernet/quic-go/congestion" -) - -// Bandwidth of a connection -type Bandwidth uint64 - -const infBandwidth Bandwidth = math.MaxUint64 - -const ( - // BitsPerSecond is 1 bit per second - BitsPerSecond Bandwidth = 1 - // BytesPerSecond is 1 byte per second - BytesPerSecond = 8 * BitsPerSecond -) - -// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta -func BandwidthFromDelta(bytes congestion.ByteCount, delta time.Duration) Bandwidth { - return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond -} diff --git a/transport/tuic/congestion/bandwidth_sampler.go b/transport/tuic/congestion/bandwidth_sampler.go deleted file mode 100644 index 908f6e0d..00000000 --- a/transport/tuic/congestion/bandwidth_sampler.go +++ /dev/null @@ -1,374 +0,0 @@ -package congestion - -import ( - "math" - "time" - - "github.com/sagernet/quic-go/congestion" -) - -var InfiniteBandwidth = Bandwidth(math.MaxUint64) - -// SendTimeState is a subset of ConnectionStateOnSentPacket which is returned -// to the caller when the packet is acked or lost. -type SendTimeState struct { - // Whether other states in this object is valid. - isValid bool - // Whether the sender is app limited at the time the packet was sent. - // App limited bandwidth sample might be artificially low because the sender - // did not have enough data to send in order to saturate the link. - isAppLimited bool - // Total number of sent bytes at the time the packet was sent. - // Includes the packet itself. - totalBytesSent congestion.ByteCount - // Total number of acked bytes at the time the packet was sent. - totalBytesAcked congestion.ByteCount - // Total number of lost bytes at the time the packet was sent. - totalBytesLost congestion.ByteCount -} - -// ConnectionStateOnSentPacket represents the information about a sent packet -// and the state of the connection at the moment the packet was sent, -// specifically the information about the most recently acknowledged packet at -// that moment. -type ConnectionStateOnSentPacket struct { - packetNumber congestion.PacketNumber - // Time at which the packet is sent. - sendTime time.Time - // Size of the packet. - size congestion.ByteCount - // The value of |totalBytesSentAtLastAckedPacket| at the time the - // packet was sent. - totalBytesSentAtLastAckedPacket congestion.ByteCount - // The value of |lastAckedPacketSentTime| at the time the packet was - // sent. - lastAckedPacketSentTime time.Time - // The value of |lastAckedPacketAckTime| at the time the packet was - // sent. - lastAckedPacketAckTime time.Time - // Send time states that are returned to the congestion controller when the - // packet is acked or lost. - sendTimeState SendTimeState -} - -// BandwidthSample -type BandwidthSample struct { - // The bandwidth at that particular sample. Zero if no valid bandwidth sample - // is available. - bandwidth Bandwidth - // The RTT measurement at this particular sample. Zero if no RTT sample is - // available. Does not correct for delayed ack time. - rtt time.Duration - // States captured when the packet was sent. - stateAtSend SendTimeState -} - -func NewBandwidthSample() *BandwidthSample { - return &BandwidthSample{ - // FIXME: the default value of original code is zero. - rtt: InfiniteRTT, - } -} - -// BandwidthSampler keeps track of sent and acknowledged packets and outputs a -// bandwidth sample for every packet acknowledged. The samples are taken for -// individual packets, and are not filtered; the consumer has to filter the -// bandwidth samples itself. In certain cases, the sampler will locally severely -// underestimate the bandwidth, hence a maximum filter with a size of at least -// one RTT is recommended. -// -// This class bases its samples on the slope of two curves: the number of bytes -// sent over time, and the number of bytes acknowledged as received over time. -// It produces a sample of both slopes for every packet that gets acknowledged, -// based on a slope between two points on each of the corresponding curves. Note -// that due to the packet loss, the number of bytes on each curve might get -// further and further away from each other, meaning that it is not feasible to -// compare byte values coming from different curves with each other. -// -// The obvious points for measuring slope sample are the ones corresponding to -// the packet that was just acknowledged. Let us denote them as S_1 (point at -// which the current packet was sent) and A_1 (point at which the current packet -// was acknowledged). However, taking a slope requires two points on each line, -// so estimating bandwidth requires picking a packet in the past with respect to -// which the slope is measured. -// -// For that purpose, BandwidthSampler always keeps track of the most recently -// acknowledged packet, and records it together with every outgoing packet. -// When a packet gets acknowledged (A_1), it has not only information about when -// it itself was sent (S_1), but also the information about the latest -// acknowledged packet right before it was sent (S_0 and A_0). -// -// Based on that data, send and ack rate are estimated as: -// -// send_rate = (bytes(S_1) - bytes(S_0)) / (time(S_1) - time(S_0)) -// ack_rate = (bytes(A_1) - bytes(A_0)) / (time(A_1) - time(A_0)) -// -// Here, the ack rate is intuitively the rate we want to treat as bandwidth. -// However, in certain cases (e.g. ack compression) the ack rate at a point may -// end up higher than the rate at which the data was originally sent, which is -// not indicative of the real bandwidth. Hence, we use the send rate as an upper -// bound, and the sample value is -// -// rate_sample = min(send_rate, ack_rate) -// -// An important edge case handled by the sampler is tracking the app-limited -// samples. There are multiple meaning of "app-limited" used interchangeably, -// hence it is important to understand and to be able to distinguish between -// them. -// -// Meaning 1: connection state. The connection is said to be app-limited when -// there is no outstanding data to send. This means that certain bandwidth -// samples in the future would not be an accurate indication of the link -// capacity, and it is important to inform consumer about that. Whenever -// connection becomes app-limited, the sampler is notified via OnAppLimited() -// method. -// -// Meaning 2: a phase in the bandwidth sampler. As soon as the bandwidth -// sampler becomes notified about the connection being app-limited, it enters -// app-limited phase. In that phase, all *sent* packets are marked as -// app-limited. Note that the connection itself does not have to be -// app-limited during the app-limited phase, and in fact it will not be -// (otherwise how would it send packets?). The boolean flag below indicates -// whether the sampler is in that phase. -// -// Meaning 3: a flag on the sent packet and on the sample. If a sent packet is -// sent during the app-limited phase, the resulting sample related to the -// packet will be marked as app-limited. -// -// With the terminology issue out of the way, let us consider the question of -// what kind of situation it addresses. -// -// Consider a scenario where we first send packets 1 to 20 at a regular -// bandwidth, and then immediately run out of data. After a few seconds, we send -// packets 21 to 60, and only receive ack for 21 between sending packets 40 and -// 41. In this case, when we sample bandwidth for packets 21 to 40, the S_0/A_0 -// we use to compute the slope is going to be packet 20, a few seconds apart -// from the current packet, hence the resulting estimate would be extremely low -// and not indicative of anything. Only at packet 41 the S_0/A_0 will become 21, -// meaning that the bandwidth sample would exclude the quiescence. -// -// Based on the analysis of that scenario, we implement the following rule: once -// OnAppLimited() is called, all sent packets will produce app-limited samples -// up until an ack for a packet that was sent after OnAppLimited() was called. -// Note that while the scenario above is not the only scenario when the -// connection is app-limited, the approach works in other cases too. -type BandwidthSampler struct { - // The total number of congestion controlled bytes sent during the connection. - totalBytesSent congestion.ByteCount - // The total number of congestion controlled bytes which were acknowledged. - totalBytesAcked congestion.ByteCount - // The total number of congestion controlled bytes which were lost. - totalBytesLost congestion.ByteCount - // The value of |totalBytesSent| at the time the last acknowledged packet - // was sent. Valid only when |lastAckedPacketSentTime| is valid. - totalBytesSentAtLastAckedPacket congestion.ByteCount - // The time at which the last acknowledged packet was sent. Set to - // QuicTime::Zero() if no valid timestamp is available. - lastAckedPacketSentTime time.Time - // The time at which the most recent packet was acknowledged. - lastAckedPacketAckTime time.Time - // The most recently sent packet. - lastSendPacket congestion.PacketNumber - // Indicates whether the bandwidth sampler is currently in an app-limited - // phase. - isAppLimited bool - // The packet that will be acknowledged after this one will cause the sampler - // to exit the app-limited phase. - endOfAppLimitedPhase congestion.PacketNumber - // Record of the connection state at the point where each packet in flight was - // sent, indexed by the packet number. - connectionStats *ConnectionStates -} - -func NewBandwidthSampler() *BandwidthSampler { - return &BandwidthSampler{ - connectionStats: &ConnectionStates{ - stats: make(map[congestion.PacketNumber]*ConnectionStateOnSentPacket), - }, - } -} - -// OnPacketSent Inputs the sent packet information into the sampler. Assumes that all -// packets are sent in order. The information about the packet will not be -// released from the sampler until it the packet is either acknowledged or -// declared lost. -func (s *BandwidthSampler) OnPacketSent(sentTime time.Time, lastSentPacket congestion.PacketNumber, sentBytes, bytesInFlight congestion.ByteCount, hasRetransmittableData bool) { - s.lastSendPacket = lastSentPacket - - if !hasRetransmittableData { - return - } - - s.totalBytesSent += sentBytes - - // If there are no packets in flight, the time at which the new transmission - // opens can be treated as the A_0 point for the purpose of bandwidth - // sampling. This underestimates bandwidth to some extent, and produces some - // artificially low samples for most packets in flight, but it provides with - // samples at important points where we would not have them otherwise, most - // importantly at the beginning of the connection. - if bytesInFlight == 0 { - s.lastAckedPacketAckTime = sentTime - s.totalBytesSentAtLastAckedPacket = s.totalBytesSent - - // In this situation ack compression is not a concern, set send rate to - // effectively infinite. - s.lastAckedPacketSentTime = sentTime - } - - s.connectionStats.Insert(lastSentPacket, sentTime, sentBytes, s) -} - -// OnPacketAcked Notifies the sampler that the |lastAckedPacket| is acknowledged. Returns a -// bandwidth sample. If no bandwidth sample is available, -// QuicBandwidth::Zero() is returned. -func (s *BandwidthSampler) OnPacketAcked(ackTime time.Time, lastAckedPacket congestion.PacketNumber) *BandwidthSample { - sentPacketState := s.connectionStats.Get(lastAckedPacket) - if sentPacketState == nil { - return NewBandwidthSample() - } - - sample := s.onPacketAckedInner(ackTime, lastAckedPacket, sentPacketState) - s.connectionStats.Remove(lastAckedPacket) - - return sample -} - -// onPacketAckedInner Handles the actual bandwidth calculations, whereas the outer method handles -// retrieving and removing |sentPacket|. -func (s *BandwidthSampler) onPacketAckedInner(ackTime time.Time, lastAckedPacket congestion.PacketNumber, sentPacket *ConnectionStateOnSentPacket) *BandwidthSample { - s.totalBytesAcked += sentPacket.size - - s.totalBytesSentAtLastAckedPacket = sentPacket.sendTimeState.totalBytesSent - s.lastAckedPacketSentTime = sentPacket.sendTime - s.lastAckedPacketAckTime = ackTime - - // Exit app-limited phase once a packet that was sent while the connection is - // not app-limited is acknowledged. - if s.isAppLimited && lastAckedPacket > s.endOfAppLimitedPhase { - s.isAppLimited = false - } - - // There might have been no packets acknowledged at the moment when the - // current packet was sent. In that case, there is no bandwidth sample to - // make. - if sentPacket.lastAckedPacketSentTime.IsZero() { - return NewBandwidthSample() - } - - // Infinite rate indicates that the sampler is supposed to discard the - // current send rate sample and use only the ack rate. - sendRate := InfiniteBandwidth - if sentPacket.sendTime.After(sentPacket.lastAckedPacketSentTime) { - sendRate = BandwidthFromDelta(sentPacket.sendTimeState.totalBytesSent-sentPacket.totalBytesSentAtLastAckedPacket, sentPacket.sendTime.Sub(sentPacket.lastAckedPacketSentTime)) - } - - // During the slope calculation, ensure that ack time of the current packet is - // always larger than the time of the previous packet, otherwise division by - // zero or integer underflow can occur. - if !ackTime.After(sentPacket.lastAckedPacketAckTime) { - // TODO(wub): Compare this code count before and after fixing clock jitter - // issue. - // if sentPacket.lastAckedPacketAckTime.Equal(sentPacket.sendTime) { - // This is the 1st packet after quiescense. - // QUIC_CODE_COUNT_N(quic_prev_ack_time_larger_than_current_ack_time, 1, 2); - // } else { - // QUIC_CODE_COUNT_N(quic_prev_ack_time_larger_than_current_ack_time, 2, 2); - // } - - return NewBandwidthSample() - } - - ackRate := BandwidthFromDelta(s.totalBytesAcked-sentPacket.sendTimeState.totalBytesAcked, - ackTime.Sub(sentPacket.lastAckedPacketAckTime)) - - // Note: this sample does not account for delayed acknowledgement time. This - // means that the RTT measurements here can be artificially high, especially - // on low bandwidth connections. - sample := &BandwidthSample{ - bandwidth: minBandwidth(sendRate, ackRate), - rtt: ackTime.Sub(sentPacket.sendTime), - } - - SentPacketToSendTimeState(sentPacket, &sample.stateAtSend) - return sample -} - -// OnPacketLost Informs the sampler that a packet is considered lost and it should no -// longer keep track of it. -func (s *BandwidthSampler) OnPacketLost(packetNumber congestion.PacketNumber) SendTimeState { - ok, sentPacket := s.connectionStats.Remove(packetNumber) - sendTimeState := SendTimeState{ - isValid: ok, - } - if sentPacket != nil { - s.totalBytesLost += sentPacket.size - SentPacketToSendTimeState(sentPacket, &sendTimeState) - } - - return sendTimeState -} - -// OnAppLimited Informs the sampler that the connection is currently app-limited, causing -// the sampler to enter the app-limited phase. The phase will expire by -// itself. -func (s *BandwidthSampler) OnAppLimited() { - s.isAppLimited = true - s.endOfAppLimitedPhase = s.lastSendPacket -} - -// SentPacketToSendTimeState Copy a subset of the (private) ConnectionStateOnSentPacket to the (public) -// SendTimeState. Always set send_time_state->is_valid to true. -func SentPacketToSendTimeState(sentPacket *ConnectionStateOnSentPacket, sendTimeState *SendTimeState) { - sendTimeState.isAppLimited = sentPacket.sendTimeState.isAppLimited - sendTimeState.totalBytesSent = sentPacket.sendTimeState.totalBytesSent - sendTimeState.totalBytesAcked = sentPacket.sendTimeState.totalBytesAcked - sendTimeState.totalBytesLost = sentPacket.sendTimeState.totalBytesLost - sendTimeState.isValid = true -} - -// ConnectionStates Record of the connection state at the point where each packet in flight was -// sent, indexed by the packet number. -// FIXME: using LinkedList replace map to fast remove all the packets lower than the specified packet number. -type ConnectionStates struct { - stats map[congestion.PacketNumber]*ConnectionStateOnSentPacket -} - -func (s *ConnectionStates) Insert(packetNumber congestion.PacketNumber, sentTime time.Time, bytes congestion.ByteCount, sampler *BandwidthSampler) bool { - if _, ok := s.stats[packetNumber]; ok { - return false - } - - s.stats[packetNumber] = NewConnectionStateOnSentPacket(packetNumber, sentTime, bytes, sampler) - return true -} - -func (s *ConnectionStates) Get(packetNumber congestion.PacketNumber) *ConnectionStateOnSentPacket { - return s.stats[packetNumber] -} - -func (s *ConnectionStates) Remove(packetNumber congestion.PacketNumber) (bool, *ConnectionStateOnSentPacket) { - state, ok := s.stats[packetNumber] - if ok { - delete(s.stats, packetNumber) - } - return ok, state -} - -func NewConnectionStateOnSentPacket(packetNumber congestion.PacketNumber, sentTime time.Time, bytes congestion.ByteCount, sampler *BandwidthSampler) *ConnectionStateOnSentPacket { - return &ConnectionStateOnSentPacket{ - packetNumber: packetNumber, - sendTime: sentTime, - size: bytes, - lastAckedPacketSentTime: sampler.lastAckedPacketSentTime, - lastAckedPacketAckTime: sampler.lastAckedPacketAckTime, - totalBytesSentAtLastAckedPacket: sampler.totalBytesSentAtLastAckedPacket, - sendTimeState: SendTimeState{ - isValid: true, - isAppLimited: sampler.isAppLimited, - totalBytesSent: sampler.totalBytesSent, - totalBytesAcked: sampler.totalBytesAcked, - totalBytesLost: sampler.totalBytesLost, - }, - } -} diff --git a/transport/tuic/congestion/bbr_sender.go b/transport/tuic/congestion/bbr_sender.go deleted file mode 100644 index 34acc676..00000000 --- a/transport/tuic/congestion/bbr_sender.go +++ /dev/null @@ -1,1000 +0,0 @@ -package congestion - -// src from https://quiche.googlesource.com/quiche.git/+/66dea072431f94095dfc3dd2743cb94ef365f7ef/quic/core/congestion_control/bbr_sender.cc - -import ( - "fmt" - "math" - "math/rand" - "net" - "time" - - "github.com/sagernet/quic-go/congestion" -) - -const ( - // InitialMaxDatagramSize is the default maximum packet size used in QUIC for congestion window computations in bytes. - InitialMaxDatagramSize = 1252 - InitialPacketSizeIPv4 = 1252 - InitialPacketSizeIPv6 = 1232 - InitialCongestionWindow = 32 - DefaultBBRMaxCongestionWindow = 10000 -) - -func GetInitialPacketSize(addr net.Addr) congestion.ByteCount { - maxSize := congestion.ByteCount(1200) - // If this is not a UDP address, we don't know anything about the MTU. - // Use the minimum size of an Initial packet as the max packet size. - if udpAddr, ok := addr.(*net.UDPAddr); ok { - if udpAddr.IP.To4() != nil { - maxSize = InitialPacketSizeIPv4 - } else { - maxSize = InitialPacketSizeIPv6 - } - } - return congestion.ByteCount(maxSize) -} - -var ( - - // Default initial rtt used before any samples are received. - InitialRtt = 100 * time.Millisecond - - // The gain used for the STARTUP, equal to 4*ln(2). - DefaultHighGain = 2.77 - - // The gain used in STARTUP after loss has been detected. - // 1.5 is enough to allow for 25% exogenous loss and still observe a 25% growth - // in measured bandwidth. - StartupAfterLossGain = 1.5 - - // The cycle of gains used during the PROBE_BW stage. - PacingGain = []float64{1.25, 0.75, 1, 1, 1, 1, 1, 1} - - // The length of the gain cycle. - GainCycleLength = len(PacingGain) - - // The size of the bandwidth filter window, in round-trips. - BandwidthWindowSize = GainCycleLength + 2 - - // The time after which the current min_rtt value expires. - MinRttExpiry = 10 * time.Second - - // The minimum time the connection can spend in PROBE_RTT mode. - ProbeRttTime = time.Millisecond * 200 - - // If the bandwidth does not increase by the factor of |kStartupGrowthTarget| - // within |kRoundTripsWithoutGrowthBeforeExitingStartup| rounds, the connection - // will exit the STARTUP mode. - StartupGrowthTarget = 1.25 - RoundTripsWithoutGrowthBeforeExitingStartup = int64(3) - - // Coefficient of target congestion window to use when basing PROBE_RTT on BDP. - ModerateProbeRttMultiplier = 0.75 - - // Coefficient to determine if a new RTT is sufficiently similar to min_rtt that - // we don't need to enter PROBE_RTT. - SimilarMinRttThreshold = 1.125 - - // Congestion window gain for QUIC BBR during PROBE_BW phase. - DefaultCongestionWindowGainConst = 2.0 -) - -type bbrMode int - -const ( - // Startup phase of the connection. - STARTUP = iota - // After achieving the highest possible bandwidth during the startup, lower - // the pacing rate in order to drain the queue. - DRAIN - // Cruising mode. - PROBE_BW - // Temporarily slow down sending in order to empty the buffer and measure - // the real minimum RTT. - PROBE_RTT -) - -type bbrRecoveryState int - -const ( - // Do not limit. - NOT_IN_RECOVERY = iota - - // Allow an extra outstanding byte for each byte acknowledged. - CONSERVATION - - // Allow two extra outstanding bytes for each byte acknowledged (slow - // start). - GROWTH -) - -type bbrSender struct { - mode bbrMode - clock Clock - rttStats congestion.RTTStatsProvider - bytesInFlight congestion.ByteCount - // return total bytes of unacked packets. - // GetBytesInFlight func() congestion.ByteCount - // Bandwidth sampler provides BBR with the bandwidth measurements at - // individual points. - sampler *BandwidthSampler - // The number of the round trips that have occurred during the connection. - roundTripCount int64 - // The packet number of the most recently sent packet. - lastSendPacket congestion.PacketNumber - // Acknowledgement of any packet after |current_round_trip_end_| will cause - // the round trip counter to advance. - currentRoundTripEnd congestion.PacketNumber - // The filter that tracks the maximum bandwidth over the multiple recent - // round-trips. - maxBandwidth *WindowedFilter - // Tracks the maximum number of bytes acked faster than the sending rate. - maxAckHeight *WindowedFilter - // The time this aggregation started and the number of bytes acked during it. - aggregationEpochStartTime time.Time - aggregationEpochBytes congestion.ByteCount - // Minimum RTT estimate. Automatically expires within 10 seconds (and - // triggers PROBE_RTT mode) if no new value is sampled during that period. - minRtt time.Duration - // The time at which the current value of |min_rtt_| was assigned. - minRttTimestamp time.Time - // The maximum allowed number of bytes in flight. - congestionWindow congestion.ByteCount - // The initial value of the |congestion_window_|. - initialCongestionWindow congestion.ByteCount - // The largest value the |congestion_window_| can achieve. - initialMaxCongestionWindow congestion.ByteCount - // The smallest value the |congestion_window_| can achieve. - // minCongestionWindow congestion.ByteCount - // The pacing gain applied during the STARTUP phase. - highGain float64 - // The CWND gain applied during the STARTUP phase. - highCwndGain float64 - // The pacing gain applied during the DRAIN phase. - drainGain float64 - // The current pacing rate of the connection. - pacingRate Bandwidth - // The gain currently applied to the pacing rate. - pacingGain float64 - // The gain currently applied to the congestion window. - congestionWindowGain float64 - // The gain used for the congestion window during PROBE_BW. Latched from - // quic_bbr_cwnd_gain flag. - congestionWindowGainConst float64 - // The number of RTTs to stay in STARTUP mode. Defaults to 3. - numStartupRtts int64 - // If true, exit startup if 1RTT has passed with no bandwidth increase and - // the connection is in recovery. - exitStartupOnLoss bool - // Number of round-trips in PROBE_BW mode, used for determining the current - // pacing gain cycle. - cycleCurrentOffset int - // The time at which the last pacing gain cycle was started. - lastCycleStart time.Time - // Indicates whether the connection has reached the full bandwidth mode. - isAtFullBandwidth bool - // Number of rounds during which there was no significant bandwidth increase. - roundsWithoutBandwidthGain int64 - // The bandwidth compared to which the increase is measured. - bandwidthAtLastRound Bandwidth - // Set to true upon exiting quiescence. - exitingQuiescence bool - // Time at which PROBE_RTT has to be exited. Setting it to zero indicates - // that the time is yet unknown as the number of packets in flight has not - // reached the required value. - exitProbeRttAt time.Time - // Indicates whether a round-trip has passed since PROBE_RTT became active. - probeRttRoundPassed bool - // Indicates whether the most recent bandwidth sample was marked as - // app-limited. - lastSampleIsAppLimited bool - // Indicates whether any non app-limited samples have been recorded. - hasNoAppLimitedSample bool - // Indicates app-limited calls should be ignored as long as there's - // enough data inflight to see more bandwidth when necessary. - flexibleAppLimited bool - // Current state of recovery. - recoveryState bbrRecoveryState - // Receiving acknowledgement of a packet after |end_recovery_at_| will cause - // BBR to exit the recovery mode. A value above zero indicates at least one - // loss has been detected, so it must not be set back to zero. - endRecoveryAt congestion.PacketNumber - // A window used to limit the number of bytes in flight during loss recovery. - recoveryWindow congestion.ByteCount - // If true, consider all samples in recovery app-limited. - isAppLimitedRecovery bool - // When true, pace at 1.5x and disable packet conservation in STARTUP. - slowerStartup bool - // When true, disables packet conservation in STARTUP. - rateBasedStartup bool - // When non-zero, decreases the rate in STARTUP by the total number of bytes - // lost in STARTUP divided by CWND. - startupRateReductionMultiplier int64 - // Sum of bytes lost in STARTUP. - startupBytesLost congestion.ByteCount - // When true, add the most recent ack aggregation measurement during STARTUP. - enableAckAggregationDuringStartup bool - // When true, expire the windowed ack aggregation values in STARTUP when - // bandwidth increases more than 25%. - expireAckAggregationInStartup bool - // If true, will not exit low gain mode until bytes_in_flight drops below BDP - // or it's time for high gain mode. - drainToTarget bool - // If true, use a CWND of 0.75*BDP during probe_rtt instead of 4 packets. - probeRttBasedOnBdp bool - // If true, skip probe_rtt and update the timestamp of the existing min_rtt to - // now if min_rtt over the last cycle is within 12.5% of the current min_rtt. - // Even if the min_rtt is 12.5% too low, the 25% gain cycling and 2x CWND gain - // should overcome an overly small min_rtt. - probeRttSkippedIfSimilarRtt bool - // If true, disable PROBE_RTT entirely as long as the connection was recently - // app limited. - probeRttDisabledIfAppLimited bool - appLimitedSinceLastProbeRtt bool - minRttSinceLastProbeRtt time.Duration - // Latched value of --quic_always_get_bw_sample_when_acked. - alwaysGetBwSampleWhenAcked bool - - pacer *pacer - - maxDatagramSize congestion.ByteCount -} - -func NewBBRSender( - clock Clock, - initialMaxDatagramSize, - initialCongestionWindow, - initialMaxCongestionWindow congestion.ByteCount, -) *bbrSender { - b := &bbrSender{ - mode: STARTUP, - clock: clock, - sampler: NewBandwidthSampler(), - maxBandwidth: NewWindowedFilter(int64(BandwidthWindowSize), MaxFilter), - maxAckHeight: NewWindowedFilter(int64(BandwidthWindowSize), MaxFilter), - congestionWindow: initialCongestionWindow, - initialCongestionWindow: initialCongestionWindow, - highGain: DefaultHighGain, - highCwndGain: DefaultHighGain, - drainGain: 1.0 / DefaultHighGain, - pacingGain: 1.0, - congestionWindowGain: 1.0, - congestionWindowGainConst: DefaultCongestionWindowGainConst, - numStartupRtts: RoundTripsWithoutGrowthBeforeExitingStartup, - recoveryState: NOT_IN_RECOVERY, - recoveryWindow: initialMaxCongestionWindow, - minRttSinceLastProbeRtt: InfiniteRTT, - maxDatagramSize: initialMaxDatagramSize, - } - b.pacer = newPacer(b.BandwidthEstimate) - return b -} - -func (b *bbrSender) maxCongestionWindow() congestion.ByteCount { - return b.maxDatagramSize * DefaultBBRMaxCongestionWindow -} - -func (b *bbrSender) minCongestionWindow() congestion.ByteCount { - return b.maxDatagramSize * b.initialCongestionWindow -} - -func (b *bbrSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) { - b.rttStats = provider -} - -func (b *bbrSender) GetBytesInFlight() congestion.ByteCount { - return b.bytesInFlight -} - -// TimeUntilSend returns when the next packet should be sent. -func (b *bbrSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time { - b.bytesInFlight = bytesInFlight - return b.pacer.TimeUntilSend() -} - -func (b *bbrSender) HasPacingBudget(now time.Time) bool { - return b.pacer.Budget(now) >= b.maxDatagramSize -} - -func (b *bbrSender) SetMaxDatagramSize(s congestion.ByteCount) { - if s < b.maxDatagramSize { - panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", b.maxDatagramSize, s)) - } - cwndIsMinCwnd := b.congestionWindow == b.minCongestionWindow() - b.maxDatagramSize = s - if cwndIsMinCwnd { - b.congestionWindow = b.minCongestionWindow() - } - b.pacer.SetMaxDatagramSize(s) -} - -func (b *bbrSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount, packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool) { - b.pacer.SentPacket(sentTime, bytes) - b.lastSendPacket = packetNumber - - b.bytesInFlight = bytesInFlight - if bytesInFlight == 0 && b.sampler.isAppLimited { - b.exitingQuiescence = true - } - - if b.aggregationEpochStartTime.IsZero() { - b.aggregationEpochStartTime = sentTime - } - - b.sampler.OnPacketSent(sentTime, packetNumber, bytes, bytesInFlight, isRetransmittable) -} - -func (b *bbrSender) CanSend(bytesInFlight congestion.ByteCount) bool { - b.bytesInFlight = bytesInFlight - return bytesInFlight < b.GetCongestionWindow() -} - -func (b *bbrSender) GetCongestionWindow() congestion.ByteCount { - if b.mode == PROBE_RTT { - return b.ProbeRttCongestionWindow() - } - - if b.InRecovery() && !(b.rateBasedStartup && b.mode == STARTUP) { - return minByteCount(b.congestionWindow, b.recoveryWindow) - } - - return b.congestionWindow -} - -func (b *bbrSender) MaybeExitSlowStart() { -} - -func (b *bbrSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount, priorInFlight congestion.ByteCount, eventTime time.Time) { - totalBytesAckedBefore := b.sampler.totalBytesAcked - isRoundStart, minRttExpired := false, false - lastAckedPacket := number - - isRoundStart = b.UpdateRoundTripCounter(lastAckedPacket) - minRttExpired = b.UpdateBandwidthAndMinRtt(eventTime, number, ackedBytes) - b.UpdateRecoveryState(false, isRoundStart) - bytesAcked := b.sampler.totalBytesAcked - totalBytesAckedBefore - excessAcked := b.UpdateAckAggregationBytes(eventTime, bytesAcked) - - // Handle logic specific to STARTUP and DRAIN modes. - if isRoundStart && !b.isAtFullBandwidth { - b.CheckIfFullBandwidthReached() - } - b.MaybeExitStartupOrDrain(eventTime) - - // Handle logic specific to PROBE_RTT. - b.MaybeEnterOrExitProbeRtt(eventTime, isRoundStart, minRttExpired) - - // After the model is updated, recalculate the pacing rate and congestion - // window. - b.CalculatePacingRate() - b.CalculateCongestionWindow(bytesAcked, excessAcked) - b.CalculateRecoveryWindow(bytesAcked, congestion.ByteCount(0)) -} - -func (b *bbrSender) OnPacketLost(number congestion.PacketNumber, lostBytes congestion.ByteCount, priorInFlight congestion.ByteCount) { - eventTime := time.Now() - totalBytesAckedBefore := b.sampler.totalBytesAcked - isRoundStart, minRttExpired := false, false - - b.DiscardLostPackets(number, lostBytes) - - // Input the new data into the BBR model of the connection. - var excessAcked congestion.ByteCount - - // Handle logic specific to PROBE_BW mode. - if b.mode == PROBE_BW { - b.UpdateGainCyclePhase(time.Now(), priorInFlight, true) - } - - // Handle logic specific to STARTUP and DRAIN modes. - b.MaybeExitStartupOrDrain(eventTime) - - // Handle logic specific to PROBE_RTT. - b.MaybeEnterOrExitProbeRtt(eventTime, isRoundStart, minRttExpired) - - // Calculate number of packets acked and lost. - bytesAcked := b.sampler.totalBytesAcked - totalBytesAckedBefore - bytesLost := lostBytes - - // After the model is updated, recalculate the pacing rate and congestion - // window. - b.CalculatePacingRate() - b.CalculateCongestionWindow(bytesAcked, excessAcked) - b.CalculateRecoveryWindow(bytesAcked, bytesLost) -} - -//func (b *bbrSender) OnCongestionEvent(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets, lostPackets []*congestion.Packet) { -// totalBytesAckedBefore := b.sampler.totalBytesAcked -// isRoundStart, minRttExpired := false, false -// -// if lostPackets != nil { -// b.DiscardLostPackets(lostPackets) -// } -// -// // Input the new data into the BBR model of the connection. -// var excessAcked congestion.ByteCount -// if len(ackedPackets) > 0 { -// lastAckedPacket := ackedPackets[len(ackedPackets)-1].PacketNumber -// isRoundStart = b.UpdateRoundTripCounter(lastAckedPacket) -// minRttExpired = b.UpdateBandwidthAndMinRtt(eventTime, ackedPackets) -// b.UpdateRecoveryState(lastAckedPacket, len(lostPackets) > 0, isRoundStart) -// bytesAcked := b.sampler.totalBytesAcked - totalBytesAckedBefore -// excessAcked = b.UpdateAckAggregationBytes(eventTime, bytesAcked) -// } -// -// // Handle logic specific to PROBE_BW mode. -// if b.mode == PROBE_BW { -// b.UpdateGainCyclePhase(eventTime, priorInFlight, len(lostPackets) > 0) -// } -// -// // Handle logic specific to STARTUP and DRAIN modes. -// if isRoundStart && !b.isAtFullBandwidth { -// b.CheckIfFullBandwidthReached() -// } -// b.MaybeExitStartupOrDrain(eventTime) -// -// // Handle logic specific to PROBE_RTT. -// b.MaybeEnterOrExitProbeRtt(eventTime, isRoundStart, minRttExpired) -// -// // Calculate number of packets acked and lost. -// bytesAcked := b.sampler.totalBytesAcked - totalBytesAckedBefore -// bytesLost := congestion.ByteCount(0) -// for _, packet := range lostPackets { -// bytesLost += packet.Length -// } -// -// // After the model is updated, recalculate the pacing rate and congestion -// // window. -// b.CalculatePacingRate() -// b.CalculateCongestionWindow(bytesAcked, excessAcked) -// b.CalculateRecoveryWindow(bytesAcked, bytesLost) -//} - -//func (b *bbrSender) SetNumEmulatedConnections(n int) { -// -//} - -func (b *bbrSender) OnRetransmissionTimeout(packetsRetransmitted bool) { -} - -//func (b *bbrSender) OnConnectionMigration() { -// -//} - -//// Experiments -//func (b *bbrSender) SetSlowStartLargeReduction(enabled bool) { -// -//} - -//func (b *bbrSender) BandwidthEstimate() Bandwidth { -// return Bandwidth(b.maxBandwidth.GetBest()) -//} - -// BandwidthEstimate returns the current bandwidth estimate -func (b *bbrSender) BandwidthEstimate() Bandwidth { - if b.rttStats == nil { - return infBandwidth - } - srtt := b.rttStats.SmoothedRTT() - if srtt == 0 { - // If we haven't measured an rtt, the bandwidth estimate is unknown. - return infBandwidth - } - return BandwidthFromDelta(b.GetCongestionWindow(), srtt) -} - -//func (b *bbrSender) HybridSlowStart() *HybridSlowStart { -// return nil -//} - -//func (b *bbrSender) SlowstartThreshold() congestion.ByteCount { -// return 0 -//} - -//func (b *bbrSender) RenoBeta() float32 { -// return 0.0 -//} - -func (b *bbrSender) InRecovery() bool { - return b.recoveryState != NOT_IN_RECOVERY -} - -func (b *bbrSender) InSlowStart() bool { - return b.mode == STARTUP -} - -//func (b *bbrSender) ShouldSendProbingPacket() bool { -// if b.pacingGain <= 1 { -// return false -// } -// // TODO(b/77975811): If the pipe is highly under-utilized, consider not -// // sending a probing transmission, because the extra bandwidth is not needed. -// // If flexible_app_limited is enabled, check if the pipe is sufficiently full. -// if b.flexibleAppLimited { -// return !b.IsPipeSufficientlyFull() -// } else { -// return true -// } -//} - -//func (b *bbrSender) IsPipeSufficientlyFull() bool { -// // See if we need more bytes in flight to see more bandwidth. -// if b.mode == STARTUP { -// // STARTUP exits if it doesn't observe a 25% bandwidth increase, so the CWND -// // must be more than 25% above the target. -// return b.GetBytesInFlight() >= b.GetTargetCongestionWindow(1.5) -// } -// if b.pacingGain > 1 { -// // Super-unity PROBE_BW doesn't exit until 1.25 * BDP is achieved. -// return b.GetBytesInFlight() >= b.GetTargetCongestionWindow(b.pacingGain) -// } -// // If bytes_in_flight are above the target congestion window, it should be -// // possible to observe the same or more bandwidth if it's available. -// return b.GetBytesInFlight() >= b.GetTargetCongestionWindow(1.1) -//} - -//func (b *bbrSender) SetFromConfig() { -// // TODO: not impl. -//} - -func (b *bbrSender) UpdateRoundTripCounter(lastAckedPacket congestion.PacketNumber) bool { - if b.currentRoundTripEnd == 0 || lastAckedPacket > b.currentRoundTripEnd { - b.currentRoundTripEnd = lastAckedPacket - b.roundTripCount++ - // if b.rttStats != nil && b.InSlowStart() { - // TODO: ++stats_->slowstart_num_rtts; - // } - return true - } - return false -} - -func (b *bbrSender) UpdateBandwidthAndMinRtt(now time.Time, number congestion.PacketNumber, ackedBytes congestion.ByteCount) bool { - sampleMinRtt := InfiniteRTT - - if !b.alwaysGetBwSampleWhenAcked && ackedBytes == 0 { - // Skip acked packets with 0 in flight bytes when updating bandwidth. - return false - } - bandwidthSample := b.sampler.OnPacketAcked(now, number) - if b.alwaysGetBwSampleWhenAcked && !bandwidthSample.stateAtSend.isValid { - // From the sampler's perspective, the packet has never been sent, or the - // packet has been acked or marked as lost previously. - return false - } - b.lastSampleIsAppLimited = bandwidthSample.stateAtSend.isAppLimited - // has_non_app_limited_sample_ |= - // !bandwidth_sample.state_at_send.is_app_limited; - if !bandwidthSample.stateAtSend.isAppLimited { - b.hasNoAppLimitedSample = true - } - if bandwidthSample.rtt > 0 { - sampleMinRtt = minRtt(sampleMinRtt, bandwidthSample.rtt) - } - if !bandwidthSample.stateAtSend.isAppLimited || bandwidthSample.bandwidth > b.BandwidthEstimate() { - b.maxBandwidth.Update(int64(bandwidthSample.bandwidth), b.roundTripCount) - } - - // If none of the RTT samples are valid, return immediately. - if sampleMinRtt == InfiniteRTT { - return false - } - - b.minRttSinceLastProbeRtt = minRtt(b.minRttSinceLastProbeRtt, sampleMinRtt) - // Do not expire min_rtt if none was ever available. - minRttExpired := b.minRtt > 0 && (now.After(b.minRttTimestamp.Add(MinRttExpiry))) - if minRttExpired || sampleMinRtt < b.minRtt || b.minRtt == 0 { - if minRttExpired && b.ShouldExtendMinRttExpiry() { - minRttExpired = false - } else { - b.minRtt = sampleMinRtt - } - b.minRttTimestamp = now - // Reset since_last_probe_rtt fields. - b.minRttSinceLastProbeRtt = InfiniteRTT - b.appLimitedSinceLastProbeRtt = false - } - - return minRttExpired -} - -func (b *bbrSender) ShouldExtendMinRttExpiry() bool { - if b.probeRttDisabledIfAppLimited && b.appLimitedSinceLastProbeRtt { - // Extend the current min_rtt if we've been app limited recently. - return true - } - - minRttIncreasedSinceLastProbe := b.minRttSinceLastProbeRtt > time.Duration(float64(b.minRtt)*SimilarMinRttThreshold) - if b.probeRttSkippedIfSimilarRtt && b.appLimitedSinceLastProbeRtt && !minRttIncreasedSinceLastProbe { - // Extend the current min_rtt if we've been app limited recently and an rtt - // has been measured in that time that's less than 12.5% more than the - // current min_rtt. - return true - } - - return false -} - -func (b *bbrSender) DiscardLostPackets(number congestion.PacketNumber, lostBytes congestion.ByteCount) { - b.sampler.OnPacketLost(number) - if b.mode == STARTUP { - // if b.rttStats != nil { - // TODO: slow start. - // } - if b.startupRateReductionMultiplier != 0 { - b.startupBytesLost += lostBytes - } - } -} - -func (b *bbrSender) UpdateRecoveryState(hasLosses, isRoundStart bool) { - // Exit recovery when there are no losses for a round. - if !hasLosses { - b.endRecoveryAt = b.lastSendPacket - } - switch b.recoveryState { - case NOT_IN_RECOVERY: - // Enter conservation on the first loss. - if hasLosses { - b.recoveryState = CONSERVATION - // This will cause the |recovery_window_| to be set to the correct - // value in CalculateRecoveryWindow(). - b.recoveryWindow = 0 - // Since the conservation phase is meant to be lasting for a whole - // round, extend the current round as if it were started right now. - b.currentRoundTripEnd = b.lastSendPacket - if false && b.lastSampleIsAppLimited { - b.isAppLimitedRecovery = true - } - } - case CONSERVATION: - if isRoundStart { - b.recoveryState = GROWTH - } - fallthrough - case GROWTH: - // Exit recovery if appropriate. - if !hasLosses && b.lastSendPacket > b.endRecoveryAt { - b.recoveryState = NOT_IN_RECOVERY - b.isAppLimitedRecovery = false - } - } - - if b.recoveryState != NOT_IN_RECOVERY && b.isAppLimitedRecovery { - b.sampler.OnAppLimited() - } -} - -func (b *bbrSender) UpdateAckAggregationBytes(ackTime time.Time, ackedBytes congestion.ByteCount) congestion.ByteCount { - // Compute how many bytes are expected to be delivered, assuming max bandwidth - // is correct. - expectedAckedBytes := congestion.ByteCount(b.maxBandwidth.GetBest()) * - congestion.ByteCount((ackTime.Sub(b.aggregationEpochStartTime))) - // Reset the current aggregation epoch as soon as the ack arrival rate is less - // than or equal to the max bandwidth. - if b.aggregationEpochBytes <= expectedAckedBytes { - // Reset to start measuring a new aggregation epoch. - b.aggregationEpochBytes = ackedBytes - b.aggregationEpochStartTime = ackTime - return 0 - } - // Compute how many extra bytes were delivered vs max bandwidth. - // Include the bytes most recently acknowledged to account for stretch acks. - b.aggregationEpochBytes += ackedBytes - b.maxAckHeight.Update(int64(b.aggregationEpochBytes-expectedAckedBytes), b.roundTripCount) - return b.aggregationEpochBytes - expectedAckedBytes -} - -func (b *bbrSender) UpdateGainCyclePhase(now time.Time, priorInFlight congestion.ByteCount, hasLosses bool) { - bytesInFlight := b.GetBytesInFlight() - // In most cases, the cycle is advanced after an RTT passes. - shouldAdvanceGainCycling := now.Sub(b.lastCycleStart) > b.GetMinRtt() - - // If the pacing gain is above 1.0, the connection is trying to probe the - // bandwidth by increasing the number of bytes in flight to at least - // pacing_gain * BDP. Make sure that it actually reaches the target, as long - // as there are no losses suggesting that the buffers are not able to hold - // that much. - if b.pacingGain > 1.0 && !hasLosses && priorInFlight < b.GetTargetCongestionWindow(b.pacingGain) { - shouldAdvanceGainCycling = false - } - // If pacing gain is below 1.0, the connection is trying to drain the extra - // queue which could have been incurred by probing prior to it. If the number - // of bytes in flight falls down to the estimated BDP value earlier, conclude - // that the queue has been successfully drained and exit this cycle early. - if b.pacingGain < 1.0 && bytesInFlight <= b.GetTargetCongestionWindow(1.0) { - shouldAdvanceGainCycling = true - } - - if shouldAdvanceGainCycling { - b.cycleCurrentOffset = (b.cycleCurrentOffset + 1) % GainCycleLength - b.lastCycleStart = now - // Stay in low gain mode until the target BDP is hit. - // Low gain mode will be exited immediately when the target BDP is achieved. - if b.drainToTarget && b.pacingGain < 1.0 && PacingGain[b.cycleCurrentOffset] == 1.0 && - bytesInFlight > b.GetTargetCongestionWindow(1.0) { - return - } - b.pacingGain = PacingGain[b.cycleCurrentOffset] - } -} - -func (b *bbrSender) GetTargetCongestionWindow(gain float64) congestion.ByteCount { - bdp := congestion.ByteCount(b.GetMinRtt()) * congestion.ByteCount(b.BandwidthEstimate()) - congestionWindow := congestion.ByteCount(gain * float64(bdp)) - - // BDP estimate will be zero if no bandwidth samples are available yet. - if congestionWindow == 0 { - congestionWindow = congestion.ByteCount(gain * float64(b.initialCongestionWindow)) - } - - return maxByteCount(congestionWindow, b.minCongestionWindow()) -} - -func (b *bbrSender) CheckIfFullBandwidthReached() { - if b.lastSampleIsAppLimited { - return - } - - target := Bandwidth(float64(b.bandwidthAtLastRound) * StartupGrowthTarget) - if b.BandwidthEstimate() >= target { - b.bandwidthAtLastRound = b.BandwidthEstimate() - b.roundsWithoutBandwidthGain = 0 - if b.expireAckAggregationInStartup { - // Expire old excess delivery measurements now that bandwidth increased. - b.maxAckHeight.Reset(0, b.roundTripCount) - } - return - } - b.roundsWithoutBandwidthGain++ - if b.roundsWithoutBandwidthGain >= b.numStartupRtts || (b.exitStartupOnLoss && b.InRecovery()) { - b.isAtFullBandwidth = true - } -} - -func (b *bbrSender) MaybeExitStartupOrDrain(now time.Time) { - if b.mode == STARTUP && b.isAtFullBandwidth { - b.OnExitStartup(now) - b.mode = DRAIN - b.pacingGain = b.drainGain - b.congestionWindowGain = b.highCwndGain - } - if b.mode == DRAIN && b.GetBytesInFlight() <= b.GetTargetCongestionWindow(1) { - b.EnterProbeBandwidthMode(now) - } -} - -func (b *bbrSender) EnterProbeBandwidthMode(now time.Time) { - b.mode = PROBE_BW - b.congestionWindowGain = b.congestionWindowGainConst - - // Pick a random offset for the gain cycle out of {0, 2..7} range. 1 is - // excluded because in that case increased gain and decreased gain would not - // follow each other. - b.cycleCurrentOffset = rand.Int() % (GainCycleLength - 1) - if b.cycleCurrentOffset >= 1 { - b.cycleCurrentOffset += 1 - } - - b.lastCycleStart = now - b.pacingGain = PacingGain[b.cycleCurrentOffset] -} - -func (b *bbrSender) MaybeEnterOrExitProbeRtt(now time.Time, isRoundStart, minRttExpired bool) { - if minRttExpired && !b.exitingQuiescence && b.mode != PROBE_RTT { - if b.InSlowStart() { - b.OnExitStartup(now) - } - b.mode = PROBE_RTT - b.pacingGain = 1.0 - // Do not decide on the time to exit PROBE_RTT until the |bytes_in_flight| - // is at the target small value. - b.exitProbeRttAt = time.Time{} - } - - if b.mode == PROBE_RTT { - b.sampler.OnAppLimited() - if b.exitProbeRttAt.IsZero() { - // If the window has reached the appropriate size, schedule exiting - // PROBE_RTT. The CWND during PROBE_RTT is kMinimumCongestionWindow, but - // we allow an extra packet since QUIC checks CWND before sending a - // packet. - if b.GetBytesInFlight() < b.ProbeRttCongestionWindow()+b.maxDatagramSize { - b.exitProbeRttAt = now.Add(ProbeRttTime) - b.probeRttRoundPassed = false - } - } else { - if isRoundStart { - b.probeRttRoundPassed = true - } - if !now.Before(b.exitProbeRttAt) && b.probeRttRoundPassed { - b.minRttTimestamp = now - if !b.isAtFullBandwidth { - b.EnterStartupMode(now) - } else { - b.EnterProbeBandwidthMode(now) - } - } - } - } - b.exitingQuiescence = false -} - -func (b *bbrSender) ProbeRttCongestionWindow() congestion.ByteCount { - if b.probeRttBasedOnBdp { - return b.GetTargetCongestionWindow(ModerateProbeRttMultiplier) - } else { - return b.minCongestionWindow() - } -} - -func (b *bbrSender) EnterStartupMode(now time.Time) { - // if b.rttStats != nil { - // TODO: slow start. - // } - b.mode = STARTUP - b.pacingGain = b.highGain - b.congestionWindowGain = b.highCwndGain -} - -func (b *bbrSender) OnExitStartup(now time.Time) { - if b.rttStats == nil { - return - } - // TODO: slow start. -} - -func (b *bbrSender) CalculatePacingRate() { - if b.BandwidthEstimate() == 0 { - return - } - - targetRate := Bandwidth(b.pacingGain * float64(b.BandwidthEstimate())) - if b.isAtFullBandwidth { - b.pacingRate = targetRate - return - } - - // Pace at the rate of initial_window / RTT as soon as RTT measurements are - // available. - if b.pacingRate == 0 && b.rttStats.MinRTT() > 0 { - b.pacingRate = BandwidthFromDelta(b.initialCongestionWindow, b.rttStats.MinRTT()) - return - } - // Slow the pacing rate in STARTUP once loss has ever been detected. - hasEverDetectedLoss := b.endRecoveryAt > 0 - if b.slowerStartup && hasEverDetectedLoss && b.hasNoAppLimitedSample { - b.pacingRate = Bandwidth(StartupAfterLossGain * float64(b.BandwidthEstimate())) - return - } - - // Slow the pacing rate in STARTUP by the bytes_lost / CWND. - if b.startupRateReductionMultiplier != 0 && hasEverDetectedLoss && b.hasNoAppLimitedSample { - b.pacingRate = Bandwidth((1.0 - (float64(b.startupBytesLost) * float64(b.startupRateReductionMultiplier) / float64(b.congestionWindow))) * float64(targetRate)) - // Ensure the pacing rate doesn't drop below the startup growth target times - // the bandwidth estimate. - b.pacingRate = maxBandwidth(b.pacingRate, Bandwidth(StartupGrowthTarget*float64(b.BandwidthEstimate()))) - return - } - - // Do not decrease the pacing rate during startup. - b.pacingRate = maxBandwidth(b.pacingRate, targetRate) -} - -func (b *bbrSender) CalculateCongestionWindow(ackedBytes, excessAcked congestion.ByteCount) { - if b.mode == PROBE_RTT { - return - } - - targetWindow := b.GetTargetCongestionWindow(b.congestionWindowGain) - if b.isAtFullBandwidth { - // Add the max recently measured ack aggregation to CWND. - targetWindow += congestion.ByteCount(b.maxAckHeight.GetBest()) - } else if b.enableAckAggregationDuringStartup { - // Add the most recent excess acked. Because CWND never decreases in - // STARTUP, this will automatically create a very localized max filter. - targetWindow += excessAcked - } - - // Instead of immediately setting the target CWND as the new one, BBR grows - // the CWND towards |target_window| by only increasing it |bytes_acked| at a - // time. - addBytesAcked := true || !b.InRecovery() - if b.isAtFullBandwidth { - b.congestionWindow = minByteCount(targetWindow, b.congestionWindow+ackedBytes) - } else if addBytesAcked && (b.congestionWindow < targetWindow || b.sampler.totalBytesAcked < b.initialCongestionWindow) { - // If the connection is not yet out of startup phase, do not decrease the - // window. - b.congestionWindow += ackedBytes - } - - // Enforce the limits on the congestion window. - b.congestionWindow = maxByteCount(b.congestionWindow, b.minCongestionWindow()) - b.congestionWindow = minByteCount(b.congestionWindow, b.maxCongestionWindow()) -} - -func (b *bbrSender) CalculateRecoveryWindow(ackedBytes, lostBytes congestion.ByteCount) { - if b.rateBasedStartup && b.mode == STARTUP { - return - } - - if b.recoveryState == NOT_IN_RECOVERY { - return - } - - // Set up the initial recovery window. - if b.recoveryWindow == 0 { - b.recoveryWindow = maxByteCount(b.GetBytesInFlight()+ackedBytes, b.minCongestionWindow()) - return - } - - // Remove losses from the recovery window, while accounting for a potential - // integer underflow. - if b.recoveryWindow >= lostBytes { - b.recoveryWindow -= lostBytes - } else { - b.recoveryWindow = congestion.ByteCount(b.maxDatagramSize) - } - // In CONSERVATION mode, just subtracting losses is sufficient. In GROWTH, - // release additional |bytes_acked| to achieve a slow-start-like behavior. - if b.recoveryState == GROWTH { - b.recoveryWindow += ackedBytes - } - // Sanity checks. Ensure that we always allow to send at least an MSS or - // |bytes_acked| in response, whichever is larger. - b.recoveryWindow = maxByteCount(b.recoveryWindow, b.GetBytesInFlight()+ackedBytes) - b.recoveryWindow = maxByteCount(b.recoveryWindow, b.minCongestionWindow()) -} - -var _ congestion.CongestionControl = (*bbrSender)(nil) - -func (b *bbrSender) GetMinRtt() time.Duration { - if b.minRtt > 0 { - return b.minRtt - } else { - return InitialRtt - } -} - -func minRtt(a, b time.Duration) time.Duration { - if a < b { - return a - } else { - return b - } -} - -func minBandwidth(a, b Bandwidth) Bandwidth { - if a < b { - return a - } else { - return b - } -} - -func maxBandwidth(a, b Bandwidth) Bandwidth { - if a > b { - return a - } else { - return b - } -} - -func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount { - if a > b { - return a - } else { - return b - } -} - -func minByteCount(a, b congestion.ByteCount) congestion.ByteCount { - if a < b { - return a - } else { - return b - } -} - -var InfiniteRTT = time.Duration(math.MaxInt64) diff --git a/transport/tuic/congestion/clock.go b/transport/tuic/congestion/clock.go deleted file mode 100644 index dc3ccdc5..00000000 --- a/transport/tuic/congestion/clock.go +++ /dev/null @@ -1,20 +0,0 @@ -package congestion - -import "time" - -// A Clock returns the current time -type Clock interface { - Now() time.Time -} - -// DefaultClock implements the Clock interface using the Go stdlib clock. -type DefaultClock struct { - TimeFunc func() time.Time -} - -var _ Clock = DefaultClock{} - -// Now gets the current time -func (c DefaultClock) Now() time.Time { - return c.TimeFunc() -} diff --git a/transport/tuic/congestion/cubic.go b/transport/tuic/congestion/cubic.go deleted file mode 100644 index d437c540..00000000 --- a/transport/tuic/congestion/cubic.go +++ /dev/null @@ -1,213 +0,0 @@ -package congestion - -import ( - "math" - "time" - - "github.com/sagernet/quic-go/congestion" -) - -// This cubic implementation is based on the one found in Chromiums's QUIC -// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}. - -// Constants based on TCP defaults. -// The following constants are in 2^10 fractions of a second instead of ms to -// allow a 10 shift right to divide. - -// 1024*1024^3 (first 1024 is from 0.100^3) -// where 0.100 is 100 ms which is the scaling round trip time. -const ( - cubeScale = 40 - cubeCongestionWindowScale = 410 - cubeFactor congestion.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize - // TODO: when re-enabling cubic, make sure to use the actual packet size here - maxDatagramSize = congestion.ByteCount(InitialPacketSizeIPv4) -) - -const defaultNumConnections = 1 - -// Default Cubic backoff factor -const beta float32 = 0.7 - -// Additional backoff factor when loss occurs in the concave part of the Cubic -// curve. This additional backoff factor is expected to give up bandwidth to -// new concurrent flows and speed up convergence. -const betaLastMax float32 = 0.85 - -// Cubic implements the cubic algorithm from TCP -type Cubic struct { - clock Clock - - // Number of connections to simulate. - numConnections int - - // Time when this cycle started, after last loss event. - epoch time.Time - - // Max congestion window used just before last loss event. - // Note: to improve fairness to other streams an additional back off is - // applied to this value if the new value is below our latest value. - lastMaxCongestionWindow congestion.ByteCount - - // Number of acked bytes since the cycle started (epoch). - ackedBytesCount congestion.ByteCount - - // TCP Reno equivalent congestion window in packets. - estimatedTCPcongestionWindow congestion.ByteCount - - // Origin point of cubic function. - originPointCongestionWindow congestion.ByteCount - - // Time to origin point of cubic function in 2^10 fractions of a second. - timeToOriginPoint uint32 - - // Last congestion window in packets computed by cubic function. - lastTargetCongestionWindow congestion.ByteCount -} - -// NewCubic returns a new Cubic instance -func NewCubic(clock Clock) *Cubic { - c := &Cubic{ - clock: clock, - numConnections: defaultNumConnections, - } - c.Reset() - return c -} - -// Reset is called after a timeout to reset the cubic state -func (c *Cubic) Reset() { - c.epoch = time.Time{} - c.lastMaxCongestionWindow = 0 - c.ackedBytesCount = 0 - c.estimatedTCPcongestionWindow = 0 - c.originPointCongestionWindow = 0 - c.timeToOriginPoint = 0 - c.lastTargetCongestionWindow = 0 -} - -func (c *Cubic) alpha() float32 { - // TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that - // beta here is a cwnd multiplier, and is equal to 1-beta from the paper. - // We derive the equivalent alpha for an N-connection emulation as: - b := c.beta() - return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b) -} - -func (c *Cubic) beta() float32 { - // kNConnectionBeta is the backoff factor after loss for our N-connection - // emulation, which emulates the effective backoff of an ensemble of N - // TCP-Reno connections on a single loss event. The effective multiplier is - // computed as: - return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections) -} - -func (c *Cubic) betaLastMax() float32 { - // betaLastMax is the additional backoff factor after loss for our - // N-connection emulation, which emulates the additional backoff of - // an ensemble of N TCP-Reno connections on a single loss event. The - // effective multiplier is computed as: - return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections) -} - -// OnApplicationLimited is called on ack arrival when sender is unable to use -// the available congestion window. Resets Cubic state during quiescence. -func (c *Cubic) OnApplicationLimited() { - // When sender is not using the available congestion window, the window does - // not grow. But to be RTT-independent, Cubic assumes that the sender has been - // using the entire window during the time since the beginning of the current - // "epoch" (the end of the last loss recovery period). Since - // application-limited periods break this assumption, we reset the epoch when - // in such a period. This reset effectively freezes congestion window growth - // through application-limited periods and allows Cubic growth to continue - // when the entire window is being used. - c.epoch = time.Time{} -} - -// CongestionWindowAfterPacketLoss computes a new congestion window to use after -// a loss event. Returns the new congestion window in packets. The new -// congestion window is a multiplicative decrease of our current window. -func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow congestion.ByteCount) congestion.ByteCount { - if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow { - // We never reached the old max, so assume we are competing with another - // flow. Use our extra back off factor to allow the other flow to go up. - c.lastMaxCongestionWindow = congestion.ByteCount(c.betaLastMax() * float32(currentCongestionWindow)) - } else { - c.lastMaxCongestionWindow = currentCongestionWindow - } - c.epoch = time.Time{} // Reset time. - return congestion.ByteCount(float32(currentCongestionWindow) * c.beta()) -} - -// CongestionWindowAfterAck computes a new congestion window to use after a received ACK. -// Returns the new congestion window in packets. The new congestion window -// follows a cubic function that depends on the time passed since last -// packet loss. -func (c *Cubic) CongestionWindowAfterAck( - ackedBytes congestion.ByteCount, - currentCongestionWindow congestion.ByteCount, - delayMin time.Duration, - eventTime time.Time, -) congestion.ByteCount { - c.ackedBytesCount += ackedBytes - - if c.epoch.IsZero() { - // First ACK after a loss event. - c.epoch = eventTime // Start of epoch. - c.ackedBytesCount = ackedBytes // Reset count. - // Reset estimated_tcp_congestion_window_ to be in sync with cubic. - c.estimatedTCPcongestionWindow = currentCongestionWindow - if c.lastMaxCongestionWindow <= currentCongestionWindow { - c.timeToOriginPoint = 0 - c.originPointCongestionWindow = currentCongestionWindow - } else { - c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow)))) - c.originPointCongestionWindow = c.lastMaxCongestionWindow - } - } - - // Change the time unit from microseconds to 2^10 fractions per second. Take - // the round trip time in account. This is done to allow us to use shift as a - // divide operator. - elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000) - - // Right-shifts of negative, signed numbers have implementation-dependent - // behavior, so force the offset to be positive, as is done in the kernel. - offset := int64(c.timeToOriginPoint) - elapsedTime - if offset < 0 { - offset = -offset - } - - deltaCongestionWindow := congestion.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale - var targetCongestionWindow congestion.ByteCount - if elapsedTime > int64(c.timeToOriginPoint) { - targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow - } else { - targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow - } - // Limit the CWND increase to half the acked bytes. - targetCongestionWindow = Min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2) - - // Increase the window by approximately Alpha * 1 MSS of bytes every - // time we ack an estimated tcp window of bytes. For small - // congestion windows (less than 25), the formula below will - // increase slightly slower than linearly per estimated tcp window - // of bytes. - c.estimatedTCPcongestionWindow += congestion.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow)) - c.ackedBytesCount = 0 - - // We have a new cubic congestion window. - c.lastTargetCongestionWindow = targetCongestionWindow - - // Compute target congestion_window based on cubic target and estimated TCP - // congestion_window, use highest (fastest). - if targetCongestionWindow < c.estimatedTCPcongestionWindow { - targetCongestionWindow = c.estimatedTCPcongestionWindow - } - return targetCongestionWindow -} - -// SetNumConnections sets the number of emulated connections -func (c *Cubic) SetNumConnections(n int) { - c.numConnections = n -} diff --git a/transport/tuic/congestion/cubic_sender.go b/transport/tuic/congestion/cubic_sender.go deleted file mode 100644 index fc97d17a..00000000 --- a/transport/tuic/congestion/cubic_sender.go +++ /dev/null @@ -1,318 +0,0 @@ -package congestion - -import ( - "fmt" - "time" - - "github.com/sagernet/quic-go/congestion" - "github.com/sagernet/quic-go/logging" -) - -const ( - maxBurstPackets = 3 - renoBeta = 0.7 // Reno backoff factor. - minCongestionWindowPackets = 2 - initialCongestionWindow = 32 -) - -const ( - InvalidPacketNumber congestion.PacketNumber = -1 - MaxCongestionWindowPackets = 20000 - MaxByteCount = congestion.ByteCount(1<<62 - 1) -) - -type cubicSender struct { - hybridSlowStart HybridSlowStart - rttStats congestion.RTTStatsProvider - cubic *Cubic - pacer *pacer - clock Clock - - reno bool - - // Track the largest packet that has been sent. - largestSentPacketNumber congestion.PacketNumber - - // Track the largest packet that has been acked. - largestAckedPacketNumber congestion.PacketNumber - - // Track the largest packet number outstanding when a CWND cutback occurs. - largestSentAtLastCutback congestion.PacketNumber - - // Whether the last loss event caused us to exit slowstart. - // Used for stats collection of slowstartPacketsLost - lastCutbackExitedSlowstart bool - - // Congestion window in bytes. - congestionWindow congestion.ByteCount - - // Slow start congestion window in bytes, aka ssthresh. - slowStartThreshold congestion.ByteCount - - // ACK counter for the Reno implementation. - numAckedPackets uint64 - - initialCongestionWindow congestion.ByteCount - initialMaxCongestionWindow congestion.ByteCount - - maxDatagramSize congestion.ByteCount - - lastState logging.CongestionState - tracer logging.ConnectionTracer -} - -var _ congestion.CongestionControl = &cubicSender{} - -// NewCubicSender makes a new cubic sender -func NewCubicSender( - clock Clock, - initialMaxDatagramSize congestion.ByteCount, - reno bool, - tracer logging.ConnectionTracer, -) *cubicSender { - return newCubicSender( - clock, - reno, - initialMaxDatagramSize, - initialCongestionWindow*initialMaxDatagramSize, - MaxCongestionWindowPackets*initialMaxDatagramSize, - tracer, - ) -} - -func newCubicSender( - clock Clock, - reno bool, - initialMaxDatagramSize, - initialCongestionWindow, - initialMaxCongestionWindow congestion.ByteCount, - tracer logging.ConnectionTracer, -) *cubicSender { - c := &cubicSender{ - largestSentPacketNumber: InvalidPacketNumber, - largestAckedPacketNumber: InvalidPacketNumber, - largestSentAtLastCutback: InvalidPacketNumber, - initialCongestionWindow: initialCongestionWindow, - initialMaxCongestionWindow: initialMaxCongestionWindow, - congestionWindow: initialCongestionWindow, - slowStartThreshold: MaxByteCount, - cubic: NewCubic(clock), - clock: clock, - reno: reno, - tracer: tracer, - maxDatagramSize: initialMaxDatagramSize, - } - c.pacer = newPacer(c.BandwidthEstimate) - if c.tracer != nil { - c.lastState = logging.CongestionStateSlowStart - c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart) - } - return c -} - -func (c *cubicSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) { - c.rttStats = provider -} - -// TimeUntilSend returns when the next packet should be sent. -func (c *cubicSender) TimeUntilSend(_ congestion.ByteCount) time.Time { - return c.pacer.TimeUntilSend() -} - -func (c *cubicSender) HasPacingBudget(now time.Time) bool { - return c.pacer.Budget(now) >= c.maxDatagramSize -} - -func (c *cubicSender) maxCongestionWindow() congestion.ByteCount { - return c.maxDatagramSize * MaxCongestionWindowPackets -} - -func (c *cubicSender) minCongestionWindow() congestion.ByteCount { - return c.maxDatagramSize * minCongestionWindowPackets -} - -func (c *cubicSender) OnPacketSent( - sentTime time.Time, - _ congestion.ByteCount, - packetNumber congestion.PacketNumber, - bytes congestion.ByteCount, - isRetransmittable bool, -) { - c.pacer.SentPacket(sentTime, bytes) - if !isRetransmittable { - return - } - c.largestSentPacketNumber = packetNumber - c.hybridSlowStart.OnPacketSent(packetNumber) -} - -func (c *cubicSender) CanSend(bytesInFlight congestion.ByteCount) bool { - return bytesInFlight < c.GetCongestionWindow() -} - -func (c *cubicSender) InRecovery() bool { - return c.largestAckedPacketNumber != InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback -} - -func (c *cubicSender) InSlowStart() bool { - return c.GetCongestionWindow() < c.slowStartThreshold -} - -func (c *cubicSender) GetCongestionWindow() congestion.ByteCount { - return c.congestionWindow -} - -func (c *cubicSender) MaybeExitSlowStart() { - if c.InSlowStart() && - c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) { - // exit slow start - c.slowStartThreshold = c.congestionWindow - c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance) - } -} - -func (c *cubicSender) OnPacketAcked( - ackedPacketNumber congestion.PacketNumber, - ackedBytes congestion.ByteCount, - priorInFlight congestion.ByteCount, - eventTime time.Time, -) { - c.largestAckedPacketNumber = Max(ackedPacketNumber, c.largestAckedPacketNumber) - if c.InRecovery() { - return - } - c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime) - if c.InSlowStart() { - c.hybridSlowStart.OnPacketAcked(ackedPacketNumber) - } -} - -func (c *cubicSender) OnPacketLost(packetNumber congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) { - // TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets - // already sent should be treated as a single loss event, since it's expected. - if packetNumber <= c.largestSentAtLastCutback { - return - } - c.lastCutbackExitedSlowstart = c.InSlowStart() - c.maybeTraceStateChange(logging.CongestionStateRecovery) - - if c.reno { - c.congestionWindow = congestion.ByteCount(float64(c.congestionWindow) * renoBeta) - } else { - c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow) - } - if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd { - c.congestionWindow = minCwnd - } - c.slowStartThreshold = c.congestionWindow - c.largestSentAtLastCutback = c.largestSentPacketNumber - // reset packet count from congestion avoidance mode. We start - // counting again when we're out of recovery. - c.numAckedPackets = 0 -} - -// Called when we receive an ack. Normal TCP tracks how many packets one ack -// represents, but quic has a separate ack for each packet. -func (c *cubicSender) maybeIncreaseCwnd( - _ congestion.PacketNumber, - ackedBytes congestion.ByteCount, - priorInFlight congestion.ByteCount, - eventTime time.Time, -) { - // Do not increase the congestion window unless the sender is close to using - // the current window. - if !c.isCwndLimited(priorInFlight) { - c.cubic.OnApplicationLimited() - c.maybeTraceStateChange(logging.CongestionStateApplicationLimited) - return - } - if c.congestionWindow >= c.maxCongestionWindow() { - return - } - if c.InSlowStart() { - // TCP slow start, exponential growth, increase by one for each ACK. - c.congestionWindow += c.maxDatagramSize - c.maybeTraceStateChange(logging.CongestionStateSlowStart) - return - } - // Congestion avoidance - c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance) - if c.reno { - // Classic Reno congestion avoidance. - c.numAckedPackets++ - if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) { - c.congestionWindow += c.maxDatagramSize - c.numAckedPackets = 0 - } - } else { - c.congestionWindow = Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime)) - } -} - -func (c *cubicSender) isCwndLimited(bytesInFlight congestion.ByteCount) bool { - congestionWindow := c.GetCongestionWindow() - if bytesInFlight >= congestionWindow { - return true - } - availableBytes := congestionWindow - bytesInFlight - slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2 - return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize -} - -// BandwidthEstimate returns the current bandwidth estimate -func (c *cubicSender) BandwidthEstimate() Bandwidth { - if c.rttStats == nil { - return infBandwidth - } - srtt := c.rttStats.SmoothedRTT() - if srtt == 0 { - // If we haven't measured an rtt, the bandwidth estimate is unknown. - return infBandwidth - } - return BandwidthFromDelta(c.GetCongestionWindow(), srtt) -} - -// OnRetransmissionTimeout is called on an retransmission timeout -func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) { - c.largestSentAtLastCutback = InvalidPacketNumber - if !packetsRetransmitted { - return - } - c.hybridSlowStart.Restart() - c.cubic.Reset() - c.slowStartThreshold = c.congestionWindow / 2 - c.congestionWindow = c.minCongestionWindow() -} - -// OnConnectionMigration is called when the connection is migrated (?) -func (c *cubicSender) OnConnectionMigration() { - c.hybridSlowStart.Restart() - c.largestSentPacketNumber = InvalidPacketNumber - c.largestAckedPacketNumber = InvalidPacketNumber - c.largestSentAtLastCutback = InvalidPacketNumber - c.lastCutbackExitedSlowstart = false - c.cubic.Reset() - c.numAckedPackets = 0 - c.congestionWindow = c.initialCongestionWindow - c.slowStartThreshold = c.initialMaxCongestionWindow -} - -func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) { - if c.tracer == nil || new == c.lastState { - return - } - c.tracer.UpdatedCongestionState(new) - c.lastState = new -} - -func (c *cubicSender) SetMaxDatagramSize(s congestion.ByteCount) { - if s < c.maxDatagramSize { - panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s)) - } - cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow() - c.maxDatagramSize = s - if cwndIsMinCwnd { - c.congestionWindow = c.minCongestionWindow() - } - c.pacer.SetMaxDatagramSize(s) -} diff --git a/transport/tuic/congestion/hybrid_slow_start.go b/transport/tuic/congestion/hybrid_slow_start.go deleted file mode 100644 index eba8f7df..00000000 --- a/transport/tuic/congestion/hybrid_slow_start.go +++ /dev/null @@ -1,112 +0,0 @@ -package congestion - -import ( - "time" - - "github.com/sagernet/quic-go/congestion" -) - -// Note(pwestin): the magic clamping numbers come from the original code in -// tcp_cubic.c. -const hybridStartLowWindow = congestion.ByteCount(16) - -// Number of delay samples for detecting the increase of delay. -const hybridStartMinSamples = uint32(8) - -// Exit slow start if the min rtt has increased by more than 1/8th. -const hybridStartDelayFactorExp = 3 // 2^3 = 8 -// The original paper specifies 2 and 8ms, but those have changed over time. -const ( - hybridStartDelayMinThresholdUs = int64(4000) - hybridStartDelayMaxThresholdUs = int64(16000) -) - -// HybridSlowStart implements the TCP hybrid slow start algorithm -type HybridSlowStart struct { - endPacketNumber congestion.PacketNumber - lastSentPacketNumber congestion.PacketNumber - started bool - currentMinRTT time.Duration - rttSampleCount uint32 - hystartFound bool -} - -// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase. -func (s *HybridSlowStart) StartReceiveRound(lastSent congestion.PacketNumber) { - s.endPacketNumber = lastSent - s.currentMinRTT = 0 - s.rttSampleCount = 0 - s.started = true -} - -// IsEndOfRound returns true if this ack is the last packet number of our current slow start round. -func (s *HybridSlowStart) IsEndOfRound(ack congestion.PacketNumber) bool { - return s.endPacketNumber < ack -} - -// ShouldExitSlowStart should be called on every new ack frame, since a new -// RTT measurement can be made then. -// rtt: the RTT for this ack packet. -// minRTT: is the lowest delay (RTT) we have seen during the session. -// congestionWindow: the congestion window in packets. -func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow congestion.ByteCount) bool { - if !s.started { - // Time to start the hybrid slow start. - s.StartReceiveRound(s.lastSentPacketNumber) - } - if s.hystartFound { - return true - } - // Second detection parameter - delay increase detection. - // Compare the minimum delay (s.currentMinRTT) of the current - // burst of packets relative to the minimum delay during the session. - // Note: we only look at the first few(8) packets in each burst, since we - // only want to compare the lowest RTT of the burst relative to previous - // bursts. - s.rttSampleCount++ - if s.rttSampleCount <= hybridStartMinSamples { - if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT { - s.currentMinRTT = latestRTT - } - } - // We only need to check this once per round. - if s.rttSampleCount == hybridStartMinSamples { - // Divide minRTT by 8 to get a rtt increase threshold for exiting. - minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp) - // Ensure the rtt threshold is never less than 2ms or more than 16ms. - minRTTincreaseThresholdUs = Min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs) - minRTTincreaseThreshold := time.Duration(Max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond - - if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) { - s.hystartFound = true - } - } - // Exit from slow start if the cwnd is greater than 16 and - // increasing delay is found. - return congestionWindow >= hybridStartLowWindow && s.hystartFound -} - -// OnPacketSent is called when a packet was sent -func (s *HybridSlowStart) OnPacketSent(packetNumber congestion.PacketNumber) { - s.lastSentPacketNumber = packetNumber -} - -// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end -// the round when the final packet of the burst is received and start it on -// the next incoming ack. -func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber congestion.PacketNumber) { - if s.IsEndOfRound(ackedPacketNumber) { - s.started = false - } -} - -// Started returns true if started -func (s *HybridSlowStart) Started() bool { - return s.started -} - -// Restart the slow start phase -func (s *HybridSlowStart) Restart() { - s.started = false - s.hystartFound = false -} diff --git a/transport/tuic/congestion/minmax.go b/transport/tuic/congestion/minmax.go deleted file mode 100644 index ed75072e..00000000 --- a/transport/tuic/congestion/minmax.go +++ /dev/null @@ -1,72 +0,0 @@ -package congestion - -import ( - "math" - "time" - - "golang.org/x/exp/constraints" -) - -// InfDuration is a duration of infinite length -const InfDuration = time.Duration(math.MaxInt64) - -func Max[T constraints.Ordered](a, b T) T { - if a < b { - return b - } - return a -} - -func Min[T constraints.Ordered](a, b T) T { - if a < b { - return a - } - return b -} - -// MinNonZeroDuration return the minimum duration that's not zero. -func MinNonZeroDuration(a, b time.Duration) time.Duration { - if a == 0 { - return b - } - if b == 0 { - return a - } - return Min(a, b) -} - -// AbsDuration returns the absolute value of a time duration -func AbsDuration(d time.Duration) time.Duration { - if d >= 0 { - return d - } - return -d -} - -// MinTime returns the earlier time -func MinTime(a, b time.Time) time.Time { - if a.After(b) { - return b - } - return a -} - -// MinNonZeroTime returns the earlist time that is not time.Time{} -// If both a and b are time.Time{}, it returns time.Time{} -func MinNonZeroTime(a, b time.Time) time.Time { - if a.IsZero() { - return b - } - if b.IsZero() { - return a - } - return MinTime(a, b) -} - -// MaxTime returns the later time -func MaxTime(a, b time.Time) time.Time { - if a.After(b) { - return a - } - return b -} diff --git a/transport/tuic/congestion/pacer.go b/transport/tuic/congestion/pacer.go deleted file mode 100644 index 5d0f13f6..00000000 --- a/transport/tuic/congestion/pacer.go +++ /dev/null @@ -1,81 +0,0 @@ -package congestion - -import ( - "math" - "time" - - "github.com/sagernet/quic-go/congestion" -) - -const ( - initialMaxDatagramSize = congestion.ByteCount(1252) - MinPacingDelay = time.Millisecond - TimerGranularity = time.Millisecond - maxBurstSizePackets = 10 -) - -// The pacer implements a token bucket pacing algorithm. -type pacer struct { - budgetAtLastSent congestion.ByteCount - maxDatagramSize congestion.ByteCount - lastSentTime time.Time - getAdjustedBandwidth func() uint64 // in bytes/s -} - -func newPacer(getBandwidth func() Bandwidth) *pacer { - p := &pacer{ - maxDatagramSize: initialMaxDatagramSize, - getAdjustedBandwidth: func() uint64 { - // Bandwidth is in bits/s. We need the value in bytes/s. - bw := uint64(getBandwidth() / BytesPerSecond) - // Use a slightly higher value than the actual measured bandwidth. - // RTT variations then won't result in under-utilization of the congestion window. - // Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire, - // provided the congestion window is fully utilized and acknowledgments arrive at regular intervals. - return bw * 5 / 4 - }, - } - p.budgetAtLastSent = p.maxBurstSize() - return p -} - -func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) { - budget := p.Budget(sendTime) - if size > budget { - p.budgetAtLastSent = 0 - } else { - p.budgetAtLastSent = budget - size - } - p.lastSentTime = sendTime -} - -func (p *pacer) Budget(now time.Time) congestion.ByteCount { - if p.lastSentTime.IsZero() { - return p.maxBurstSize() - } - budget := p.budgetAtLastSent + (congestion.ByteCount(p.getAdjustedBandwidth())*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 - return Min(p.maxBurstSize(), budget) -} - -func (p *pacer) maxBurstSize() congestion.ByteCount { - return Max( - congestion.ByteCount(uint64((MinPacingDelay+TimerGranularity).Nanoseconds())*p.getAdjustedBandwidth())/1e9, - maxBurstSizePackets*p.maxDatagramSize, - ) -} - -// TimeUntilSend returns when the next packet should be sent. -// It returns the zero value of time.Time if a packet can be sent immediately. -func (p *pacer) TimeUntilSend() time.Time { - if p.budgetAtLastSent >= p.maxDatagramSize { - return time.Time{} - } - return p.lastSentTime.Add(Max( - MinPacingDelay, - time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.getAdjustedBandwidth())))*time.Nanosecond, - )) -} - -func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) { - p.maxDatagramSize = s -} diff --git a/transport/tuic/congestion/windowed_filter.go b/transport/tuic/congestion/windowed_filter.go deleted file mode 100644 index 4da595b9..00000000 --- a/transport/tuic/congestion/windowed_filter.go +++ /dev/null @@ -1,132 +0,0 @@ -package congestion - -// WindowedFilter Use the following to construct a windowed filter object of type T. -// For example, a min filter using QuicTime as the time type: -// -// WindowedFilter, QuicTime, QuicTime::Delta> ObjectName; -// -// A max filter using 64-bit integers as the time type: -// -// WindowedFilter, uint64_t, int64_t> ObjectName; -// -// Specifically, this template takes four arguments: -// 1. T -- type of the measurement that is being filtered. -// 2. Compare -- MinFilter or MaxFilter, depending on the type of filter -// desired. -// 3. TimeT -- the type used to represent timestamps. -// 4. TimeDeltaT -- the type used to represent continuous time intervals between -// two timestamps. Has to be the type of (a - b) if both |a| and |b| are -// of type TimeT. -type WindowedFilter struct { - // Time length of window. - windowLength int64 - estimates []Sample - comparator func(int64, int64) bool -} - -type Sample struct { - sample int64 - time int64 -} - -// Compares two values and returns true if the first is greater than or equal -// to the second. -func MaxFilter(a, b int64) bool { - return a >= b -} - -// Compares two values and returns true if the first is less than or equal -// to the second. -func MinFilter(a, b int64) bool { - return a <= b -} - -func NewWindowedFilter(windowLength int64, comparator func(int64, int64) bool) *WindowedFilter { - return &WindowedFilter{ - windowLength: windowLength, - estimates: make([]Sample, 3), - comparator: comparator, - } -} - -// Changes the window length. Does not update any current samples. -func (f *WindowedFilter) SetWindowLength(windowLength int64) { - f.windowLength = windowLength -} - -func (f *WindowedFilter) GetBest() int64 { - return f.estimates[0].sample -} - -func (f *WindowedFilter) GetSecondBest() int64 { - return f.estimates[1].sample -} - -func (f *WindowedFilter) GetThirdBest() int64 { - return f.estimates[2].sample -} - -func (f *WindowedFilter) Update(sample int64, time int64) { - if f.estimates[0].time == 0 || f.comparator(sample, f.estimates[0].sample) || (time-f.estimates[2].time) > f.windowLength { - f.Reset(sample, time) - return - } - - if f.comparator(sample, f.estimates[1].sample) { - f.estimates[1].sample = sample - f.estimates[1].time = time - f.estimates[2].sample = sample - f.estimates[2].time = time - } else if f.comparator(sample, f.estimates[2].sample) { - f.estimates[2].sample = sample - f.estimates[2].time = time - } - - // Expire and update estimates as necessary. - if time-f.estimates[0].time > f.windowLength { - // The best estimate hasn't been updated for an entire window, so promote - // second and third best estimates. - f.estimates[0].sample = f.estimates[1].sample - f.estimates[0].time = f.estimates[1].time - f.estimates[1].sample = f.estimates[2].sample - f.estimates[1].time = f.estimates[2].time - f.estimates[2].sample = sample - f.estimates[2].time = time - // Need to iterate one more time. Check if the new best estimate is - // outside the window as well, since it may also have been recorded a - // long time ago. Don't need to iterate once more since we cover that - // case at the beginning of the method. - if time-f.estimates[0].time > f.windowLength { - f.estimates[0].sample = f.estimates[1].sample - f.estimates[0].time = f.estimates[1].time - f.estimates[1].sample = f.estimates[2].sample - f.estimates[1].time = f.estimates[2].time - } - return - } - if f.estimates[1].sample == f.estimates[0].sample && time-f.estimates[1].time > f.windowLength>>2 { - // A quarter of the window has passed without a better sample, so the - // second-best estimate is taken from the second quarter of the window. - f.estimates[1].sample = sample - f.estimates[1].time = time - f.estimates[2].sample = sample - f.estimates[2].time = time - return - } - - if f.estimates[2].sample == f.estimates[1].sample && time-f.estimates[2].time > f.windowLength>>1 { - // We've passed a half of the window without a better estimate, so take - // a third-best estimate from the second half of the window. - f.estimates[2].sample = sample - f.estimates[2].time = time - } -} - -func (f *WindowedFilter) Reset(newSample int64, newTime int64) { - f.estimates[0].sample = newSample - f.estimates[0].time = newTime - f.estimates[1].sample = newSample - f.estimates[1].time = newTime - f.estimates[2].sample = newSample - f.estimates[2].time = newTime -} diff --git a/transport/tuic/packet.go b/transport/tuic/packet.go deleted file mode 100644 index abc46206..00000000 --- a/transport/tuic/packet.go +++ /dev/null @@ -1,532 +0,0 @@ -package tuic - -import ( - "bytes" - "context" - "encoding/binary" - "errors" - "io" - "math" - "net" - "os" - "sync" - "time" - - "github.com/sagernet/quic-go" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/atomic" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/cache" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" -) - -var udpMessagePool = sync.Pool{ - New: func() interface{} { - return new(udpMessage) - }, -} - -func allocMessage() *udpMessage { - message := udpMessagePool.Get().(*udpMessage) - message.referenced = true - return message -} - -func releaseMessages(messages []*udpMessage) { - for _, message := range messages { - if message != nil { - message.release() - } - } -} - -type udpMessage struct { - sessionID uint16 - packetID uint16 - fragmentTotal uint8 - fragmentID uint8 - destination M.Socksaddr - data *buf.Buffer - referenced bool -} - -func (m *udpMessage) release() { - if !m.referenced { - return - } - *m = udpMessage{} - udpMessagePool.Put(m) -} - -func (m *udpMessage) releaseMessage() { - m.data.Release() - m.release() -} - -func (m *udpMessage) pack() *buf.Buffer { - buffer := buf.NewSize(m.headerSize() + m.data.Len()) - common.Must( - buffer.WriteByte(Version), - buffer.WriteByte(CommandPacket), - binary.Write(buffer, binary.BigEndian, m.sessionID), - binary.Write(buffer, binary.BigEndian, m.packetID), - binary.Write(buffer, binary.BigEndian, m.fragmentTotal), - binary.Write(buffer, binary.BigEndian, m.fragmentID), - binary.Write(buffer, binary.BigEndian, uint16(m.data.Len())), - addressSerializer.WriteAddrPort(buffer, m.destination), - common.Error(buffer.Write(m.data.Bytes())), - ) - return buffer -} - -func (m *udpMessage) headerSize() int { - return 10 + addressSerializer.AddrPortLen(m.destination) -} - -func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage { - if message.data.Len() <= maxPacketSize { - return []*udpMessage{message} - } - var fragments []*udpMessage - originPacket := message.data.Bytes() - udpMTU := maxPacketSize - message.headerSize() - for remaining := len(originPacket); remaining > 0; remaining -= udpMTU { - fragment := allocMessage() - *fragment = *message - if remaining > udpMTU { - fragment.data = buf.As(originPacket[:udpMTU]) - originPacket = originPacket[udpMTU:] - } else { - fragment.data = buf.As(originPacket) - originPacket = nil - } - fragments = append(fragments, fragment) - } - fragmentTotal := uint16(len(fragments)) - for index, fragment := range fragments { - fragment.fragmentID = uint8(index) - fragment.fragmentTotal = uint8(fragmentTotal) - if index > 0 { - fragment.destination = M.Socksaddr{} - } - } - return fragments -} - -type udpPacketConn struct { - ctx context.Context - cancel common.ContextCancelCauseFunc - sessionID uint16 - quicConn quic.Connection - data chan *udpMessage - udpStream bool - udpMTU int - udpMTUTime time.Time - packetId atomic.Uint32 - closeOnce sync.Once - isServer bool - defragger *udpDefragger - onDestroy func() -} - -func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, udpStream bool, isServer bool, onDestroy func()) *udpPacketConn { - ctx, cancel := common.ContextWithCancelCause(ctx) - return &udpPacketConn{ - ctx: ctx, - cancel: cancel, - quicConn: quicConn, - data: make(chan *udpMessage, 64), - udpStream: udpStream, - isServer: isServer, - defragger: newUDPDefragger(), - onDestroy: onDestroy, - } -} - -func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) { - select { - case p := <-c.data: - buffer = p.data - destination = p.destination - p.release() - return - case <-c.ctx.Done(): - return nil, M.Socksaddr{}, io.ErrClosedPipe - } -} - -func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { - select { - case p := <-c.data: - _, err = buffer.ReadOnceFrom(p.data) - destination = p.destination - p.releaseMessage() - return - case <-c.ctx.Done(): - return M.Socksaddr{}, io.ErrClosedPipe - } -} - -func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) { - select { - case p := <-c.data: - _, err = newBuffer().ReadOnceFrom(p.data) - destination = p.destination - p.releaseMessage() - return - case <-c.ctx.Done(): - return M.Socksaddr{}, io.ErrClosedPipe - } -} - -func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - select { - case pkt := <-c.data: - n = copy(p, pkt.data.Bytes()) - if pkt.destination.IsFqdn() { - addr = pkt.destination - } else { - addr = pkt.destination.UDPAddr() - } - pkt.releaseMessage() - return n, addr, nil - case <-c.ctx.Done(): - return 0, nil, io.ErrClosedPipe - } -} - -func (c *udpPacketConn) needFragment() bool { - nowTime := time.Now() - if c.udpMTU > 0 && nowTime.Sub(c.udpMTUTime) < 5*time.Second { - c.udpMTUTime = nowTime - return true - } - return false -} - -func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - defer buffer.Release() - select { - case <-c.ctx.Done(): - return net.ErrClosed - default: - } - if buffer.Len() > 0xffff { - return quic.ErrMessageTooLarge(0xffff) - } - if !destination.IsValid() { - return E.New("invalid destination address") - } - packetId := c.packetId.Add(1) - if packetId > math.MaxUint16 { - c.packetId.Store(0) - packetId = 0 - } - message := allocMessage() - *message = udpMessage{ - sessionID: c.sessionID, - packetID: uint16(packetId), - fragmentTotal: 1, - destination: destination, - data: buffer, - } - defer message.releaseMessage() - var err error - if !c.udpStream && c.needFragment() && buffer.Len() > c.udpMTU { - err = c.writePackets(fragUDPMessage(message, c.udpMTU)) - } else { - err = c.writePacket(message) - } - if err == nil { - return nil - } - var tooLargeErr quic.ErrMessageTooLarge - if !errors.As(err, &tooLargeErr) { - return err - } - c.udpMTU = int(tooLargeErr) - c.udpMTUTime = time.Now() - return c.writePackets(fragUDPMessage(message, c.udpMTU)) -} - -func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - select { - case <-c.ctx.Done(): - return 0, net.ErrClosed - default: - } - if len(p) > 0xffff { - return 0, quic.ErrMessageTooLarge(0xffff) - } - destination := M.SocksaddrFromNet(addr) - if !destination.IsValid() { - return 0, E.New("invalid destination address") - } - packetId := c.packetId.Add(1) - if packetId > math.MaxUint16 { - c.packetId.Store(0) - packetId = 0 - } - message := allocMessage() - *message = udpMessage{ - sessionID: c.sessionID, - packetID: uint16(packetId), - fragmentTotal: 1, - destination: destination, - data: buf.As(p), - } - if !c.udpStream && c.needFragment() && len(p) > c.udpMTU { - err = c.writePackets(fragUDPMessage(message, c.udpMTU)) - if err == nil { - return len(p), nil - } - } else { - err = c.writePacket(message) - } - if err == nil { - return len(p), nil - } - var tooLargeErr quic.ErrMessageTooLarge - if !errors.As(err, &tooLargeErr) { - return - } - c.udpMTU = int(tooLargeErr) - c.udpMTUTime = time.Now() - err = c.writePackets(fragUDPMessage(message, c.udpMTU)) - if err == nil { - return len(p), nil - } - return -} - -func (c *udpPacketConn) inputPacket(message *udpMessage) { - if message.fragmentTotal <= 1 { - select { - case c.data <- message: - default: - } - } else { - newMessage := c.defragger.feed(message) - if newMessage != nil { - select { - case c.data <- newMessage: - default: - } - } - } -} - -func (c *udpPacketConn) writePackets(messages []*udpMessage) error { - defer releaseMessages(messages) - for _, message := range messages { - err := c.writePacket(message) - if err != nil { - return err - } - } - return nil -} - -func (c *udpPacketConn) writePacket(message *udpMessage) error { - if !c.udpStream { - buffer := message.pack() - err := c.quicConn.SendMessage(buffer.Bytes()) - buffer.Release() - if err != nil { - return err - } - } else { - stream, err := c.quicConn.OpenUniStream() - if err != nil { - return err - } - buffer := message.pack() - _, err = stream.Write(buffer.Bytes()) - buffer.Release() - stream.Close() - if err != nil { - return err - } - } - return nil -} - -func (c *udpPacketConn) Close() error { - c.closeOnce.Do(func() { - c.closeWithError(os.ErrClosed) - c.onDestroy() - }) - return nil -} - -func (c *udpPacketConn) closeWithError(err error) { - c.cancel(err) - if !c.isServer { - buffer := buf.NewSize(4) - defer buffer.Release() - buffer.WriteByte(Version) - buffer.WriteByte(CommandDissociate) - binary.Write(buffer, binary.BigEndian, c.sessionID) - sendStream, openErr := c.quicConn.OpenUniStream() - if openErr != nil { - return - } - defer sendStream.Close() - sendStream.Write(buffer.Bytes()) - } -} - -func (c *udpPacketConn) LocalAddr() net.Addr { - return c.quicConn.LocalAddr() -} - -func (c *udpPacketConn) SetDeadline(t time.Time) error { - return os.ErrInvalid -} - -func (c *udpPacketConn) SetReadDeadline(t time.Time) error { - return os.ErrInvalid -} - -func (c *udpPacketConn) SetWriteDeadline(t time.Time) error { - return os.ErrInvalid -} - -type udpDefragger struct { - packetMap *cache.LruCache[uint16, *packetItem] -} - -func newUDPDefragger() *udpDefragger { - return &udpDefragger{ - packetMap: cache.New( - cache.WithAge[uint16, *packetItem](10), - cache.WithUpdateAgeOnGet[uint16, *packetItem](), - cache.WithEvict[uint16, *packetItem](func(key uint16, value *packetItem) { - releaseMessages(value.messages) - }), - ), - } -} - -type packetItem struct { - access sync.Mutex - messages []*udpMessage - count uint8 -} - -func (d *udpDefragger) feed(m *udpMessage) *udpMessage { - if m.fragmentTotal <= 1 { - return m - } - if m.fragmentID >= m.fragmentTotal { - return nil - } - item, _ := d.packetMap.LoadOrStore(m.packetID, newPacketItem) - item.access.Lock() - defer item.access.Unlock() - if int(m.fragmentTotal) != len(item.messages) { - releaseMessages(item.messages) - item.messages = make([]*udpMessage, m.fragmentTotal) - item.count = 1 - item.messages[m.fragmentID] = m - return nil - } - if item.messages[m.fragmentID] != nil { - return nil - } - item.messages[m.fragmentID] = m - item.count++ - if int(item.count) != len(item.messages) { - return nil - } - newMessage := allocMessage() - *newMessage = *item.messages[0] - var dataLength uint16 - for _, message := range item.messages { - dataLength += uint16(message.data.Len()) - } - if dataLength > 0 { - newMessage.data = buf.NewSize(int(dataLength)) - for _, message := range item.messages { - common.Must1(newMessage.data.Write(message.data.Bytes())) - message.releaseMessage() - } - item.messages = nil - return newMessage - } - item.messages = nil - return nil -} - -func newPacketItem() *packetItem { - return new(packetItem) -} - -func readUDPMessage(message *udpMessage, reader io.Reader) error { - err := binary.Read(reader, binary.BigEndian, &message.sessionID) - if err != nil { - return err - } - err = binary.Read(reader, binary.BigEndian, &message.packetID) - if err != nil { - return err - } - err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal) - if err != nil { - return err - } - err = binary.Read(reader, binary.BigEndian, &message.fragmentID) - if err != nil { - return err - } - var dataLength uint16 - err = binary.Read(reader, binary.BigEndian, &dataLength) - if err != nil { - return err - } - message.destination, err = addressSerializer.ReadAddrPort(reader) - if err != nil { - return err - } - message.data = buf.NewSize(int(dataLength)) - _, err = message.data.ReadFullFrom(reader, message.data.FreeLen()) - if err != nil { - return err - } - return nil -} - -func decodeUDPMessage(message *udpMessage, data []byte) error { - reader := bytes.NewReader(data) - err := binary.Read(reader, binary.BigEndian, &message.sessionID) - if err != nil { - return err - } - err = binary.Read(reader, binary.BigEndian, &message.packetID) - if err != nil { - return err - } - err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal) - if err != nil { - return err - } - err = binary.Read(reader, binary.BigEndian, &message.fragmentID) - if err != nil { - return err - } - var dataLength uint16 - err = binary.Read(reader, binary.BigEndian, &dataLength) - if err != nil { - return err - } - message.destination, err = addressSerializer.ReadAddrPort(reader) - if err != nil { - return err - } - if reader.Len() != int(dataLength) { - return io.ErrUnexpectedEOF - } - message.data = buf.As(data[len(data)-reader.Len():]) - return nil -} diff --git a/transport/tuic/protocol.go b/transport/tuic/protocol.go deleted file mode 100644 index 1247516b..00000000 --- a/transport/tuic/protocol.go +++ /dev/null @@ -1,15 +0,0 @@ -package tuic - -const ( - Version = 5 -) - -const ( - CommandAuthenticate = iota - CommandConnect - CommandPacket - CommandDissociate - CommandHeartbeat -) - -const AuthenticateLen = 2 + 16 + 32 diff --git a/transport/tuic/server.go b/transport/tuic/server.go deleted file mode 100644 index d4f6de8f..00000000 --- a/transport/tuic/server.go +++ /dev/null @@ -1,437 +0,0 @@ -//go:build with_quic - -package tuic - -import ( - "bytes" - "context" - "encoding/binary" - "io" - "net" - "runtime" - "strings" - "sync" - "time" - - "github.com/sagernet/quic-go" - "github.com/sagernet/sing-box/common/qtls" - "github.com/sagernet/sing-box/common/tls" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/auth" - "github.com/sagernet/sing/common/baderror" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - - "github.com/gofrs/uuid/v5" -) - -type ServerOptions struct { - Context context.Context - Logger logger.Logger - TLSConfig tls.ServerConfig - Users []User - CongestionControl string - AuthTimeout time.Duration - ZeroRTTHandshake bool - Heartbeat time.Duration - Handler ServerHandler -} - -type User struct { - Name string - UUID uuid.UUID - Password string -} - -type ServerHandler interface { - N.TCPConnectionHandler - N.UDPConnectionHandler -} - -type Server struct { - ctx context.Context - logger logger.Logger - tlsConfig tls.ServerConfig - heartbeat time.Duration - quicConfig *quic.Config - userMap map[uuid.UUID]User - congestionControl string - authTimeout time.Duration - handler ServerHandler - - quicListener io.Closer -} - -func NewServer(options ServerOptions) (*Server, error) { - if options.AuthTimeout == 0 { - options.AuthTimeout = 3 * time.Second - } - if options.Heartbeat == 0 { - options.Heartbeat = 10 * time.Second - } - quicConfig := &quic.Config{ - DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), - MaxDatagramFrameSize: 1400, - EnableDatagrams: true, - Allow0RTT: options.ZeroRTTHandshake, - MaxIncomingStreams: 1 << 60, - MaxIncomingUniStreams: 1 << 60, - } - switch options.CongestionControl { - case "": - options.CongestionControl = "cubic" - case "cubic", "new_reno", "bbr": - default: - return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl) - } - if len(options.Users) == 0 { - return nil, E.New("missing users") - } - userMap := make(map[uuid.UUID]User) - for _, user := range options.Users { - userMap[user.UUID] = user - } - return &Server{ - ctx: options.Context, - logger: options.Logger, - tlsConfig: options.TLSConfig, - heartbeat: options.Heartbeat, - quicConfig: quicConfig, - userMap: userMap, - congestionControl: options.CongestionControl, - authTimeout: options.AuthTimeout, - handler: options.Handler, - }, nil -} - -func (s *Server) Start(conn net.PacketConn) error { - if !s.quicConfig.Allow0RTT { - listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig) - if err != nil { - return err - } - s.quicListener = listener - go func() { - for { - connection, hErr := listener.Accept(s.ctx) - if hErr != nil { - if strings.Contains(hErr.Error(), "server closed") { - s.logger.Debug(E.Cause(hErr, "listener closed")) - } else { - s.logger.Error(E.Cause(hErr, "listener closed")) - } - return - } - go s.handleConnection(connection) - } - }() - } else { - listener, err := qtls.ListenEarly(conn, s.tlsConfig, s.quicConfig) - if err != nil { - return err - } - s.quicListener = listener - go func() { - for { - connection, hErr := listener.Accept(s.ctx) - if hErr != nil { - if strings.Contains(hErr.Error(), "server closed") { - s.logger.Debug(E.Cause(hErr, "listener closed")) - } else { - s.logger.Error(E.Cause(hErr, "listener closed")) - } - return - } - go s.handleConnection(connection) - } - }() - } - return nil -} - -func (s *Server) Close() error { - return common.Close( - s.quicListener, - ) -} - -func (s *Server) handleConnection(connection quic.Connection) { - setCongestion(s.ctx, connection, s.congestionControl) - session := &serverSession{ - Server: s, - ctx: s.ctx, - quicConn: connection, - source: M.SocksaddrFromNet(connection.RemoteAddr()), - connDone: make(chan struct{}), - authDone: make(chan struct{}), - udpConnMap: make(map[uint16]*udpPacketConn), - } - session.handle() -} - -type serverSession struct { - *Server - ctx context.Context - quicConn quic.Connection - source M.Socksaddr - connAccess sync.Mutex - connDone chan struct{} - connErr error - authDone chan struct{} - authUser *User - udpAccess sync.RWMutex - udpConnMap map[uint16]*udpPacketConn -} - -func (s *serverSession) handle() { - if s.ctx.Done() != nil { - go func() { - select { - case <-s.ctx.Done(): - s.closeWithError(s.ctx.Err()) - case <-s.connDone: - } - }() - } - go s.loopUniStreams() - go s.loopStreams() - go s.loopMessages() - go s.handleAuthTimeout() - go s.loopHeartbeats() -} - -func (s *serverSession) loopUniStreams() { - for { - uniStream, err := s.quicConn.AcceptUniStream(s.ctx) - if err != nil { - return - } - go func() { - err = s.handleUniStream(uniStream) - if err != nil { - s.closeWithError(E.Cause(err, "handle uni stream")) - } - }() - } -} - -func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error { - defer stream.CancelRead(0) - buffer := buf.New() - defer buffer.Release() - _, err := buffer.ReadAtLeastFrom(stream, 2) - if err != nil { - return E.Cause(err, "read request") - } - version := buffer.Byte(0) - if version != Version { - return E.New("unknown version ", buffer.Byte(0)) - } - command := buffer.Byte(1) - switch command { - case CommandAuthenticate: - select { - case <-s.authDone: - return E.New("authentication: multiple authentication requests") - default: - } - if buffer.Len() < AuthenticateLen { - _, err = buffer.ReadFullFrom(stream, AuthenticateLen-buffer.Len()) - if err != nil { - return E.Cause(err, "authentication: read request") - } - } - userUUID := uuid.FromBytesOrNil(buffer.Range(2, 2+16)) - user, loaded := s.userMap[userUUID] - if !loaded { - return E.New("authentication: unknown user ", userUUID) - } - handshakeState := s.quicConn.ConnectionState() - tuicToken, err := handshakeState.ExportKeyingMaterial(string(user.UUID[:]), []byte(user.Password), 32) - if err != nil { - return E.Cause(err, "authentication: export keying material") - } - if !bytes.Equal(tuicToken, buffer.Range(2+16, 2+16+32)) { - return E.New("authentication: token mismatch") - } - s.authUser = &user - close(s.authDone) - return nil - case CommandPacket: - select { - case <-s.connDone: - return s.connErr - case <-s.authDone: - } - message := allocMessage() - err = readUDPMessage(message, io.MultiReader(bytes.NewReader(buffer.From(2)), stream)) - if err != nil { - message.release() - return err - } - s.handleUDPMessage(message, true) - return nil - case CommandDissociate: - select { - case <-s.connDone: - return s.connErr - case <-s.authDone: - } - if buffer.Len() > 4 { - return E.New("invalid dissociate message") - } - var sessionID uint16 - err = binary.Read(io.MultiReader(bytes.NewReader(buffer.From(2)), stream), binary.BigEndian, &sessionID) - if err != nil { - return err - } - s.udpAccess.RLock() - udpConn, loaded := s.udpConnMap[sessionID] - s.udpAccess.RUnlock() - if loaded { - udpConn.closeWithError(E.New("remote closed")) - s.udpAccess.Lock() - delete(s.udpConnMap, sessionID) - s.udpAccess.Unlock() - } - return nil - default: - return E.New("unknown command ", command) - } -} - -func (s *serverSession) handleAuthTimeout() { - select { - case <-s.connDone: - case <-s.authDone: - case <-time.After(s.authTimeout): - s.closeWithError(E.New("authentication timeout")) - } -} - -func (s *serverSession) loopStreams() { - for { - stream, err := s.quicConn.AcceptStream(s.ctx) - if err != nil { - return - } - go func() { - err = s.handleStream(stream) - if err != nil { - stream.CancelRead(0) - stream.Close() - s.logger.Error(E.Cause(err, "handle stream request")) - } - }() - } -} - -func (s *serverSession) handleStream(stream quic.Stream) error { - buffer := buf.NewSize(2 + M.MaxSocksaddrLength) - defer buffer.Release() - _, err := buffer.ReadAtLeastFrom(stream, 2) - if err != nil { - return E.Cause(err, "read request") - } - version, _ := buffer.ReadByte() - if version != Version { - return E.New("unknown version ", buffer.Byte(0)) - } - command, _ := buffer.ReadByte() - if command != CommandConnect { - return E.New("unsupported stream command ", command) - } - destination, err := addressSerializer.ReadAddrPort(io.MultiReader(buffer, stream)) - if err != nil { - return E.Cause(err, "read request destination") - } - select { - case <-s.connDone: - return s.connErr - case <-s.authDone: - } - var conn net.Conn = &serverConn{ - Stream: stream, - destination: destination, - } - if buffer.IsEmpty() { - buffer.Release() - } else { - conn = bufio.NewCachedConn(conn, buffer) - } - ctx := s.ctx - if s.authUser.Name != "" { - ctx = auth.ContextWithUser(s.ctx, s.authUser.Name) - } - _ = s.handler.NewConnection(ctx, conn, M.Metadata{ - Source: s.source, - Destination: destination, - }) - return nil -} - -func (s *serverSession) loopHeartbeats() { - ticker := time.NewTicker(s.heartbeat) - defer ticker.Stop() - for { - select { - case <-s.connDone: - return - case <-ticker.C: - err := s.quicConn.SendMessage([]byte{Version, CommandHeartbeat}) - if err != nil { - s.closeWithError(E.Cause(err, "send heartbeat")) - } - } - } -} - -func (s *serverSession) closeWithError(err error) { - s.connAccess.Lock() - defer s.connAccess.Unlock() - select { - case <-s.connDone: - return - default: - s.connErr = err - close(s.connDone) - } - if E.IsClosedOrCanceled(err) { - s.logger.Debug(E.Cause(err, "connection failed")) - } else { - s.logger.Error(E.Cause(err, "connection failed")) - } - _ = s.quicConn.CloseWithError(0, "") -} - -type serverConn struct { - quic.Stream - destination M.Socksaddr -} - -func (c *serverConn) Read(p []byte) (n int, err error) { - n, err = c.Stream.Read(p) - return n, baderror.WrapQUIC(err) -} - -func (c *serverConn) Write(p []byte) (n int, err error) { - n, err = c.Stream.Write(p) - return n, baderror.WrapQUIC(err) -} - -func (c *serverConn) LocalAddr() net.Addr { - return c.destination -} - -func (c *serverConn) RemoteAddr() net.Addr { - return M.Socksaddr{} -} - -func (c *serverConn) Close() error { - c.Stream.CancelRead(0) - return c.Stream.Close() -} diff --git a/transport/tuic/server_packet.go b/transport/tuic/server_packet.go deleted file mode 100644 index 5a26cf50..00000000 --- a/transport/tuic/server_packet.go +++ /dev/null @@ -1,75 +0,0 @@ -//go:build with_quic - -package tuic - -import ( - "github.com/sagernet/sing/common" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" -) - -func (s *serverSession) loopMessages() { - select { - case <-s.connDone: - return - case <-s.authDone: - } - for { - message, err := s.quicConn.ReceiveMessage(s.ctx) - if err != nil { - s.closeWithError(E.Cause(err, "receive message")) - return - } - hErr := s.handleMessage(message) - if hErr != nil { - s.closeWithError(E.Cause(hErr, "handle message")) - return - } - } -} - -func (s *serverSession) handleMessage(data []byte) error { - if len(data) < 2 { - return E.New("invalid message") - } - if data[0] != Version { - return E.New("unknown version ", data[0]) - } - switch data[1] { - case CommandPacket: - message := allocMessage() - err := decodeUDPMessage(message, data[2:]) - if err != nil { - message.release() - return E.Cause(err, "decode UDP message") - } - s.handleUDPMessage(message, false) - return nil - case CommandHeartbeat: - return nil - default: - return E.New("unknown command ", data[0]) - } -} - -func (s *serverSession) handleUDPMessage(message *udpMessage, udpStream bool) { - s.udpAccess.RLock() - udpConn, loaded := s.udpConnMap[message.sessionID] - s.udpAccess.RUnlock() - if !loaded || common.Done(udpConn.ctx) { - udpConn = newUDPPacketConn(s.ctx, s.quicConn, udpStream, true, func() { - s.udpAccess.Lock() - delete(s.udpConnMap, message.sessionID) - s.udpAccess.Unlock() - }) - udpConn.sessionID = message.sessionID - s.udpAccess.Lock() - s.udpConnMap[message.sessionID] = udpConn - s.udpAccess.Unlock() - go s.handler.NewPacketConnection(udpConn.ctx, udpConn, M.Metadata{ - Source: s.source, - Destination: message.destination, - }) - } - udpConn.inputPacket(message) -} diff --git a/transport/v2rayquic/client.go b/transport/v2rayquic/client.go index b8037d95..c3345780 100644 --- a/transport/v2rayquic/client.go +++ b/transport/v2rayquic/client.go @@ -9,11 +9,11 @@ import ( "github.com/sagernet/quic-go" "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/qtls" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/hysteria" + "github.com/sagernet/sing-quic" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" M "github.com/sagernet/sing/common/metadata" diff --git a/transport/v2rayquic/server.go b/transport/v2rayquic/server.go index c366006e..71960e58 100644 --- a/transport/v2rayquic/server.go +++ b/transport/v2rayquic/server.go @@ -9,11 +9,11 @@ import ( "github.com/sagernet/quic-go" "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/qtls" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/hysteria" + "github.com/sagernet/sing-quic" "github.com/sagernet/sing/common" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -27,7 +27,7 @@ type Server struct { quicConfig *quic.Config handler adapter.V2RayServerTransportHandler udpListener net.PacketConn - quicListener qtls.QUICListener + quicListener qtls.Listener } func NewServer(ctx context.Context, options option.V2RayQUICOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (adapter.V2RayServerTransport, error) {