1
0
Fork 1
mirror of https://github.com/maunium/stickerpicker synced 2024-10-18 02:14:06 +00:00
stickerpicker/sticker/server/api/auth.py
2020-10-31 21:53:46 +02:00

216 lines
7.5 KiB
Python

# maunium-stickerpicker - A fast and simple Matrix sticker picker widget.
# Copyright (C) 2020 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, Callable, Awaitable, Optional, TYPE_CHECKING
import logging
import json
from mautrix.client import Client
from mautrix.types import UserID
from mautrix.util.logging import TraceLogger
from aiohttp import web, hdrs, ClientError, ClientSession
from yarl import URL
from ..database import AccessToken, User
from ..config import Config
from .errors import Error
from . import fed_connector
if TYPE_CHECKING:
    # This branch is only taken by static type checkers, so the classes below
    # are referenced through quoted annotations elsewhere in the module.
    from typing import TypedDict

    class OpenIDPayload(TypedDict):
        # Shape of the JSON body parsed by the /account/register handler.
        access_token: str
        token_type: str
        matrix_server_name: str
        expires_in: int

    class OpenIDResponse(TypedDict):
        # Shape of the JSON body read from the OpenID userinfo request.
        sub: str
# Signature of a request handler as it is passed to the middlewares in this module.
Handler = Callable[[web.Request], Awaitable[web.Response]]
# Logger used by this module.
log: TraceLogger = logging.getLogger("mau.api.auth")
# Route table that the @routes.* decorators in this module register handlers on.
routes = web.RouteTableDef()
# Shared configuration; only declared here, the value is assigned by init().
config: Config
def get_ip(request: web.Request) -> str:
    """Return the address to record as the client IP for *request*.

    If the ``server.trust_forward_headers`` config value is truthy and the
    request carries an ``X-Forwarded-For`` header, that header's value is
    returned unmodified. In every other case ``request.remote`` is returned.
    """
    if not config["server.trust_forward_headers"]:
        return request.remote
    try:
        forwarded_for = request.headers["X-Forwarded-For"]
    except KeyError:
        # Forwarding is trusted but the header is absent: use request.remote.
        return request.remote
    return forwarded_for
def get_auth_header(request: web.Request) -> str:
    """Extract the bearer credential from the ``Authorization`` header.

    Returns:
        The header value with the leading ``"Bearer "`` removed.

    Raises:
        Error.missing_auth_header: the request has no ``Authorization`` header.
        Error.invalid_auth_header: the header does not start with ``"Bearer "``.
    """
    bearer_prefix = "Bearer "
    try:
        header_value = request.headers["Authorization"]
    except KeyError:
        raise Error.missing_auth_header
    if not header_value.startswith(bearer_prefix):
        raise Error.invalid_auth_header
    return header_value[len(bearer_prefix):]
async def get_user(request: web.Request) -> Tuple[User, AccessToken]:
    """Authenticate *request* from its bearer token.

    The bearer value must have the form ``<integer token id>:<token value>``.
    On success the token's IP is updated with ``get_ip(request)`` and the
    owning user is loaded.

    Returns:
        A ``(user, token)`` pair.

    Raises:
        Error.invalid_auth_token: the value is malformed, the token ID is not
            found, or the token value fails the check.
        Error.auth_token_expired: the token exists but is expired.
        Error.missing_auth_header / Error.invalid_auth_header: propagated
            from ``get_auth_header``.
    """
    bearer = get_auth_header(request)
    try:
        # Anything other than exactly two ":"-separated parts makes the
        # unpacking raise ValueError, as does a non-integer ID.
        raw_id, secret = bearer.split(":")
        token_id = int(raw_id)
    except ValueError:
        raise Error.invalid_auth_token

    token = await AccessToken.get(token_id)
    if not (token and token.check(secret)):
        raise Error.invalid_auth_token
    if token.expired:
        raise Error.auth_token_expired

    await token.update_ip(get_ip(request))
    owner = await User.get(token.user_id)
    return owner, token
@web.middleware
async def token_middleware(request: web.Request, handler: Handler) -> web.Response:
    """Middleware that authenticates requests with ``get_user``.

    ``OPTIONS`` requests are passed to *handler* untouched. For any other
    method, the authenticated user and token are stored on the request under
    the ``"user"`` and ``"token"`` keys before *handler* runs. Errors raised
    by ``get_user`` propagate to the caller.
    """
    if request.method != hdrs.METH_OPTIONS:
        request["user"], request["token"] = await get_user(request)
    return await handler(request)
async def get_widget_user(request: web.Request) -> User:
    """Load the user named by the ``X-Matrix-User-ID`` request header.

    The header value is used as the lookup key as-is; this function does not
    verify any credential.

    Raises:
        Error.missing_user_id_header: the header is absent.
        Error.user_not_found: ``User.get`` returned ``None`` for that ID.
    """
    try:
        raw_user_id = request.headers["X-Matrix-User-ID"]
    except KeyError:
        raise Error.missing_user_id_header
    widget_user = await User.get(UserID(raw_user_id))
    if widget_user is None:
        raise Error.user_not_found
    return widget_user
@web.middleware
async def widget_secret_middleware(request: web.Request, handler: Handler) -> web.Response:
    """Middleware that attaches the widget user to the request.

    ``OPTIONS`` requests are passed to *handler* untouched. For any other
    method, the user returned by ``get_widget_user`` is stored under the
    ``"user"`` key before *handler* runs. Errors raised by
    ``get_widget_user`` propagate to the caller.
    """
    if request.method != hdrs.METH_OPTIONS:
        request["user"] = await get_widget_user(request)
    return await handler(request)
# CORS headers sent with every response from the /account endpoints in this
# module, including the OPTIONS preflight responses.
account_cors_headers = {
    "Access-Control-Allow-Origin": "*",
    "Access-Control-Allow-Methods": "OPTIONS, GET, POST",
    "Access-Control-Allow-Headers": "Authorization, Content-Type",
}
@routes.get("/account")
async def get_auth(request: web.Request) -> web.Response:
    """Report which user the presented access token belongs to.

    Responds with a JSON object holding the token's ``user_id`` and the
    account CORS headers. Authentication errors from ``get_user`` propagate.
    """
    _, token = await get_user(request)
    body = {"user_id": token.user_id}
    return web.json_response(body, headers=account_cors_headers)
async def check_openid_token(homeserver: str, token: str) -> Optional[UserID]:
    """Ask *homeserver* which user an OpenID access token belongs to.

    The server name is resolved through ``fed_connector`` and a GET request is
    made to ``/_matrix/federation/v1/openid/userinfo`` on the resolved host,
    with the token passed as the ``access_token`` query parameter.

    Args:
        homeserver: Matrix server name that supposedly issued the token.
        token: The OpenID access token to validate.

    Returns:
        The ``sub`` field of the JSON response as a ``UserID``, or ``None`` if
        the request failed or the response could not be used. The HTTP status
        code is not inspected; only the response body is.

    Note:
        Server name resolution and URL construction happen outside the
        ``try`` block, so errors from those steps propagate to the caller.
    """
    server_info = await fed_connector.resolve_server_name(homeserver)
    headers = {"Host": server_info.host_header}
    userinfo_url = URL.build(scheme="https", host=server_info.host, port=server_info.port,
                             path="/_matrix/federation/v1/openid/userinfo",
                             query={"access_token": token})
    try:
        async with fed_connector.http.get(userinfo_url, headers=headers) as resp:
            data: 'OpenIDResponse' = await resp.json()
            return UserID(data["sub"])
    except (ClientError, json.JSONDecodeError, KeyError, ValueError, TypeError):
        # TypeError covers a JSON body that is not an object (e.g. a list or
        # null), for which the ["sub"] lookup cannot work. Treat it like any
        # other unusable response instead of letting it escape.
        log.debug(f"Failed to check OpenID token from {homeserver}", exc_info=True)
        return None
@routes.route(hdrs.METH_OPTIONS, "/account/register")
@routes.route(hdrs.METH_OPTIONS, "/account/logout")
@routes.route(hdrs.METH_OPTIONS, "/account")
async def cors_token(_: web.Request) -> web.Response:
    """Answer ``OPTIONS`` requests for the three account endpoints.

    The request itself is ignored; the reply is an empty 200 response that
    carries the account CORS headers.
    """
    preflight_reply = web.Response(headers=account_cors_headers, status=200)
    return preflight_reply
async def resolve_client_well_known(server_name: str) -> str:
    """Fetch the homeserver base URL from a server's client ``.well-known``.

    Requests ``https://<server_name>:443/.well-known/matrix/client`` with a
    fresh ``ClientSession`` and returns the ``m.homeserver`` -> ``base_url``
    value of the JSON body.

    No errors are handled here: request failures, JSON decoding problems and
    missing keys all propagate to the caller.
    """
    well_known_url = URL.build(scheme="https", host=server_name, port=443,
                               path="/.well-known/matrix/client")
    async with ClientSession() as session:
        async with session.get(well_known_url) as response:
            document = await response.json()
            homeserver_section = document["m.homeserver"]
            return homeserver_section["base_url"]
@routes.post("/account/register")
async def exchange_token(request: web.Request) -> web.Response:
    """Exchange a Matrix OpenID token for an access token of this service.

    Steps, in order:

    1. Parse the JSON request body and read ``matrix_server_name`` and
       ``access_token`` from it.
    2. Validate the OpenID token with ``check_openid_token``.
    3. Require that the server part of the returned user ID equals the
       ``matrix_server_name`` from the request.
    4. Require that the configured permissions for that user grant access.
    5. Resolve the homeserver's client URL via its client ``.well-known``.
    6. Create the user if unknown, or update the stored homeserver URL if it
       changed.
    7. Issue a new access token, recorded with the requester's IP.

    Returns:
        A JSON response with ``user_id``, ``token`` and ``permissions``,
        carrying the account CORS headers.

    Raises:
        Error.request_not_json: the body is not valid JSON.
        Error.invalid_openid_payload: the body is not a JSON object or lacks
            one of the required keys.
        Error.invalid_openid_token: the token could not be validated.
        Error.homeserver_mismatch: the user ID belongs to a different server.
        Error.no_access: the user's permissions do not grant access.
        Error.client_well_known_error: the client ``.well-known`` lookup failed.
    """
    try:
        data: 'OpenIDPayload' = await request.json()
    except json.JSONDecodeError:
        raise Error.request_not_json
    try:
        matrix_server_name = data["matrix_server_name"]
        access_token = data["access_token"]
    except (KeyError, TypeError):
        # TypeError: the body was valid JSON but not an object (e.g. a list,
        # string or null), so key lookup is impossible. Report it as a bad
        # payload rather than letting it surface as an internal error.
        raise Error.invalid_openid_payload
    log.trace(f"Validating OpenID token from {matrix_server_name}")
    user_id = await check_openid_token(matrix_server_name, access_token)
    if user_id is None:
        raise Error.invalid_openid_token
    _, homeserver = Client.parse_user_id(user_id)
    # The server that vouched for the token must be the user's own server.
    if homeserver != matrix_server_name:
        raise Error.homeserver_mismatch

    permissions = config.get_permissions(user_id)
    if not permissions.access:
        raise Error.no_access

    try:
        log.trace(f"Trying to resolve {matrix_server_name}'s client .well-known")
        homeserver_url = await resolve_client_well_known(matrix_server_name)
        log.trace(f"Got {homeserver_url} from {matrix_server_name}'s client .well-known")
    except (ClientError, json.JSONDecodeError, KeyError, ValueError, TypeError):
        log.trace(f"Failed to resolve {matrix_server_name}'s client .well-known", exc_info=True)
        raise Error.client_well_known_error

    user = await User.get(user_id)
    if user is None:
        log.debug(f"Creating user {user_id} with homeserver client URL {homeserver_url}")
        user = User.new(user_id, homeserver_url=homeserver_url)
        await user.insert()
    elif user.homeserver_url != homeserver_url:
        log.debug(f"Updating {user_id}'s homeserver client URL from {user.homeserver_url} "
                  f"to {homeserver_url}")
        await user.set_homeserver_url(homeserver_url)
    token = await user.new_access_token(get_ip(request))
    return web.json_response({
        "user_id": user_id,
        "token": token,
        "permissions": permissions._asdict(),
    }, headers=account_cors_headers)
@routes.post("/account/logout")
async def logout(request: web.Request) -> web.Response:
    """Delete the access token that was used to make this request.

    Authentication errors from ``get_user`` propagate. On success the reply
    is built from an empty JSON object with status 204 and the account CORS
    headers.
    """
    _, current_token = await get_user(request)
    await current_token.delete()
    empty_body = {}
    return web.json_response(empty_body, status=204, headers=account_cors_headers)
def init(cfg: Config) -> None:
    """Install *cfg* as the module-level ``config`` used by this module.

    ``config`` has no value until this has been called.
    """
    global config
    config = cfg