1
0
Fork 0
mirror of https://github.com/SamTherapy/dnscrypt.git synced 2024-07-02 21:56:06 +00:00
dnscrypt/server_test.go
2023-03-15 18:57:22 +03:00

319 lines
8.4 KiB
Go

package dnscrypt
import (
"bytes"
"context"
"crypto/ed25519"
"fmt"
"net"
"runtime"
"testing"
"time"
"github.com/ameshkov/dnsstamps"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
// TestServer_Shutdown starts a test server and checks that Close (which calls
// Server.Shutdown) returns no error. GOMAXPROCS is pinned to 1 for the
// duration of the test and restored on cleanup.
func TestServer_Shutdown(t *testing.T) {
	prevProcs := runtime.GOMAXPROCS(1)
	t.Cleanup(func() { runtime.GOMAXPROCS(prevProcs) })

	srv := newTestServer(t, &testHandler{})

	// ServeUDP and ServeTCP are started in their own goroutines by
	// newTestServer; give them a brief moment to actually begin serving.
	const startupDelay = 10 * time.Millisecond
	time.Sleep(startupDelay)

	err := srv.Close()
	require.NoError(t, err)
}
// TestServer_UDPServeCert checks that the server certificate is served over UDP.
func TestServer_UDPServeCert(t *testing.T) {
	const network = "udp"
	testServerServeCert(t, network)
}
// TestServer_TCPServeCert checks that the server certificate is served over TCP.
func TestServer_TCPServeCert(t *testing.T) {
	const network = "tcp"
	testServerServeCert(t, network)
}
// TestServer_UDPRespondMessages checks that queries sent over UDP are answered.
func TestServer_UDPRespondMessages(t *testing.T) {
	const network = "udp"
	testServerRespondMessages(t, network)
}
// TestServer_TCPRespondMessages checks that queries sent over TCP are answered.
func TestServer_TCPRespondMessages(t *testing.T) {
	const network = "tcp"
	testServerRespondMessages(t, network)
}
// TestServer_ReadTimeout leaves a fresh server idle for defaultReadTimeout and
// then verifies that it still answers over UDP and TCP. This makes sure the
// Serve* goroutines handle their read timeouts without exiting.
func TestServer_ReadTimeout(t *testing.T) {
	srv := newTestServer(t, &testHandler{})
	t.Cleanup(func() {
		require.NoError(t, srv.Close())
	})

	// Stay idle for a whole read timeout before sending any query.
	time.Sleep(defaultReadTimeout)

	for _, network := range []string{"udp", "tcp"} {
		testThisServerRespondMessages(t, network, srv)
	}
}
// TestServer_UDPTruncateMessage checks that a response too large for a plain
// UDP reply comes back empty with the TC (truncated) bit set.
func TestServer_UDPTruncateMessage(t *testing.T) {
	// testLargeMsgHandler answers with 64 A records, which should not fit
	// into a UDP response without EDNS0.
	srv := newTestServer(t, &testLargeMsgHandler{})
	t.Cleanup(func() {
		require.NoError(t, srv.Close())
	})

	client := &Client{
		Net:     "udp",
		Timeout: time.Second,
	}
	udpPort := srv.UDPAddr().Port
	stamp := dnsstamps.ServerStamp{
		ServerAddrStr: fmt.Sprintf("127.0.0.1:%d", udpPort),
		ServerPk:      srv.resolverPk,
		ProviderName:  srv.server.ProviderName,
		Proto:         dnsstamps.StampProtoTypeDNSCrypt,
	}

	resolver, err := client.DialStamp(stamp)
	require.NoError(t, err)
	require.NotNil(t, resolver)

	reply, err := client.Exchange(createTestMessage(), resolver)
	require.NoError(t, err)
	require.NotNil(t, reply)

	require.Equal(t, dns.RcodeSuccess, reply.Rcode)
	require.Empty(t, reply.Answer)
	require.True(t, reply.Truncated)
}
// TestServer_UDPEDNS0_NoTruncate checks the EDNS0 case. The same large
// response that is truncated over plain UDP must be delivered in full when the
// query advertises a large enough EDNS0 UDP buffer size.
func TestServer_UDPEDNS0_NoTruncate(t *testing.T) {
	srv := newTestServer(t, &testLargeMsgHandler{})
	t.Cleanup(func() {
		require.NoError(t, srv.Close())
	})

	client := &Client{
		Net:     "udp",
		Timeout: time.Second,
		// Big enough for the client itself to read the full response.
		UDPSize: 7000,
	}
	udpPort := srv.UDPAddr().Port
	stamp := dnsstamps.ServerStamp{
		ServerAddrStr: fmt.Sprintf("127.0.0.1:%d", udpPort),
		ServerPk:      srv.resolverPk,
		ProviderName:  srv.server.ProviderName,
		Proto:         dnsstamps.StampProtoTypeDNSCrypt,
	}

	resolver, err := client.DialStamp(stamp)
	require.NoError(t, err)
	require.NotNil(t, resolver)

	// Advertise a 2000-byte UDP buffer via an OPT record, which SetEdns0
	// appends to the Extra section.
	query := createTestMessage()
	query.SetEdns0(2000, false)

	reply, err := client.Exchange(query, resolver)
	require.NoError(t, err)
	require.NotNil(t, reply)

	require.Equal(t, dns.RcodeSuccess, reply.Rcode)
	require.Len(t, reply.Answer, 64)
	require.False(t, reply.Truncated)
}
// testServerServeCert starts a dedicated test server, fetches its certificate
// over network ("udp" or "tcp") with Client.DialStamp, and checks that the
// certificate fields the client received match the server-side certificate.
// The server is closed on test cleanup.
func testServerServeCert(t *testing.T, network string) {
	t.Helper()

	srv := newTestServer(t, &testHandler{})
	t.Cleanup(func() {
		require.NoError(t, srv.Close())
	})

	client := &Client{
		Net:     network,
		Timeout: 1 * time.Second,
	}
	serverAddr := fmt.Sprintf("127.0.0.1:%d", srv.UDPAddr().Port)
	if network == "tcp" {
		serverAddr = fmt.Sprintf("127.0.0.1:%d", srv.TCPAddr().Port)
	}
	stamp := dnsstamps.ServerStamp{
		ServerAddrStr: serverAddr,
		ServerPk:      srv.resolverPk,
		ProviderName:  srv.server.ProviderName,
		Proto:         dnsstamps.StampProtoTypeDNSCrypt,
	}

	ri, err := client.DialStamp(stamp)
	require.NoError(t, err)
	require.NotNil(t, ri)
	require.NotNil(t, ri.ResolverCert)

	want := srv.server.ResolverCert
	got := ri.ResolverCert

	// Expected value first, actual second (testify convention).
	require.Equal(t, srv.server.ProviderName, ri.ProviderName)
	require.True(t, bytes.Equal(want.ClientMagic[:], got.ClientMagic[:]))
	require.Equal(t, want.EsVersion, got.EsVersion)
	require.Equal(t, want.Signature, got.Signature)
	// Previously ResolverPk was asserted twice; check the serial instead,
	// which is also part of the transmitted certificate.
	require.Equal(t, want.Serial, got.Serial)
	require.Equal(t, want.NotBefore, got.NotBefore)
	require.Equal(t, want.NotAfter, got.NotAfter)
	require.True(t, bytes.Equal(want.ResolverPk[:], got.ResolverPk[:]))
}
// testServerRespondMessages starts a dedicated test server, closed on test
// cleanup, and runs testThisServerRespondMessages against it over network.
func testServerRespondMessages(t *testing.T, network string) {
	srv := newTestServer(t, &testHandler{})
	closeServer := func() {
		require.NoError(t, srv.Close())
	}
	t.Cleanup(closeServer)

	testThisServerRespondMessages(t, network, srv)
}
// testThisServerRespondMessages runs a round trip against srv over network
// ("udp" or "tcp"). It fetches the certificate via Client.DialStamp, opens one
// raw connection, and sends 10 test queries over it. Each response is checked
// with assertTestMessageResponse. The connection is closed before returning.
func testThisServerRespondMessages(t *testing.T, network string, srv *testServer) {
	t.Helper()

	client := &Client{
		Timeout: 1 * time.Second,
		Net:     network,
	}
	serverAddr := fmt.Sprintf("127.0.0.1:%d", srv.UDPAddr().Port)
	if network == "tcp" {
		serverAddr = fmt.Sprintf("127.0.0.1:%d", srv.TCPAddr().Port)
	}
	stamp := dnsstamps.ServerStamp{
		ServerAddrStr: serverAddr,
		ServerPk:      srv.resolverPk,
		ProviderName:  srv.server.ProviderName,
		Proto:         dnsstamps.StampProtoTypeDNSCrypt,
	}

	ri, err := client.DialStamp(stamp)
	require.NoError(t, err)
	require.NotNil(t, ri)

	conn, err := net.Dial(network, stamp.ServerAddrStr)
	require.NoError(t, err)
	// The connection used to be leaked. Close it when the helper returns; the
	// deferred call also runs when a require assertion aborts the test.
	// A close error is irrelevant to what is being tested, hence ignored.
	defer func() { _ = conn.Close() }()

	for i := 0; i < 10; i++ {
		m := createTestMessage()
		res, err := client.ExchangeConn(conn, m, ri)
		require.NoError(t, err)
		assertTestMessageResponse(t, res)
	}
}
// testServer bundles a DNSCrypt Server with the sockets it serves on and the
// provider public key that clients need to build a DNS stamp for it.
type testServer struct {
	// server is the DNSCrypt server under test.
	server *Server
	// resolverPk is the public half of the provider key used to sign the
	// server's certificate.
	resolverPk ed25519.PublicKey
	// tcpListen is the TCP listener passed to Server.ServeTCP.
	tcpListen net.Listener
	// udpConn is the UDP socket passed to Server.ServeUDP.
	udpConn *net.UDPConn
}
// TCPAddr returns the local address of the server's TCP listener.
func (s *testServer) TCPAddr() *net.TCPAddr {
	addr := s.tcpListen.Addr()
	return addr.(*net.TCPAddr)
}
// UDPAddr returns the local address of the server's UDP socket.
func (s *testServer) UDPAddr() *net.UDPAddr {
	addr := s.udpConn.LocalAddr()
	return addr.(*net.UDPAddr)
}
// Close shuts down the DNSCrypt server, giving Shutdown up to one second, and
// then closes the UDP socket and TCP listener. Only the Shutdown error is
// returned; errors from closing the sockets are discarded (presumably because
// they may already be closed by Shutdown — not established here).
func (s *testServer) Close() error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	shutdownErr := s.server.Shutdown(ctx)

	_ = s.udpConn.Close()
	_ = s.tcpListen.Close()

	return shutdownErr
}
// newTestServer builds and starts a DNSCrypt test server:
//   - it generates a fresh resolver configuration for "example.org" and a
//     certificate from it;
//   - it wraps handler in a Server;
//   - it starts ServeUDP and ServeTCP in background goroutines on ephemeral
//     loopback ports.
//
// The returned testServer also carries the provider public key needed to build
// a DNS stamp. Callers must call Close on it.
func newTestServer(t require.TestingT, handler Handler) *testServer {
	rc, err := GenerateResolverConfig("example.org", nil)
	require.NoError(t, err)

	cert, err := rc.CreateCert()
	require.NoError(t, err)

	s := &Server{
		ProviderName: rc.ProviderName,
		ResolverCert: cert,
		Handler:      handler,
	}

	privateKey, err := HexDecodeKey(rc.PrivateKey)
	require.NoError(t, err)
	publicKey := ed25519.PrivateKey(privateKey).Public().(ed25519.PublicKey)

	srv := &testServer{
		server:     s,
		resolverPk: publicKey,
	}

	// Bind to loopback only. Every test dials 127.0.0.1, so listening on all
	// interfaces would only expose the test server to the network.
	loopback := net.IPv4(127, 0, 0, 1)

	srv.tcpListen, err = net.ListenTCP("tcp", &net.TCPAddr{IP: loopback, Port: 0})
	require.NoError(t, err)

	srv.udpConn, err = net.ListenUDP("udp", &net.UDPAddr{IP: loopback, Port: 0})
	if err != nil {
		// Don't leak the already-open TCP listener if the UDP socket fails.
		_ = srv.tcpListen.Close()
	}
	require.NoError(t, err)

	go func() {
		_ = s.ServeUDP(srv.udpConn)
	}()
	go func() {
		_ = s.ServeTCP(srv.tcpListen)
	}()

	return srv
}
// testHandler answers every query with a single A record pointing to 8.8.8.8
// (class IN, TTL 300) for the name of the first question.
type testHandler struct{}

// ServeDNS implements the Handler interface.
func (h *testHandler) ServeDNS(rw ResponseWriter, r *dns.Msg) error {
	reply := new(dns.Msg)
	reply.SetReply(r)

	rr := &dns.A{
		Hdr: dns.RR_Header{
			Name:   r.Question[0].Name,
			Rrtype: dns.TypeA,
			Class:  dns.ClassINET,
			Ttl:    300,
		},
		A: net.IPv4(8, 8, 8, 8),
	}
	reply.Answer = append(reply.Answer, rr)

	return rw.WriteMsg(reply)
}
// testLargeMsgHandler is a handler that returns a huge response, used for
// testing message truncation. It answers with 64 A records (127.0.0.0 through
// 127.0.0.63, class IN, TTL 300) for the first question's name, with DNS name
// compression enabled.
type testLargeMsgHandler struct{}

// ServeDNS implements the Handler interface.
func (h *testLargeMsgHandler) ServeDNS(rw ResponseWriter, r *dns.Msg) error {
	const answerCount = 64

	reply := new(dns.Msg)
	reply.SetReply(r)
	reply.Compress = true

	for idx := 0; idx < answerCount; idx++ {
		reply.Answer = append(reply.Answer, &dns.A{
			Hdr: dns.RR_Header{
				Name:   r.Question[0].Name,
				Rrtype: dns.TypeA,
				Class:  dns.ClassINET,
				Ttl:    300,
			},
			A: net.IPv4(127, 0, 0, byte(idx)),
		})
	}

	return rw.WriteMsg(reply)
}