1
0
Fork 0
mirror of https://github.com/SamTherapy/dnscrypt.git synced 2024-12-21 16:50:42 +00:00

all: fix panic in cert parsing (#19)

Fixes ameshkov/dnscrypt#18.
This commit is contained in:
Ainar Garipov 2023-03-15 18:49:10 +03:00 committed by GitHub
parent aeff595567
commit 266e248ed5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 76 additions and 52 deletions

124
client.go
View file

@ -3,10 +3,12 @@ package dnscrypt
import (
"crypto/ed25519"
"encoding/binary"
"fmt"
"net"
"strings"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/ameshkov/dnsstamps"
"github.com/miekg/dns"
@ -69,6 +71,7 @@ func (c *Client) DialStamp(stamp dnsstamps.ServerStamp) (*ResolverInfo, error) {
if err != nil {
return nil, err
}
resolverInfo.ResolverCert = cert
// Compute shared key that we'll use to encrypt/decrypt messages
@ -83,7 +86,7 @@ func (c *Client) DialStamp(stamp dnsstamps.ServerStamp) (*ResolverInfo, error) {
// Exchange performs a synchronous DNS query to the specified DNSCrypt server and returns a DNS response.
// This method creates a new network connection for every call so avoid using it for TCP.
// DNSCrypt cert needs to be fetched and validated prior to this call using the c.DialStamp method.
func (c *Client) Exchange(m *dns.Msg, resolverInfo *ResolverInfo) (*dns.Msg, error) {
func (c *Client) Exchange(m *dns.Msg, resolverInfo *ResolverInfo) (resp *dns.Msg, err error) {
network := "udp"
if c.Net == "tcp" {
network = "tcp"
@ -91,15 +94,16 @@ func (c *Client) Exchange(m *dns.Msg, resolverInfo *ResolverInfo) (*dns.Msg, err
conn, err := net.Dial(network, resolverInfo.ServerAddress)
if err != nil {
return nil, err
return nil, fmt.Errorf("dialing: %w", err)
}
defer conn.Close()
defer func() { err = errors.WithDeferred(err, conn.Close()) }()
r, err := c.ExchangeConn(conn, m, resolverInfo)
resp, err = c.ExchangeConn(conn, m, resolverInfo)
if err != nil {
return nil, err
return nil, fmt.Errorf("exchanging: %w", err)
}
return r, nil
return resp, nil
}
// ExchangeConn performs a synchronous DNS query to the specified DNSCrypt server and returns a DNS response.
@ -217,7 +221,7 @@ func (c *Client) decrypt(b []byte, resolverInfo *ResolverInfo) (*dns.Msg, error)
}
// fetchCert loads DNSCrypt cert from the specified server
func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) {
func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (cert *Cert, err error) {
providerName := stamp.ProviderName
if !strings.HasSuffix(providerName, ".") {
providerName = providerName + "."
@ -236,67 +240,87 @@ func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) {
return nil, ErrFailedToFetchCert
}
var certErr error
currentCert := &Cert{}
foundValid := false
for _, rr := range r.Answer {
txt, ok := rr.(*dns.TXT)
if !ok {
continue
}
var b []byte
b, certErr = unpackTxtString(strings.Join(txt.Txt, ""))
if certErr != nil {
log.Debug("[%s] failed to pack TXT record: %v", providerName, certErr)
cert, err = parseCert(stamp, currentCert, providerName, strings.Join(txt.Txt, ""))
if err != nil {
log.Debug("[%s] bad cert: %s", providerName, err)
} else if cert == nil {
continue
}
cert := &Cert{}
certErr = cert.Deserialize(b)
if certErr != nil {
log.Debug("[%s] failed to deserialize cert: %v", providerName, certErr)
continue
}
log.Debug("[%s] fetched certificate %d", providerName, cert.Serial)
if !cert.VerifyDate() {
certErr = ErrInvalidDate
log.Debug("[%s] cert %d date is not valid", providerName, cert.Serial)
continue
}
if !cert.VerifySignature(stamp.ServerPk) {
certErr = ErrInvalidCertSignature
log.Debug("[%s] cert %d signature is not valid", providerName, cert.Serial)
continue
}
if cert.Serial < currentCert.Serial {
log.Debug("[%v] cert %d superseded by a previous certificate", providerName, cert.Serial)
continue
}
if cert.Serial == currentCert.Serial {
if cert.EsVersion > currentCert.EsVersion {
log.Debug("[%v] Upgrading the construction from %v to %v", providerName, currentCert.EsVersion, cert.EsVersion)
} else {
log.Debug("[%v] Keeping the previous, preferred crypto construction", providerName)
continue
}
}
// Setting the cert
currentCert = cert
foundValid = true
}
if foundValid {
return currentCert, nil
} else if err == nil {
err = errors.Error("no valid txt records")
}
return nil, certErr
return nil, err
}
// parseCert parses a certificate from its string form and returns it if it has
// priority over currentCert.
func parseCert(
stamp dnsstamps.ServerStamp,
currentCert *Cert,
providerName string,
certStr string,
) (cert *Cert, err error) {
certBytes, err := unpackTxtString(certStr)
if err != nil {
return nil, fmt.Errorf("unpacking txt record: %w", err)
}
cert = &Cert{}
err = cert.Deserialize(certBytes)
if err != nil {
return nil, fmt.Errorf("deserializing cert for: %w", err)
}
log.Debug("[%s] fetched certificate %d", providerName, cert.Serial)
if !cert.VerifyDate() {
return nil, ErrInvalidDate
}
if !cert.VerifySignature(stamp.ServerPk) {
return nil, ErrInvalidCertSignature
}
if cert.Serial < currentCert.Serial {
log.Debug("[%v] cert %d superseded by a previous certificate", providerName, cert.Serial)
return nil, nil
}
if cert.Serial > currentCert.Serial {
return cert, nil
}
if cert.EsVersion <= currentCert.EsVersion {
log.Debug("[%v] keeping the previous, preferred crypto construction", providerName)
return nil, nil
}
log.Debug(
"[%v] upgrading the construction from %v to %v",
providerName,
currentCert.EsVersion,
cert.EsVersion,
)
return cert, nil
}
func (c *Client) maxQuerySize() int {

View file

@ -74,7 +74,7 @@ func TestTimeoutOnDialExchange(t *testing.T) {
// Check error
require.NotNil(t, err)
require.True(t, os.IsTimeout(err))
require.ErrorIs(t, err, os.ErrDeadlineExceeded)
}
func TestFetchCertPublicResolvers(t *testing.T) {

2
go.mod
View file

@ -1,6 +1,6 @@
module github.com/ameshkov/dnscrypt/v2
go 1.18
go 1.19
require (
github.com/AdguardTeam/golibs v0.10.9