mirror of
https://github.com/SamTherapy/dnscrypt.git
synced 2024-12-21 16:50:42 +00:00
parent
aeff595567
commit
266e248ed5
3 changed files with 76 additions and 52 deletions
124
client.go
124
client.go
|
@ -3,10 +3,12 @@ package dnscrypt
|
|||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/ameshkov/dnsstamps"
|
||||
"github.com/miekg/dns"
|
||||
|
@ -69,6 +71,7 @@ func (c *Client) DialStamp(stamp dnsstamps.ServerStamp) (*ResolverInfo, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resolverInfo.ResolverCert = cert
|
||||
|
||||
// Compute shared key that we'll use to encrypt/decrypt messages
|
||||
|
@ -83,7 +86,7 @@ func (c *Client) DialStamp(stamp dnsstamps.ServerStamp) (*ResolverInfo, error) {
|
|||
// Exchange performs a synchronous DNS query to the specified DNSCrypt server and returns a DNS response.
|
||||
// This method creates a new network connection for every call so avoid using it for TCP.
|
||||
// DNSCrypt cert needs to be fetched and validated prior to this call using the c.DialStamp method.
|
||||
func (c *Client) Exchange(m *dns.Msg, resolverInfo *ResolverInfo) (*dns.Msg, error) {
|
||||
func (c *Client) Exchange(m *dns.Msg, resolverInfo *ResolverInfo) (resp *dns.Msg, err error) {
|
||||
network := "udp"
|
||||
if c.Net == "tcp" {
|
||||
network = "tcp"
|
||||
|
@ -91,15 +94,16 @@ func (c *Client) Exchange(m *dns.Msg, resolverInfo *ResolverInfo) (*dns.Msg, err
|
|||
|
||||
conn, err := net.Dial(network, resolverInfo.ServerAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("dialing: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
defer func() { err = errors.WithDeferred(err, conn.Close()) }()
|
||||
|
||||
r, err := c.ExchangeConn(conn, m, resolverInfo)
|
||||
resp, err = c.ExchangeConn(conn, m, resolverInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("exchanging: %w", err)
|
||||
}
|
||||
return r, nil
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExchangeConn performs a synchronous DNS query to the specified DNSCrypt server and returns a DNS response.
|
||||
|
@ -217,7 +221,7 @@ func (c *Client) decrypt(b []byte, resolverInfo *ResolverInfo) (*dns.Msg, error)
|
|||
}
|
||||
|
||||
// fetchCert loads DNSCrypt cert from the specified server
|
||||
func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) {
|
||||
func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (cert *Cert, err error) {
|
||||
providerName := stamp.ProviderName
|
||||
if !strings.HasSuffix(providerName, ".") {
|
||||
providerName = providerName + "."
|
||||
|
@ -236,67 +240,87 @@ func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) {
|
|||
return nil, ErrFailedToFetchCert
|
||||
}
|
||||
|
||||
var certErr error
|
||||
currentCert := &Cert{}
|
||||
foundValid := false
|
||||
|
||||
for _, rr := range r.Answer {
|
||||
txt, ok := rr.(*dns.TXT)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
var b []byte
|
||||
b, certErr = unpackTxtString(strings.Join(txt.Txt, ""))
|
||||
if certErr != nil {
|
||||
log.Debug("[%s] failed to pack TXT record: %v", providerName, certErr)
|
||||
|
||||
cert, err = parseCert(stamp, currentCert, providerName, strings.Join(txt.Txt, ""))
|
||||
if err != nil {
|
||||
log.Debug("[%s] bad cert: %s", providerName, err)
|
||||
} else if cert == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
cert := &Cert{}
|
||||
certErr = cert.Deserialize(b)
|
||||
if certErr != nil {
|
||||
log.Debug("[%s] failed to deserialize cert: %v", providerName, certErr)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug("[%s] fetched certificate %d", providerName, cert.Serial)
|
||||
|
||||
if !cert.VerifyDate() {
|
||||
certErr = ErrInvalidDate
|
||||
log.Debug("[%s] cert %d date is not valid", providerName, cert.Serial)
|
||||
continue
|
||||
}
|
||||
|
||||
if !cert.VerifySignature(stamp.ServerPk) {
|
||||
certErr = ErrInvalidCertSignature
|
||||
log.Debug("[%s] cert %d signature is not valid", providerName, cert.Serial)
|
||||
continue
|
||||
}
|
||||
|
||||
if cert.Serial < currentCert.Serial {
|
||||
log.Debug("[%v] cert %d superseded by a previous certificate", providerName, cert.Serial)
|
||||
continue
|
||||
}
|
||||
|
||||
if cert.Serial == currentCert.Serial {
|
||||
if cert.EsVersion > currentCert.EsVersion {
|
||||
log.Debug("[%v] Upgrading the construction from %v to %v", providerName, currentCert.EsVersion, cert.EsVersion)
|
||||
} else {
|
||||
log.Debug("[%v] Keeping the previous, preferred crypto construction", providerName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Setting the cert
|
||||
currentCert = cert
|
||||
foundValid = true
|
||||
}
|
||||
|
||||
if foundValid {
|
||||
return currentCert, nil
|
||||
} else if err == nil {
|
||||
err = errors.Error("no valid txt records")
|
||||
}
|
||||
|
||||
return nil, certErr
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// parseCert parses a certificate from its string form and returns it if it has
|
||||
// priority over currentCert.
|
||||
func parseCert(
|
||||
stamp dnsstamps.ServerStamp,
|
||||
currentCert *Cert,
|
||||
providerName string,
|
||||
certStr string,
|
||||
) (cert *Cert, err error) {
|
||||
certBytes, err := unpackTxtString(certStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unpacking txt record: %w", err)
|
||||
}
|
||||
|
||||
cert = &Cert{}
|
||||
err = cert.Deserialize(certBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("deserializing cert for: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("[%s] fetched certificate %d", providerName, cert.Serial)
|
||||
|
||||
if !cert.VerifyDate() {
|
||||
return nil, ErrInvalidDate
|
||||
}
|
||||
|
||||
if !cert.VerifySignature(stamp.ServerPk) {
|
||||
return nil, ErrInvalidCertSignature
|
||||
}
|
||||
|
||||
if cert.Serial < currentCert.Serial {
|
||||
log.Debug("[%v] cert %d superseded by a previous certificate", providerName, cert.Serial)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if cert.Serial > currentCert.Serial {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
if cert.EsVersion <= currentCert.EsVersion {
|
||||
log.Debug("[%v] keeping the previous, preferred crypto construction", providerName)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Debug(
|
||||
"[%v] upgrading the construction from %v to %v",
|
||||
providerName,
|
||||
currentCert.EsVersion,
|
||||
cert.EsVersion,
|
||||
)
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
func (c *Client) maxQuerySize() int {
|
||||
|
|
|
@ -74,7 +74,7 @@ func TestTimeoutOnDialExchange(t *testing.T) {
|
|||
|
||||
// Check error
|
||||
require.NotNil(t, err)
|
||||
require.True(t, os.IsTimeout(err))
|
||||
require.ErrorIs(t, err, os.ErrDeadlineExceeded)
|
||||
}
|
||||
|
||||
func TestFetchCertPublicResolvers(t *testing.T) {
|
||||
|
|
2
go.mod
2
go.mod
|
@ -1,6 +1,6 @@
|
|||
module github.com/ameshkov/dnscrypt/v2
|
||||
|
||||
go 1.18
|
||||
go 1.19
|
||||
|
||||
require (
|
||||
github.com/AdguardTeam/golibs v0.10.9
|
||||
|
|
Loading…
Reference in a new issue