mihomo/dns/middleware.go

232 lines
5.4 KiB
Go
Raw Normal View History

2019-07-14 19:29:58 +08:00
package dns
import (
2022-04-06 04:25:53 +08:00
"net/netip"
2019-07-14 19:29:58 +08:00
"strings"
"time"
2019-07-14 19:29:58 +08:00
2023-11-03 21:01:45 +08:00
"github.com/metacubex/mihomo/common/cache"
"github.com/metacubex/mihomo/common/nnip"
"github.com/metacubex/mihomo/component/fakeip"
R "github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/context"
"github.com/metacubex/mihomo/log"
2019-07-14 19:29:58 +08:00
D "github.com/miekg/dns"
)
2021-10-10 23:44:09 +08:00
// handler processes the DNS request message r within the DNS context ctx and
// returns the reply message, or an error.
type handler func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error)

// middleware wraps a handler and returns a new handler that may answer the
// request itself, rewrite it, or forward it to next.
type middleware func(next handler) handler
2019-07-14 19:29:58 +08:00
func withHosts(hosts R.Hosts, mapping *cache.LruCache[netip.Addr, string]) middleware {
return func(next handler) handler {
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
q := r.Question[0]
if !isIPRequest(q) {
return next(ctx, r)
}
2022-04-11 06:28:42 +08:00
host := strings.TrimRight(q.Name, ".")
handleCName := func(resp *D.Msg, domain string) {
rr := &D.CNAME{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeCNAME, Class: D.ClassINET, Ttl: 10}
rr.Target = domain + "."
resp.Answer = append([]D.RR{rr}, resp.Answer...)
}
record, ok := hosts.Search(host, q.Qtype != D.TypeA && q.Qtype != D.TypeAAAA)
if !ok {
if record != nil && record.IsDomain {
// replace request domain
newR := r.Copy()
newR.Question[0].Name = record.Domain + "."
resp, err := next(ctx, newR)
if err == nil {
resp.Id = r.Id
resp.Question = r.Question
handleCName(resp, record.Domain)
}
return resp, err
}
return next(ctx, r)
}
msg := r.Copy()
handleIPs := func() {
for _, ipAddr := range record.IPs {
if ipAddr.Is4() && q.Qtype == D.TypeA {
rr := &D.A{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 10}
rr.A = ipAddr.AsSlice()
msg.Answer = append(msg.Answer, rr)
if mapping != nil {
mapping.SetWithExpire(ipAddr, host, time.Now().Add(time.Second*10))
}
} else if q.Qtype == D.TypeAAAA {
rr := &D.AAAA{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: 10}
ip := ipAddr.As16()
rr.AAAA = ip[:]
msg.Answer = append(msg.Answer, rr)
if mapping != nil {
mapping.SetWithExpire(ipAddr, host, time.Now().Add(time.Second*10))
}
}
}
}
switch q.Qtype {
case D.TypeA:
handleIPs()
case D.TypeAAAA:
handleIPs()
case D.TypeCNAME:
handleCName(r, record.Domain)
default:
return next(ctx, r)
2022-04-11 06:28:42 +08:00
}
ctx.SetType(context.DNSTypeHost)
msg.SetRcode(r, D.RcodeSuccess)
msg.Authoritative = true
msg.RecursionAvailable = true
return msg, nil
}
}
}
func withMapping(mapping *cache.LruCache[netip.Addr, string]) middleware {
return func(next handler) handler {
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
q := r.Question[0]
if !isIPRequest(q) {
return next(ctx, r)
}
msg, err := next(ctx, r)
if err != nil {
return nil, err
}
host := strings.TrimRight(q.Name, ".")
for _, ans := range msg.Answer {
2022-04-20 01:52:51 +08:00
var ip netip.Addr
var ttl uint32
switch a := ans.(type) {
case *D.A:
2022-04-20 01:52:51 +08:00
ip = nnip.IpToAddr(a.A)
ttl = a.Hdr.Ttl
case *D.AAAA:
2022-04-20 01:52:51 +08:00
ip = nnip.IpToAddr(a.AAAA)
ttl = a.Hdr.Ttl
default:
continue
}
if ttl < 1 {
ttl = 1
}
2022-04-20 01:52:51 +08:00
mapping.SetWithExpire(ip, host, time.Now().Add(time.Second*time.Duration(ttl)))
}
return msg, nil
}
}
}
2019-09-11 17:00:55 +08:00
func withFakeIP(fakePool *fakeip.Pool) middleware {
return func(next handler) handler {
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
2019-09-11 17:00:55 +08:00
q := r.Question[0]
2020-10-09 00:04:24 +08:00
host := strings.TrimRight(q.Name, ".")
2021-10-11 20:48:58 +08:00
if fakePool.ShouldSkipped(host) {
return next(ctx, r)
2020-10-09 00:04:24 +08:00
}
switch q.Qtype {
case D.TypeAAAA, D.TypeSVCB, D.TypeHTTPS:
return handleMsgWithEmptyAnswer(r), nil
}
if q.Qtype != D.TypeA {
return next(ctx, r)
2019-09-11 17:00:55 +08:00
}
rr := &D.A{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
ip := fakePool.Lookup(host)
rr.A = ip.AsSlice()
2019-09-11 17:00:55 +08:00
msg := r.Copy()
msg.Answer = []D.RR{rr}
ctx.SetType(context.DNSTypeFakeIP)
2019-09-11 17:00:55 +08:00
setMsgTTL(msg, 1)
msg.SetRcode(r, D.RcodeSuccess)
msg.Authoritative = true
msg.RecursionAvailable = true
return msg, nil
2019-09-11 17:00:55 +08:00
}
2019-07-14 19:29:58 +08:00
}
}
func withResolver(resolver *Resolver) handler {
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
ctx.SetType(context.DNSTypeRaw)
2020-06-18 18:11:02 +08:00
q := r.Question[0]
// return a empty AAAA msg when ipv6 disabled
if !resolver.ipv6 && q.Qtype == D.TypeAAAA {
return handleMsgWithEmptyAnswer(r), nil
2020-06-18 18:11:02 +08:00
}
msg, err := resolver.ExchangeContext(ctx, r)
2019-07-14 19:29:58 +08:00
if err != nil {
2019-09-11 17:00:55 +08:00
log.Debugln("[DNS Server] Exchange %s failed: %v", q.String(), err)
return msg, err
2019-07-14 19:29:58 +08:00
}
msg.SetRcode(r, msg.Rcode)
msg.Authoritative = true
return msg, nil
2019-07-14 19:29:58 +08:00
}
}
2019-09-11 17:00:55 +08:00
func compose(middlewares []middleware, endpoint handler) handler {
length := len(middlewares)
h := endpoint
for i := length - 1; i >= 0; i-- {
middleware := middlewares[i]
h = middleware(h)
2019-07-14 19:29:58 +08:00
}
2019-09-11 17:00:55 +08:00
return h
2019-07-14 19:29:58 +08:00
}
2021-11-17 16:03:47 +08:00
func NewHandler(resolver *Resolver, mapper *ResolverEnhancer) handler {
2019-09-11 17:00:55 +08:00
middlewares := []middleware{}
2019-07-14 19:29:58 +08:00
if resolver.hosts != nil {
middlewares = append(middlewares, withHosts(R.NewHosts(resolver.hosts), mapper.mapping))
}
if mapper.mode == C.DNSFakeIP {
middlewares = append(middlewares, withFakeIP(mapper.fakePool))
}
if mapper.mode != C.DNSNormal {
middlewares = append(middlewares, withMapping(mapper.mapping))
2019-07-14 19:29:58 +08:00
}
2019-09-11 17:00:55 +08:00
return compose(middlewares, withResolver(resolver))
2019-07-14 19:29:58 +08:00
}