1
0
Fork 1
mirror of https://github.com/maunium/stickerpicker synced 2024-10-18 02:14:06 +00:00
stickerpicker/sticker/server/api/fed_connector.py
2020-10-31 21:53:46 +02:00

110 lines
4.1 KiB
Python

from typing import Tuple, Any, NamedTuple, Dict, Optional
from time import time
import ipaddress
import logging
import asyncio
import json
from mautrix.util.logging import TraceLogger
from aiohttp import ClientRequest, TCPConnector, ClientSession, ClientTimeout, ClientError
from aiohttp.client_proto import ResponseHandler
from yarl import URL
import aiodns
# Module-level logger for this file, registered under the name "mau.federation".
# It is annotated as mautrix's TraceLogger because the functions below call log.trace().
log: TraceLogger = logging.getLogger("mau.federation")
class ResolvedServerName(NamedTuple):
    """Outcome of resolving a Matrix server name, as stored in the resolution cache."""

    # Server name to use for the Host header (the .well-known delegated name when one was found).
    host_header: str
    # Hostname or IP literal that the connection should be opened to.
    host: str
    # Port that the connection should be opened to.
    port: int
    # Unix timestamp (whole seconds) after which this entry must be resolved again.
    expire: int
class ServerNameSplit(NamedTuple):
    """A server name broken into its host part and optional port."""

    # Host part of the server name, without the ":port" suffix.
    host: str
    # Explicit port from the server name, or None when the name did not carry one.
    port: Optional[int]
    # True when the host part is an IP address literal rather than a DNS name.
    is_ip: bool
# Shared asynchronous DNS resolver. Only declared here; init() assigns the actual instance.
dns_resolver: aiodns.DNSResolver
# Shared aiohttp client session. Only declared here; init() assigns the actual instance.
http: ClientSession
# Cache of resolve_server_name() results, keyed by the server name that was passed in.
# Each entry carries its own expiry timestamp; an expired entry is overwritten by the next
# resolution of the same name.
server_name_cache: Dict[str, ResolvedServerName] = {}
class MatrixFederationTCPConnector(TCPConnector):
    """An extension to aiohttp's TCPConnector that correctly sets the TLS SNI for Matrix federation
    requests, where the TCP host may not match the SNI/Host header."""

    async def _wrap_create_connection(self, *args: Any, server_hostname: Optional[str] = None,
                                      req: ClientRequest, **kwargs: Any
                                      ) -> Tuple[asyncio.Transport, ResponseHandler]:
        """Open the connection, using the request's Host header as the TLS server name.

        Args:
            server_hostname: The TLS server name chosen by aiohttp. aiohttp passes ``None`` here
                for connections that do not use TLS.
            req: The request being sent; its ``Host`` header provides the replacement name.

        Returns:
            Whatever the parent connector's ``_wrap_create_connection`` returns.
        """
        if server_hostname is not None:
            # Only rewrite the name when TLS is in use: asyncio refuses a server_hostname
            # on a connection without ssl, so a None value has to be passed through as-is.
            server_hostname = parse_server_name(req.headers["Host"]).host
        return await super()._wrap_create_connection(*args, server_hostname=server_hostname,
                                                     req=req, **kwargs)
def parse_server_name(name: str) -> ServerNameSplit:
    """Split a server name into host and optional port, and flag IP address literals.

    A trailing ``:<digits>`` is treated as the port and returned as an ``int``. IPv6 literals
    are recognised both in the bracketed form (``[::1]`` or ``[::1]:8448``) and, without a
    port, in the bare form.

    Args:
        name: The server name, e.g. ``example.com``, ``example.com:8448`` or ``[::1]:8448``.

    Returns:
        A ServerNameSplit whose ``host`` is the name without the port suffix (brackets around
        an IPv6 literal are kept), whose ``port`` is the parsed port or ``None``, and whose
        ``is_ip`` tells whether the host is an IP address literal.
    """
    host = name
    port: Optional[int] = None
    head, sep, tail = name.rpartition(":")
    # A numeric suffix is only a port if what precedes it is a plain hostname/IPv4 address
    # (no other colons) or a bracketed IPv6 literal. Otherwise the "suffix" is really the
    # last group of an unbracketed IPv6 address and must stay part of the host.
    if sep and tail.isdecimal() and (":" not in head or head.endswith("]")):
        host = head
        # Convert so the value matches the Optional[int] type declared on ServerNameSplit.
        port = int(tail)

    # ipaddress does not understand the surrounding brackets of an IPv6 literal.
    literal = host[1:-1] if host.startswith("[") and host.endswith("]") else host
    try:
        ipaddress.ip_address(literal)
        is_ip = True
    except ValueError:
        is_ip = False

    res = ServerNameSplit(host=host, port=port, is_ip=is_ip)
    log.trace(f"Parsed server name {name} into {res}")
    return res
async def resolve_server_name(server_name: str) -> ResolvedServerName:
    """Resolve a server name to the Host header, host and port to use for requests to it.

    Steps, in order:

    1. Return the entry from ``server_name_cache`` if one exists and has not expired.
    2. If the name has no explicit port and is not an IP literal, request
       ``https://<server_name>/.well-known/matrix/server`` and, on a 200 response, switch to
       the name given in its ``m.server`` field.
    3. If the resulting name still has no explicit port and is not an IP literal, query the
       ``_matrix._tcp.<name>`` SRV record and use the host and port of the first answer.
    4. Otherwise use the host as-is, with port 8448 when none was given.

    A failed .well-known request or SRV lookup is logged and treated as "no delegation", so
    resolution continues with the next step instead of raising.

    Args:
        server_name: The server name to resolve.

    Returns:
        The resolution result, which is also stored in ``server_name_cache``.
    """
    try:
        cached = server_name_cache[server_name]
    except KeyError:
        log.trace(f"No cached server name resolution for {server_name}")
    else:
        if cached.expire > int(time()):
            log.trace(f"Using cached server name resolution for {server_name}: {cached}")
            return cached

    host_header = server_name
    hostname, port, is_ip = parse_server_name(host_header)
    ttl = 86400

    if port is None and not is_ip:
        try:
            # Built inside the try so that a name the URL builder rejects (ValueError) is
            # handled like any other .well-known failure.
            well_known_url = URL.build(scheme="https", host=host_header, port=443,
                                       path="/.well-known/matrix/server")
            log.trace(f"Requesting {well_known_url} to resolve {server_name}'s .well-known")
            async with http.get(well_known_url) as resp:
                if resp.status == 200:
                    # content_type=None disables aiohttp's Content-Type check, so the body is
                    # parsed as JSON even when the server labels it as something else.
                    well_known_data = await resp.json(content_type=None)
                    delegated = well_known_data["m.server"]
                    if not isinstance(delegated, str) or not delegated:
                        raise ValueError("m.server is not a non-empty string")
                    host_header = delegated
                    log.debug(f"Got {host_header} from {server_name}'s .well-known")
                    hostname, port, is_ip = parse_server_name(host_header)
                else:
                    log.trace(f"Got non-200 status {resp.status} from {server_name}'s .well-known")
        except (ClientError, asyncio.TimeoutError, json.JSONDecodeError,
                KeyError, TypeError, ValueError) as e:
            # asyncio.TimeoutError: the session's total timeout elapsed.
            # TypeError: the JSON body was not an object (e.g. empty body, list, string).
            log.debug(f"Failed to fetch .well-known for {server_name}: {e}")

    if port is None and not is_ip:
        log.trace(f"Querying SRV at _matrix._tcp.{host_header}")
        try:
            res = await dns_resolver.query(f"_matrix._tcp.{host_header}", "SRV")
        except aiodns.error.DNSError as e:
            # aiodns reports a missing record (and any other lookup failure) by raising, so
            # this is the normal path for servers without an SRV record.
            log.trace(f"SRV lookup at _matrix._tcp.{host_header} failed: {e}")
            res = None
            # The failure may be transient, so don't cache the fallback for a whole day.
            ttl = 300
        if res:
            hostname = res[0].host
            port = res[0].port
            ttl = max(res[0].ttl, 300)
            log.debug(f"Got {hostname}:{port} from {host_header}'s Matrix SRV record")
        else:
            log.trace(f"No SRV records found at _matrix._tcp.{host_header}")

    # int() keeps the stored port numeric even if the parsed port arrives as a digit string.
    result = ResolvedServerName(host_header=host_header, host=hostname,
                                port=int(port or 8448), expire=int(time()) + ttl)
    server_name_cache[server_name] = result
    log.debug(f"Resolved server name {server_name} -> {result}")
    return result
def init():
    """Create the module-wide DNS resolver and HTTP client session.

    The resolver is bound to ``asyncio.get_running_loop()``, so this has to be called while
    an event loop is running. The session uses a 10 second total timeout and a
    MatrixFederationTCPConnector for its connections.
    """
    global http, dns_resolver
    running_loop = asyncio.get_running_loop()
    dns_resolver = aiodns.DNSResolver(loop=running_loop)
    request_timeout = ClientTimeout(total=10)
    federation_connector = MatrixFederationTCPConnector()
    http = ClientSession(timeout=request_timeout, connector=federation_connector)