From 923d3222b013af5174569a83843f324770f6ea4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 24 Jun 2024 09:49:15 +0800 Subject: [PATCH] WTF is this --- adapter/experimental.go | 17 +-- adapter/router.go | 2 +- box_outbound.go | 4 +- cmd/internal/build_libbox/main.go | 4 +- cmd/internal/build_shared/sdk.go | 10 +- cmd/sing-box/cmd_merge.go | 6 +- cmd/sing-box/cmd_run.go | 2 +- common/geosite/geosite_test.go | 34 +++++ common/geosite/reader.go | 90 +++++++---- common/geosite/writer.go | 21 ++- common/srs/binary.go | 121 ++++++--------- common/srs/ip_set.go | 107 +++++-------- constant/path.go | 6 +- experimental/libbox/command_clash_mode.go | 18 +-- .../libbox/command_close_connection.go | 5 +- experimental/libbox/command_connections.go | 36 +++-- experimental/libbox/command_group.go | 140 +++--------------- experimental/libbox/command_log.go | 66 +++++---- experimental/libbox/command_power.go | 10 +- experimental/libbox/command_select.go | 10 +- experimental/libbox/command_server.go | 8 - experimental/libbox/command_shared.go | 6 +- experimental/libbox/command_urltest.go | 6 +- experimental/libbox/profile_import.go | 43 +++--- experimental/libbox/setup.go | 6 + go.mod | 2 +- go.sum | 4 +- inbound/mixed.go | 14 +- inbound/vless.go | 11 +- inbound/vmess.go | 11 +- option/outbound.go | 2 +- option/route.go | 2 +- outbound/proxy.go | 18 +-- outbound/tor.go | 12 +- route/router.go | 4 +- route/router_geo_resources.go | 8 +- route/rule_abstract.go | 19 ++- transport/trojan/mux.go | 27 +++- transport/trojan/service.go | 4 +- transport/v2raygrpc/conn.go | 4 +- transport/v2raygrpclite/conn.go | 8 +- 41 files changed, 421 insertions(+), 507 deletions(-) create mode 100644 common/geosite/geosite_test.go diff --git a/adapter/experimental.go b/adapter/experimental.go index 5e1cbd9d..0cab5ed5 100644 --- a/adapter/experimental.go +++ b/adapter/experimental.go @@ -4,14 +4,13 @@ import ( "bytes" "context" "encoding/binary" - "io" "net" "time" "github.com/sagernet/sing-box/common/urltest" "github.com/sagernet/sing-dns" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/varbin" ) type ClashServer interface { @@ -56,16 +55,15 @@ func (s *SavedRuleSet) MarshalBinary() ([]byte, error) { if err != nil { return nil, err } - err = rw.WriteUVariant(&buffer, uint64(len(s.Content))) + err = varbin.Write(&buffer, binary.BigEndian, s.Content) if err != nil { return nil, err } - buffer.Write(s.Content) err = binary.Write(&buffer, binary.BigEndian, s.LastUpdated.Unix()) if err != nil { return nil, err } - err = rw.WriteVString(&buffer, s.LastEtag) + err = varbin.Write(&buffer, binary.BigEndian, s.LastEtag) if err != nil { return nil, err } @@ -79,12 +77,7 @@ func (s *SavedRuleSet) UnmarshalBinary(data []byte) error { if err != nil { return err } - contentLen, err := rw.ReadUVariant(reader) - if err != nil { - return err - } - s.Content = make([]byte, contentLen) - _, err = io.ReadFull(reader, s.Content) + err = varbin.Read(reader, binary.BigEndian, &s.Content) if err != nil { return err } @@ -94,7 +87,7 @@ func (s *SavedRuleSet) UnmarshalBinary(data []byte) error { return err } s.LastUpdated = time.Unix(lastUpdated, 0) - s.LastEtag, err = rw.ReadVString(reader) + err = varbin.Read(reader, binary.BigEndian, &s.LastEtag) if err != nil { return err } diff --git a/adapter/router.go b/adapter/router.go index 54dc3396..c481f0c8 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -45,7 +45,7 @@ type Router interface { DefaultInterface() string AutoDetectInterface() bool AutoDetectInterfaceFunc() control.Func - DefaultMark() int + DefaultMark() uint32 NetworkMonitor() tun.NetworkUpdateMonitor InterfaceMonitor() tun.DefaultInterfaceMonitor PackageManager() tun.PackageManager diff --git a/box_outbound.go b/box_outbound.go index 6e3f0617..f03f3b7d 100644 --- a/box_outbound.go +++ b/box_outbound.go @@ -45,7 +45,9 @@ func (s *Box) startOutbounds() error { } started[outboundTag] = true canContinue = true - if starter, isStarter := outboundToStart.(common.Starter); isStarter { + if starter, isStarter := outboundToStart.(interface { + Start() error + }); isStarter { monitor.Start("initialize outbound/", outboundToStart.Type(), "[", outboundTag, "]") err := starter.Start() monitor.Finish() diff --git a/cmd/internal/build_libbox/main.go b/cmd/internal/build_libbox/main.go index ae0fe34a..fc9308ff 100644 --- a/cmd/internal/build_libbox/main.go +++ b/cmd/internal/build_libbox/main.go @@ -93,7 +93,7 @@ func buildAndroid() { const name = "libbox.aar" copyPath := filepath.Join("..", "sing-box-for-android", "app", "libs") - if rw.FileExists(copyPath) { + if rw.IsDir(copyPath) { copyPath, _ = filepath.Abs(copyPath) err = rw.CopyFile(name, filepath.Join(copyPath, name)) if err != nil { @@ -134,7 +134,7 @@ func buildiOS() { } copyPath := filepath.Join("..", "sing-box-for-apple") - if rw.FileExists(copyPath) { + if rw.IsDir(copyPath) { targetDir := filepath.Join(copyPath, "Libbox.xcframework") targetDir, _ = filepath.Abs(targetDir) os.RemoveAll(targetDir) diff --git a/cmd/internal/build_shared/sdk.go b/cmd/internal/build_shared/sdk.go index ce7f0c86..b6c1ec9d 100644 --- a/cmd/internal/build_shared/sdk.go +++ b/cmd/internal/build_shared/sdk.go @@ -30,7 +30,7 @@ func FindSDK() { } for _, path := range searchPath { path = os.ExpandEnv(path) - if rw.FileExists(filepath.Join(path, "licenses", "android-sdk-license")) { + if rw.IsFile(filepath.Join(path, "licenses", "android-sdk-license")) { androidSDKPath = path break } @@ -60,7 +60,7 @@ func FindSDK() { func findNDK() bool { const fixedVersion = "26.2.11394342" const versionFile = "source.properties" - if fixedPath := filepath.Join(androidSDKPath, "ndk", fixedVersion); rw.FileExists(filepath.Join(fixedPath, versionFile)) { + if fixedPath := filepath.Join(androidSDKPath, "ndk", fixedVersion); rw.IsFile(filepath.Join(fixedPath, versionFile)) { androidNDKPath = fixedPath return true } @@ -86,7 +86,7 @@ func findNDK() bool { }) for _, versionName := range versionNames { currentNDKPath := filepath.Join(androidSDKPath, "ndk", versionName) - if rw.FileExists(filepath.Join(androidSDKPath, versionFile)) { + if rw.IsFile(filepath.Join(androidSDKPath, versionFile)) { androidNDKPath = currentNDKPath log.Warn("reproducibility warning: using NDK version " + versionName + " instead of " + fixedVersion) return true @@ -100,11 +100,11 @@ var GoBinPath string func FindMobile() { goBin := filepath.Join(build.Default.GOPATH, "bin") if runtime.GOOS == "windows" { - if !rw.FileExists(filepath.Join(goBin, "gobind.exe")) { + if !rw.IsFile(filepath.Join(goBin, "gobind.exe")) { log.Fatal("missing gomobile installation") } } else { - if !rw.FileExists(filepath.Join(goBin, "gobind")) { + if !rw.IsFile(filepath.Join(goBin, "gobind")) { log.Fatal("missing gomobile installation") } } diff --git a/cmd/sing-box/cmd_merge.go b/cmd/sing-box/cmd_merge.go index 1d19ff17..10dd38a1 100644 --- a/cmd/sing-box/cmd_merge.go +++ b/cmd/sing-box/cmd_merge.go @@ -54,7 +54,11 @@ func merge(outputPath string) error { return nil } } - err = rw.WriteFile(outputPath, buffer.Bytes()) + err = rw.MkdirParent(outputPath) + if err != nil { + return err + } + err = os.WriteFile(outputPath, buffer.Bytes(), 0o644) if err != nil { return err } diff --git a/cmd/sing-box/cmd_run.go b/cmd/sing-box/cmd_run.go index 3c4dd0d9..e717c594 100644 --- a/cmd/sing-box/cmd_run.go +++ b/cmd/sing-box/cmd_run.go @@ -109,7 +109,7 @@ func readConfigAndMerge() (option.Options, error) { } var mergedMessage json.RawMessage for _, options := range optionsList { - mergedMessage, err = badjson.MergeJSON(options.options.RawMessage, mergedMessage) + mergedMessage, err = badjson.MergeJSON(options.options.RawMessage, mergedMessage, false) if err != nil { return option.Options{}, E.Cause(err, "merge config at ", options.path) } diff --git a/common/geosite/geosite_test.go b/common/geosite/geosite_test.go new file mode 100644 index 00000000..bdcb7a7a --- /dev/null +++ b/common/geosite/geosite_test.go @@ -0,0 +1,34 @@ +package geosite_test + +import ( + "bytes" + "testing" + + "github.com/sagernet/sing-box/common/geosite" + + "github.com/stretchr/testify/require" +) + +func TestGeosite(t *testing.T) { + t.Parallel() + + var buffer bytes.Buffer + err := geosite.Write(&buffer, map[string][]geosite.Item{ + "test": { + { + Type: geosite.RuleTypeDomain, + Value: "example.org", + }, + }, + }) + require.NoError(t, err) + reader, codes, err := geosite.NewReader(bytes.NewReader(buffer.Bytes())) + require.NoError(t, err) + require.Equal(t, []string{"test"}, codes) + items, err := reader.Read("test") + require.NoError(t, err) + require.Equal(t, []geosite.Item{{ + Type: geosite.RuleTypeDomain, + Value: "example.org", + }}, items) +} diff --git a/common/geosite/reader.go b/common/geosite/reader.go index a1b39f28..3b3f7fec 100644 --- a/common/geosite/reader.go +++ b/common/geosite/reader.go @@ -1,17 +1,24 @@ package geosite import ( + "bufio" + "encoding/binary" "io" "os" + "sync" + "sync/atomic" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/varbin" ) type Reader struct { - reader io.ReadSeeker - domainIndex map[string]int - domainLength map[string]int + access sync.Mutex + reader io.ReadSeeker + bufferedReader *bufio.Reader + metadataIndex int64 + domainIndex map[string]int + domainLength map[string]int } func Open(path string) (*Reader, []string, error) { @@ -19,14 +26,22 @@ func Open(path string) (*Reader, []string, error) { if err != nil { return nil, nil, err } - reader := &Reader{ - reader: content, - } - err = reader.readMetadata() + reader, codes, err := NewReader(content) if err != nil { content.Close() return nil, nil, err } + return reader, codes, nil +} + +func NewReader(readSeeker io.ReadSeeker) (*Reader, []string, error) { + reader := &Reader{ + reader: readSeeker, + } + err := reader.readMetadata() + if err != nil { + return nil, nil, err + } codes := make([]string, 0, len(reader.domainIndex)) for code := range reader.domainIndex { codes = append(codes, code) @@ -34,15 +49,23 @@ func Open(path string) (*Reader, []string, error) { return reader, codes, nil } +type geositeMetadata struct { + Code string + Index uint64 + Length uint64 +} + func (r *Reader) readMetadata() error { - version, err := rw.ReadByte(r.reader) + counter := &readCounter{Reader: r.reader} + reader := bufio.NewReader(counter) + version, err := reader.ReadByte() if err != nil { return err } if version != 0 { return E.New("unknown version") } - entryLength, err := rw.ReadUVariant(r.reader) + entryLength, err := binary.ReadUvarint(reader) if err != nil { return err } @@ -55,16 +78,16 @@ func (r *Reader) readMetadata() error { codeIndex uint64 codeLength uint64 ) - code, err = rw.ReadVString(r.reader) + code, err = varbin.ReadValue[string](reader, binary.BigEndian) if err != nil { return err } keys[i] = code - codeIndex, err = rw.ReadUVariant(r.reader) + codeIndex, err = binary.ReadUvarint(reader) if err != nil { return err } - codeLength, err = rw.ReadUVariant(r.reader) + codeLength, err = binary.ReadUvarint(reader) if err != nil { return err } @@ -73,6 +96,8 @@ func (r *Reader) readMetadata() error { } r.domainIndex = domainIndex r.domainLength = domainLength + r.metadataIndex = counter.count - int64(reader.Buffered()) + r.bufferedReader = reader return nil } @@ -81,31 +106,32 @@ func (r *Reader) Read(code string) ([]Item, error) { if !exists { return nil, E.New("code ", code, " not exists!") } - _, err := r.reader.Seek(int64(index), io.SeekCurrent) + _, err := r.reader.Seek(r.metadataIndex+int64(index), io.SeekStart) if err != nil { return nil, err } - counter := &rw.ReadCounter{Reader: r.reader} - domain := make([]Item, r.domainLength[code]) - for i := range domain { - var ( - item Item - err error - ) - item.Type, err = rw.ReadByte(counter) - if err != nil { - return nil, err - } - item.Value, err = rw.ReadVString(counter) - if err != nil { - return nil, err - } - domain[i] = item + r.bufferedReader.Reset(r.reader) + itemList := make([]Item, r.domainLength[code]) + err = varbin.Read(r.bufferedReader, binary.BigEndian, &itemList) + if err != nil { + return nil, err } - _, err = r.reader.Seek(int64(-index)-counter.Count(), io.SeekCurrent) - return domain, err + return itemList, nil } func (r *Reader) Upstream() any { return r.reader } + +type readCounter struct { + io.Reader + count int64 +} + +func (r *readCounter) Read(p []byte) (n int, err error) { + n, err = r.Reader.Read(p) + if n > 0 { + atomic.AddInt64(&r.count, int64(n)) + } + return +} diff --git a/common/geosite/writer.go b/common/geosite/writer.go index 4e7ec514..1615fa34 100644 --- a/common/geosite/writer.go +++ b/common/geosite/writer.go @@ -2,13 +2,13 @@ package geosite import ( "bytes" - "io" + "encoding/binary" "sort" - "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/varbin" ) -func Write(writer io.Writer, domains map[string][]Item) error { +func Write(writer varbin.Writer, domains map[string][]Item) error { keys := make([]string, 0, len(domains)) for code := range domains { keys = append(keys, code) @@ -19,35 +19,34 @@ func Write(writer io.Writer, domains map[string][]Item) error { index := make(map[string]int) for _, code := range keys { index[code] = content.Len() - for _, domain := range domains[code] { - content.WriteByte(domain.Type) - err := rw.WriteVString(content, domain.Value) + for _, item := range domains[code] { + err := varbin.Write(content, binary.BigEndian, item) if err != nil { return err } } } - err := rw.WriteByte(writer, 0) + err := writer.WriteByte(0) if err != nil { return err } - err = rw.WriteUVariant(writer, uint64(len(keys))) + _, err = varbin.WriteUvarint(writer, uint64(len(keys))) if err != nil { return err } for _, code := range keys { - err = rw.WriteVString(writer, code) + err = varbin.Write(writer, binary.BigEndian, code) if err != nil { return err } - err = rw.WriteUVariant(writer, uint64(index[code])) + _, err = varbin.WriteUvarint(writer, uint64(index[code])) if err != nil { return err } - err = rw.WriteUVariant(writer, uint64(len(domains[code]))) + _, err = varbin.WriteUvarint(writer, uint64(len(domains[code]))) if err != nil { return err } diff --git a/common/srs/binary.go b/common/srs/binary.go index faf4cd17..c7c55e08 100644 --- a/common/srs/binary.go +++ b/common/srs/binary.go @@ -1,6 +1,7 @@ package srs import ( + "bufio" "compress/zlib" "encoding/binary" "io" @@ -11,7 +12,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/domain" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/varbin" "go4.org/netipx" ) @@ -38,7 +39,7 @@ const ( ruleItemFinal uint8 = 0xFF ) -func Read(reader io.Reader, recovery bool) (ruleSet option.PlainRuleSet, err error) { +func Read(reader io.Reader, recover bool) (ruleSet option.PlainRuleSet, err error) { var magicBytes [3]byte _, err = io.ReadFull(reader, magicBytes[:]) if err != nil { @@ -60,13 +61,14 @@ func Read(reader io.Reader, recovery bool) (ruleSet option.PlainRuleSet, err err if err != nil { return } - length, err := rw.ReadUVariant(zReader) + bReader := bufio.NewReader(zReader) + length, err := binary.ReadUvarint(bReader) if err != nil { return } ruleSet.Rules = make([]option.HeadlessRule, length) for i := uint64(0); i < length; i++ { - ruleSet.Rules[i], err = readRule(zReader, recovery) + ruleSet.Rules[i], err = readRule(bReader, recover) if err != nil { err = E.Cause(err, "read rule[", i, "]") return @@ -88,20 +90,25 @@ func Write(writer io.Writer, ruleSet option.PlainRuleSet) error { if err != nil { return err } - err = rw.WriteUVariant(zWriter, uint64(len(ruleSet.Rules))) + bWriter := bufio.NewWriter(zWriter) + _, err = varbin.WriteUvarint(bWriter, uint64(len(ruleSet.Rules))) if err != nil { return err } for _, rule := range ruleSet.Rules { - err = writeRule(zWriter, rule) + err = writeRule(bWriter, rule) if err != nil { return err } } + err = bWriter.Flush() + if err != nil { + return err + } return zWriter.Close() } -func readRule(reader io.Reader, recovery bool) (rule option.HeadlessRule, err error) { +func readRule(reader varbin.Reader, recover bool) (rule option.HeadlessRule, err error) { var ruleType uint8 err = binary.Read(reader, binary.BigEndian, &ruleType) if err != nil { @@ -110,17 +117,17 @@ func readRule(reader io.Reader, recovery bool) (rule option.HeadlessRule, err er switch ruleType { case 0: rule.Type = C.RuleTypeDefault - rule.DefaultOptions, err = readDefaultRule(reader, recovery) + rule.DefaultOptions, err = readDefaultRule(reader, recover) case 1: rule.Type = C.RuleTypeLogical - rule.LogicalOptions, err = readLogicalRule(reader, recovery) + rule.LogicalOptions, err = readLogicalRule(reader, recover) default: err = E.New("unknown rule type: ", ruleType) } return } -func writeRule(writer io.Writer, rule option.HeadlessRule) error { +func writeRule(writer varbin.Writer, rule option.HeadlessRule) error { switch rule.Type { case C.RuleTypeDefault: return writeDefaultRule(writer, rule.DefaultOptions) @@ -131,7 +138,7 @@ func writeRule(writer io.Writer, rule option.HeadlessRule) error { } } -func readDefaultRule(reader io.Reader, recovery bool) (rule option.DefaultHeadlessRule, err error) { +func readDefaultRule(reader varbin.Reader, recover bool) (rule option.DefaultHeadlessRule, err error) { var lastItemType uint8 for { var itemType uint8 @@ -158,6 +165,9 @@ func readDefaultRule(reader io.Reader, recovery bool) (rule option.DefaultHeadle return } rule.DomainMatcher = matcher + if recover { + rule.Domain, rule.DomainSuffix = matcher.Dump() + } case ruleItemDomainKeyword: rule.DomainKeyword, err = readRuleItemString(reader) case ruleItemDomainRegex: @@ -167,7 +177,7 @@ func readDefaultRule(reader io.Reader, recovery bool) (rule option.DefaultHeadle if err != nil { return } - if recovery { + if recover { rule.SourceIPCIDR = common.Map(rule.SourceIPSet.Prefixes(), netip.Prefix.String) } case ruleItemIPCIDR: @@ -175,7 +185,7 @@ func readDefaultRule(reader io.Reader, recovery bool) (rule option.DefaultHeadle if err != nil { return } - if recovery { + if recover { rule.IPCIDR = common.Map(rule.IPSet.Prefixes(), netip.Prefix.String) } case ruleItemSourcePort: @@ -209,7 +219,7 @@ func readDefaultRule(reader io.Reader, recovery bool) (rule option.DefaultHeadle } } -func writeDefaultRule(writer io.Writer, rule option.DefaultHeadlessRule) error { +func writeDefaultRule(writer varbin.Writer, rule option.DefaultHeadlessRule) error { err := binary.Write(writer, binary.BigEndian, uint8(0)) if err != nil { return err @@ -327,73 +337,31 @@ func writeDefaultRule(writer io.Writer, rule option.DefaultHeadlessRule) error { return nil } -func readRuleItemString(reader io.Reader) ([]string, error) { - length, err := rw.ReadUVariant(reader) - if err != nil { - return nil, err - } - value := make([]string, length) - for i := uint64(0); i < length; i++ { - value[i], err = rw.ReadVString(reader) - if err != nil { - return nil, err - } - } - return value, nil +func readRuleItemString(reader varbin.Reader) ([]string, error) { + return varbin.ReadValue[[]string](reader, binary.BigEndian) } -func writeRuleItemString(writer io.Writer, itemType uint8, value []string) error { - err := binary.Write(writer, binary.BigEndian, itemType) +func writeRuleItemString(writer varbin.Writer, itemType uint8, value []string) error { + err := writer.WriteByte(itemType) if err != nil { return err } - err = rw.WriteUVariant(writer, uint64(len(value))) + return varbin.Write(writer, binary.BigEndian, value) +} + +func readRuleItemUint16(reader varbin.Reader) ([]uint16, error) { + return varbin.ReadValue[[]uint16](reader, binary.BigEndian) +} + +func writeRuleItemUint16(writer varbin.Writer, itemType uint8, value []uint16) error { + err := writer.WriteByte(itemType) if err != nil { return err } - for _, item := range value { - err = rw.WriteVString(writer, item) - if err != nil { - return err - } - } - return nil + return varbin.Write(writer, binary.BigEndian, value) } -func readRuleItemUint16(reader io.Reader) ([]uint16, error) { - length, err := rw.ReadUVariant(reader) - if err != nil { - return nil, err - } - value := make([]uint16, length) - for i := uint64(0); i < length; i++ { - err = binary.Read(reader, binary.BigEndian, &value[i]) - if err != nil { - return nil, err - } - } - return value, nil -} - -func writeRuleItemUint16(writer io.Writer, itemType uint8, value []uint16) error { - err := binary.Write(writer, binary.BigEndian, itemType) - if err != nil { - return err - } - err = rw.WriteUVariant(writer, uint64(len(value))) - if err != nil { - return err - } - for _, item := range value { - err = binary.Write(writer, binary.BigEndian, item) - if err != nil { - return err - } - } - return nil -} - -func writeRuleItemCIDR(writer io.Writer, itemType uint8, value []string) error { +func writeRuleItemCIDR(writer varbin.Writer, itemType uint8, value []string) error { var builder netipx.IPSetBuilder for i, prefixString := range value { prefix, err := netip.ParsePrefix(prefixString) @@ -419,9 +387,8 @@ func writeRuleItemCIDR(writer io.Writer, itemType uint8, value []string) error { return writeIPSet(writer, ipSet) } -func readLogicalRule(reader io.Reader, recovery bool) (logicalRule option.LogicalHeadlessRule, err error) { - var mode uint8 - err = binary.Read(reader, binary.BigEndian, &mode) +func readLogicalRule(reader varbin.Reader, recovery bool) (logicalRule option.LogicalHeadlessRule, err error) { + mode, err := reader.ReadByte() if err != nil { return } @@ -434,7 +401,7 @@ func readLogicalRule(reader io.Reader, recovery bool) (logicalRule option.Logica err = E.New("unknown logical mode: ", mode) return } - length, err := rw.ReadUVariant(reader) + length, err := binary.ReadUvarint(reader) if err != nil { return } @@ -453,7 +420,7 @@ func readLogicalRule(reader io.Reader, recovery bool) (logicalRule option.Logica return } -func writeLogicalRule(writer io.Writer, logicalRule option.LogicalHeadlessRule) error { +func writeLogicalRule(writer varbin.Writer, logicalRule option.LogicalHeadlessRule) error { err := binary.Write(writer, binary.BigEndian, uint8(1)) if err != nil { return err @@ -469,7 +436,7 @@ func writeLogicalRule(writer io.Writer, logicalRule option.LogicalHeadlessRule) if err != nil { return err } - err = rw.WriteUVariant(writer, uint64(len(logicalRule.Rules))) + _, err = varbin.WriteUvarint(writer, uint64(len(logicalRule.Rules))) if err != nil { return err } diff --git a/common/srs/ip_set.go b/common/srs/ip_set.go index b346da26..044dc823 100644 --- a/common/srs/ip_set.go +++ b/common/srs/ip_set.go @@ -2,11 +2,13 @@ package srs import ( "encoding/binary" - "io" "net/netip" + "os" "unsafe" - "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/varbin" "go4.org/netipx" ) @@ -20,94 +22,57 @@ type myIPRange struct { to netip.Addr } -func readIPSet(reader io.Reader) (*netipx.IPSet, error) { - var version uint8 - err := binary.Read(reader, binary.BigEndian, &version) +type myIPRangeData struct { + From []byte + To []byte +} + +func readIPSet(reader varbin.Reader) (*netipx.IPSet, error) { + version, err := reader.ReadByte() if err != nil { return nil, err } + if version != 1 { + return nil, os.ErrInvalid + } + // WTF why using uint64 here var length uint64 err = binary.Read(reader, binary.BigEndian, &length) if err != nil { return nil, err } - mySet := &myIPSet{ - rr: make([]myIPRange, length), + ranges := make([]myIPRangeData, length) + err = varbin.Read(reader, binary.BigEndian, &ranges) + if err != nil { + return nil, err } - for i := uint64(0); i < length; i++ { - var ( - fromLen uint64 - toLen uint64 - fromAddr netip.Addr - toAddr netip.Addr - ) - fromLen, err = rw.ReadUVariant(reader) - if err != nil { - return nil, err - } - fromBytes := make([]byte, fromLen) - _, err = io.ReadFull(reader, fromBytes) - if err != nil { - return nil, err - } - err = fromAddr.UnmarshalBinary(fromBytes) - if err != nil { - return nil, err - } - toLen, err = rw.ReadUVariant(reader) - if err != nil { - return nil, err - } - toBytes := make([]byte, toLen) - _, err = io.ReadFull(reader, toBytes) - if err != nil { - return nil, err - } - err = toAddr.UnmarshalBinary(toBytes) - if err != nil { - return nil, err - } - mySet.rr[i] = myIPRange{fromAddr, toAddr} + mySet := &myIPSet{ + rr: make([]myIPRange, len(ranges)), + } + for i, rangeData := range ranges { + mySet.rr[i].from = M.AddrFromIP(rangeData.From) + mySet.rr[i].to = M.AddrFromIP(rangeData.To) } return (*netipx.IPSet)(unsafe.Pointer(mySet)), nil } -func writeIPSet(writer io.Writer, set *netipx.IPSet) error { - err := binary.Write(writer, binary.BigEndian, uint8(1)) +func writeIPSet(writer varbin.Writer, set *netipx.IPSet) error { + err := writer.WriteByte(1) if err != nil { return err } - mySet := (*myIPSet)(unsafe.Pointer(set)) - err = binary.Write(writer, binary.BigEndian, uint64(len(mySet.rr))) + dataList := common.Map((*myIPSet)(unsafe.Pointer(set)).rr, func(rr myIPRange) myIPRangeData { + return myIPRangeData{ + From: rr.from.AsSlice(), + To: rr.to.AsSlice(), + } + }) + err = binary.Write(writer, binary.BigEndian, uint64(len(dataList))) if err != nil { return err } - for _, rr := range mySet.rr { - var ( - fromBinary []byte - toBinary []byte - ) - fromBinary, err = rr.from.MarshalBinary() - if err != nil { - return err - } - err = rw.WriteUVariant(writer, uint64(len(fromBinary))) - if err != nil { - return err - } - _, err = writer.Write(fromBinary) - if err != nil { - return err - } - toBinary, err = rr.to.MarshalBinary() - if err != nil { - return err - } - err = rw.WriteUVariant(writer, uint64(len(toBinary))) - if err != nil { - return err - } - _, err = writer.Write(toBinary) + for _, data := range dataList { + err = varbin.Write(writer, binary.BigEndian, data) if err != nil { return err } diff --git a/constant/path.go b/constant/path.go index 98acacdc..ea2aad3e 100644 --- a/constant/path.go +++ b/constant/path.go @@ -13,14 +13,14 @@ var resourcePaths []string func FindPath(name string) (string, bool) { name = os.ExpandEnv(name) - if rw.FileExists(name) { + if rw.IsFile(name) { return name, true } for _, dir := range resourcePaths { - if path := filepath.Join(dir, dirName, name); rw.FileExists(path) { + if path := filepath.Join(dir, dirName, name); rw.IsFile(path) { return path, true } - if path := filepath.Join(dir, name); rw.FileExists(path) { + if path := filepath.Join(dir, name); rw.IsFile(path) { return path, true } } diff --git a/experimental/libbox/command_clash_mode.go b/experimental/libbox/command_clash_mode.go index 3377ae3a..1b6eb470 100644 --- a/experimental/libbox/command_clash_mode.go +++ b/experimental/libbox/command_clash_mode.go @@ -9,7 +9,7 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/experimental/clashapi" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/varbin" ) func (c *CommandClient) SetClashMode(newMode string) error { @@ -22,7 +22,7 @@ func (c *CommandClient) SetClashMode(newMode string) error { if err != nil { return err } - err = rw.WriteVString(conn, newMode) + err = varbin.Write(conn, binary.BigEndian, newMode) if err != nil { return err } @@ -30,7 +30,7 @@ func (c *CommandClient) SetClashMode(newMode string) error { } func (s *CommandServer) handleSetClashMode(conn net.Conn) error { - newMode, err := rw.ReadVString(conn) + newMode, err := varbin.ReadValue[string](conn, binary.BigEndian) if err != nil { return err } @@ -50,7 +50,7 @@ func (c *CommandClient) handleModeConn(conn net.Conn) { defer conn.Close() for { - newMode, err := rw.ReadVString(conn) + newMode, err := varbin.ReadValue[string](conn, binary.BigEndian) if err != nil { c.handler.Disconnected(err.Error()) return @@ -80,7 +80,7 @@ func (s *CommandServer) handleModeConn(conn net.Conn) error { for { select { case <-s.modeUpdate: - err = rw.WriteVString(conn, clashServer.Mode()) + err = varbin.Write(conn, binary.BigEndian, clashServer.Mode()) if err != nil { return err } @@ -101,12 +101,12 @@ func readClashModeList(reader io.Reader) (modeList []string, currentMode string, } modeList = make([]string, modeListLength) for i := 0; i < int(modeListLength); i++ { - modeList[i], err = rw.ReadVString(reader) + modeList[i], err = varbin.ReadValue[string](reader, binary.BigEndian) if err != nil { return } } - currentMode, err = rw.ReadVString(reader) + currentMode, err = varbin.ReadValue[string](reader, binary.BigEndian) return } @@ -118,12 +118,12 @@ func writeClashModeList(writer io.Writer, clashServer adapter.ClashServer) error } if len(modeList) > 0 { for _, mode := range modeList { - err = rw.WriteVString(writer, mode) + err = varbin.Write(writer, binary.BigEndian, mode) if err != nil { return err } } - err = rw.WriteVString(writer, clashServer.Mode()) + err = varbin.Write(writer, binary.BigEndian, clashServer.Mode()) if err != nil { return err } diff --git a/experimental/libbox/command_close_connection.go b/experimental/libbox/command_close_connection.go index 62f5dc84..1edd5911 100644 --- a/experimental/libbox/command_close_connection.go +++ b/experimental/libbox/command_close_connection.go @@ -7,6 +7,7 @@ import ( "github.com/sagernet/sing-box/experimental/clashapi" "github.com/sagernet/sing/common/binary" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/varbin" "github.com/gofrs/uuid/v5" ) @@ -18,7 +19,7 @@ func (c *CommandClient) CloseConnection(connId string) error { } defer conn.Close() writer := bufio.NewWriter(conn) - err = binary.WriteData(writer, binary.BigEndian, connId) + err = varbin.Write(writer, binary.BigEndian, connId) if err != nil { return err } @@ -32,7 +33,7 @@ func (c *CommandClient) CloseConnection(connId string) error { func (s *CommandServer) handleCloseConnection(conn net.Conn) error { reader := bufio.NewReader(conn) var connId string - err := binary.ReadData(reader, binary.BigEndian, &connId) + err := varbin.Read(reader, binary.BigEndian, &connId) if err != nil { return E.Cause(err, "read connection id") } diff --git a/experimental/libbox/command_connections.go b/experimental/libbox/command_connections.go index 9aaa995a..b51c7352 100644 --- a/experimental/libbox/command_connections.go +++ b/experimental/libbox/command_connections.go @@ -12,6 +12,7 @@ import ( "github.com/sagernet/sing/common/binary" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/varbin" "github.com/gofrs/uuid/v5" ) @@ -19,14 +20,18 @@ import ( func (c *CommandClient) handleConnectionsConn(conn net.Conn) { defer conn.Close() reader := bufio.NewReader(conn) - var connections Connections + var ( + rawConnections []Connection + connections Connections + ) for { rawConnections = nil - err := binary.ReadData(reader, binary.BigEndian, &connections.connections) + err := varbin.Read(reader, binary.BigEndian, &rawConnections) if err != nil { c.handler.Disconnected(err.Error()) return } + connections.input = rawConnections c.handler.WriteConnections(&connections) } } @@ -70,7 +75,7 @@ func (s *CommandServer) handleConnectionsConn(conn net.Conn) error { for _, connection := range trafficManager.ClosedConnections() { outConnections = append(outConnections, newConnection(connections, connection, true)) } - err = binary.WriteData(writer, binary.BigEndian, outConnections) + err = varbin.Write(writer, binary.BigEndian, outConnections) if err != nil { return err } @@ -93,33 +98,32 @@ const ( ) type Connections struct { - connections []Connection - filteredConnections []Connection - outConnections *[]Connection + input []Connection + filtered []Connection } func (c *Connections) FilterState(state int32) { - c.filteredConnections = c.filteredConnections[:0] + c.filtered = c.filtered[:0] switch state { case ConnectionStateAll: - c.filteredConnections = append(c.filteredConnections, c.connections...) + c.filtered = append(c.filtered, c.input...) case ConnectionStateActive: - for _, connection := range c.connections { + for _, connection := range c.input { if connection.ClosedAt == 0 { - c.filteredConnections = append(c.filteredConnections, connection) + c.filtered = append(c.filtered, connection) } } case ConnectionStateClosed: - for _, connection := range c.connections { + for _, connection := range c.input { if connection.ClosedAt != 0 { - c.filteredConnections = append(c.filteredConnections, connection) + c.filtered = append(c.filtered, connection) } } } } func (c *Connections) SortByDate() { - slices.SortStableFunc(c.filteredConnections, func(x, y Connection) int { + slices.SortStableFunc(c.filtered, func(x, y Connection) int { if x.CreatedAt < y.CreatedAt { return 1 } else if x.CreatedAt > y.CreatedAt { @@ -131,7 +135,7 @@ func (c *Connections) SortByDate() { } func (c *Connections) SortByTraffic() { - slices.SortStableFunc(c.filteredConnections, func(x, y Connection) int { + slices.SortStableFunc(c.filtered, func(x, y Connection) int { xTraffic := x.Uplink + x.Downlink yTraffic := y.Uplink + y.Downlink if xTraffic < yTraffic { @@ -145,7 +149,7 @@ func (c *Connections) SortByTraffic() { } func (c *Connections) SortByTrafficTotal() { - slices.SortStableFunc(c.filteredConnections, func(x, y Connection) int { + slices.SortStableFunc(c.filtered, func(x, y Connection) int { xTraffic := x.UplinkTotal + x.DownlinkTotal yTraffic := y.UplinkTotal + y.DownlinkTotal if xTraffic < yTraffic { @@ -159,7 +163,7 @@ func (c *Connections) SortByTrafficTotal() { } func (c *Connections) Iterator() ConnectionIterator { - return newPtrIterator(c.filteredConnections) + return newPtrIterator(c.filtered) } type Connection struct { diff --git a/experimental/libbox/command_group.go b/experimental/libbox/command_group.go index a5572ea1..3a8d2a07 100644 --- a/experimental/libbox/command_group.go +++ b/experimental/libbox/command_group.go @@ -1,6 +1,7 @@ package libbox import ( + "bufio" "encoding/binary" "io" "net" @@ -10,7 +11,7 @@ import ( "github.com/sagernet/sing-box/common/urltest" "github.com/sagernet/sing-box/outbound" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/varbin" "github.com/sagernet/sing/service" ) @@ -36,19 +37,24 @@ func (s *CommandServer) handleGroupConn(conn net.Conn) error { ticker := time.NewTicker(time.Duration(interval)) defer ticker.Stop() ctx := connKeepAlive(conn) + writer := bufio.NewWriter(conn) for { service := s.service if service != nil { - err := writeGroups(conn, service) + err = writeGroups(writer, service) if err != nil { return err } } else { - err := binary.Write(conn, binary.BigEndian, uint16(0)) + err = binary.Write(writer, binary.BigEndian, uint16(0)) if err != nil { return err } } + err = writer.Flush() + if err != nil { + return err + } select { case <-ctx.Done(): return ctx.Err() @@ -68,11 +74,11 @@ type OutboundGroup struct { Selectable bool Selected string IsExpand bool - items []*OutboundGroupItem + ItemList []*OutboundGroupItem } func (g *OutboundGroup) GetItems() OutboundGroupItemIterator { - return newIterator(g.items) + return newIterator(g.ItemList) } type OutboundGroupIterator interface { @@ -93,73 +99,10 @@ type OutboundGroupItemIterator interface { } func readGroups(reader io.Reader) (OutboundGroupIterator, error) { - var groupLength uint16 - err := binary.Read(reader, binary.BigEndian, &groupLength) + groups, err := varbin.ReadValue[[]*OutboundGroup](reader, binary.BigEndian) if err != nil { return nil, err } - - groups := make([]*OutboundGroup, 0, groupLength) - for i := 0; i < int(groupLength); i++ { - var group OutboundGroup - group.Tag, err = rw.ReadVString(reader) - if err != nil { - return nil, err - } - - group.Type, err = rw.ReadVString(reader) - if err != nil { - return nil, err - } - - err = binary.Read(reader, binary.BigEndian, &group.Selectable) - if err != nil { - return nil, err - } - - group.Selected, err = rw.ReadVString(reader) - if err != nil { - return nil, err - } - - err = binary.Read(reader, binary.BigEndian, &group.IsExpand) - if err != nil { - return nil, err - } - - var itemLength uint16 - err = binary.Read(reader, binary.BigEndian, &itemLength) - if err != nil { - return nil, err - } - - group.items = make([]*OutboundGroupItem, itemLength) - for j := 0; j < int(itemLength); j++ { - var item OutboundGroupItem - item.Tag, err = rw.ReadVString(reader) - if err != nil { - return nil, err - } - - item.Type, err = rw.ReadVString(reader) - if err != nil { - return nil, err - } - - err = binary.Read(reader, binary.BigEndian, &item.URLTestTime) - if err != nil { - return nil, err - } - - err = binary.Read(reader, binary.BigEndian, &item.URLTestDelay) - if err != nil { - return nil, err - } - - group.items[j] = &item - } - groups = append(groups, &group) - } return newIterator(groups), nil } @@ -199,63 +142,14 @@ func writeGroups(writer io.Writer, boxService *BoxService) error { item.URLTestTime = history.Time.Unix() item.URLTestDelay = int32(history.Delay) } - group.items = append(group.items, &item) + group.ItemList = append(group.ItemList, &item) } - if len(group.items) < 2 { + if len(group.ItemList) < 2 { continue } groups = append(groups, group) } - - err := binary.Write(writer, binary.BigEndian, uint16(len(groups))) - if err != nil { - return err - } - for _, group := range groups { - err = rw.WriteVString(writer, group.Tag) - if err != nil { - return err - } - err = rw.WriteVString(writer, group.Type) - if err != nil { - return err - } - err = binary.Write(writer, binary.BigEndian, group.Selectable) - if err != nil { - return err - } - err = rw.WriteVString(writer, group.Selected) - if err != nil { - return err - } - err = binary.Write(writer, binary.BigEndian, group.IsExpand) - if err != nil { - return err - } - err = binary.Write(writer, binary.BigEndian, uint16(len(group.items))) - if err != nil { - return err - } - for _, item := range group.items { - err = rw.WriteVString(writer, item.Tag) - if err != nil { - return err - } - err = rw.WriteVString(writer, item.Type) - if err != nil { - return err - } - err = binary.Write(writer, binary.BigEndian, item.URLTestTime) - if err != nil { - return err - } - err = binary.Write(writer, binary.BigEndian, item.URLTestDelay) - if err != nil { - return err - } - } - } - return nil + return varbin.Write(writer, binary.BigEndian, groups) } func (c *CommandClient) SetGroupExpand(groupTag string, isExpand bool) error { @@ -268,7 +162,7 @@ func (c *CommandClient) SetGroupExpand(groupTag string, isExpand bool) error { if err != nil { return err } - err = rw.WriteVString(conn, groupTag) + err = varbin.Write(conn, binary.BigEndian, groupTag) if err != nil { return err } @@ -280,7 +174,7 @@ func (c *CommandClient) SetGroupExpand(groupTag string, isExpand bool) error { } func (s *CommandServer) handleSetGroupExpand(conn net.Conn) error { - groupTag, err := rw.ReadVString(conn) + groupTag, err := varbin.ReadValue[string](conn, binary.BigEndian) if err != nil { return err } diff --git a/experimental/libbox/command_log.go b/experimental/libbox/command_log.go index 8a22aa2e..07f6e839 100644 --- a/experimental/libbox/command_log.go +++ b/experimental/libbox/command_log.go @@ -9,8 +9,19 @@ import ( "github.com/sagernet/sing/common/binary" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/varbin" ) +func (s *CommandServer) ResetLog() { + s.access.Lock() + defer s.access.Unlock() + s.savedLines.Init() + select { + case s.logReset <- struct{}{}: + default: + } +} + func (s *CommandServer) WriteMessage(message string) { s.subscriber.Emit(message) s.access.Lock() @@ -21,26 +32,6 @@ func (s *CommandServer) WriteMessage(message string) { s.access.Unlock() } -func writeLog(writer *bufio.Writer, messages []string) error { - err := binary.Write(writer, binary.BigEndian, uint8(0)) - if err != nil { - return err - } - err = binary.WriteData(writer, binary.BigEndian, messages) - if err != nil { - return err - } - return writer.Flush() -} - -func writeClearLog(writer *bufio.Writer) error { - err := binary.Write(writer, binary.BigEndian, uint8(1)) - if err != nil { - return err - } - return writer.Flush() -} - func (s *CommandServer) handleLogConn(conn net.Conn) error { var ( interval int64 @@ -67,8 +58,24 @@ func (s *CommandServer) handleLogConn(conn net.Conn) error { } defer s.observer.UnSubscribe(subscription) writer := bufio.NewWriter(conn) + select { + case <-s.logReset: + err = writer.WriteByte(1) + if err != nil { + return err + } + err = writer.Flush() + if err != nil { + return err + } + default: + } if len(savedLines) > 0 { - err = writeLog(writer, savedLines) + err = writer.WriteByte(0) + if err != nil { + return err + } + err = varbin.Write(writer, binary.BigEndian, savedLines) if err != nil { return err } @@ -76,11 +83,15 @@ func (s *CommandServer) handleLogConn(conn net.Conn) error { ctx := connKeepAlive(conn) var logLines []string for { + err = writer.Flush() + if err != nil { + return err + } select { case <-ctx.Done(): return ctx.Err() case <-s.logReset: - err = writeClearLog(writer) + err = writer.WriteByte(1) if err != nil { return err } @@ -99,7 +110,11 @@ func (s *CommandServer) handleLogConn(conn net.Conn) error { break loopLogs } } - err = writeLog(writer, logLines) + err = writer.WriteByte(0) + if err != nil { + return err + } + err = varbin.Write(writer, binary.BigEndian, logLines) if err != nil { return err } @@ -110,8 +125,7 @@ func (s *CommandServer) handleLogConn(conn net.Conn) error { func (c *CommandClient) handleLogConn(conn net.Conn) { reader := bufio.NewReader(conn) for { - var messageType uint8 - err := binary.Read(reader, binary.BigEndian, &messageType) + messageType, err := reader.ReadByte() if err != nil { c.handler.Disconnected(err.Error()) return @@ -119,7 +133,7 @@ func (c *CommandClient) handleLogConn(conn net.Conn) { var messages []string switch messageType { case 0: - err = binary.ReadData(reader, binary.BigEndian, &messages) + err = varbin.Read(reader, binary.BigEndian, &messages) if err != nil { c.handler.Disconnected(err.Error()) return diff --git a/experimental/libbox/command_power.go b/experimental/libbox/command_power.go index 619cb57b..5ed7b014 100644 --- a/experimental/libbox/command_power.go +++ b/experimental/libbox/command_power.go @@ -5,7 +5,7 @@ import ( "net" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/varbin" ) func (c *CommandClient) ServiceReload() error { @@ -24,7 +24,7 @@ func (c *CommandClient) ServiceReload() error { return err } if hasError { - errorMessage, err := rw.ReadVString(conn) + errorMessage, err := varbin.ReadValue[string](conn, binary.BigEndian) if err != nil { return err } @@ -40,7 +40,7 @@ func (s *CommandServer) handleServiceReload(conn net.Conn) error { return err } if rErr != nil { - return rw.WriteVString(conn, rErr.Error()) + return varbin.Write(conn, binary.BigEndian, rErr.Error()) } return nil } @@ -61,7 +61,7 @@ func (c *CommandClient) ServiceClose() error { return nil } if hasError { - errorMessage, err := rw.ReadVString(conn) + errorMessage, err := varbin.ReadValue[string](conn, binary.BigEndian) if err != nil { return nil } @@ -78,7 +78,7 @@ func (s *CommandServer) handleServiceClose(conn net.Conn) error { return err } if rErr != nil { - return rw.WriteVString(conn, rErr.Error()) + return varbin.Write(conn, binary.BigEndian, rErr.Error()) } return nil } diff --git a/experimental/libbox/command_select.go b/experimental/libbox/command_select.go index e7d5b08f..e1e67e60 100644 --- a/experimental/libbox/command_select.go +++ b/experimental/libbox/command_select.go @@ -6,7 +6,7 @@ import ( "github.com/sagernet/sing-box/outbound" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/varbin" ) func (c *CommandClient) SelectOutbound(groupTag string, outboundTag string) error { @@ -19,11 +19,11 @@ func (c *CommandClient) SelectOutbound(groupTag string, outboundTag string) erro if err != nil { return err } - err = rw.WriteVString(conn, groupTag) + err = varbin.Write(conn, binary.BigEndian, groupTag) if err != nil { return err } - err = rw.WriteVString(conn, outboundTag) + err = varbin.Write(conn, binary.BigEndian, outboundTag) if err != nil { return err } @@ -31,11 +31,11 @@ func (c *CommandClient) SelectOutbound(groupTag string, outboundTag string) erro } func (s *CommandServer) handleSelectOutbound(conn net.Conn) error { - groupTag, err := rw.ReadVString(conn) + groupTag, err := varbin.ReadValue[string](conn, binary.BigEndian) if err != nil { return err } - outboundTag, err := rw.ReadVString(conn) + outboundTag, err := varbin.ReadValue[string](conn, binary.BigEndian) if err != nil { return err } diff --git a/experimental/libbox/command_server.go b/experimental/libbox/command_server.go index 8918756d..f913191d 100644 --- a/experimental/libbox/command_server.go +++ b/experimental/libbox/command_server.go @@ -66,14 +66,6 @@ func (s *CommandServer) SetService(newService *BoxService) { s.notifyURLTestUpdate() } -func (s *CommandServer) ResetLog() { - s.savedLines.Init() - select { - case s.logReset <- struct{}{}: - default: - } -} - func (s *CommandServer) notifyURLTestUpdate() { select { case s.urlTestUpdate <- struct{}{}: diff --git a/experimental/libbox/command_shared.go b/experimental/libbox/command_shared.go index ecad78dd..b98c2e5d 100644 --- a/experimental/libbox/command_shared.go +++ b/experimental/libbox/command_shared.go @@ -5,7 +5,7 @@ import ( "io" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/varbin" ) func readError(reader io.Reader) error { @@ -15,7 +15,7 @@ func readError(reader io.Reader) error { return err } if hasError { - errorMessage, err := rw.ReadVString(reader) + errorMessage, err := varbin.ReadValue[string](reader, binary.BigEndian) if err != nil { return err } @@ -30,7 +30,7 @@ func writeError(writer io.Writer, wErr error) error { return err } if wErr != nil { - err = rw.WriteVString(writer, wErr.Error()) + err = varbin.Write(writer, binary.BigEndian, wErr.Error()) if err != nil { return err } diff --git a/experimental/libbox/command_urltest.go b/experimental/libbox/command_urltest.go index 19ddf3da..6feda3f8 100644 --- a/experimental/libbox/command_urltest.go +++ b/experimental/libbox/command_urltest.go @@ -11,7 +11,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/batch" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/varbin" "github.com/sagernet/sing/service" ) @@ -25,7 +25,7 @@ func (c *CommandClient) URLTest(groupTag string) error { if err != nil { return err } - err = rw.WriteVString(conn, groupTag) + err = varbin.Write(conn, binary.BigEndian, groupTag) if err != nil { return err } @@ -33,7 +33,7 @@ func (c *CommandClient) URLTest(groupTag string) error { } func (s *CommandServer) handleURLTest(conn net.Conn) error { - groupTag, err := rw.ReadVString(conn) + groupTag, err := varbin.ReadValue[string](conn, binary.BigEndian) if err != nil { return err } diff --git a/experimental/libbox/profile_import.go b/experimental/libbox/profile_import.go index 75ddb06d..258c175a 100644 --- a/experimental/libbox/profile_import.go +++ b/experimental/libbox/profile_import.go @@ -1,13 +1,13 @@ package libbox import ( + "bufio" "bytes" "compress/gzip" "encoding/binary" - "io" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/varbin" ) func EncodeChunkedMessage(data []byte) []byte { @@ -35,13 +35,13 @@ type ErrorMessage struct { func (e *ErrorMessage) Encode() []byte { var buffer bytes.Buffer buffer.WriteByte(MessageTypeError) - rw.WriteVString(&buffer, e.Message) + varbin.Write(&buffer, binary.BigEndian, e.Message) return buffer.Bytes() } func DecodeErrorMessage(data []byte) (*ErrorMessage, error) { reader := bytes.NewReader(data) - messageType, err := rw.ReadByte(reader) + messageType, err := reader.ReadByte() if err != nil { return nil, err } @@ -49,7 +49,7 @@ func DecodeErrorMessage(data []byte) (*ErrorMessage, error) { return nil, E.New("invalid message") } var message ErrorMessage - message.Message, err = rw.ReadVString(reader) + message.Message, err = varbin.ReadValue[string](reader, binary.BigEndian) if err != nil { return nil, err } @@ -87,7 +87,7 @@ func (e *ProfileEncoder) Encode() []byte { binary.Write(&buffer, binary.BigEndian, uint16(len(e.profiles))) for _, preview := range e.profiles { binary.Write(&buffer, binary.BigEndian, preview.ProfileID) - rw.WriteVString(&buffer, preview.Name) + varbin.Write(&buffer, binary.BigEndian, preview.Name) binary.Write(&buffer, binary.BigEndian, preview.Type) } return buffer.Bytes() @@ -117,7 +117,7 @@ func (d *ProfileDecoder) Decode(data []byte) error { if err != nil { return err } - profile.Name, err = rw.ReadVString(reader) + profile.Name, err = varbin.ReadValue[string](reader, binary.BigEndian) if err != nil { return err } @@ -147,7 +147,7 @@ func (r *ProfileContentRequest) Encode() []byte { func DecodeProfileContentRequest(data []byte) (*ProfileContentRequest, error) { reader := bytes.NewReader(data) - messageType, err := rw.ReadByte(reader) + messageType, err := reader.ReadByte() if err != nil { return nil, err } @@ -176,12 +176,13 @@ func (c *ProfileContent) Encode() []byte { buffer := new(bytes.Buffer) buffer.WriteByte(MessageTypeProfileContent) buffer.WriteByte(1) - writer := gzip.NewWriter(buffer) - rw.WriteVString(writer, c.Name) + gWriter := gzip.NewWriter(buffer) + writer := bufio.NewWriter(gWriter) + varbin.Write(writer, binary.BigEndian, c.Name) binary.Write(writer, binary.BigEndian, c.Type) - rw.WriteVString(writer, c.Config) + varbin.Write(writer, binary.BigEndian, c.Config) if c.Type != ProfileTypeLocal { - rw.WriteVString(writer, c.RemotePath) + varbin.Write(writer, binary.BigEndian, c.RemotePath) } if c.Type == ProfileTypeRemote { binary.Write(writer, binary.BigEndian, c.AutoUpdate) @@ -189,29 +190,31 @@ func (c *ProfileContent) Encode() []byte { binary.Write(writer, binary.BigEndian, c.LastUpdated) } writer.Flush() - writer.Close() + gWriter.Flush() + gWriter.Close() return buffer.Bytes() } func DecodeProfileContent(data []byte) (*ProfileContent, error) { - var reader io.Reader = bytes.NewReader(data) - messageType, err := rw.ReadByte(reader) + reader := bytes.NewReader(data) + messageType, err := reader.ReadByte() if err != nil { return nil, err } if messageType != MessageTypeProfileContent { return nil, E.New("invalid message") } - version, err := rw.ReadByte(reader) + version, err := reader.ReadByte() if err != nil { return nil, err } - reader, err = gzip.NewReader(reader) + gReader, err := gzip.NewReader(reader) if err != nil { return nil, E.Cause(err, "unsupported profile") } + bReader := varbin.StubReader(gReader) var content ProfileContent - content.Name, err = rw.ReadVString(reader) + content.Name, err = varbin.ReadValue[string](bReader, binary.BigEndian) if err != nil { return nil, err } @@ -219,12 +222,12 @@ func DecodeProfileContent(data []byte) (*ProfileContent, error) { if err != nil { return nil, err } - content.Config, err = rw.ReadVString(reader) + content.Config, err = varbin.ReadValue[string](bReader, binary.BigEndian) if err != nil { return nil, err } if content.Type != ProfileTypeLocal { - content.RemotePath, err = rw.ReadVString(reader) + content.RemotePath, err = varbin.ReadValue[string](bReader, binary.BigEndian) if err != nil { return nil, err } diff --git a/experimental/libbox/setup.go b/experimental/libbox/setup.go index 31611354..ac67db38 100644 --- a/experimental/libbox/setup.go +++ b/experimental/libbox/setup.go @@ -3,6 +3,7 @@ package libbox import ( "os" "os/user" + "runtime/debug" "strconv" "time" @@ -21,6 +22,11 @@ var ( sTVOS bool ) +func init() { + debug.SetPanicOnFault(true) + debug.SetTraceback("all") +} + func Setup(basePath string, workingPath string, tempPath string, isTVOS bool) { sBasePath = basePath sWorkingPath = workingPath diff --git a/go.mod b/go.mod index d471b443..6346e38a 100644 --- a/go.mod +++ b/go.mod @@ -27,7 +27,7 @@ require ( github.com/sagernet/gvisor v0.0.0-20240428053021-e691de28565f github.com/sagernet/quic-go v0.47.0-beta.2 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 - github.com/sagernet/sing v0.4.3 + github.com/sagernet/sing v0.5.0-beta.2 github.com/sagernet/sing-dns v0.2.3 github.com/sagernet/sing-mux v0.2.0 github.com/sagernet/sing-quic v0.2.2 diff --git a/go.sum b/go.sum index e6147304..1301c072 100644 --- a/go.sum +++ b/go.sum @@ -108,8 +108,8 @@ github.com/sagernet/quic-go v0.47.0-beta.2/go.mod h1:bLVKvElSEMNv7pu7SZHscW02TYi github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byLGkEnIYp6grlXfo1QYUfiYFGjewIdc= github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU= github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo= -github.com/sagernet/sing v0.4.3 h1:Ty/NAiNnVd6844k7ujlL5lkzydhcTH5Psc432jXA4Y8= -github.com/sagernet/sing v0.4.3/go.mod h1:ieZHA/+Y9YZfXs2I3WtuwgyCZ6GPsIR7HdKb1SdEnls= +github.com/sagernet/sing v0.5.0-beta.2 h1:V12EpwtsgYo5OLGjAiGoJobDJZeUsKv0b5y+yGAM6W0= +github.com/sagernet/sing v0.5.0-beta.2/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/sagernet/sing-dns v0.2.3 h1:YzeBUn2tR38F7HtvGEQ0kLRLmZWMEgi/+7wqa4Twb1k= github.com/sagernet/sing-dns v0.2.3/go.mod h1:BJpJv6XLnrUbSyIntOT6DG9FW0f4fETmPAHvNjOprLg= github.com/sagernet/sing-mux v0.2.0 h1:4C+vd8HztJCWNYfufvgL49xaOoOHXty2+EAjnzN3IYo= diff --git a/inbound/mixed.go b/inbound/mixed.go index 982842ef..3933f7af 100644 --- a/inbound/mixed.go +++ b/inbound/mixed.go @@ -12,10 +12,7 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common/auth" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/protocol/http" "github.com/sagernet/sing/protocol/socks" "github.com/sagernet/sing/protocol/socks/socks4" @@ -51,16 +48,17 @@ func NewMixed(ctx context.Context, router adapter.Router, logger log.ContextLogg } func (h *Mixed) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { - headerType, err := rw.ReadByte(conn) + reader := std_bufio.NewReader(conn) + headerBytes, err := reader.Peek(1) if err != nil { return err } - switch headerType { + switch headerBytes[0] { case socks4.Version, socks5.Version: - return socks.HandleConnection0(ctx, conn, headerType, h.authenticator, h.upstreamUserHandler(metadata), adapter.UpstreamMetadata(metadata)) + return socks.HandleConnection0(ctx, conn, reader, h.authenticator, h.upstreamUserHandler(metadata), adapter.UpstreamMetadata(metadata)) + default: + return http.HandleConnection(ctx, conn, reader, h.authenticator, h.upstreamUserHandler(metadata), adapter.UpstreamMetadata(metadata)) } - reader := std_bufio.NewReader(bufio.NewCachedReader(conn, buf.As([]byte{headerType}))) - return http.HandleConnection(ctx, conn, reader, h.authenticator, h.upstreamUserHandler(metadata), adapter.UpstreamMetadata(metadata)) } func (h *Mixed) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { diff --git a/inbound/vless.go b/inbound/vless.go index 69ed042b..56747564 100644 --- a/inbound/vless.go +++ b/inbound/vless.go @@ -83,12 +83,11 @@ func NewVLESS(ctx context.Context, router adapter.Router, logger log.ContextLogg } func (h *VLESS) Start() error { - err := common.Start( - h.service, - h.tlsConfig, - ) - if err != nil { - return err + if h.tlsConfig != nil { + err := h.tlsConfig.Start() + if err != nil { + return err + } } if h.transport == nil { return h.myInboundAdapter.Start() diff --git a/inbound/vmess.go b/inbound/vmess.go index 70676bbd..15451275 100644 --- a/inbound/vmess.go +++ b/inbound/vmess.go @@ -93,13 +93,16 @@ func NewVMess(ctx context.Context, router adapter.Router, logger log.ContextLogg } func (h *VMess) Start() error { - err := common.Start( - h.service, - h.tlsConfig, - ) + err := h.service.Start() if err != nil { return err } + if h.tlsConfig != nil { + err = h.tlsConfig.Start() + if err != nil { + return err + } + } if h.transport == nil { return h.myInboundAdapter.Start() } diff --git a/option/outbound.go b/option/outbound.go index 59ee85ab..6c943cd9 100644 --- a/option/outbound.go +++ b/option/outbound.go @@ -113,7 +113,7 @@ type DialerOptions struct { Inet4BindAddress *ListenAddress `json:"inet4_bind_address,omitempty"` Inet6BindAddress *ListenAddress `json:"inet6_bind_address,omitempty"` ProtectPath string `json:"protect_path,omitempty"` - RoutingMark int `json:"routing_mark,omitempty"` + RoutingMark uint32 `json:"routing_mark,omitempty"` ReuseAddr bool `json:"reuse_addr,omitempty"` ConnectTimeout Duration `json:"connect_timeout,omitempty"` TCPFastOpen bool `json:"tcp_fast_open,omitempty"` diff --git a/option/route.go b/option/route.go index e313fcf2..dfd72986 100644 --- a/option/route.go +++ b/option/route.go @@ -10,7 +10,7 @@ type RouteOptions struct { AutoDetectInterface bool `json:"auto_detect_interface,omitempty"` OverrideAndroidVPN bool `json:"override_android_vpn,omitempty"` DefaultInterface string `json:"default_interface,omitempty"` - DefaultMark int `json:"default_mark,omitempty"` + DefaultMark uint32 `json:"default_mark,omitempty"` } type GeoIPOptions struct { diff --git a/outbound/proxy.go b/outbound/proxy.go index 6127f0f2..fbc48481 100644 --- a/outbound/proxy.go +++ b/outbound/proxy.go @@ -1,7 +1,6 @@ package outbound import ( - std_bufio "bufio" "context" "crypto/rand" "encoding/hex" @@ -11,16 +10,10 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/rw" - "github.com/sagernet/sing/protocol/http" "github.com/sagernet/sing/protocol/socks" - "github.com/sagernet/sing/protocol/socks/socks4" - "github.com/sagernet/sing/protocol/socks/socks5" ) type ProxyListener struct { @@ -102,16 +95,7 @@ func (l *ProxyListener) acceptLoop() { } func (l *ProxyListener) accept(ctx context.Context, conn *net.TCPConn) error { - headerType, err := rw.ReadByte(conn) - if err != nil { - return err - } - switch headerType { - case socks4.Version, socks5.Version: - return socks.HandleConnection0(ctx, conn, headerType, l.authenticator, l, M.Metadata{}) - } - reader := std_bufio.NewReader(bufio.NewCachedReader(conn, buf.As([]byte{headerType}))) - return http.HandleConnection(ctx, conn, reader, l.authenticator, l, M.Metadata{}) + return socks.HandleConnection(ctx, conn, l.authenticator, l, M.Metadata{}) } func (l *ProxyListener) NewConnection(ctx context.Context, conn net.Conn, upstreamMetadata M.Metadata) error { diff --git a/outbound/tor.go b/outbound/tor.go index 76c7955d..8ae73a66 100644 --- a/outbound/tor.go +++ b/outbound/tor.go @@ -44,10 +44,10 @@ func NewTor(ctx context.Context, router adapter.Router, logger log.ContextLogger startConf.ExtraArgs = options.ExtraArgs if options.DataDirectory != "" { dataDirAbs, _ := filepath.Abs(startConf.DataDir) - if geoIPPath := filepath.Join(dataDirAbs, "geoip"); rw.FileExists(geoIPPath) && !common.Contains(options.ExtraArgs, "--GeoIPFile") { + if geoIPPath := filepath.Join(dataDirAbs, "geoip"); rw.IsFile(geoIPPath) && !common.Contains(options.ExtraArgs, "--GeoIPFile") { options.ExtraArgs = append(options.ExtraArgs, "--GeoIPFile", geoIPPath) } - if geoIP6Path := filepath.Join(dataDirAbs, "geoip6"); rw.FileExists(geoIP6Path) && !common.Contains(options.ExtraArgs, "--GeoIPv6File") { + if geoIP6Path := filepath.Join(dataDirAbs, "geoip6"); rw.IsFile(geoIP6Path) && !common.Contains(options.ExtraArgs, "--GeoIPv6File") { options.ExtraArgs = append(options.ExtraArgs, "--GeoIPv6File", geoIP6Path) } } @@ -58,8 +58,12 @@ func NewTor(ctx context.Context, router adapter.Router, logger log.ContextLogger } if startConf.DataDir != "" { torrcFile := filepath.Join(startConf.DataDir, "torrc") - if !rw.FileExists(torrcFile) { - err := rw.WriteFile(torrcFile, []byte("")) + err := rw.MkdirParent(torrcFile) + if err != nil { + return nil, err + } + if !rw.IsFile(torrcFile) { + err := os.WriteFile(torrcFile, []byte(""), 0o600) if err != nil { return nil, err } diff --git a/route/router.go b/route/router.go index a8698264..3a895685 100644 --- a/route/router.go +++ b/route/router.go @@ -82,7 +82,7 @@ type Router struct { interfaceFinder *control.DefaultInterfaceFinder autoDetectInterface bool defaultInterface string - defaultMark int + defaultMark uint32 networkMonitor tun.NetworkUpdateMonitor interfaceMonitor tun.DefaultInterfaceMonitor packageManager tun.PackageManager @@ -1171,7 +1171,7 @@ func (r *Router) DefaultInterface() string { return r.defaultInterface } -func (r *Router) DefaultMark() int { +func (r *Router) DefaultMark() uint32 { return r.defaultMark } diff --git a/route/router_geo_resources.go b/route/router_geo_resources.go index e0a572c9..14364d21 100644 --- a/route/router_geo_resources.go +++ b/route/router_geo_resources.go @@ -50,7 +50,7 @@ func (r *Router) prepareGeoIPDatabase() error { geoPath = foundPath } } - if !rw.FileExists(geoPath) { + if !rw.IsFile(geoPath) { geoPath = filemanager.BasePath(r.ctx, geoPath) } if stat, err := os.Stat(geoPath); err == nil { @@ -61,7 +61,7 @@ func (r *Router) prepareGeoIPDatabase() error { os.Remove(geoPath) } } - if !rw.FileExists(geoPath) { + if !rw.IsFile(geoPath) { r.logger.Warn("geoip database not exists: ", geoPath) var err error for attempts := 0; attempts < 3; attempts++ { @@ -96,7 +96,7 @@ func (r *Router) prepareGeositeDatabase() error { geoPath = foundPath } } - if !rw.FileExists(geoPath) { + if !rw.IsFile(geoPath) { geoPath = filemanager.BasePath(r.ctx, geoPath) } if stat, err := os.Stat(geoPath); err == nil { @@ -107,7 +107,7 @@ func (r *Router) prepareGeositeDatabase() error { os.Remove(geoPath) } } - if !rw.FileExists(geoPath) { + if !rw.IsFile(geoPath) { r.logger.Warn("geosite database not exists: ", geoPath) var err error for attempts := 0; attempts < 3; attempts++ { diff --git a/route/rule_abstract.go b/route/rule_abstract.go index c13bdd8d..9ef2e932 100644 --- a/route/rule_abstract.go +++ b/route/rule_abstract.go @@ -29,9 +29,13 @@ func (r *abstractDefaultRule) Type() string { func (r *abstractDefaultRule) Start() error { for _, item := range r.allItems { - err := common.Start(item) - if err != nil { - return err + if starter, isStarter := item.(interface { + Start() error + }); isStarter { + err := starter.Start() + if err != nil { + return err + } } } return nil @@ -183,8 +187,13 @@ func (r *abstractLogicalRule) UpdateGeosite() error { } func (r *abstractLogicalRule) Start() error { - for _, rule := range common.FilterIsInstance(r.rules, func(it adapter.HeadlessRule) (common.Starter, bool) { - rule, loaded := it.(common.Starter) + for _, rule := range common.FilterIsInstance(r.rules, func(it adapter.HeadlessRule) (interface { + Start() error + }, bool, + ) { + rule, loaded := it.(interface { + Start() error + }) return rule, loaded }) { err := rule.Start() diff --git a/transport/trojan/mux.go b/transport/trojan/mux.go index 13ac1e83..b1cc9985 100644 --- a/transport/trojan/mux.go +++ b/transport/trojan/mux.go @@ -1,12 +1,14 @@ package trojan import ( + std_bufio "bufio" "context" "net" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" - "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/task" "github.com/sagernet/smux" ) @@ -33,27 +35,36 @@ func HandleMuxConnection(ctx context.Context, conn net.Conn, metadata M.Metadata return group.Run(ctx) } -func newMuxConnection(ctx context.Context, stream net.Conn, metadata M.Metadata, handler Handler) { - err := newMuxConnection0(ctx, stream, metadata, handler) +func newMuxConnection(ctx context.Context, conn net.Conn, metadata M.Metadata, handler Handler) { + err := newMuxConnection0(ctx, conn, metadata, handler) if err != nil { handler.NewError(ctx, E.Cause(err, "process trojan-go multiplex connection")) } } -func newMuxConnection0(ctx context.Context, stream net.Conn, metadata M.Metadata, handler Handler) error { - command, err := rw.ReadByte(stream) +func newMuxConnection0(ctx context.Context, conn net.Conn, metadata M.Metadata, handler Handler) error { + reader := std_bufio.NewReader(conn) + command, err := reader.ReadByte() if err != nil { return E.Cause(err, "read command") } - metadata.Destination, err = M.SocksaddrSerializer.ReadAddrPort(stream) + metadata.Destination, err = M.SocksaddrSerializer.ReadAddrPort(reader) if err != nil { return E.Cause(err, "read destination") } + if reader.Buffered() > 0 { + buffer := buf.NewSize(reader.Buffered()) + _, err = buffer.ReadFullFrom(reader, buffer.Len()) + if err != nil { + return err + } + conn = bufio.NewCachedConn(conn, buffer) + } switch command { case CommandTCP: - return handler.NewConnection(ctx, stream, metadata) + return handler.NewConnection(ctx, conn, metadata) case CommandUDP: - return handler.NewPacketConnection(ctx, &PacketConn{Conn: stream}, metadata) + return handler.NewPacketConnection(ctx, &PacketConn{Conn: conn}, metadata) default: return E.New("unknown command ", command) } diff --git a/transport/trojan/service.go b/transport/trojan/service.go index 9078276c..97f674ab 100644 --- a/transport/trojan/service.go +++ b/transport/trojan/service.go @@ -2,6 +2,7 @@ package trojan import ( "context" + "encoding/binary" "net" "github.com/sagernet/sing/common/auth" @@ -76,7 +77,8 @@ func (s *Service[K]) NewConnection(ctx context.Context, conn net.Conn, metadata return E.Cause(err, "skip crlf") } - command, err := rw.ReadByte(conn) + var command byte + err = binary.Read(conn, binary.BigEndian, &command) if err != nil { return E.Cause(err, "read command") } diff --git a/transport/v2raygrpc/conn.go b/transport/v2raygrpc/conn.go index 0fecbf33..bc78f91e 100644 --- a/transport/v2raygrpc/conn.go +++ b/transport/v2raygrpc/conn.go @@ -8,7 +8,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/baderror" M "github.com/sagernet/sing/common/metadata" - "github.com/sagernet/sing/common/rw" + N "github.com/sagernet/sing/common/network" ) var _ net.Conn = (*GRPCConn)(nil) @@ -90,7 +90,7 @@ func (c *GRPCConn) Upstream() any { return c.GunService } -var _ rw.WriteCloser = (*clientConnWrapper)(nil) +var _ N.WriteCloser = (*clientConnWrapper)(nil) type clientConnWrapper struct { GunService_TunClient diff --git a/transport/v2raygrpclite/conn.go b/transport/v2raygrpclite/conn.go index f5a71939..5ab02569 100644 --- a/transport/v2raygrpclite/conn.go +++ b/transport/v2raygrpclite/conn.go @@ -13,7 +13,7 @@ import ( "github.com/sagernet/sing/common/baderror" "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" - "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/varbin" ) // kanged from: https://github.com/Qv2ray/gun-lite @@ -96,7 +96,7 @@ func (c *GunConn) read(b []byte) (n int, err error) { } func (c *GunConn) Write(b []byte) (n int, err error) { - varLen := rw.UVariantLen(uint64(len(b))) + varLen := varbin.UvarintLen(uint64(len(b))) buffer := buf.NewSize(6 + varLen + len(b)) header := buffer.Extend(6 + varLen) header[0] = 0x00 @@ -117,13 +117,13 @@ func (c *GunConn) Write(b []byte) (n int, err error) { func (c *GunConn) WriteBuffer(buffer *buf.Buffer) error { defer buffer.Release() dataLen := buffer.Len() - varLen := rw.UVariantLen(uint64(dataLen)) + varLen := varbin.UvarintLen(uint64(dataLen)) header := buffer.ExtendHeader(6 + varLen) header[0] = 0x00 binary.BigEndian.PutUint32(header[1:5], uint32(1+varLen+dataLen)) header[5] = 0x0A binary.PutUvarint(header[6:], uint64(dataLen)) - err := rw.WriteBytes(c.writer, buffer.Bytes()) + err := common.Error(c.writer.Write(buffer.Bytes())) if err != nil { return baderror.WrapH2(err) }