diff --git a/adapter/inbound.go b/adapter/inbound.go index 821b7b65..aeef3798 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -1,6 +1,8 @@ package adapter import ( + "context" + M "github.com/sagernet/sing/common/metadata" ) @@ -17,6 +19,7 @@ type InboundContext struct { Destination M.Socksaddr Domain string Protocol string + Outbound string // cache @@ -26,3 +29,26 @@ type InboundContext struct { SourceGeoIPCode string GeoIPCode string } + +type inboundContextKey struct{} + +func WithContext(ctx context.Context, inboundContext *InboundContext) context.Context { + return context.WithValue(ctx, (*inboundContextKey)(nil), inboundContext) +} + +func ContextFrom(ctx context.Context) *InboundContext { + metadata := ctx.Value((*inboundContextKey)(nil)) + if metadata == nil { + return nil + } + return metadata.(*InboundContext) +} + +func AppendContext(ctx context.Context) (context.Context, *InboundContext) { + metadata := ContextFrom(ctx) + if metadata != nil { + return ctx, metadata + } + metadata = new(InboundContext) + return WithContext(ctx, metadata), nil +} diff --git a/adapter/router.go b/adapter/router.go index d094113b..f6d0c8e3 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -3,11 +3,15 @@ package adapter import ( "context" "net" + "net/netip" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing-box/common/geoip" "github.com/sagernet/sing-box/common/geosite" + C "github.com/sagernet/sing-box/constant" + + "golang.org/x/net/dns/dnsmessage" ) type Router interface { @@ -18,6 +22,9 @@ type Router interface { RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata InboundContext) error GeoIPReader() *geoip.Reader GeositeReader() *geosite.Reader + Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) + Lookup(ctx context.Context, domain string, strategy C.DomainStrategy) ([]netip.Addr, error) + LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) } type Rule interface { diff --git a/service.go b/box.go similarity index 90% rename from service.go rename to box.go index 7772491e..e4e941ba 100644 --- a/service.go +++ b/box.go @@ -16,9 +16,9 @@ import ( "github.com/sagernet/sing-box/route" ) -var _ adapter.Service = (*Service)(nil) +var _ adapter.Service = (*Box)(nil) -type Service struct { +type Box struct { router adapter.Router logger log.Logger inbounds []adapter.Inbound @@ -26,13 +26,13 @@ type Service struct { createdAt time.Time } -func NewService(ctx context.Context, options option.Options) (*Service, error) { +func New(ctx context.Context, options option.Options) (*Box, error) { createdAt := time.Now() logger, err := log.NewLogger(common.PtrValueOrDefault(options.Log)) if err != nil { return nil, E.Cause(err, "parse log options") } - router, err := route.NewRouter(ctx, logger, common.PtrValueOrDefault(options.Route)) + router, err := route.NewRouter(ctx, logger, common.PtrValueOrDefault(options.Route), common.PtrValueOrDefault(options.DNS)) if err != nil { return nil, E.Cause(err, "parse route options") } @@ -63,7 +63,7 @@ func NewService(ctx context.Context, options option.Options) (*Service, error) { if err != nil { return nil, err } - return &Service{ + return &Box{ router: router, logger: logger, inbounds: inbounds, @@ -72,7 +72,7 @@ func NewService(ctx context.Context, options option.Options) (*Service, error) { }, nil } -func (s *Service) Start() error { +func (s *Box) Start() error { err := s.logger.Start() if err != nil { return err @@ -91,7 +91,7 @@ func (s *Service) Start() error { return nil } -func (s *Service) Close() error { +func (s *Box) Close() error { for _, in := range s.inbounds { in.Close() } diff --git a/cmd/sing-box/cmd_check.go b/cmd/sing-box/cmd_check.go index 3a48d52c..41a97a59 100644 --- a/cmd/sing-box/cmd_check.go +++ b/cmd/sing-box/cmd_check.go @@ -30,7 +30,7 @@ func checkConfiguration(cmd *cobra.Command, args []string) { logrus.Fatal("decode config: ", err) } ctx, cancel := context.WithCancel(context.Background()) - _, err = box.NewService(ctx, options) + _, err = box.New(ctx, options) if err != nil { logrus.Fatal("create service: ", err) } diff --git a/cmd/sing-box/cmd_run.go b/cmd/sing-box/cmd_run.go index 374f66ef..a41669fa 100644 --- a/cmd/sing-box/cmd_run.go +++ b/cmd/sing-box/cmd_run.go @@ -38,11 +38,11 @@ func run(cmd *cobra.Command, args []string) { options.Log.DisableColor = true } ctx, cancel := context.WithCancel(context.Background()) - service, err := box.NewService(ctx, options) + instance, err := box.New(ctx, options) if err != nil { logrus.Fatal("create service: ", err) } - err = service.Start() + err = instance.Start() if err != nil { logrus.Fatal("start service: ", err) } @@ -50,5 +50,5 @@ func run(cmd *cobra.Command, args []string) { signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM) <-osSignals cancel() - service.Close() + instance.Close() } diff --git a/outbound/dialer/default.go b/common/dialer/default.go similarity index 81% rename from outbound/dialer/default.go rename to common/dialer/default.go index 0332e309..d7e34bd2 100644 --- a/outbound/dialer/default.go +++ b/common/dialer/default.go @@ -7,7 +7,6 @@ import ( "github.com/sagernet/sing/common/control" M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" @@ -15,12 +14,12 @@ import ( "github.com/database64128/tfo-go" ) -type defaultDialer struct { +type DefaultDialer struct { tfo.Dialer net.ListenConfig } -func NewDefault(options option.DialerOptions) N.Dialer { +func NewDefault(options option.DialerOptions) *DefaultDialer { var dialer net.Dialer var listener net.ListenConfig if options.BindInterface != "" { @@ -41,13 +40,17 @@ func NewDefault(options option.DialerOptions) N.Dialer { if options.ConnectTimeout != 0 { dialer.Timeout = time.Duration(options.ConnectTimeout) * time.Second } - return &defaultDialer{tfo.Dialer{Dialer: dialer, DisableTFO: !options.TCPFastOpen}, listener} + return &DefaultDialer{tfo.Dialer{Dialer: dialer, DisableTFO: !options.TCPFastOpen}, listener} } -func (d *defaultDialer) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) { +func (d *DefaultDialer) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) { return d.Dialer.DialContext(ctx, network, address.String()) } -func (d *defaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { +func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { return d.ListenConfig.ListenPacket(ctx, C.NetworkUDP, "") } + +func (d *DefaultDialer) Upstream() any { + return &d.Dialer +} diff --git a/outbound/dialer/detour.go b/common/dialer/detour.go similarity index 71% rename from outbound/dialer/detour.go rename to common/dialer/detour.go index bc79d35f..43fa654a 100644 --- a/outbound/dialer/detour.go +++ b/common/dialer/detour.go @@ -12,7 +12,7 @@ import ( "github.com/sagernet/sing-box/adapter" ) -type detourDialer struct { +type DetourDialer struct { router adapter.Router detour string dialer N.Dialer @@ -21,15 +21,15 @@ type detourDialer struct { } func NewDetour(router adapter.Router, detour string) N.Dialer { - return &detourDialer{router: router, detour: detour} + return &DetourDialer{router: router, detour: detour} } -func (d *detourDialer) Start() error { +func (d *DetourDialer) Start() error { _, err := d.Dialer() return err } -func (d *detourDialer) Dialer() (N.Dialer, error) { +func (d *DetourDialer) Dialer() (N.Dialer, error) { d.initOnce.Do(func() { var loaded bool d.dialer, loaded = d.router.Outbound(d.detour) @@ -40,7 +40,7 @@ func (d *detourDialer) Dialer() (N.Dialer, error) { return d.dialer, d.initErr } -func (d *detourDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { +func (d *DetourDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { dialer, err := d.Dialer() if err != nil { return nil, err @@ -48,10 +48,15 @@ func (d *detourDialer) DialContext(ctx context.Context, network string, destinat return dialer.DialContext(ctx, network, destination) } -func (d *detourDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { +func (d *DetourDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { dialer, err := d.Dialer() if err != nil { return nil, err } return dialer.ListenPacket(ctx, destination) } + +func (d *DetourDialer) Upstream() any { + detour, _ := d.Dialer() + return detour +} diff --git a/outbound/dialer/dialer.go b/common/dialer/dialer.go similarity index 66% rename from outbound/dialer/dialer.go rename to common/dialer/dialer.go index 411163ba..2039729d 100644 --- a/outbound/dialer/dialer.go +++ b/common/dialer/dialer.go @@ -5,15 +5,21 @@ import ( N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" ) func New(router adapter.Router, options option.DialerOptions) N.Dialer { + domainStrategy := C.DomainStrategy(options.DomainStrategy) var dialer N.Dialer if options.Detour == "" { dialer = NewDefault(options) + dialer = NewResolveDialer(router, dialer, domainStrategy) } else { dialer = NewDetour(router, options.Detour) + if domainStrategy != C.DomainStrategyAsIS { + dialer = NewResolveDialer(router, dialer, domainStrategy) + } } if options.OverrideOptions.IsValid() { dialer = NewOverride(dialer, common.PtrValueOrDefault(options.OverrideOptions)) diff --git a/outbound/dialer/override.go b/common/dialer/override.go similarity index 83% rename from outbound/dialer/override.go rename to common/dialer/override.go index 4da9a122..dce05c95 100644 --- a/outbound/dialer/override.go +++ b/common/dialer/override.go @@ -13,9 +13,9 @@ import ( "github.com/sagernet/sing-box/option" ) -var _ N.Dialer = (*overrideDialer)(nil) +var _ N.Dialer = (*OverrideDialer)(nil) -type overrideDialer struct { +type OverrideDialer struct { upstream N.Dialer tlsEnabled bool tlsConfig tls.Config @@ -23,7 +23,7 @@ type overrideDialer struct { } func NewOverride(upstream N.Dialer, options option.OverrideStreamOptions) N.Dialer { - return &overrideDialer{ + return &OverrideDialer{ upstream, options.TLS, tls.Config{ @@ -34,7 +34,7 @@ func NewOverride(upstream N.Dialer, options option.OverrideStreamOptions) N.Dial } } -func (d *overrideDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { +func (d *OverrideDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { switch network { case C.NetworkTCP: conn, err := d.upstream.DialContext(ctx, C.NetworkTCP, destination) @@ -54,7 +54,7 @@ func (d *overrideDialer) DialContext(ctx context.Context, network string, destin return d.upstream.DialContext(ctx, network, destination) } -func (d *overrideDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { +func (d *OverrideDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { if d.uotEnabled { tcpConn, err := d.upstream.DialContext(ctx, C.NetworkTCP, destination) if err != nil { @@ -64,3 +64,7 @@ func (d *overrideDialer) ListenPacket(ctx context.Context, destination M.Socksad } return d.upstream.ListenPacket(ctx, destination) } + +func (d *OverrideDialer) Upstream() any { + return d.upstream +} diff --git a/outbound/dialer/protect.go b/common/dialer/protect.go similarity index 100% rename from outbound/dialer/protect.go rename to common/dialer/protect.go diff --git a/outbound/dialer/protect_stub.go b/common/dialer/protect_stub.go similarity index 100% rename from outbound/dialer/protect_stub.go rename to common/dialer/protect_stub.go diff --git a/common/dialer/resolve.go b/common/dialer/resolve.go new file mode 100644 index 00000000..2584c568 --- /dev/null +++ b/common/dialer/resolve.go @@ -0,0 +1,84 @@ +package dialer + +import ( + "context" + "net" + "net/netip" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" +) + +type ResolveDialer struct { + dialer N.Dialer + router adapter.Router + strategy C.DomainStrategy +} + +func NewResolveDialer(router adapter.Router, dialer N.Dialer, strategy C.DomainStrategy) *ResolveDialer { + return &ResolveDialer{ + dialer, + router, + strategy, + } +} + +func (d *ResolveDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + if !destination.IsFqdn() { + return d.dialer.DialContext(ctx, network, destination) + } + var addresses []netip.Addr + var err error + if d.strategy == C.DomainStrategyAsIS { + addresses, err = d.router.LookupDefault(ctx, destination.Fqdn) + } else { + addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy) + } + if err != nil { + return nil, err + } + var conn net.Conn + var connErrors []error + for _, address := range addresses { + conn, err = d.dialer.DialContext(ctx, network, M.SocksaddrFromAddrPort(address, destination.Port)) + if err != nil { + connErrors = append(connErrors, err) + } + return conn, nil + } + return nil, E.Errors(connErrors...) +} + +func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + if !destination.IsFqdn() { + return d.dialer.ListenPacket(ctx, destination) + } + var addresses []netip.Addr + var err error + if d.strategy == C.DomainStrategyAsIS { + addresses, err = d.router.LookupDefault(ctx, destination.Fqdn) + } else { + addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy) + } + if err != nil { + return nil, err + } + var conn net.PacketConn + var connErrors []error + for _, address := range addresses { + conn, err = d.dialer.ListenPacket(ctx, M.SocksaddrFromAddrPort(address, destination.Port)) + if err != nil { + connErrors = append(connErrors, err) + } + return conn, nil + } + return nil, E.Errors(connErrors...) +} + +func (d *ResolveDialer) Upstream() any { + return d.dialer +} diff --git a/common/domain/matcher_test.go b/common/domain/matcher_test.go index c4bc89ae..4ba31e4b 100644 --- a/common/domain/matcher_test.go +++ b/common/domain/matcher_test.go @@ -1,14 +1,16 @@ -package domain +package domain_test import ( "testing" + "github.com/sagernet/sing-box/common/domain" + "github.com/stretchr/testify/require" ) func TestMatch(t *testing.T) { r := require.New(t) - matcher := NewMatcher([]string{"domain.com"}, []string{"suffix.com", ".suffix.org"}) + matcher := domain.NewMatcher([]string{"domain.com"}, []string{"suffix.com", ".suffix.org"}) r.True(matcher.Match("domain.com")) r.False(matcher.Match("my.domain.com")) r.True(matcher.Match("suffix.com")) diff --git a/common/sniff/quic_test.go b/common/sniff/quic_test.go index fd0a5cd5..d47e0416 100644 --- a/common/sniff/quic_test.go +++ b/common/sniff/quic_test.go @@ -1,23 +1,25 @@ -package sniff +package sniff_test import ( "context" "encoding/hex" "testing" + "github.com/sagernet/sing-box/common/sniff" + "github.com/stretchr/testify/require" ) func TestSniffQUICv1(t *testing.T) { pkt, err := hex.DecodeString("cc0000000108d2dc7bad02241f5003796e71004215a71bfcb05159416c724be418537389acdd9a4047306283dcb4d7a9cad5cc06322042d204da67a8dbaa328ab476bb428b48fd001501863afd203f8d4ef085629d664f1a734a65969a47e4a63d4e01a21f18c1d90db0c027180906dc135f9ae421bb8617314c8d54c175fef3d3383d310d0916ebcbd6eed9329befbbb109d8fd4af1d2cf9d6adce8e6c1260a7f8256e273e326da0aa7cc148d76e7a08489dc9d52ade89c027cbc3491ada46417c2c04e2ca768e9a7dd6aa00c594e48b678927325da796817693499bb727050cb3baf3d3291a397c3a8d868e8ec7b8f7295e347455c9dadbe2252ae917ac793d958c7fb8a3d2cdb34e3891eb4286f18617556ff7216dd60256aa5b1d11ff4753459fc5f9dedf11d483a26a0835dc6cd50e1c1f54f86e8f1e502821183cd874f6447a74e818bf3445c7795acf4559d1c1fac474911d2ead5c8d23e4aa4f67afb66efe305a30a0b5d825679b31ddc186cbea936535795c7e8c378c87b8c5adc065154d15bae8f85ac8fec2da40c3aa623b682a065440831555011d7647cde44446a0fb4cf5892f2c088ae1920643094be72e3c499fe8d265caf939e8ab607a5b9317917d2a32a812e8a0e6a2f84721bbb5984ffd242838f705d13f4cfb249bc6a5c80d58ac2595edf56648ec3fe21d787573c253a79805252d6d81e26d367d4ff29ef66b5fe8992086af7bada8cad10b82a7c0dc406c5b6d0c5ec3c583e767f759ce08cad6c3c8f91e5a8") require.NoError(t, err) - metadata, err := QUICClientHello(context.Background(), pkt) + metadata, err := sniff.QUICClientHello(context.Background(), pkt) require.NoError(t, err) require.Equal(t, metadata.Domain, "cloudflare-quic.com") } func FuzzSniffQUIC(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { - QUICClientHello(context.Background(), data) + sniff.QUICClientHello(context.Background(), data) }) } diff --git a/dns/client_test.go b/dns/client_test.go new file mode 100644 index 00000000..cda7f543 --- /dev/null +++ b/dns/client_test.go @@ -0,0 +1,46 @@ +package dns_test + +import ( + "context" + "testing" + "time" + + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + + "github.com/stretchr/testify/require" + "golang.org/x/net/dns/dnsmessage" +) + +func TestClient(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + client := dns.NewClient(option.DNSClientOptions{}) + dnsTransport := dns.NewTCPTransport(context.Background(), N.SystemDialer, log.NewNopLogger(), M.ParseSocksaddr("1.0.0.1:53")) + response, err := client.Exchange(ctx, dnsTransport, makeQuery()) + require.NoError(t, err) + require.NotEmpty(t, response.Answers, "no answers") + response, err = client.Exchange(ctx, dnsTransport, makeQuery()) + require.NoError(t, err) + require.NotEmpty(t, response.Answers, "no answers") + addresses, err := client.Lookup(ctx, dnsTransport, "www.google.com", C.DomainStrategyAsIS) + require.NoError(t, err) + require.NotEmpty(t, addresses, "no answers") + cancel() +} + +func makeQuery() *dnsmessage.Message { + message := &dnsmessage.Message{} + message.Header.ID = 1 + message.Header.RecursionDesired = true + message.Questions = append(message.Questions, dnsmessage.Question{ + Name: dnsmessage.MustNewName("google.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }) + return message +} diff --git a/dns/dialer.go b/dns/dialer.go new file mode 100644 index 00000000..d7f47e42 --- /dev/null +++ b/dns/dialer.go @@ -0,0 +1,68 @@ +package dns + +import ( + "context" + "net" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" +) + +type DialerWrapper struct { + dialer N.Dialer + strategy C.DomainStrategy + client adapter.DNSClient + transport adapter.DNSTransport +} + +func NewDialerWrapper(dialer N.Dialer, strategy C.DomainStrategy, client adapter.DNSClient, transport adapter.DNSTransport) N.Dialer { + return &DialerWrapper{dialer, strategy, client, transport} +} + +func (d *DialerWrapper) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + if destination.IsIP() { + return d.dialer.DialContext(ctx, network, destination) + } + addresses, err := d.client.Lookup(ctx, d.transport, destination.Fqdn, d.strategy) + if err != nil { + return nil, err + } + var conn net.Conn + var connErrors []error + for _, address := range addresses { + conn, err = d.dialer.DialContext(ctx, network, M.SocksaddrFromAddrPort(address, destination.Port)) + if err != nil { + connErrors = append(connErrors, err) + } + return conn, nil + } + return nil, E.Errors(connErrors...) +} + +func (d *DialerWrapper) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + if destination.IsIP() { + return d.dialer.ListenPacket(ctx, destination) + } + addresses, err := d.client.Lookup(ctx, d.transport, destination.Fqdn, d.strategy) + if err != nil { + return nil, err + } + var conn net.PacketConn + var connErrors []error + for _, address := range addresses { + conn, err = d.dialer.ListenPacket(ctx, M.SocksaddrFromAddrPort(address, destination.Port)) + if err != nil { + connErrors = append(connErrors, err) + } + return conn, nil + } + return nil, E.Errors(connErrors...) +} + +func (d *DialerWrapper) Upstream() any { + return d.dialer +} diff --git a/go.mod b/go.mod index a89d9e24..228f4d00 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/goccy/go-json v0.9.8 github.com/logrusorgru/aurora v2.0.3+incompatible github.com/oschwald/maxminddb-golang v1.9.0 - github.com/sagernet/sing v0.0.0-20220706131532-6d16497f03a6 + github.com/sagernet/sing v0.0.0-20220707133944-6a0987c52ae4 github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649 github.com/sirupsen/logrus v1.8.1 github.com/spf13/cobra v1.5.0 diff --git a/go.sum b/go.sum index ffce45a4..a78d1274 100644 --- a/go.sum +++ b/go.sum @@ -23,8 +23,8 @@ github.com/oschwald/maxminddb-golang v1.9.0/go.mod h1:TK+s/Z2oZq0rSl4PSeAEoP0bgm github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sagernet/sing v0.0.0-20220706131532-6d16497f03a6 h1:NKDjOKPHP4JOrYomj2Q/tvKDWLmCNLHNQSPZLE5o3I4= -github.com/sagernet/sing v0.0.0-20220706131532-6d16497f03a6/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c= +github.com/sagernet/sing v0.0.0-20220707133944-6a0987c52ae4 h1:nV/DyNi+O1VxNoChD5E9M6Y0VoFdVr0UEW9h9JnqxNs= +github.com/sagernet/sing v0.0.0-20220707133944-6a0987c52ae4/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c= github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649 h1:whNDUGOAX5GPZkSy4G3Gv9QyIgk5SXRyjkRuP7ohF8k= github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649/go.mod h1:MuyT+9fEPjvauAv0fSE0a6Q+l0Tv2ZrAafTkYfnxBFw= github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= diff --git a/option/config.go b/option/config.go index 0871df22..f0c7e26f 100644 --- a/option/config.go +++ b/option/config.go @@ -2,17 +2,19 @@ package option import ( "bytes" + "strings" "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" "github.com/goccy/go-json" ) type _Options struct { Log *LogOption `json:"log,omitempty"` + DNS *DNSOptions `json:"dns,omitempty"` Inbounds []Inbound `json:"inbounds,omitempty"` Outbounds []Outbound `json:"outbounds,omitempty"` - DNS *DNSOptions `json:"dns,omitempty"` Route *RouteOptions `json:"route,omitempty"` } @@ -21,11 +23,22 @@ type Options _Options func (o *Options) UnmarshalJSON(content []byte) error { decoder := json.NewDecoder(bytes.NewReader(content)) decoder.DisallowUnknownFields() - return decoder.Decode((*_Options)(o)) + err := decoder.Decode((*_Options)(o)) + if err == nil { + return nil + } + if syntaxError, isSyntaxError := err.(*json.SyntaxError); isSyntaxError { + prefix := string(content[:syntaxError.Offset]) + row := strings.Count(prefix, "\n") + 1 + column := len(prefix) - strings.LastIndex(prefix, "\n") - 1 + return E.Extend(syntaxError, "row ", row, ", column ", column) + } + return err } func (o Options) Equals(other Options) bool { return common.ComparablePtrEquals(o.Log, other.Log) && + common.PtrEquals(o.DNS, other.DNS) && common.SliceEquals(o.Inbounds, other.Inbounds) && common.ComparableSliceEquals(o.Outbounds, other.Outbounds) && common.PtrEquals(o.Route, other.Route) diff --git a/option/dns.go b/option/dns.go index f7f380e2..02d6779e 100644 --- a/option/dns.go +++ b/option/dns.go @@ -1,18 +1,146 @@ package option +import ( + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + + C "github.com/sagernet/sing-box/constant" + + "github.com/goccy/go-json" +) + type DNSOptions struct { - Servers []DNSServerOptions `json:"servers,omitempty"` + Servers []DNSServerOptions `json:"servers,omitempty"` + Rules []DNSRule `json:"rules,omitempty"` + Final string `json:"final,omitempty"` + Strategy DomainStrategy `json:"strategy,omitempty"` DNSClientOptions } +func (o DNSOptions) Equals(other DNSOptions) bool { + return common.ComparableSliceEquals(o.Servers, other.Servers) && + common.SliceEquals(o.Rules, other.Rules) && + o.Final == other.Final && + o.Strategy == other.Strategy && + o.DNSClientOptions == other.DNSClientOptions +} + type DNSClientOptions struct { DisableCache bool `json:"disable_cache,omitempty"` DisableExpire bool `json:"disable_expire,omitempty"` } type DNSServerOptions struct { - Tag string `json:"tag,omitempty"` - Address string `json:"address"` - AddressResolver string `json:"address_resolver,omitempty"` + Tag string `json:"tag,omitempty"` + Address string `json:"address"` + AddressResolver string `json:"address_resolver,omitempty"` + AddressStrategy DomainStrategy `json:"address_strategy,omitempty"` DialerOptions } + +type _DNSRule struct { + Type string `json:"type,omitempty"` + DefaultOptions DefaultDNSRule `json:"-"` + LogicalOptions LogicalDNSRule `json:"-"` +} + +type DNSRule _DNSRule + +func (r DNSRule) Equals(other DNSRule) bool { + return r.Type == other.Type && + r.DefaultOptions.Equals(other.DefaultOptions) && + r.LogicalOptions.Equals(other.LogicalOptions) +} + +func (r DNSRule) MarshalJSON() ([]byte, error) { + var v any + switch r.Type { + case C.RuleTypeDefault: + v = r.DefaultOptions + case C.RuleTypeLogical: + v = r.LogicalOptions + default: + return nil, E.New("unknown rule type: " + r.Type) + } + return MarshallObjects((_DNSRule)(r), v) +} + +func (r *DNSRule) UnmarshalJSON(bytes []byte) error { + err := json.Unmarshal(bytes, (*_DNSRule)(r)) + if err != nil { + return err + } + if r.Type == "" { + r.Type = C.RuleTypeDefault + } + var v any + switch r.Type { + case C.RuleTypeDefault: + v = &r.DefaultOptions + case C.RuleTypeLogical: + v = &r.LogicalOptions + default: + return E.New("unknown rule type: " + r.Type) + } + err = UnmarshallExcluded(bytes, (*_DNSRule)(r), v) + if err != nil { + return E.Cause(err, "dns route rule") + } + return nil +} + +type DefaultDNSRule struct { + Inbound Listable[string] `json:"inbound,omitempty"` + Network string `json:"network,omitempty"` + Protocol Listable[string] `json:"protocol,omitempty"` + Domain Listable[string] `json:"domain,omitempty"` + DomainSuffix Listable[string] `json:"domain_suffix,omitempty"` + DomainKeyword Listable[string] `json:"domain_keyword,omitempty"` + DomainRegex Listable[string] `json:"domain_regex,omitempty"` + Geosite Listable[string] `json:"geosite,omitempty"` + SourceGeoIP Listable[string] `json:"source_geoip,omitempty"` + SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"` + SourcePort Listable[uint16] `json:"source_port,omitempty"` + Port Listable[uint16] `json:"port,omitempty"` + Outbound Listable[string] `json:"outbound,omitempty"` + Server string `json:"server,omitempty"` +} + +func (r DefaultDNSRule) IsValid() bool { + var defaultValue DefaultDNSRule + defaultValue.Server = r.Server + return !r.Equals(defaultValue) +} + +func (r DefaultDNSRule) Equals(other DefaultDNSRule) bool { + return common.ComparableSliceEquals(r.Inbound, other.Inbound) && + r.Network == other.Network && + common.ComparableSliceEquals(r.Protocol, other.Protocol) && + common.ComparableSliceEquals(r.Domain, other.Domain) && + common.ComparableSliceEquals(r.DomainSuffix, other.DomainSuffix) && + common.ComparableSliceEquals(r.DomainKeyword, other.DomainKeyword) && + common.ComparableSliceEquals(r.DomainRegex, other.DomainRegex) && + common.ComparableSliceEquals(r.Geosite, other.Geosite) && + common.ComparableSliceEquals(r.SourceGeoIP, other.SourceGeoIP) && + common.ComparableSliceEquals(r.SourceIPCIDR, other.SourceIPCIDR) && + common.ComparableSliceEquals(r.SourcePort, other.SourcePort) && + common.ComparableSliceEquals(r.Port, other.Port) && + common.ComparableSliceEquals(r.Outbound, other.Outbound) && + r.Server == other.Server +} + +type LogicalDNSRule struct { + Mode string `json:"mode"` + Rules []DefaultDNSRule `json:"rules,omitempty"` + Server string `json:"server,omitempty"` +} + +func (r LogicalDNSRule) IsValid() bool { + return len(r.Rules) > 0 && common.All(r.Rules, DefaultDNSRule.IsValid) +} + +func (r LogicalDNSRule) Equals(other LogicalDNSRule) bool { + return r.Mode == other.Mode && + common.SliceEquals(r.Rules, other.Rules) && + r.Server == other.Server +} diff --git a/option/outbound.go b/option/outbound.go index 1d086a14..5f297c16 100644 --- a/option/outbound.go +++ b/option/outbound.go @@ -75,6 +75,7 @@ type DialerOptions struct { ConnectTimeout int `json:"connect_timeout,omitempty"` TCPFastOpen bool `json:"tcp_fast_open,omitempty"` OverrideOptions *OverrideStreamOptions `json:"override,omitempty"` + DomainStrategy DomainStrategy `json:"domain_strategy,omitempty"` } type OverrideStreamOptions struct { diff --git a/option/route.go b/option/route.go index 2d8ed17f..875074ea 100644 --- a/option/route.go +++ b/option/route.go @@ -10,10 +10,10 @@ import ( ) type RouteOptions struct { - GeoIP *GeoIPOptions `json:"geoip,omitempty"` - Geosite *GeositeOptions `json:"geosite,omitempty"` - Rules []Rule `json:"rules,omitempty"` - DefaultDetour string `json:"default_detour,omitempty"` + GeoIP *GeoIPOptions `json:"geoip,omitempty"` + Geosite *GeositeOptions `json:"geosite,omitempty"` + Rules []Rule `json:"rules,omitempty"` + Final string `json:"final,omitempty"` } func (o RouteOptions) Equals(other RouteOptions) bool { @@ -52,6 +52,7 @@ func (r Rule) MarshalJSON() ([]byte, error) { var v any switch r.Type { case C.RuleTypeDefault: + r.Type = "" v = r.DefaultOptions case C.RuleTypeLogical: v = r.LogicalOptions @@ -66,12 +67,10 @@ func (r *Rule) UnmarshalJSON(bytes []byte) error { if err != nil { return err } - if r.Type == "" { - r.Type = C.RuleTypeDefault - } var v any switch r.Type { - case C.RuleTypeDefault: + case "": + r.Type = C.RuleTypeDefault v = &r.DefaultOptions case C.RuleTypeLogical: v = &r.LogicalOptions diff --git a/option/types.go b/option/types.go index f2567891..81f6f06e 100644 --- a/option/types.go +++ b/option/types.go @@ -97,7 +97,8 @@ func (s DomainStrategy) MarshalJSON() ([]byte, error) { var value string switch C.DomainStrategy(s) { case C.DomainStrategyAsIS: - value = "AsIS" + value = "" + // value = "AsIS" case C.DomainStrategyPreferIPv4: value = "PreferIPv4" case C.DomainStrategyPreferIPv6: @@ -119,7 +120,7 @@ func (s *DomainStrategy) UnmarshalJSON(bytes []byte) error { return err } switch value { - case "AsIS": + case "", "AsIS": *s = DomainStrategy(C.DomainStrategyAsIS) case "PreferIPv4": *s = DomainStrategy(C.DomainStrategyPreferIPv4) diff --git a/outbound/direct.go b/outbound/direct.go index 4127cf80..0c489af1 100644 --- a/outbound/direct.go +++ b/outbound/direct.go @@ -9,10 +9,10 @@ import ( N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-box/outbound/dialer" ) var _ adapter.Outbound = (*Direct)(nil) @@ -47,41 +47,45 @@ func NewDirect(router adapter.Router, logger log.Logger, tag string, options opt return outbound } -func (d *Direct) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - switch d.overrideOption { +func (h *Direct) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + ctx, metadata := adapter.AppendContext(ctx) + metadata.Outbound = h.tag + switch h.overrideOption { case 1: - destination = d.overrideDestination + destination = h.overrideDestination case 2: - newDestination := d.overrideDestination + newDestination := h.overrideDestination newDestination.Port = destination.Port destination = newDestination case 3: - destination.Port = d.overrideDestination.Port + destination.Port = h.overrideDestination.Port } switch network { case C.NetworkTCP: - d.logger.WithContext(ctx).Info("outbound connection to ", destination) + h.logger.WithContext(ctx).Info("outbound connection to ", destination) case C.NetworkUDP: - d.logger.WithContext(ctx).Info("outbound packet connection to ", destination) + h.logger.WithContext(ctx).Info("outbound packet connection to ", destination) } - return d.dialer.DialContext(ctx, network, destination) + return h.dialer.DialContext(ctx, network, destination) } -func (d *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - d.logger.WithContext(ctx).Info("outbound packet connection") - return d.dialer.ListenPacket(ctx, destination) +func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + ctx, metadata := adapter.AppendContext(ctx) + metadata.Outbound = h.tag + h.logger.WithContext(ctx).Info("outbound packet connection") + return h.dialer.ListenPacket(ctx, destination) } -func (d *Direct) NewConnection(ctx context.Context, conn net.Conn, destination M.Socksaddr) error { - outConn, err := d.DialContext(ctx, C.NetworkTCP, destination) +func (h *Direct) NewConnection(ctx context.Context, conn net.Conn, destination M.Socksaddr) error { + outConn, err := h.DialContext(ctx, C.NetworkTCP, destination) if err != nil { return err } return bufio.CopyConn(ctx, conn, outConn) } -func (d *Direct) NewPacketConnection(ctx context.Context, conn N.PacketConn, destination M.Socksaddr) error { - outConn, err := d.ListenPacket(ctx, destination) +func (h *Direct) NewPacketConnection(ctx context.Context, conn N.PacketConn, destination M.Socksaddr) error { + outConn, err := h.ListenPacket(ctx, destination) if err != nil { return err } diff --git a/outbound/http.go b/outbound/http.go index 254d1984..284651bd 100644 --- a/outbound/http.go +++ b/outbound/http.go @@ -11,10 +11,10 @@ import ( "github.com/sagernet/sing/protocol/http" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-box/outbound/dialer" ) var _ adapter.Outbound = (*HTTP)(nil) @@ -32,16 +32,20 @@ func NewHTTP(router adapter.Router, logger log.Logger, tag string, options optio tag: tag, network: []string{C.NetworkTCP}, }, - http.NewClient(dialer.New(router, options.DialerOptions), M.ParseSocksaddrHostPort(options.Server, options.ServerPort), options.Username, options.Password), + http.NewClient(dialer.New(router, options.DialerOptions), options.ServerOptions.Build(), options.Username, options.Password), } } func (h *HTTP) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + ctx, metadata := adapter.AppendContext(ctx) + metadata.Outbound = h.tag h.logger.WithContext(ctx).Info("outbound connection to ", destination) return h.client.DialContext(ctx, network, destination) } func (h *HTTP) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + ctx, metadata := adapter.AppendContext(ctx) + metadata.Outbound = h.tag return nil, os.ErrInvalid } diff --git a/outbound/shadowsocks.go b/outbound/shadowsocks.go index ade64c0e..8f3d159f 100644 --- a/outbound/shadowsocks.go +++ b/outbound/shadowsocks.go @@ -5,7 +5,6 @@ import ( "net" "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -13,10 +12,10 @@ import ( "github.com/sagernet/sing-shadowsocks/shadowimpl" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-box/outbound/dialer" ) var _ adapter.Outbound = (*Shadowsocks)(nil) @@ -29,69 +28,67 @@ type Shadowsocks struct { } func NewShadowsocks(router adapter.Router, logger log.Logger, tag string, options option.ShadowsocksOutboundOptions) (*Shadowsocks, error) { - outbound := &Shadowsocks{ - myOutboundAdapter: myOutboundAdapter{ + method, err := shadowimpl.FetchMethod(options.Method, options.Password) + if err != nil { + return nil, err + } + return &Shadowsocks{ + myOutboundAdapter{ protocol: C.TypeDirect, logger: logger, tag: tag, network: options.Network.Build(), }, - dialer: dialer.New(router, options.DialerOptions), - } - var err error - outbound.method, err = shadowimpl.FetchMethod(options.Method, options.Password) - if err != nil { - return nil, err - } - if options.Server == "" { - return nil, E.New("missing server address") - } else if options.ServerPort == 0 { - return nil, E.New("missing server port") - } - outbound.serverAddr = M.ParseSocksaddrHostPort(options.Server, options.ServerPort) - return outbound, nil + dialer.New(router, options.DialerOptions), + method, + options.ServerOptions.Build(), + }, nil } -func (o *Shadowsocks) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { +func (h *Shadowsocks) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + ctx, metadata := adapter.AppendContext(ctx) + metadata.Outbound = h.tag switch network { case C.NetworkTCP: - o.logger.WithContext(ctx).Info("outbound connection to ", destination) - outConn, err := o.dialer.DialContext(ctx, C.NetworkTCP, o.serverAddr) + h.logger.WithContext(ctx).Info("outbound connection to ", destination) + outConn, err := h.dialer.DialContext(ctx, C.NetworkTCP, h.serverAddr) if err != nil { return nil, err } - return o.method.DialEarlyConn(outConn, destination), nil + return h.method.DialEarlyConn(outConn, destination), nil case C.NetworkUDP: - o.logger.WithContext(ctx).Info("outbound packet connection to ", destination) - outConn, err := o.dialer.DialContext(ctx, C.NetworkUDP, o.serverAddr) + h.logger.WithContext(ctx).Info("outbound packet connection to ", destination) + outConn, err := h.dialer.DialContext(ctx, C.NetworkUDP, h.serverAddr) if err != nil { return nil, err } - return &bufio.BindPacketConn{PacketConn: o.method.DialPacketConn(outConn), Addr: destination}, nil + return &bufio.BindPacketConn{PacketConn: h.method.DialPacketConn(outConn), Addr: destination}, nil default: panic("unknown network " + network) } } -func (o *Shadowsocks) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - o.logger.WithContext(ctx).Info("outbound packet connection to ", o.serverAddr) - outConn, err := o.dialer.ListenPacket(ctx, destination) +func (h *Shadowsocks) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + ctx, metadata := adapter.AppendContext(ctx) + metadata.Outbound = h.tag + h.logger.WithContext(ctx).Info("outbound packet connection to ", h.serverAddr) + outConn, err := h.dialer.ListenPacket(ctx, destination) if err != nil { return nil, err } - return o.method.DialPacketConn(&bufio.BindPacketConn{PacketConn: outConn, Addr: o.serverAddr.UDPAddr()}), nil + return h.method.DialPacketConn(&bufio.BindPacketConn{PacketConn: outConn, Addr: h.serverAddr.UDPAddr()}), nil } -func (o *Shadowsocks) NewConnection(ctx context.Context, conn net.Conn, destination M.Socksaddr) error { - serverConn, err := o.DialContext(ctx, C.NetworkTCP, destination) +func (h *Shadowsocks) NewConnection(ctx context.Context, conn net.Conn, destination M.Socksaddr) error { + serverConn, err := h.DialContext(ctx, C.NetworkTCP, destination) if err != nil { return err } return CopyEarlyConn(ctx, conn, serverConn) } -func (o *Shadowsocks) NewPacketConnection(ctx context.Context, conn N.PacketConn, destination M.Socksaddr) error { - serverConn, err := o.ListenPacket(ctx, destination) +func (h *Shadowsocks) NewPacketConnection(ctx context.Context, conn N.PacketConn, destination M.Socksaddr) error { + serverConn, err := h.ListenPacket(ctx, destination) if err != nil { return err } diff --git a/outbound/socks.go b/outbound/socks.go index b4d60ff2..6ca959a0 100644 --- a/outbound/socks.go +++ b/outbound/socks.go @@ -10,10 +10,10 @@ import ( "github.com/sagernet/sing/protocol/socks" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-box/outbound/dialer" ) var _ adapter.Outbound = (*Socks)(nil) @@ -42,11 +42,13 @@ func NewSocks(router adapter.Router, logger log.Logger, tag string, options opti tag: tag, network: options.Network.Build(), }, - socks.NewClient(detour, M.ParseSocksaddrHostPort(options.Server, options.ServerPort), version, options.Username, options.Password), + socks.NewClient(detour, options.ServerOptions.Build(), version, options.Username, options.Password), }, nil } func (h *Socks) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + ctx, metadata := adapter.AppendContext(ctx) + metadata.Outbound = h.tag switch network { case C.NetworkTCP: h.logger.WithContext(ctx).Info("outbound connection to ", destination) @@ -59,6 +61,8 @@ func (h *Socks) DialContext(ctx context.Context, network string, destination M.S } func (h *Socks) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + ctx, metadata := adapter.AppendContext(ctx) + metadata.Outbound = h.tag h.logger.WithContext(ctx).Info("outbound packet connection to ", destination) return h.client.ListenPacket(ctx, destination) } diff --git a/route/router.go b/route/router.go index f3579623..3172b412 100644 --- a/route/router.go +++ b/route/router.go @@ -5,8 +5,10 @@ import ( "io" "net" "net/http" + "net/netip" "os" "path/filepath" + "strings" "time" "github.com/sagernet/sing/common" @@ -19,19 +21,24 @@ import ( "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/geoip" "github.com/sagernet/sing-box/common/geosite" "github.com/sagernet/sing-box/common/sniff" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + + "golang.org/x/net/dns/dnsmessage" ) var _ adapter.Router = (*Router)(nil) type Router struct { - ctx context.Context - logger log.Logger + ctx context.Context + logger log.Logger + dnsLogger log.Logger outboundByTag map[string]adapter.Outbound rules []adapter.Rule @@ -46,26 +53,116 @@ type Router struct { geositeOptions option.GeositeOptions geoIPReader *geoip.Reader geositeReader *geosite.Reader + + dnsClient adapter.DNSClient + defaultDomainStrategy C.DomainStrategy + + defaultTransport adapter.DNSTransport + transports []adapter.DNSTransport + transportMap map[string]adapter.DNSTransport } -func NewRouter(ctx context.Context, logger log.Logger, options option.RouteOptions) (*Router, error) { +func NewRouter(ctx context.Context, logger log.Logger, options option.RouteOptions, dnsOptions option.DNSOptions) (*Router, error) { router := &Router{ - ctx: ctx, - logger: logger.WithPrefix("router: "), - outboundByTag: make(map[string]adapter.Outbound), - rules: make([]adapter.Rule, 0, len(options.Rules)), - needGeoIPDatabase: hasGeoRule(options.Rules, isGeoIPRule), - needGeositeDatabase: hasGeoRule(options.Rules, isGeositeRule), - geoIPOptions: common.PtrValueOrDefault(options.GeoIP), - defaultDetour: options.DefaultDetour, + ctx: ctx, + logger: logger.WithPrefix("router: "), + dnsLogger: logger.WithPrefix("dns: "), + outboundByTag: make(map[string]adapter.Outbound), + rules: make([]adapter.Rule, 0, len(options.Rules)), + needGeoIPDatabase: hasGeoRule(options.Rules, isGeoIPRule) || hasGeoDNSRule(dnsOptions.Rules, isGeoIPDNSRule), + needGeositeDatabase: hasGeoRule(options.Rules, isGeositeRule) || hasGeoDNSRule(dnsOptions.Rules, isGeositeDNSRule), + geoIPOptions: common.PtrValueOrDefault(options.GeoIP), + defaultDetour: options.Final, + dnsClient: dns.NewClient(dnsOptions.DNSClientOptions), + defaultDomainStrategy: C.DomainStrategy(dnsOptions.Strategy), } for i, ruleOptions := range options.Rules { - rule, err := NewRule(router, logger, ruleOptions) + routeRule, err := NewRule(router, logger, ruleOptions) if err != nil { return nil, E.Cause(err, "parse rule[", i, "]") } - router.rules = append(router.rules, rule) + router.rules = append(router.rules, routeRule) } + for i, dnsRuleOptions := range dnsOptions.Rules { + dnsRule, err := NewDNSRule(router, logger, dnsRuleOptions) + if err != nil { + return nil, E.Cause(err, "parse dns rule[", i, "]") + } + router.rules = append(router.rules, dnsRule) + } + transports := make([]adapter.DNSTransport, len(dnsOptions.Servers)) + dummyTransportMap := make(map[string]adapter.DNSTransport) + transportMap := make(map[string]adapter.DNSTransport) + transportTags := make([]string, len(dnsOptions.Servers)) + transportTagMap := make(map[string]bool) + for i, server := range dnsOptions.Servers { + var tag string + if server.Tag != "" { + tag = server.Tag + } else { + tag = F.ToString(i) + } + transportTags[i] = tag + transportTagMap[tag] = true + } + for { + lastLen := len(dummyTransportMap) + for i, server := range dnsOptions.Servers { + tag := transportTags[i] + if _, exists := dummyTransportMap[tag]; exists { + continue + } + detour := dialer.New(router, server.DialerOptions) + if server.AddressResolver != "" { + if !transportTagMap[server.AddressResolver] { + return nil, E.New("parse dns server[", tag, "]: address resolver not found: ", server.AddressResolver) + } + if upstream, exists := dummyTransportMap[server.AddressResolver]; exists { + detour = dns.NewDialerWrapper(detour, C.DomainStrategy(server.AddressStrategy), router.dnsClient, upstream) + } else { + continue + } + } + transport, err := dns.NewTransport(ctx, detour, logger, server.Address) + if err != nil { + return nil, E.Cause(err, "parse dns server[", tag, "]") + } + transports[i] = transport + dummyTransportMap[tag] = transport + if server.Tag != "" { + transportMap[server.Tag] = transport + } + } + if len(transports) == len(dummyTransportMap) { + break + } + if lastLen != len(dummyTransportMap) { + continue + } + unresolvedTags := common.MapIndexed(common.FilterIndexed(dnsOptions.Servers, func(index int, server option.DNSServerOptions) bool { + _, exists := dummyTransportMap[transportTags[index]] + return !exists + }), func(index int, server option.DNSServerOptions) string { + return transportTags[index] + }) + return nil, E.New("found circular reference in dns servers: ", strings.Join(unresolvedTags, " ")) + } + var defaultTransport adapter.DNSTransport + if options.Final != "" { + defaultTransport = dummyTransportMap[options.Final] + if defaultTransport == nil { + return nil, E.New("default dns server not found: ", options.Final) + } + } + if defaultTransport == nil { + if len(transports) == 0 { + transports = append(transports, dns.NewLocalTransport()) + } + defaultTransport = transports[0] + } + router.defaultTransport = defaultTransport + router.transports = transports + router.transportMap = transportMap return router, nil } @@ -135,6 +232,11 @@ func (r *Router) Initialize(outbounds []adapter.Outbound, defaultOutbound func() r.defaultOutboundForConnection = defaultOutboundForConnection r.defaultOutboundForPacketConnection = defaultOutboundForPacketConnection r.outboundByTag = outboundByTag + for i, rule := range r.rules { + if _, loaded := outboundByTag[rule.Outbound()]; !loaded { + return E.New("outbound not found for rule[", i, "]: ", rule.Outbound()) + } + } return nil } @@ -228,7 +330,7 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad conn.Close() return E.New("missing supported outbound, closing connection") } - return detour.NewConnection(ctx, conn, metadata.Destination) + return detour.NewConnection(adapter.WithContext(ctx, &metadata), conn, metadata.Destination) } func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { @@ -262,7 +364,19 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m conn.Close() return E.New("missing supported outbound, closing packet connection") } - return detour.NewPacketConnection(ctx, conn, metadata.Destination) + return detour.NewPacketConnection(adapter.WithContext(ctx, &metadata), conn, metadata.Destination) +} + +func (r *Router) Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) { + return r.dnsClient.Exchange(ctx, r.matchDNS(ctx), message) +} + +func (r *Router) Lookup(ctx context.Context, domain string, strategy C.DomainStrategy) ([]netip.Addr, error) { + return r.dnsClient.Lookup(ctx, r.matchDNS(ctx), domain, strategy) +} + +func (r *Router) LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) { + return r.dnsClient.Lookup(ctx, r.matchDNS(ctx), domain, r.defaultDomainStrategy) } func (r *Router) match(ctx context.Context, metadata adapter.InboundContext, defaultOutbound adapter.Outbound) adapter.Outbound { @@ -280,6 +394,26 @@ func (r *Router) match(ctx context.Context, metadata adapter.InboundContext, def return defaultOutbound } +func (r *Router) matchDNS(ctx context.Context) adapter.DNSTransport { + metadata := adapter.ContextFrom(ctx) + if metadata == nil { + r.dnsLogger.WithContext(ctx).Info("no context") + return r.defaultTransport + } + for i, rule := range r.rules { + if rule.Match(metadata) { + detour := rule.Outbound() + r.dnsLogger.WithContext(ctx).Info("match[", i, "] ", rule.String(), " => ", detour) + if transport, loaded := r.transportMap[detour]; loaded { + return transport + } + r.dnsLogger.WithContext(ctx).Error("transport not found: ", detour) + } + } + r.dnsLogger.WithContext(ctx).Info("no match") + return r.defaultTransport +} + func hasGeoRule(rules []option.Rule, cond func(rule option.DefaultRule) bool) bool { for _, rule := range rules { switch rule.Type { @@ -298,14 +432,40 @@ func hasGeoRule(rules []option.Rule, cond func(rule option.DefaultRule) bool) bo return false } +func hasGeoDNSRule(rules []option.DNSRule, cond func(rule option.DefaultDNSRule) bool) bool { + for _, rule := range rules { + switch rule.Type { + case C.RuleTypeDefault: + if cond(rule.DefaultOptions) { + return true + } + case C.RuleTypeLogical: + for _, subRule := range rule.LogicalOptions.Rules { + if cond(subRule) { + return true + } + } + } + } + return false +} + func isGeoIPRule(rule option.DefaultRule) bool { return len(rule.SourceGeoIP) > 0 && common.Any(rule.SourceGeoIP, notPrivateNode) || len(rule.GeoIP) > 0 && common.Any(rule.GeoIP, notPrivateNode) } +func isGeoIPDNSRule(rule option.DefaultDNSRule) bool { + return len(rule.SourceGeoIP) > 0 && common.Any(rule.SourceGeoIP, notPrivateNode) +} + func isGeositeRule(rule option.DefaultRule) bool { return len(rule.Geosite) > 0 } +func isGeositeDNSRule(rule option.DefaultDNSRule) bool { + return len(rule.Geosite) > 0 +} + func notPrivateNode(code string) bool { return code != "private" } diff --git a/route/rule.go b/route/rule.go index 77f6caf1..ae570cd7 100644 --- a/route/rule.go +++ b/route/rule.go @@ -225,3 +225,91 @@ func (r *DefaultRule) Outbound() string { func (r *DefaultRule) String() string { return strings.Join(common.Map(r.allItems, F.ToString0[RuleItem]), " ") } + +var _ adapter.Rule = (*LogicalRule)(nil) + +type LogicalRule struct { + mode string + rules []*DefaultRule + outbound string +} + +func (r *LogicalRule) UpdateGeosite() error { + for _, rule := range r.rules { + err := rule.UpdateGeosite() + if err != nil { + return err + } + } + return nil +} + +func (r *LogicalRule) Start() error { + for _, rule := range r.rules { + err := rule.Start() + if err != nil { + return err + } + } + return nil +} + +func (r *LogicalRule) Close() error { + for _, rule := range r.rules { + err := rule.Close() + if err != nil { + return err + } + } + return nil +} + +func NewLogicalRule(router adapter.Router, logger log.Logger, options option.LogicalRule) (*LogicalRule, error) { + r := &LogicalRule{ + rules: make([]*DefaultRule, len(options.Rules)), + outbound: options.Outbound, + } + switch options.Mode { + case C.LogicalTypeAnd: + r.mode = C.LogicalTypeAnd + case C.LogicalTypeOr: + r.mode = C.LogicalTypeOr + default: + return nil, E.New("unknown logical mode: ", options.Mode) + } + for i, subRule := range options.Rules { + rule, err := NewDefaultRule(router, logger, subRule) + if err != nil { + return nil, E.Cause(err, "sub rule[", i, "]") + } + r.rules[i] = rule + } + return r, nil +} + +func (r *LogicalRule) Match(metadata *adapter.InboundContext) bool { + if r.mode == C.LogicalTypeAnd { + return common.All(r.rules, func(it *DefaultRule) bool { + return it.Match(metadata) + }) + } else { + return common.Any(r.rules, func(it *DefaultRule) bool { + return it.Match(metadata) + }) + } +} + +func (r *LogicalRule) Outbound() string { + return r.outbound +} + +func (r *LogicalRule) String() string { + var op string + switch r.mode { + case C.LogicalTypeAnd: + op = "&&" + case C.LogicalTypeOr: + op = "||" + } + return "logical(" + strings.Join(common.Map(r.rules, F.ToString0[*DefaultRule]), " "+op+" ") + ")" +} diff --git a/route/rule_dns.go b/route/rule_dns.go new file mode 100644 index 00000000..85f2bf01 --- /dev/null +++ b/route/rule_dns.go @@ -0,0 +1,250 @@ +package route + +import ( + "strings" + + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" +) + +func NewDNSRule(router adapter.Router, logger log.Logger, options option.DNSRule) (adapter.Rule, error) { + if common.IsEmptyByEquals(options) { + return nil, E.New("empty rule config") + } + switch options.Type { + case "", C.RuleTypeDefault: + if !options.DefaultOptions.IsValid() { + return nil, E.New("missing conditions") + } + if options.DefaultOptions.Server == "" { + return nil, E.New("missing server field") + } + return NewDefaultDNSRule(router, logger, options.DefaultOptions) + case C.RuleTypeLogical: + if !options.LogicalOptions.IsValid() { + return nil, E.New("missing conditions") + } + if options.LogicalOptions.Server == "" { + return nil, E.New("missing server field") + } + return NewLogicalDNSRule(router, logger, options.LogicalOptions) + default: + return nil, E.New("unknown rule type: ", options.Type) + } +} + +var _ adapter.Rule = (*DefaultDNSRule)(nil) + +type DefaultDNSRule struct { + items []RuleItem + outbound string +} + +func NewDefaultDNSRule(router adapter.Router, logger log.Logger, options option.DefaultDNSRule) (*DefaultDNSRule, error) { + rule := &DefaultDNSRule{ + outbound: options.Server, + } + if len(options.Inbound) > 0 { + item := NewInboundRule(options.Inbound) + rule.items = append(rule.items, item) + } + if options.Network != "" { + switch options.Network { + case C.NetworkTCP, C.NetworkUDP: + item := NewNetworkItem(options.Network) + rule.items = append(rule.items, item) + default: + return nil, E.New("invalid network: ", options.Network) + } + } + if len(options.Protocol) > 0 { + item := NewProtocolItem(options.Protocol) + rule.items = append(rule.items, item) + } + if len(options.Domain) > 0 || len(options.DomainSuffix) > 0 { + item := NewDomainItem(options.Domain, options.DomainSuffix) + rule.items = append(rule.items, item) + } + if len(options.DomainKeyword) > 0 { + item := NewDomainKeywordItem(options.DomainKeyword) + rule.items = append(rule.items, item) + } + if len(options.DomainRegex) > 0 { + item, err := NewDomainRegexItem(options.DomainRegex) + if err != nil { + return nil, E.Cause(err, "domain_regex") + } + rule.items = append(rule.items, item) + } + if len(options.Geosite) > 0 { + item := NewGeositeItem(router, logger, options.Geosite) + rule.items = append(rule.items, item) + } + if len(options.SourceGeoIP) > 0 { + item := NewGeoIPItem(router, logger, true, options.SourceGeoIP) + rule.items = append(rule.items, item) + } + if len(options.SourceIPCIDR) > 0 { + item, err := NewIPCIDRItem(true, options.SourceIPCIDR) + if err != nil { + return nil, E.Cause(err, "source_ipcidr") + } + rule.items = append(rule.items, item) + } + if len(options.SourcePort) > 0 { + item := NewPortItem(true, options.SourcePort) + rule.items = append(rule.items, item) + } + if len(options.Port) > 0 { + item := NewPortItem(false, options.Port) + rule.items = append(rule.items, item) + } + if len(options.Outbound) > 0 { + item := NewOutboundRule(options.Outbound) + rule.items = append(rule.items, item) + } + return rule, nil +} + +func (r *DefaultDNSRule) Start() error { + for _, item := range r.items { + err := common.Start(item) + if err != nil { + return err + } + } + return nil +} + +func (r *DefaultDNSRule) Close() error { + for _, item := range r.items { + err := common.Close(item) + if err != nil { + return err + } + } + return nil +} + +func (r *DefaultDNSRule) UpdateGeosite() error { + for _, item := range r.items { + if geositeItem, isSite := item.(*GeositeItem); isSite { + err := geositeItem.Update() + if err != nil { + return err + } + } + } + return nil +} + +func (r *DefaultDNSRule) Match(metadata *adapter.InboundContext) bool { + for _, item := range r.items { + if !item.Match(metadata) { + return false + } + } + return true +} + +func (r *DefaultDNSRule) Outbound() string { + return r.outbound +} + +func (r *DefaultDNSRule) String() string { + return strings.Join(common.Map(r.items, F.ToString0[RuleItem]), " ") +} + +var _ adapter.Rule = (*LogicalRule)(nil) + +type LogicalDNSRule struct { + mode string + rules []*DefaultDNSRule + outbound string +} + +func (r *LogicalDNSRule) UpdateGeosite() error { + for _, rule := range r.rules { + err := rule.UpdateGeosite() + if err != nil { + return err + } + } + return nil +} + +func (r *LogicalDNSRule) Start() error { + for _, rule := range r.rules { + err := rule.Start() + if err != nil { + return err + } + } + return nil +} + +func (r *LogicalDNSRule) Close() error { + for _, rule := range r.rules { + err := rule.Close() + if err != nil { + return err + } + } + return nil +} + +func NewLogicalDNSRule(router adapter.Router, logger log.Logger, options option.LogicalDNSRule) (*LogicalDNSRule, error) { + r := &LogicalDNSRule{ + rules: make([]*DefaultDNSRule, len(options.Rules)), + outbound: options.Server, + } + switch options.Mode { + case C.LogicalTypeAnd: + r.mode = C.LogicalTypeAnd + case C.LogicalTypeOr: + r.mode = C.LogicalTypeOr + default: + return nil, E.New("unknown logical mode: ", options.Mode) + } + for i, subRule := range options.Rules { + rule, err := NewDefaultDNSRule(router, logger, subRule) + if err != nil { + return nil, E.Cause(err, "sub rule[", i, "]") + } + r.rules[i] = rule + } + return r, nil +} + +func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool { + if r.mode == C.LogicalTypeAnd { + return common.All(r.rules, func(it *DefaultDNSRule) bool { + return it.Match(metadata) + }) + } else { + return common.Any(r.rules, func(it *DefaultDNSRule) bool { + return it.Match(metadata) + }) + } +} + +func (r *LogicalDNSRule) Outbound() string { + return r.outbound +} + +func (r *LogicalDNSRule) String() string { + var op string + switch r.mode { + case C.LogicalTypeAnd: + op = "&&" + case C.LogicalTypeOr: + op = "||" + } + return "logical(" + strings.Join(common.Map(r.rules, F.ToString0[*DefaultDNSRule]), " "+op+" ") + ")" +} diff --git a/route/rule_logical.go b/route/rule_logical.go deleted file mode 100644 index 141a3ab3..00000000 --- a/route/rule_logical.go +++ /dev/null @@ -1,102 +0,0 @@ -package route - -import ( - "strings" - - "github.com/sagernet/sing/common" - E "github.com/sagernet/sing/common/exceptions" - F "github.com/sagernet/sing/common/format" - - "github.com/sagernet/sing-box/adapter" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing-box/option" -) - -var _ adapter.Rule = (*LogicalRule)(nil) - -type LogicalRule struct { - mode string - rules []*DefaultRule - outbound string -} - -func (r *LogicalRule) UpdateGeosite() error { - for _, rule := range r.rules { - err := rule.UpdateGeosite() - if err != nil { - return err - } - } - return nil -} - -func (r *LogicalRule) Start() error { - for _, rule := range r.rules { - err := rule.Start() - if err != nil { - return err - } - } - return nil -} - -func (r *LogicalRule) Close() error { - for _, rule := range r.rules { - err := rule.Close() - if err != nil { - return err - } - } - return nil -} - -func NewLogicalRule(router adapter.Router, logger log.Logger, options option.LogicalRule) (*LogicalRule, error) { - r := &LogicalRule{ - rules: make([]*DefaultRule, len(options.Rules)), - outbound: options.Outbound, - } - switch options.Mode { - case C.LogicalTypeAnd: - r.mode = C.LogicalTypeAnd - case C.LogicalTypeOr: - r.mode = C.LogicalTypeOr - default: - return nil, E.New("unknown logical mode: ", options.Mode) - } - for i, subRule := range options.Rules { - rule, err := NewDefaultRule(router, logger, subRule) - if err != nil { - return nil, E.Cause(err, "sub rule[", i, "]") - } - r.rules[i] = rule - } - return r, nil -} - -func (r *LogicalRule) Match(metadata *adapter.InboundContext) bool { - if r.mode == C.LogicalTypeAnd { - return common.All(r.rules, func(it *DefaultRule) bool { - return it.Match(metadata) - }) - } else { - return common.Any(r.rules, func(it *DefaultRule) bool { - return it.Match(metadata) - }) - } -} - -func (r *LogicalRule) Outbound() string { - return r.outbound -} - -func (r *LogicalRule) String() string { - var op string - switch r.mode { - case C.LogicalTypeAnd: - op = "&&" - case C.LogicalTypeOr: - op = "||" - } - return "logical(" + strings.Join(common.Map(r.rules, F.ToString0[*DefaultRule]), " "+op+" ") + ")" -} diff --git a/route/rule_outbound.go b/route/rule_outbound.go new file mode 100644 index 00000000..612d4876 --- /dev/null +++ b/route/rule_outbound.go @@ -0,0 +1,36 @@ +package route + +import ( + "strings" + + F "github.com/sagernet/sing/common/format" + + "github.com/sagernet/sing-box/adapter" +) + +var _ RuleItem = (*OutboundItem)(nil) + +type OutboundItem struct { + outbounds []string + outboundMap map[string]bool +} + +func NewOutboundRule(outbounds []string) *OutboundItem { + rule := &OutboundItem{outbounds, make(map[string]bool)} + for _, outbound := range outbounds { + rule.outboundMap[outbound] = true + } + return rule +} + +func (r *OutboundItem) Match(metadata *adapter.InboundContext) bool { + return r.outboundMap[metadata.Outbound] +} + +func (r *OutboundItem) String() string { + if len(r.outbounds) == 1 { + return F.ToString("outbound=", r.outbounds[0]) + } else { + return F.ToString("outbound=[", strings.Join(r.outbounds, " "), "]") + } +}