diff --git a/bitbake/bin/bitbake-hashclient b/bitbake/bin/bitbake-hashclient index a02a65b937..328c15cdec 100755 --- a/bitbake/bin/bitbake-hashclient +++ b/bitbake/bin/bitbake-hashclient @@ -14,6 +14,7 @@ import sys import threading import time import warnings +import netrc warnings.simplefilter("default") try: @@ -36,10 +37,18 @@ except ImportError: sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'lib')) import hashserv +import bb.asyncrpc DEFAULT_ADDRESS = 'unix://./hashserve.sock' METHOD = 'stress.test.method' +def print_user(u): + print(f"Username: {u['username']}") + if "permissions" in u: + print("Permissions: " + " ".join(u["permissions"])) + if "token" in u: + print(f"Token: {u['token']}") + def main(): def handle_stats(args, client): @@ -125,9 +134,39 @@ def main(): print("Removed %d rows" % (result["count"])) return 0 + def handle_refresh_token(args, client): + r = client.refresh_token(args.username) + print_user(r) + + def handle_set_user_permissions(args, client): + r = client.set_user_perms(args.username, args.permissions) + print_user(r) + + def handle_get_user(args, client): + r = client.get_user(args.username) + print_user(r) + + def handle_get_all_users(args, client): + users = client.get_all_users() + print("{username:20}| {permissions}".format(username="Username", permissions="Permissions")) + print(("-" * 20) + "+" + ("-" * 20)) + for u in users: + print("{username:20}| {permissions}".format(username=u["username"], permissions=" ".join(u["permissions"]))) + + def handle_new_user(args, client): + r = client.new_user(args.username, args.permissions) + print_user(r) + + def handle_delete_user(args, client): + r = client.delete_user(args.username) + print_user(r) + parser = argparse.ArgumentParser(description='Hash Equivalence Client') parser.add_argument('--address', default=DEFAULT_ADDRESS, help='Server address (default "%(default)s")') parser.add_argument('--log', default='WARNING', help='Set logging level') + parser.add_argument('--login', '-l', metavar="USERNAME", help="Authenticate as USERNAME") + parser.add_argument('--password', '-p', metavar="TOKEN", help="Authenticate using token TOKEN") + parser.add_argument('--no-netrc', '-n', action="store_false", dest="netrc", help="Do not use .netrc") subparsers = parser.add_subparsers() @@ -158,6 +197,31 @@ def main(): clean_unused_parser.add_argument("max_age", metavar="SECONDS", type=int, help="Remove unused entries older than SECONDS old") clean_unused_parser.set_defaults(func=handle_clean_unused) + refresh_token_parser = subparsers.add_parser('refresh-token', help="Refresh auth token") + refresh_token_parser.add_argument("--username", "-u", help="Refresh the token for another user (if authorized)") + refresh_token_parser.set_defaults(func=handle_refresh_token) + + set_user_perms_parser = subparsers.add_parser('set-user-perms', help="Set new permissions for user") + set_user_perms_parser.add_argument("--username", "-u", help="Username", required=True) + set_user_perms_parser.add_argument("permissions", metavar="PERM", nargs="*", default=[], help="New permissions") + set_user_perms_parser.set_defaults(func=handle_set_user_permissions) + + get_user_parser = subparsers.add_parser('get-user', help="Get user") + get_user_parser.add_argument("--username", "-u", help="Username") + get_user_parser.set_defaults(func=handle_get_user) + + get_all_users_parser = subparsers.add_parser('get-all-users', help="List all users") + get_all_users_parser.set_defaults(func=handle_get_all_users) + + new_user_parser = subparsers.add_parser('new-user', help="Create new user") + new_user_parser.add_argument("--username", "-u", help="Username", required=True) + new_user_parser.add_argument("permissions", metavar="PERM", nargs="*", default=[], help="New permissions") + new_user_parser.set_defaults(func=handle_new_user) + + delete_user_parser = subparsers.add_parser('delete-user', help="Delete user") + delete_user_parser.add_argument("--username", "-u", help="Username", required=True) + delete_user_parser.set_defaults(func=handle_delete_user) + args = parser.parse_args() logger = logging.getLogger('hashserv') @@ -171,10 +235,26 @@ def main(): console.setLevel(level) logger.addHandler(console) + login = args.login + password = args.password + + if login is None and args.netrc: + try: + n = netrc.netrc() + auth = n.authenticators(args.address) + if auth is not None: + login, _, password = auth + except FileNotFoundError: + pass + func = getattr(args, 'func', None) if func: - with hashserv.create_client(args.address) as client: - return func(args, client) + try: + with hashserv.create_client(args.address, login, password) as client: + return func(args, client) + except bb.asyncrpc.InvokeError as e: + print(f"ERROR: {e}") + return 1 return 0 diff --git a/bitbake/bin/bitbake-hashserv b/bitbake/bin/bitbake-hashserv index 59b8b07f59..1085d0584e 100755 --- a/bitbake/bin/bitbake-hashserv +++ b/bitbake/bin/bitbake-hashserv @@ -17,6 +17,7 @@ warnings.simplefilter("default") sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), "lib")) import hashserv +from hashserv.server import DEFAULT_ANON_PERMS VERSION = "1.0.0" @@ -36,6 +37,22 @@ The bind address may take one of the following formats: To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or "--bind ws://:8686". To bind to a specific IPv6 address, enclose the address in "[]", e.g. "--bind [::1]:8686" or "--bind ws://[::1]:8686" + +Note that the default Anonymous permissions are designed to not break existing +server instances when upgrading, but are not particularly secure defaults. If +you want to use authentication, it is recommended that you use "--anon-perms +@read" to only give anonymous users read access, or "--anon-perms @none" to +give un-authenticated users no access at all. + +Setting "--anon-perms @all" or "--anon-perms @user-admin" is not allowed, since +this would allow anonymous users to manage all users accounts, which is a bad +idea. + +If you are using user authentication, you should run your server in websockets +mode with an SSL terminating load balancer in front of it (as this server does +not implement SSL). Otherwise all usernames and passwords will be transmitted +in the clear. When configured this way, clients can connect using a secure +websocket, as in "wss://SERVER:PORT" """, ) @@ -79,6 +96,22 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or default=os.environ.get("HASHSERVER_DB_PASSWORD", None), help="Database password ($HASHSERVER_DB_PASSWORD)", ) + parser.add_argument( + "--anon-perms", + metavar="PERM[,PERM[,...]]", + default=os.environ.get("HASHSERVER_ANON_PERMS", ",".join(DEFAULT_ANON_PERMS)), + help='Permissions to give anonymous users (default $HASHSERVER_ANON_PERMS, "%(default)s")', + ) + parser.add_argument( + "--admin-user", + default=os.environ.get("HASHSERVER_ADMIN_USER", None), + help="Create default admin user with name ADMIN_USER ($HASHSERVER_ADMIN_USER)", + ) + parser.add_argument( + "--admin-password", + default=os.environ.get("HASHSERVER_ADMIN_PASSWORD", None), + help="Create default admin user with password ADMIN_PASSWORD ($HASHSERVER_ADMIN_PASSWORD)", + ) args = parser.parse_args() @@ -94,6 +127,7 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or logger.addHandler(console) read_only = (os.environ.get("HASHSERVER_READ_ONLY", "0") == "1") or args.read_only + anon_perms = args.anon_perms.split(",") server = hashserv.create_server( args.bind, @@ -102,6 +136,9 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or read_only=read_only, db_username=args.db_username, db_password=args.db_password, + anon_perms=anon_perms, + admin_username=args.admin_user, + admin_password=args.admin_password, ) server.serve_forever() return 0 diff --git a/bitbake/lib/hashserv/__init__.py b/bitbake/lib/hashserv/__init__.py index 9a8ee4e88b..552a33278f 100644 --- a/bitbake/lib/hashserv/__init__.py +++ b/bitbake/lib/hashserv/__init__.py @@ -8,6 +8,7 @@ from contextlib import closing import re import itertools import json +from collections import namedtuple from urllib.parse import urlparse UNIX_PREFIX = "unix://" @@ -18,6 +19,8 @@ ADDR_TYPE_UNIX = 0 ADDR_TYPE_TCP = 1 ADDR_TYPE_WS = 2 +User = namedtuple("User", ("username", "permissions")) + def parse_address(addr): if addr.startswith(UNIX_PREFIX): @@ -43,7 +46,10 @@ def create_server( upstream=None, read_only=False, db_username=None, - db_password=None + db_password=None, + anon_perms=None, + admin_username=None, + admin_password=None, ): def sqlite_engine(): from .sqlite import DatabaseEngine @@ -62,7 +68,17 @@ def create_server( else: db_engine = sqlite_engine() - s = server.Server(db_engine, upstream=upstream, read_only=read_only) + if anon_perms is None: + anon_perms = server.DEFAULT_ANON_PERMS + + s = server.Server( + db_engine, + upstream=upstream, + read_only=read_only, + anon_perms=anon_perms, + admin_username=admin_username, + admin_password=admin_password, + ) (typ, a) = parse_address(addr) if typ == ADDR_TYPE_UNIX: @@ -76,33 +92,40 @@ def create_server( return s -def create_client(addr): +def create_client(addr, username=None, password=None): from . import client - c = client.Client() + c = client.Client(username, password) - (typ, a) = parse_address(addr) - if typ == ADDR_TYPE_UNIX: - c.connect_unix(*a) - elif typ == ADDR_TYPE_WS: - c.connect_websocket(*a) - else: - c.connect_tcp(*a) - - return c + try: + (typ, a) = parse_address(addr) + if typ == ADDR_TYPE_UNIX: + c.connect_unix(*a) + elif typ == ADDR_TYPE_WS: + c.connect_websocket(*a) + else: + c.connect_tcp(*a) + return c + except Exception as e: + c.close() + raise e -async def create_async_client(addr): +async def create_async_client(addr, username=None, password=None): from . import client - c = client.AsyncClient() + c = client.AsyncClient(username, password) - (typ, a) = parse_address(addr) - if typ == ADDR_TYPE_UNIX: - await c.connect_unix(*a) - elif typ == ADDR_TYPE_WS: - await c.connect_websocket(*a) - else: - await c.connect_tcp(*a) + try: + (typ, a) = parse_address(addr) + if typ == ADDR_TYPE_UNIX: + await c.connect_unix(*a) + elif typ == ADDR_TYPE_WS: + await c.connect_websocket(*a) + else: + await c.connect_tcp(*a) - return c + return c + except Exception as e: + await c.close() + raise e diff --git a/bitbake/lib/hashserv/client.py b/bitbake/lib/hashserv/client.py index 9542d72f6c..82400fe5aa 100644 --- a/bitbake/lib/hashserv/client.py +++ b/bitbake/lib/hashserv/client.py @@ -6,6 +6,7 @@ import logging import socket import bb.asyncrpc +import json from . import create_async_client @@ -16,15 +17,19 @@ class AsyncClient(bb.asyncrpc.AsyncClient): MODE_NORMAL = 0 MODE_GET_STREAM = 1 - def __init__(self): + def __init__(self, username=None, password=None): super().__init__('OEHASHEQUIV', '1.1', logger) self.mode = self.MODE_NORMAL + self.username = username + self.password = password async def setup_connection(self): await super().setup_connection() cur_mode = self.mode self.mode = self.MODE_NORMAL await self._set_mode(cur_mode) + if self.username: + await self.auth(self.username, self.password) async def send_stream(self, msg): async def proc(): @@ -41,6 +46,7 @@ class AsyncClient(bb.asyncrpc.AsyncClient): if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM: r = await self._send_wrapper(stream_to_normal) if r != "ok": + self.check_invoke_error(r) raise ConnectionError("Unable to transition to normal mode: Bad response from server %r" % r) elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL: r = await self.invoke({"get-stream": None}) @@ -109,9 +115,52 @@ class AsyncClient(bb.asyncrpc.AsyncClient): await self._set_mode(self.MODE_NORMAL) return await self.invoke({"clean-unused": {"max_age_seconds": max_age}}) + async def auth(self, username, token): + await self._set_mode(self.MODE_NORMAL) + result = await self.invoke({"auth": {"username": username, "token": token}}) + self.username = username + self.password = token + return result + + async def refresh_token(self, username=None): + await self._set_mode(self.MODE_NORMAL) + m = {} + if username: + m["username"] = username + result = await self.invoke({"refresh-token": m}) + if self.username and result["username"] == self.username: + self.password = result["token"] + return result + + async def set_user_perms(self, username, permissions): + await self._set_mode(self.MODE_NORMAL) + return await self.invoke({"set-user-perms": {"username": username, "permissions": permissions}}) + + async def get_user(self, username=None): + await self._set_mode(self.MODE_NORMAL) + m = {} + if username: + m["username"] = username + return await self.invoke({"get-user": m}) + + async def get_all_users(self): + await self._set_mode(self.MODE_NORMAL) + return (await self.invoke({"get-all-users": {}}))["users"] + + async def new_user(self, username, permissions): + await self._set_mode(self.MODE_NORMAL) + return await self.invoke({"new-user": {"username": username, "permissions": permissions}}) + + async def delete_user(self, username): + await self._set_mode(self.MODE_NORMAL) + return await self.invoke({"delete-user": {"username": username}}) + class Client(bb.asyncrpc.Client): - def __init__(self): + def __init__(self, username=None, password=None): + self.username = username + self.password = password + super().__init__() self._add_methods( "connect_tcp", @@ -126,7 +175,14 @@ class Client(bb.asyncrpc.Client): "backfill_wait", "remove", "clean_unused", + "auth", + "refresh_token", + "set_user_perms", + "get_user", + "get_all_users", + "new_user", + "delete_user", ) def _get_async_client(self): - return AsyncClient() + return AsyncClient(self.username, self.password) diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py index c691df7618..f5baa6be78 100644 --- a/bitbake/lib/hashserv/server.py +++ b/bitbake/lib/hashserv/server.py @@ -8,13 +8,48 @@ import asyncio import logging import math import time +import os +import base64 +import hashlib from . import create_async_client import bb.asyncrpc - logger = logging.getLogger("hashserv.server") +# This permission only exists to match nothing +NONE_PERM = "@none" + +READ_PERM = "@read" +REPORT_PERM = "@report" +DB_ADMIN_PERM = "@db-admin" +USER_ADMIN_PERM = "@user-admin" +ALL_PERM = "@all" + +ALL_PERMISSIONS = { + READ_PERM, + REPORT_PERM, + DB_ADMIN_PERM, + USER_ADMIN_PERM, + ALL_PERM, +} + +DEFAULT_ANON_PERMS = ( + READ_PERM, + REPORT_PERM, + DB_ADMIN_PERM, +) + +TOKEN_ALGORITHM = "sha256" + +# 48 bytes of random data will result in 64 characters when base64 +# encoded. This number also ensures that the base64 encoding won't have any +# trailing '=' characters. +TOKEN_SIZE = 48 + +SALT_SIZE = 8 + + class Measurement(object): def __init__(self, sample): self.sample = sample @@ -108,6 +143,85 @@ class Stats(object): } +token_refresh_semaphore = asyncio.Lock() + + +async def new_token(): + # Prevent malicious users from using this API to deduce the entropy + # pool on the server and thus be able to guess a token. *All* token + # refresh requests lock the same global semaphore and then sleep for a + # short time. The effectively rate limits the total number of requests + # than can be made across all clients to 10/second, which should be enough + # since you have to be an authenticated users to make the request in the + # first place + async with token_refresh_semaphore: + await asyncio.sleep(0.1) + raw = os.getrandom(TOKEN_SIZE, os.GRND_NONBLOCK) + + return base64.b64encode(raw, b"._").decode("utf-8") + + +def new_salt(): + return os.getrandom(SALT_SIZE, os.GRND_NONBLOCK).hex() + + +def hash_token(algo, salt, token): + h = hashlib.new(algo) + h.update(salt.encode("utf-8")) + h.update(token.encode("utf-8")) + return ":".join([algo, salt, h.hexdigest()]) + + +def permissions(*permissions, allow_anon=True, allow_self_service=False): + """ + Function decorator that can be used to decorate an RPC function call and + check that the current users permissions match the require permissions. + + If allow_anon is True, the user will also be allowed to make the RPC call + if the anonymous user permissions match the permissions. + + If allow_self_service is True, and the "username" property in the request + is the currently logged in user, or not specified, the user will also be + allowed to make the request. This allows users to access normal privileged + API, as long as they are only modifying their own user properties (e.g. + users can be allowed to reset their own token without @user-admin + permissions, but not the token for any other user. + """ + + def wrapper(func): + async def wrap(self, request): + if allow_self_service and self.user is not None: + username = request.get("username", self.user.username) + if username == self.user.username: + request["username"] = self.user.username + return await func(self, request) + + if not self.user_has_permissions(*permissions, allow_anon=allow_anon): + if not self.user: + username = "Anonymous user" + user_perms = self.anon_perms + else: + username = self.user.username + user_perms = self.user.permissions + + self.logger.info( + "User %s with permissions %r denied from calling %s. Missing permissions(s) %r", + username, + ", ".join(user_perms), + func.__name__, + ", ".join(permissions), + ) + raise bb.asyncrpc.InvokeError( + f"{username} is not allowed to access permissions(s) {', '.join(permissions)}" + ) + + return await func(self, request) + + return wrap + + return wrapper + + class ServerClient(bb.asyncrpc.AsyncServerConnection): def __init__( self, @@ -117,6 +231,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): backfill_queue, upstream, read_only, + anon_perms, ): super().__init__(socket, "OEHASHEQUIV", logger) self.db_engine = db_engine @@ -125,6 +240,8 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): self.backfill_queue = backfill_queue self.upstream = upstream self.read_only = read_only + self.user = None + self.anon_perms = anon_perms self.handlers.update( { @@ -135,6 +252,9 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): # Not always read-only, but internally checks if the server is # read-only "report": self.handle_report, + "auth": self.handle_auth, + "get-user": self.handle_get_user, + "get-all-users": self.handle_get_all_users, } ) @@ -146,9 +266,36 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): "backfill-wait": self.handle_backfill_wait, "remove": self.handle_remove, "clean-unused": self.handle_clean_unused, + "refresh-token": self.handle_refresh_token, + "set-user-perms": self.handle_set_perms, + "new-user": self.handle_new_user, + "delete-user": self.handle_delete_user, } ) + def raise_no_user_error(self, username): + raise bb.asyncrpc.InvokeError(f"No user named '{username}' exists") + + def user_has_permissions(self, *permissions, allow_anon=True): + permissions = set(permissions) + if allow_anon: + if ALL_PERM in self.anon_perms: + return True + + if not permissions - self.anon_perms: + return True + + if self.user is None: + return False + + if ALL_PERM in self.user.permissions: + return True + + if not permissions - self.user.permissions: + return True + + return False + def validate_proto_version(self): return self.proto_version > (1, 0) and self.proto_version <= (1, 1) @@ -178,6 +325,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg) + @permissions(READ_PERM) async def handle_get(self, request): method = request["method"] taskhash = request["taskhash"] @@ -206,6 +354,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): return d + @permissions(READ_PERM) async def handle_get_outhash(self, request): method = request["method"] outhash = request["outhash"] @@ -236,6 +385,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"]) await self.db.insert_outhash(data) + @permissions(READ_PERM) async def handle_get_stream(self, request): await self.socket.send_message("ok") @@ -304,8 +454,11 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): "unihash": unihash, } + # Since this can be called either read only or to report, the check to + # report is made inside the function + @permissions(READ_PERM) async def handle_report(self, data): - if self.read_only: + if self.read_only or not self.user_has_permissions(REPORT_PERM): return await self.report_readonly(data) outhash_data = { @@ -358,6 +511,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): "unihash": unihash, } + @permissions(READ_PERM, REPORT_PERM) async def handle_equivreport(self, data): await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"]) @@ -375,11 +529,13 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): return {k: row[k] for k in ("taskhash", "method", "unihash")} + @permissions(READ_PERM) async def handle_get_stats(self, request): return { "requests": self.request_stats.todict(), } + @permissions(DB_ADMIN_PERM) async def handle_reset_stats(self, request): d = { "requests": self.request_stats.todict(), @@ -388,6 +544,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): self.request_stats.reset() return d + @permissions(READ_PERM) async def handle_backfill_wait(self, request): d = { "tasks": self.backfill_queue.qsize(), @@ -395,6 +552,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): await self.backfill_queue.join() return d + @permissions(DB_ADMIN_PERM) async def handle_remove(self, request): condition = request["where"] if not isinstance(condition, dict): @@ -402,19 +560,178 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): return {"count": await self.db.remove(condition)} + @permissions(DB_ADMIN_PERM) async def handle_clean_unused(self, request): max_age = request["max_age_seconds"] oldest = datetime.now() - timedelta(seconds=-max_age) return {"count": await self.db.clean_unused(oldest)} + # The authentication API is always allowed + async def handle_auth(self, request): + username = str(request["username"]) + token = str(request["token"]) + + async def fail_auth(): + nonlocal username + # Rate limit bad login attempts + await asyncio.sleep(1) + raise bb.asyncrpc.InvokeError(f"Unable to authenticate as {username}") + + user, db_token = await self.db.lookup_user_token(username) + + if not user or not db_token: + await fail_auth() + + try: + algo, salt, _ = db_token.split(":") + except ValueError: + await fail_auth() + + if hash_token(algo, salt, token) != db_token: + await fail_auth() + + self.user = user + + self.logger.info("Authenticated as %s", username) + + return { + "result": True, + "username": self.user.username, + "permissions": sorted(list(self.user.permissions)), + } + + @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False) + async def handle_refresh_token(self, request): + username = str(request["username"]) + + token = await new_token() + + updated = await self.db.set_user_token( + username, + hash_token(TOKEN_ALGORITHM, new_salt(), token), + ) + if not updated: + self.raise_no_user_error(username) + + return {"username": username, "token": token} + + def get_perm_arg(self, arg): + if not isinstance(arg, list): + raise bb.asyncrpc.InvokeError("Unexpected type for permissions") + + arg = set(arg) + try: + arg.remove(NONE_PERM) + except KeyError: + pass + + unknown_perms = arg - ALL_PERMISSIONS + if unknown_perms: + raise bb.asyncrpc.InvokeError( + "Unknown permissions %s" % ", ".join(sorted(list(unknown_perms))) + ) + + return sorted(list(arg)) + + def return_perms(self, permissions): + if ALL_PERM in permissions: + return sorted(list(ALL_PERMISSIONS)) + return sorted(list(permissions)) + + @permissions(USER_ADMIN_PERM, allow_anon=False) + async def handle_set_perms(self, request): + username = str(request["username"]) + permissions = self.get_perm_arg(request["permissions"]) + + if not await self.db.set_user_perms(username, permissions): + self.raise_no_user_error(username) + + return { + "username": username, + "permissions": self.return_perms(permissions), + } + + @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False) + async def handle_get_user(self, request): + username = str(request["username"]) + + user = await self.db.lookup_user(username) + if user is None: + return None + + return { + "username": user.username, + "permissions": self.return_perms(user.permissions), + } + + @permissions(USER_ADMIN_PERM, allow_anon=False) + async def handle_get_all_users(self, request): + users = await self.db.get_all_users() + return { + "users": [ + { + "username": u.username, + "permissions": self.return_perms(u.permissions), + } + for u in users + ] + } + + @permissions(USER_ADMIN_PERM, allow_anon=False) + async def handle_new_user(self, request): + username = str(request["username"]) + permissions = self.get_perm_arg(request["permissions"]) + + token = await new_token() + + inserted = await self.db.new_user( + username, + permissions, + hash_token(TOKEN_ALGORITHM, new_salt(), token), + ) + if not inserted: + raise bb.asyncrpc.InvokeError(f"Cannot create new user '{username}'") + + return { + "username": username, + "permissions": self.return_perms(permissions), + "token": token, + } + + @permissions(USER_ADMIN_PERM, allow_anon=False) + async def handle_delete_user(self, request): + username = str(request["username"]) + + if not await self.db.delete_user(username): + self.raise_no_user_error(username) + + return {"username": username} + class Server(bb.asyncrpc.AsyncServer): - def __init__(self, db_engine, upstream=None, read_only=False): + def __init__( + self, + db_engine, + upstream=None, + read_only=False, + anon_perms=DEFAULT_ANON_PERMS, + admin_username=None, + admin_password=None, + ): if upstream and read_only: raise bb.asyncrpc.ServerError( "Read-only hashserv cannot pull from an upstream server" ) + disallowed_perms = set(anon_perms) - set( + [NONE_PERM, READ_PERM, REPORT_PERM, DB_ADMIN_PERM] + ) + + if disallowed_perms: + raise bb.asyncrpc.ServerError( + f"Permission(s) {' '.join(disallowed_perms)} are not allowed for anonymous users" + ) + super().__init__(logger) self.request_stats = Stats() @@ -422,6 +739,13 @@ class Server(bb.asyncrpc.AsyncServer): self.upstream = upstream self.read_only = read_only self.backfill_queue = None + self.anon_perms = set(anon_perms) + self.admin_username = admin_username + self.admin_password = admin_password + + self.logger.info( + "Anonymous user permissions are: %s", ", ".join(self.anon_perms) + ) def accept_client(self, socket): return ServerClient( @@ -431,12 +755,34 @@ class Server(bb.asyncrpc.AsyncServer): self.backfill_queue, self.upstream, self.read_only, + self.anon_perms, ) + async def create_admin_user(self): + admin_permissions = (ALL_PERM,) + async with self.db_engine.connect(self.logger) as db: + added = await db.new_user( + self.admin_username, + admin_permissions, + hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password), + ) + if added: + self.logger.info("Created admin user '%s'", self.admin_username) + else: + await db.set_user_perms( + self.admin_username, + admin_permissions, + ) + await db.set_user_token( + self.admin_username, + hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password), + ) + self.logger.info("Admin user '%s' updated", self.admin_username) + async def backfill_worker_task(self): async with await create_async_client( self.upstream - ) as client, self.db_engine.connect(logger) as db: + ) as client, self.db_engine.connect(self.logger) as db: while True: item = await self.backfill_queue.get() if item is None: @@ -457,6 +803,9 @@ class Server(bb.asyncrpc.AsyncServer): self.loop.run_until_complete(self.db_engine.create()) + if self.admin_username: + self.loop.run_until_complete(self.create_admin_user()) + return tasks async def stop(self): diff --git a/bitbake/lib/hashserv/sqlalchemy.py b/bitbake/lib/hashserv/sqlalchemy.py index 3216621f9d..bfd8a8446e 100644 --- a/bitbake/lib/hashserv/sqlalchemy.py +++ b/bitbake/lib/hashserv/sqlalchemy.py @@ -7,6 +7,7 @@ import logging from datetime import datetime +from . import User from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.pool import NullPool @@ -25,13 +26,12 @@ from sqlalchemy import ( literal, and_, delete, + update, ) import sqlalchemy.engine from sqlalchemy.orm import declarative_base from sqlalchemy.exc import IntegrityError -logger = logging.getLogger("hashserv.sqlalchemy") - Base = declarative_base() @@ -68,9 +68,19 @@ class OuthashesV2(Base): ) +class Users(Base): + __tablename__ = "users" + id = Column(Integer, primary_key=True, autoincrement=True) + username = Column(Text, nullable=False) + token = Column(Text, nullable=False) + permissions = Column(Text) + + __table_args__ = (UniqueConstraint("username"),) + + class DatabaseEngine(object): def __init__(self, url, username=None, password=None): - self.logger = logger + self.logger = logging.getLogger("hashserv.sqlalchemy") self.url = sqlalchemy.engine.make_url(url) if username is not None: @@ -85,7 +95,7 @@ class DatabaseEngine(object): async with self.engine.begin() as conn: # Create tables - logger.info("Creating tables...") + self.logger.info("Creating tables...") await conn.run_sync(Base.metadata.create_all) def connect(self, logger): @@ -98,6 +108,15 @@ def map_row(row): return dict(**row._mapping) +def map_user(row): + if row is None: + return None + return User( + username=row.username, + permissions=set(row.permissions.split()), + ) + + class Database(object): def __init__(self, engine, logger): self.engine = engine @@ -278,7 +297,7 @@ class Database(object): await self.db.execute(statement) return True except IntegrityError: - logger.debug( + self.logger.debug( "%s, %s, %s already in unihash database", method, taskhash, unihash ) return False @@ -298,7 +317,87 @@ class Database(object): await self.db.execute(statement) return True except IntegrityError: - logger.debug( + self.logger.debug( "%s, %s already in outhash database", data["method"], data["outhash"] ) return False + + async def _get_user(self, username): + statement = select( + Users.username, + Users.permissions, + Users.token, + ).where( + Users.username == username, + ) + self.logger.debug("%s", statement) + async with self.db.begin(): + result = await self.db.execute(statement) + return result.first() + + async def lookup_user_token(self, username): + row = await self._get_user(username) + if not row: + return None, None + return map_user(row), row.token + + async def lookup_user(self, username): + return map_user(await self._get_user(username)) + + async def set_user_token(self, username, token): + statement = ( + update(Users) + .where( + Users.username == username, + ) + .values( + token=token, + ) + ) + self.logger.debug("%s", statement) + async with self.db.begin(): + result = await self.db.execute(statement) + return result.rowcount != 0 + + async def set_user_perms(self, username, permissions): + statement = ( + update(Users) + .where(Users.username == username) + .values(permissions=" ".join(permissions)) + ) + self.logger.debug("%s", statement) + async with self.db.begin(): + result = await self.db.execute(statement) + return result.rowcount != 0 + + async def get_all_users(self): + statement = select( + Users.username, + Users.permissions, + ) + self.logger.debug("%s", statement) + async with self.db.begin(): + result = await self.db.execute(statement) + return [map_user(row) for row in result] + + async def new_user(self, username, permissions, token): + statement = insert(Users).values( + username=username, + permissions=" ".join(permissions), + token=token, + ) + self.logger.debug("%s", statement) + try: + async with self.db.begin(): + await self.db.execute(statement) + return True + except IntegrityError as e: + self.logger.debug("Cannot create new user %s: %s", username, e) + return False + + async def delete_user(self, username): + statement = delete(Users).where(Users.username == username) + self.logger.debug("%s", statement) + async with self.db.begin(): + result = await self.db.execute(statement) + return result.rowcount != 0 diff --git a/bitbake/lib/hashserv/sqlite.py b/bitbake/lib/hashserv/sqlite.py index 6809c53706..414ee8ffb8 100644 --- a/bitbake/lib/hashserv/sqlite.py +++ b/bitbake/lib/hashserv/sqlite.py @@ -7,6 +7,7 @@ import sqlite3 import logging from contextlib import closing +from . import User logger = logging.getLogger("hashserv.sqlite") @@ -34,6 +35,14 @@ OUTHASH_TABLE_DEFINITION = ( OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION) +USERS_TABLE_DEFINITION = ( + ("username", "TEXT NOT NULL", "UNIQUE"), + ("token", "TEXT NOT NULL", ""), + ("permissions", "TEXT NOT NULL", ""), +) + +USERS_TABLE_COLUMNS = tuple(name for name, _, _ in USERS_TABLE_DEFINITION) + def _make_table(cursor, name, definition): cursor.execute( @@ -53,6 +62,15 @@ def _make_table(cursor, name, definition): ) +def map_user(row): + if row is None: + return None + return User( + username=row["username"], + permissions=set(row["permissions"].split()), + ) + + class DatabaseEngine(object): def __init__(self, dbname, sync): self.dbname = dbname @@ -66,6 +84,7 @@ class DatabaseEngine(object): with closing(db.cursor()) as cursor: _make_table(cursor, "unihashes_v2", UNIHASH_TABLE_DEFINITION) _make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION) + _make_table(cursor, "users", USERS_TABLE_DEFINITION) cursor.execute("PRAGMA journal_mode = WAL") cursor.execute( @@ -227,6 +246,7 @@ class Database(object): "oldest": oldest, }, ) + self.db.commit() return cursor.rowcount async def insert_unihash(self, method, taskhash, unihash): @@ -257,3 +277,88 @@ class Database(object): cursor.execute(query, data) self.db.commit() return cursor.lastrowid != prevrowid + + def _get_user(self, username): + with closing(self.db.cursor()) as cursor: + cursor.execute( + """ + SELECT username, permissions, token FROM users WHERE username=:username + """, + { + "username": username, + }, + ) + return cursor.fetchone() + + async def lookup_user_token(self, username): + row = self._get_user(username) + if row is None: + return None, None + return map_user(row), row["token"] + + async def lookup_user(self, username): + return map_user(self._get_user(username)) + + async def set_user_token(self, username, token): + with closing(self.db.cursor()) as cursor: + cursor.execute( + """ + UPDATE users SET token=:token WHERE username=:username + """, + { + "username": username, + "token": token, + }, + ) + self.db.commit() + return cursor.rowcount != 0 + + async def set_user_perms(self, username, permissions): + with closing(self.db.cursor()) as cursor: + cursor.execute( + """ + UPDATE users SET permissions=:permissions WHERE username=:username + """, + { + "username": username, + "permissions": " ".join(permissions), + }, + ) + self.db.commit() + return cursor.rowcount != 0 + + async def get_all_users(self): + with closing(self.db.cursor()) as cursor: + cursor.execute("SELECT username, permissions FROM users") + return [map_user(r) for r in cursor.fetchall()] + + async def new_user(self, username, permissions, token): + with closing(self.db.cursor()) as cursor: + try: + cursor.execute( + """ + INSERT INTO users (username, token, permissions) VALUES (:username, :token, :permissions) + """, + { + "username": username, + "token": token, + "permissions": " ".join(permissions), + }, + ) + self.db.commit() + return True + except sqlite3.IntegrityError: + return False + + async def delete_user(self, username): + with closing(self.db.cursor()) as cursor: + cursor.execute( + """ + DELETE FROM users WHERE username=:username + """, + { + "username": username, + }, + ) + self.db.commit() + return cursor.rowcount != 0 diff --git a/bitbake/lib/hashserv/tests.py b/bitbake/lib/hashserv/tests.py index e9a361dc4b..f92f37c459 100644 --- a/bitbake/lib/hashserv/tests.py +++ b/bitbake/lib/hashserv/tests.py @@ -6,6 +6,8 @@ # from . import create_server, create_client +from .server import DEFAULT_ANON_PERMS, ALL_PERMISSIONS +from bb.asyncrpc import InvokeError import hashlib import logging import multiprocessing @@ -29,8 +31,9 @@ class HashEquivalenceTestSetup(object): METHOD = 'TestMethod' server_index = 0 + client_index = 0 - def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc): + def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc, anon_perms=DEFAULT_ANON_PERMS, admin_username=None, admin_password=None): self.server_index += 1 if dbpath is None: dbpath = self.make_dbpath() @@ -45,7 +48,10 @@ class HashEquivalenceTestSetup(object): server = create_server(self.get_server_addr(self.server_index), dbpath, upstream=upstream, - read_only=read_only) + read_only=read_only, + anon_perms=anon_perms, + admin_username=admin_username, + admin_password=admin_password) server.dbpath = dbpath server.serve_as_process(prefunc=prefunc, args=(self.server_index,)) @@ -56,18 +62,31 @@ class HashEquivalenceTestSetup(object): def make_dbpath(self): return os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index) - def start_client(self, server_address): + def start_client(self, server_address, username=None, password=None): def cleanup_client(client): client.close() - client = create_client(server_address) + client = create_client(server_address, username=username, password=password) self.addCleanup(cleanup_client, client) return client def start_test_server(self): - server = self.start_server() - return server.address + self.server = self.start_server() + return self.server.address + + def start_auth_server(self): + self.auth_server = self.start_server(self.server.dbpath, anon_perms=[], admin_username="admin", admin_password="password") + self.admin_client = self.start_client(self.auth_server.address, username="admin", password="password") + return self.admin_client + + def auth_client(self, user): + return self.start_client(self.auth_server.address, user["username"], user["token"]) + + def auth_perms(self, *permissions): + self.client_index += 1 + user = self.admin_client.new_user(f"user-{self.client_index}", permissions) + return self.auth_client(user) def setUp(self): if sys.version_info < (3, 5, 0): @@ -86,18 +105,21 @@ class HashEquivalenceTestSetup(object): class HashEquivalenceCommonTests(object): - def test_create_hash(self): + def create_test_hash(self, client): # Simple test that hashes can be created taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9' outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd' - self.assertClientGetHash(self.client, taskhash, None) + self.assertClientGetHash(client, taskhash, None) - result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) + result = client.report_unihash(taskhash, self.METHOD, outhash, unihash) self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') return taskhash, outhash, unihash + def test_create_hash(self): + return self.create_test_hash(self.client) + def test_create_equivalent(self): # Tests that a second reported task with the same outhash will be # assigned the same unihash @@ -471,6 +493,242 @@ class HashEquivalenceCommonTests(object): # shares a taskhash with Task 2 self.assertClientGetHash(self.client, taskhash2, unihash2) + def test_auth_read_perms(self): + admin_client = self.start_auth_server() + + # Create hashes with non-authenticated server + taskhash, outhash, unihash = self.test_create_hash() + + # Validate hash can be retrieved using authenticated client + with self.auth_perms("@read") as client: + self.assertClientGetHash(client, taskhash, unihash) + + with self.auth_perms() as client, self.assertRaises(InvokeError): + self.assertClientGetHash(client, taskhash, unihash) + + def test_auth_report_perms(self): + admin_client = self.start_auth_server() + + # Without read permission, the user is completely denied + with self.auth_perms() as client, self.assertRaises(InvokeError): + self.create_test_hash(client) + + # Read permission allows the call to succeed, but it doesn't record + # anythin in the database + with self.auth_perms("@read") as client: + taskhash, outhash, unihash = self.create_test_hash(client) + self.assertClientGetHash(client, taskhash, None) + + # Report permission alone is insufficient + with self.auth_perms("@report") as client, self.assertRaises(InvokeError): + self.create_test_hash(client) + + # Read and report permission actually modify the database + with self.auth_perms("@read", "@report") as client: + taskhash, outhash, unihash = self.create_test_hash(client) + self.assertClientGetHash(client, taskhash, unihash) + + def test_auth_no_token_refresh_from_anon_user(self): + self.start_auth_server() + + with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError): + client.refresh_token() + + def assertUserCanAuth(self, user): + with self.start_client(self.auth_server.address) as client: + client.auth(user["username"], user["token"]) + + def assertUserCannotAuth(self, user): + with self.start_client(self.auth_server.address) as client, self.assertRaises(InvokeError): + client.auth(user["username"], user["token"]) + + def test_auth_self_token_refresh(self): + admin_client = self.start_auth_server() + + # Create a new user with no permissions + user = admin_client.new_user("test-user", []) + + with self.auth_client(user) as client: + new_user = client.refresh_token() + + self.assertEqual(user["username"], new_user["username"]) + self.assertNotEqual(user["token"], new_user["token"]) + self.assertUserCanAuth(new_user) + self.assertUserCannotAuth(user) + + # Explicitly specifying with your own username is fine also + with self.auth_client(new_user) as client: + new_user2 = client.refresh_token(user["username"]) + + self.assertEqual(user["username"], new_user2["username"]) + self.assertNotEqual(user["token"], new_user2["token"]) + self.assertUserCanAuth(new_user2) + self.assertUserCannotAuth(new_user) + self.assertUserCannotAuth(user) + + def test_auth_token_refresh(self): + admin_client = self.start_auth_server() + + user = admin_client.new_user("test-user", []) + + with self.auth_perms() as client, self.assertRaises(InvokeError): + client.refresh_token(user["username"]) + + with self.auth_perms("@user-admin") as client: + new_user = client.refresh_token(user["username"]) + + self.assertEqual(user["username"], new_user["username"]) + self.assertNotEqual(user["token"], new_user["token"]) + self.assertUserCanAuth(new_user) + self.assertUserCannotAuth(user) + + def test_auth_self_get_user(self): + admin_client = self.start_auth_server() + + user = admin_client.new_user("test-user", []) + user_info = user.copy() + del user_info["token"] + + with self.auth_client(user) as client: + info = client.get_user() + self.assertEqual(info, user_info) + + # Explicitly asking for your own username is fine also + info = client.get_user(user["username"]) + self.assertEqual(info, user_info) + + def test_auth_get_user(self): + admin_client = self.start_auth_server() + + user = admin_client.new_user("test-user", []) + user_info = user.copy() + del user_info["token"] + + with self.auth_perms() as client, self.assertRaises(InvokeError): + client.get_user(user["username"]) + + with self.auth_perms("@user-admin") as client: + info = client.get_user(user["username"]) + self.assertEqual(info, user_info) + + info = client.get_user("nonexist-user") + self.assertIsNone(info) + + def test_auth_reconnect(self): + admin_client = self.start_auth_server() + + user = admin_client.new_user("test-user", []) + user_info = user.copy() + del user_info["token"] + + with self.auth_client(user) as client: + info = client.get_user() + self.assertEqual(info, user_info) + + client.disconnect() + + info = client.get_user() + self.assertEqual(info, user_info) + + def test_auth_delete_user(self): + admin_client = self.start_auth_server() + + user = admin_client.new_user("test-user", []) + + # No self service + with self.auth_client(user) as client, self.assertRaises(InvokeError): + client.delete_user(user["username"]) + + with self.auth_perms() as client, self.assertRaises(InvokeError): + client.delete_user(user["username"]) + + with self.auth_perms("@user-admin") as client: + client.delete_user(user["username"]) + + # User doesn't exist, so even though the permission is correct, it's an + # error + with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError): + client.delete_user(user["username"]) + + def assertUserPerms(self, user, permissions): + with self.auth_client(user) as client: + info = client.get_user() + self.assertEqual(info, { + "username": user["username"], + "permissions": permissions, + }) + + def test_auth_set_user_perms(self): + admin_client = self.start_auth_server() + + user = admin_client.new_user("test-user", []) + + self.assertUserPerms(user, []) + + # No self service to change permissions + with self.auth_client(user) as client, self.assertRaises(InvokeError): + client.set_user_perms(user["username"], ["@all"]) + self.assertUserPerms(user, []) + + with self.auth_perms() as client, self.assertRaises(InvokeError): + client.set_user_perms(user["username"], ["@all"]) + self.assertUserPerms(user, []) + + with self.auth_perms("@user-admin") as client: + client.set_user_perms(user["username"], ["@all"]) + self.assertUserPerms(user, sorted(list(ALL_PERMISSIONS))) + + # Bad permissions + with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError): + client.set_user_perms(user["username"], ["@this-is-not-a-permission"]) + self.assertUserPerms(user, sorted(list(ALL_PERMISSIONS))) + + def test_auth_get_all_users(self): + admin_client = self.start_auth_server() + + user = admin_client.new_user("test-user", []) + + with self.auth_client(user) as client, self.assertRaises(InvokeError): + client.get_all_users() + + # Give the test user the correct permission + admin_client.set_user_perms(user["username"], ["@user-admin"]) + + with self.auth_client(user) as client: + all_users = client.get_all_users() + + # Convert to a dictionary for easier comparison + all_users = {u["username"]: u for u in all_users} + + self.assertEqual(all_users, + { + "admin": { + "username": "admin", + "permissions": sorted(list(ALL_PERMISSIONS)), + }, + "test-user": { + "username": "test-user", + "permissions": ["@user-admin"], + } + } + ) + + def test_auth_new_user(self): + self.start_auth_server() + + permissions = ["@read", "@report", "@db-admin", "@user-admin"] + permissions.sort() + + with self.auth_perms() as client, self.assertRaises(InvokeError): + client.new_user("test-user", permissions) + + with self.auth_perms("@user-admin") as client: + user = client.new_user("test-user", permissions) + self.assertIn("token", user) + self.assertEqual(user["username"], "test-user") + self.assertEqual(user["permissions"], permissions) + + class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase): def get_server_addr(self, server_idx): return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)