1
0
Fork 0
mirror of https://github.com/SamTherapy/dnscrypt.git synced 2024-07-02 21:56:06 +00:00
dnscrypt/client.go
2023-04-13 16:24:15 +03:00

342 lines
8.8 KiB
Go

package dnscrypt
import (
"crypto/ed25519"
"encoding/binary"
"fmt"
"net"
"strings"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/ameshkov/dnsstamps"
"github.com/miekg/dns"
)
// Client is a DNSCrypt resolver client.
type Client struct {
	// Net is the protocol to use: "udp" or "tcp".  The empty value means
	// "udp".
	Net string

	// Timeout bounds network operations.  Zero means no timeout.
	Timeout time.Duration

	// UDPSize is the maximum size of a DNS response (or query) this client can
	// send or receive.  If not set, dns.MinMsgSize is used by default.
	UDPSize int
}
// ResolverInfo contains the DNSCrypt resolver information necessary for
// encrypting queries and decrypting responses.  Client.DialStamp builds it.
type ResolverInfo struct {
	// SecretKey is the client's short-term secret key.
	SecretKey [keySize]byte

	// PublicKey is the client's short-term public key; it is included in
	// every encrypted query.
	PublicKey [keySize]byte

	// ServerPublicKey is the resolver's public key from the DNS stamp; it is
	// used to validate the certificate signature.
	ServerPublicKey ed25519.PublicKey

	// ServerAddress is the resolver's network address as given in the DNS
	// stamp.  It is passed to net.Dial, so it must be in host:port form
	// rather than a bare IP address.
	ServerAddress string

	// ProviderName is the DNSCrypt provider name.
	ProviderName string

	// ResolverCert is the resolver certificate, obtained with an unencrypted
	// TXT query for the provider name.
	ResolverCert *Cert

	// SharedKey is the key used to encrypt and decrypt messages, computed from
	// SecretKey and the certificate's resolver public key.
	SharedKey [keySize]byte
}
// Dial parses stampStr, an sdns:// DNS stamp, and then fetches and validates
// the DNSCrypt certificate of the server the stamp describes.  The returned
// ResolverInfo is what later exchanges use to encrypt queries and decrypt
// responses.
//
// ErrInvalidDNSStamp is returned when the stamp is well-formed but does not
// describe a DNSCrypt server.
func (c *Client) Dial(stampStr string) (*ResolverInfo, error) {
	parsed, parseErr := dnsstamps.NewServerStampFromString(stampStr)

	switch {
	case parseErr != nil:
		// Not a valid SDNS stamp at all.
		return nil, parseErr
	case parsed.Proto != dnsstamps.StampProtoTypeDNSCrypt:
		return nil, ErrInvalidDNSStamp
	default:
		return c.DialStamp(parsed)
	}
}
// DialStamp fetches and validates the DNSCrypt certificate of the server
// described by stamp and returns the ResolverInfo needed to encrypt queries to
// it and decrypt its responses.
//
// A fresh short-term client key pair is generated on every call, and the
// shared key is derived from it and the certificate's resolver public key.
func (c *Client) DialStamp(stamp dnsstamps.ServerStamp) (*ResolverInfo, error) {
	secret, public := generateRandomKeyPair()

	info := &ResolverInfo{
		SecretKey:       secret,
		PublicKey:       public,
		ServerPublicKey: stamp.ServerPk,
		ServerAddress:   stamp.ServerAddrStr,
		ProviderName:    stamp.ProviderName,
	}

	cert, err := c.fetchCert(stamp)
	if err != nil {
		return nil, err
	}
	info.ResolverCert = cert

	shared, err := computeSharedKey(cert.EsVersion, &info.SecretKey, &cert.ResolverPk)
	if err != nil {
		return nil, err
	}
	info.SharedKey = shared

	return info, nil
}
// Exchange performs a synchronous DNS query to the DNSCrypt server described
// by resolverInfo and returns its response.
//
// A new network connection is dialed and closed on every call, so avoid using
// it for TCP.  TCP is used when c.Net is "tcp" and UDP otherwise.  A positive
// c.Timeout also bounds the dial, so that connecting to an unresponsive TCP
// server cannot block well beyond the configured timeout.
//
// The DNSCrypt certificate must be fetched and validated beforehand with
// c.Dial or c.DialStamp.
func (c *Client) Exchange(m *dns.Msg, resolverInfo *ResolverInfo) (resp *dns.Msg, err error) {
	network := "udp"
	if c.Net == "tcp" {
		network = "tcp"
	}

	// Keep the dialer's timeout at zero (no timeout, exactly like net.Dial)
	// unless a positive timeout is configured; a negative value would put the
	// dial deadline in the past.
	dialer := &net.Dialer{}
	if c.Timeout > 0 {
		dialer.Timeout = c.Timeout
	}

	conn, err := dialer.Dial(network, resolverInfo.ServerAddress)
	if err != nil {
		return nil, fmt.Errorf("dialing: %w", err)
	}
	defer func() { err = errors.WithDeferred(err, conn.Close()) }()

	resp, err = c.ExchangeConn(conn, m, resolverInfo)
	if err != nil {
		return nil, fmt.Errorf("exchanging: %w", err)
	}

	return resp, nil
}
// ExchangeConn encrypts m, writes it to conn, reads the reply and decrypts it
// into a DNS response.  The connection is used as-is: it is neither dialed nor
// closed here.
//
// The DNSCrypt server information must be fetched and validated beforehand
// with c.Dial or c.DialStamp.
func (c *Client) ExchangeConn(conn net.Conn, m *dns.Msg, resolverInfo *ResolverInfo) (*dns.Msg, error) {
	encrypted, err := c.encrypt(m, resolverInfo)
	if err != nil {
		return nil, err
	}

	if err = c.writeQuery(conn, encrypted); err != nil {
		return nil, err
	}

	reply, err := c.readResponse(conn)
	if err != nil {
		return nil, err
	}

	return c.decrypt(reply, resolverInfo)
}
// writeQuery writes the encrypted query to conn.
//
// When conn is a *net.TCPConn, the message is preceded by a 2-byte big-endian
// length prefix; on any other connection it is written as-is.  When c.Timeout
// is positive, a write deadline is set first.
//
// A query longer than the 2-byte prefix can describe is rejected with an
// error rather than written with a silently truncated length.
func (c *Client) writeQuery(conn net.Conn, query []byte) error {
	// maxPrefixedLen is the largest length a 2-byte prefix can encode.
	const maxPrefixedLen = 1<<16 - 1

	if c.Timeout > 0 {
		// Setting the deadline is best-effort; its error is ignored.
		_ = conn.SetWriteDeadline(time.Now().Add(c.Timeout))
	}

	if _, isTCP := conn.(*net.TCPConn); !isTCP {
		_, err := conn.Write(query)

		return err
	}

	if len(query) > maxPrefixedLen {
		return fmt.Errorf("query of %d bytes does not fit a 2-byte length prefix", len(query))
	}

	prefix := make([]byte, 2)
	binary.BigEndian.PutUint16(prefix, uint16(len(query)))

	// Write prefix and body together so they can go out in one call.
	_, err := (&net.Buffers{prefix, query}).WriteTo(conn)

	return err
}
// readResponse reads a single response from conn.
//
// When conn is a *net.TCPConn, the response carries a 2-byte length prefix and
// is read with readPrefixed.  Otherwise one Read is issued into a buffer of
// c.UDPSize bytes (dns.MinMsgSize when unset), and only the bytes actually
// received are returned.  When c.Timeout is positive, a read deadline is set
// first.
func (c *Client) readResponse(conn net.Conn) ([]byte, error) {
	if c.Timeout > 0 {
		_ = conn.SetReadDeadline(time.Now().Add(c.Timeout))
	}

	if _, isTCP := conn.(*net.TCPConn); isTCP {
		return readPrefixed(conn)
	}

	size := dns.MinMsgSize
	if c.UDPSize != 0 {
		size = c.UDPSize
	}

	buf := make([]byte, size)
	n, err := conn.Read(buf)
	if err != nil {
		return nil, err
	}

	return buf[:n], nil
}
// encrypt packs m and encrypts it with the shared key and client public key
// from resolverInfo, using the certificate's ES version and client magic.
//
// It returns ErrQueryTooLarge if the encrypted query is longer than
// c.maxQuerySize().
func (c *Client) encrypt(m *dns.Msg, resolverInfo *ResolverInfo) ([]byte, error) {
	q := EncryptedQuery{
		EsVersion:   resolverInfo.ResolverCert.EsVersion,
		ClientMagic: resolverInfo.ResolverCert.ClientMagic,
		ClientPk:    resolverInfo.PublicKey,
	}

	query, err := m.Pack()
	if err != nil {
		return nil, err
	}

	b, err := q.Encrypt(query, resolverInfo.SharedKey)
	if err != nil {
		// Check the error before the size, so that an encryption failure is
		// never masked by ErrQueryTooLarge.
		return nil, err
	}

	if len(b) > c.maxQuerySize() {
		return nil, ErrQueryTooLarge
	}

	return b, nil
}
// decrypt decrypts b, an encrypted DNSCrypt response, using the shared key and
// the certificate's ES version from resolverInfo, and unpacks the plaintext
// into a DNS message.
func (c *Client) decrypt(b []byte, resolverInfo *ResolverInfo) (*dns.Msg, error) {
	dr := EncryptedResponse{
		EsVersion: resolverInfo.ResolverCert.EsVersion,
	}

	msg, err := dr.Decrypt(b, resolverInfo.SharedKey)
	if err != nil {
		return nil, err
	}

	res := new(dns.Msg)
	if err = res.Unpack(msg); err != nil {
		return nil, err
	}

	return res, nil
}
// fetchCert queries the server described by stamp for the TXT records of its
// provider name and returns the preferred certificate among them, as selected
// by parseCert.
//
// If no certificate is usable, it returns the error produced while parsing the
// last TXT record.  If that record produced no error, or there were no TXT
// records at all, it returns an error saying no valid records were found.
func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (cert *Cert, err error) {
	name := stamp.ProviderName
	if !strings.HasSuffix(name, ".") {
		name += "."
	}

	req := new(dns.Msg)
	req.SetQuestion(name, dns.TypeTXT)

	// A 1252-byte UDP size keeps the response buffer from being too small.
	dnsClient := dns.Client{Net: c.Net, UDPSize: uint16(1252), Timeout: c.Timeout}
	resp, _, err := dnsClient.Exchange(req, stamp.ServerAddrStr)
	if err != nil {
		return nil, err
	}
	if resp.Rcode != dns.RcodeSuccess {
		return nil, ErrFailedToFetchCert
	}

	best := &Cert{}
	haveBest := false
	var lastErr error
	for _, rr := range resp.Answer {
		txt, isTXT := rr.(*dns.TXT)
		if !isTXT {
			continue
		}

		candidate, parseErr := parseCert(stamp, best, name, strings.Join(txt.Txt, ""))
		lastErr = parseErr

		switch {
		case parseErr != nil:
			log.Debug("[%s] bad cert: %s", name, parseErr)
		case candidate != nil:
			// A nil candidate means it was skipped due to Serial or EsVersion.
			best = candidate
			haveBest = true
		}
	}

	if haveBest {
		return best, nil
	}
	if lastErr == nil {
		lastErr = fmt.Errorf("no valid txt records for provider %q", name)
	}

	return nil, lastErr
}
// parseCert decodes and validates a certificate from its TXT-record string
// form, then decides whether it takes priority over currentCert.
//
// It returns an error when the string cannot be unpacked or deserialized, when
// the certificate's date is invalid, or when its signature does not verify
// against stamp.ServerPk.  It returns (nil, nil) when the certificate is valid
// but does not take priority: a higher Serial wins, and among equal serials a
// higher EsVersion wins.
func parseCert(
	stamp dnsstamps.ServerStamp,
	currentCert *Cert,
	providerName string,
	certStr string,
) (cert *Cert, err error) {
	certBytes, err := unpackTxtString(certStr)
	if err != nil {
		return nil, fmt.Errorf("unpacking txt record: %w", err)
	}

	cert = &Cert{}
	err = cert.Deserialize(certBytes)
	if err != nil {
		return nil, fmt.Errorf("deserializing cert for %q: %w", providerName, err)
	}

	log.Debug("[%s] fetched certificate %d", providerName, cert.Serial)

	if !cert.VerifyDate() {
		return nil, ErrInvalidDate
	}

	if !cert.VerifySignature(stamp.ServerPk) {
		return nil, ErrInvalidCertSignature
	}

	switch {
	case cert.Serial < currentCert.Serial:
		log.Debug("[%v] cert %d superseded by a previous certificate", providerName, cert.Serial)

		return nil, nil
	case cert.Serial > currentCert.Serial:
		return cert, nil
	case cert.EsVersion <= currentCert.EsVersion:
		log.Debug("[%v] keeping the previous, preferred crypto construction", providerName)

		return nil, nil
	}

	log.Debug(
		"[%v] upgrading the construction from %v to %v",
		providerName,
		currentCert.EsVersion,
		cert.EsVersion,
	)

	return cert, nil
}
// maxQuerySize returns the largest encrypted query, in bytes, this client will
// send.  Over TCP that is dns.MaxMsgSize.  Otherwise it is c.UDPSize when that
// is positive, and dns.MinMsgSize when it is not.
func (c *Client) maxQuerySize() int {
	switch {
	case c.Net == "tcp":
		return dns.MaxMsgSize
	case c.UDPSize > 0:
		return c.UDPSize
	default:
		return dns.MinMsgSize
	}
}