diff --git a/component/sniffer/base_sniffer.go b/component/sniffer/base_sniffer.go index 6d076b59..c2958cc6 100644 --- a/component/sniffer/base_sniffer.go +++ b/component/sniffer/base_sniffer.go @@ -9,7 +9,8 @@ import ( ) type SnifferConfig struct { - Ports []utils.Range[uint16] + OverrideDest bool + Ports []utils.Range[uint16] } type BaseSniffer struct { diff --git a/component/sniffer/dispatcher.go b/component/sniffer/dispatcher.go index a450693b..0d6badf5 100644 --- a/component/sniffer/dispatcher.go +++ b/component/sniffer/dispatcher.go @@ -26,15 +26,12 @@ var ( var Dispatcher *SnifferDispatcher type SnifferDispatcher struct { - enable bool - - sniffers []sniffer.Sniffer - - forceDomain *trie.DomainTrie[struct{}] - skipSNI *trie.DomainTrie[struct{}] - skipList *cache.LruCache[string, uint8] - rwMux sync.RWMutex - + enable bool + sniffers map[sniffer.Sniffer]SnifferConfig + forceDomain *trie.DomainTrie[struct{}] + skipSNI *trie.DomainTrie[struct{}] + skipList *cache.LruCache[string, uint8] + rwMux sync.RWMutex forceDnsMapping bool parsePureIp bool } @@ -53,10 +50,12 @@ func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) { } inWhitelist := false - for _, sniffer := range sd.sniffers { + overrideDest := false + for sniffer, config := range sd.sniffers { if sniffer.SupportNetwork() == C.TCP || sniffer.SupportNetwork() == C.ALLNet { inWhitelist = sniffer.SupportPort(uint16(port)) if inWhitelist { + overrideDest = config.OverrideDest break } } @@ -89,12 +88,12 @@ func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) { sd.skipList.Delete(dst) sd.rwMux.RUnlock() - sd.replaceDomain(metadata, host) + sd.replaceDomain(metadata, host, overrideDest) } } } -func (sd *SnifferDispatcher) replaceDomain(metadata *C.Metadata, host string) { +func (sd *SnifferDispatcher) replaceDomain(metadata *C.Metadata, host string, overrideDest bool) { dstIP := "" if metadata.DstIP.IsValid() { dstIP = metadata.DstIP.String() @@ -112,7 +111,11 @@ func (sd *SnifferDispatcher) replaceDomain(metadata *C.Metadata, host string) { metadata.Host, host) } - metadata.Host = host + if overrideDest { + metadata.Host = host + } else { + metadata.SniffHost = host + } metadata.DNSMode = C.DNSNormal } @@ -121,7 +124,7 @@ func (sd *SnifferDispatcher) Enable() bool { } func (sd *SnifferDispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metadata) (string, error) { - for _, s := range sd.sniffers { + for s := range sd.sniffers { if s.SupportNetwork() == C.TCP { _ = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) _, err := conn.Peek(1) @@ -189,9 +192,10 @@ func NewSnifferDispatcher(snifferConfig map[sniffer.Type]SnifferConfig, forceDom enable: true, forceDomain: forceDomain, skipSNI: skipSNI, - skipList: cache.New[string, uint8](cache.WithSize[string, uint8](128), cache.WithAge[string, uint8](600)), + skipList: cache.New(cache.WithSize[string, uint8](128), cache.WithAge[string, uint8](600)), forceDnsMapping: forceDnsMapping, parsePureIp: parsePureIp, + sniffers: make(map[sniffer.Sniffer]SnifferConfig, 0), } for snifferName, config := range snifferConfig { @@ -200,8 +204,7 @@ func NewSnifferDispatcher(snifferConfig map[sniffer.Type]SnifferConfig, forceDom log.Errorln("Sniffer name[%s] is error", snifferName) return &SnifferDispatcher{enable: false}, err } - - dispatcher.sniffers = append(dispatcher.sniffers, s) + dispatcher.sniffers[s] = config } return &dispatcher, nil diff --git a/config/config.go b/config/config.go index f4a991a9..c5ff0577 100644 --- a/config/config.go +++ b/config/config.go @@ -288,6 +288,7 @@ type RawGeoXUrl struct { type RawSniffer struct { Enable bool `yaml:"enable" json:"enable"` + OverrideDest bool `yaml:"override-destination" json:"override-destination"` Sniffing []string `yaml:"sniffing" json:"sniffing"` ForceDomain []string `yaml:"force-domain" json:"force-domain"` SkipDomain []string `yaml:"skip-domain" json:"skip-domain"` @@ -298,7 +299,8 @@ type RawSniffer struct { } type RawSniffingConfig struct { - Ports []string `yaml:"ports" json:"ports"` + Ports []string `yaml:"ports" json:"ports"` + OverrideDest *bool `yaml:"override-destination" json:"override-destination"` } // EBpf config @@ -1201,11 +1203,16 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) { if err != nil { return nil, err } + overrideDest := snifferRaw.OverrideDest + if sniffConfig.OverrideDest != nil { + overrideDest = *sniffConfig.OverrideDest + } for _, snifferType := range snifferTypes.List { if snifferType.String() == strings.ToUpper(sniffType) { find = true loadSniffer[snifferType] = SNIFF.SnifferConfig{ - Ports: ports, + Ports: ports, + OverrideDest: overrideDest, } } } @@ -1228,7 +1235,8 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) { if snifferType.String() == strings.ToUpper(snifferName) { find = true loadSniffer[snifferType] = SNIFF.SnifferConfig{ - Ports: globalPorts, + Ports: globalPorts, + OverrideDest: snifferRaw.OverrideDest, } } } diff --git a/constant/metadata.go b/constant/metadata.go index 4605a146..d57c21b6 100644 --- a/constant/metadata.go +++ b/constant/metadata.go @@ -134,6 +134,8 @@ type Metadata struct { SpecialProxy string `json:"specialProxy"` SpecialRules string `json:"specialRules"` RemoteDst string `json:"remoteDestination"` + // Only domain rule + SniffHost string } func (m *Metadata) RemoteAddress() string { @@ -176,6 +178,14 @@ func (m *Metadata) Resolved() bool { return m.DstIP.IsValid() } +func (m *Metadata) RuleHost() string { + if len(m.SniffHost) == 0 { + return m.Host + } else { + return m.SniffHost + } +} + // Pure is used to solve unexpected behavior // when dialing proxy connection in DNSMapping mode. func (m *Metadata) Pure() *Metadata { diff --git a/rules/common/domain.go b/rules/common/domain.go index 1dc6b250..6b3eba22 100644 --- a/rules/common/domain.go +++ b/rules/common/domain.go @@ -19,7 +19,7 @@ func (d *Domain) RuleType() C.RuleType { } func (d *Domain) Match(metadata *C.Metadata) (bool, string) { - return metadata.Host == d.domain, d.adapter + return metadata.RuleHost() == d.domain, d.adapter } func (d *Domain) Adapter() string { diff --git a/rules/common/domain_keyword.go b/rules/common/domain_keyword.go index 4fff673e..94d2a949 100644 --- a/rules/common/domain_keyword.go +++ b/rules/common/domain_keyword.go @@ -19,7 +19,7 @@ func (dk *DomainKeyword) RuleType() C.RuleType { } func (dk *DomainKeyword) Match(metadata *C.Metadata) (bool, string) { - domain := metadata.Host + domain := metadata.RuleHost() return strings.Contains(domain, dk.keyword), dk.adapter } diff --git a/rules/common/domain_suffix.go b/rules/common/domain_suffix.go index 1e48a236..4bdc2e2e 100644 --- a/rules/common/domain_suffix.go +++ b/rules/common/domain_suffix.go @@ -19,7 +19,7 @@ func (ds *DomainSuffix) RuleType() C.RuleType { } func (ds *DomainSuffix) Match(metadata *C.Metadata) (bool, string) { - domain := metadata.Host + domain := metadata.RuleHost() return strings.HasSuffix(domain, "."+ds.suffix) || domain == ds.suffix, ds.adapter } diff --git a/rules/common/geosite.go b/rules/common/geosite.go index 865b0b1b..c0a3a1da 100644 --- a/rules/common/geosite.go +++ b/rules/common/geosite.go @@ -29,7 +29,7 @@ func (gs *GEOSITE) Match(metadata *C.Metadata) (bool, string) { return false, "" } - domain := metadata.Host + domain := metadata.RuleHost() return gs.matcher.ApplyDomain(domain), gs.adapter } diff --git a/rules/provider/domain_strategy.go b/rules/provider/domain_strategy.go index 85528f40..add64e76 100644 --- a/rules/provider/domain_strategy.go +++ b/rules/provider/domain_strategy.go @@ -17,7 +17,7 @@ func (d *domainStrategy) ShouldFindProcess() bool { } func (d *domainStrategy) Match(metadata *C.Metadata) bool { - return d.domainRules != nil && d.domainRules.Search(metadata.Host) != nil + return d.domainRules != nil && d.domainRules.Search(metadata.RuleHost()) != nil } func (d *domainStrategy) Count() int {