diff --git a/dns/client.go b/dns/client.go index 6a54f9fa..5cb1fe02 100644 --- a/dns/client.go +++ b/dns/client.go @@ -15,9 +15,10 @@ import ( type client struct { *D.Client - r *Resolver - port string - host string + r *Resolver + port string + host string + iface string } func (c *client) Exchange(m *D.Msg) (*D.Msg, error) { @@ -45,7 +46,11 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error) network = "tcp" } - conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), c.port)) + options := []dialer.Option{} + if c.iface != "" { + options = append(options, dialer.WithInterface(c.iface)) + } + conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), c.port), options...) if err != nil { return nil, err } diff --git a/dns/dhcp.go b/dns/dhcp.go index 94f8a36c..f964cec8 100644 --- a/dns/dhcp.go +++ b/dns/dhcp.go @@ -68,8 +68,11 @@ func (d *dhcpClient) resolve(ctx context.Context) (*Resolver, error) { dns, err := dhcp.ResolveDNSFromDHCP(ctx, d.ifaceName) if err == nil { nameserver := make([]NameServer, 0, len(dns)) - for _, d := range dns { - nameserver = append(nameserver, NameServer{Addr: net.JoinHostPort(d.String(), "53")}) + for _, item := range dns { + nameserver = append(nameserver, NameServer{ + Addr: net.JoinHostPort(item.String(), "53"), + Interface: d.ifaceName, + }) } res = NewResolver(Config{ diff --git a/dns/resolver.go b/dns/resolver.go index 914f1418..8bbd0e8b 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -302,8 +302,9 @@ func (r *Resolver) asyncExchange(ctx context.Context, client []dnsClient, msg *D } type NameServer struct { - Net string - Addr string + Net string + Addr string + Interface string } type FallbackFilter struct { diff --git a/dns/util.go b/dns/util.go index 0c3f42e5..b167809a 100644 --- a/dns/util.go +++ b/dns/util.go @@ -138,9 +138,10 @@ func transform(servers []NameServer, resolver *Resolver) []dnsClient { UDPSize: 4096, Timeout: 5 * time.Second, }, - port: port, - host: host, - r: resolver, + port: port, + host: host, + iface: s.Interface, + r: resolver, }) } return ret