poky/bitbake/bin/bitbake-hashclient
Joshua Watt 6c7c9b1146 bitbake: asyncrpc: Add context manager API
Adds a context manager API for the asyncrpc client class, which allows
writing code that will automatically close the connection, like so:

    with hashserv.create_client(address) as client:
       ...

Rework the bitbake-hashclient tool and PR server to use this new API to
fix warnings about unclosed event loops when exiting
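
For illustration, a minimal sketch of the synchronous side of such a
wrapper, assuming the client exposes a close() method that tears down
the connection (the names here are illustrative, not the exact asyncrpc
API):

    class ClientWrapper:
        def close(self):
            # Assumed teardown hook for the underlying connection/event loop
            ...

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            self.close()
            # A falsy return propagates any exception raised inside the block
            return False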

(Bitbake rev: ee090484cc25d760b8c20f18add17b5eff485b40)

Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
(cherry picked from commit d01d684a0f6398270fe35ed59b7d28f3fd9b7e41)
Signed-off-by: Steve Sakoman <steve@sakoman.com>
2024-01-10 05:14:16 -10:00


#! /usr/bin/env python3

#
# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

import argparse
import hashlib
import logging
import os
import pprint
import sys
import threading
import time
import warnings
warnings.simplefilter("default")

# Fall back to a no-op progress bar if tqdm is not installed
try:
    import tqdm
    ProgressBar = tqdm.tqdm
except ImportError:
    class ProgressBar(object):
        def __init__(self, *args, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args, **kwargs):
            pass

        def update(self):
            pass

# Make the in-tree 'hashserv' module (under bitbake/lib) importable
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'lib'))

import hashserv

DEFAULT_ADDRESS = 'unix://./hashserve.sock'
METHOD = 'stress.test.method'

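# 'stats' subcommand handler: print the server's statistics, optionally resetting them first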
def handle_stats(args, client):
    if args.reset:
        s = client.reset_stats()
    else:
        s = client.get_stats()

    pprint.pprint(s)
    return 0

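# 'stress' subcommand handler: query many unihashes from parallel client threads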
def handle_stress(args, client):
    def thread_main(pbar, lock):
        nonlocal found_hashes
        nonlocal missed_hashes
        nonlocal max_time

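        # Each worker opens its own connection; the context manager API added in
        # this commit closes it automatically when the thread finishes its requests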
        with hashserv.create_client(args.address) as client:
            for i in range(args.requests):
                taskhash = hashlib.sha256()
                taskhash.update(args.taskhash_seed.encode('utf-8'))
                taskhash.update(str(i).encode('utf-8'))

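                # Time a single unihash lookup; an empty result counts as a miss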
                start_time = time.perf_counter()
                l = client.get_unihash(METHOD, taskhash.hexdigest())
                elapsed = time.perf_counter() - start_time

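                # Shared counters and the progress bar are updated under the lock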
                with lock:
                    if l:
                        found_hashes += 1
                    else:
                        missed_hashes += 1

                    max_time = max(elapsed, max_time)
                    pbar.update()

    max_time = 0
    found_hashes = 0
    missed_hashes = 0
    lock = threading.Lock()
    total_requests = args.clients * args.requests
    start_time = time.perf_counter()
    with ProgressBar(total=total_requests) as pbar:
        threads = [threading.Thread(target=thread_main, args=(pbar, lock), daemon=False) for _ in range(args.clients)]
        for t in threads:
            t.start()

        for t in threads:
            t.join()

    elapsed = time.perf_counter() - start_time
    with lock:
        print("%d requests in %.1fs. %.1f requests per second" % (total_requests, elapsed, total_requests / elapsed))
        print("Average request time %.8fs" % (elapsed / total_requests))
        print("Max request time was %.8fs" % max_time)
        print("Found %d hashes, missed %d" % (found_hashes, missed_hashes))

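    # With --report, also populate the server by reporting a deterministic
    # outhash for each of the taskhashes queried above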
    if args.report:
        with ProgressBar(total=args.requests) as pbar:
            for i in range(args.requests):
                taskhash = hashlib.sha256()
                taskhash.update(args.taskhash_seed.encode('utf-8'))
                taskhash.update(str(i).encode('utf-8'))

                outhash = hashlib.sha256()
                outhash.update(args.outhash_seed.encode('utf-8'))
                outhash.update(str(i).encode('utf-8'))

                client.report_unihash(taskhash.hexdigest(), METHOD, outhash.hexdigest(), taskhash.hexdigest())

                with lock:
                    pbar.update()

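# Command-line interface: global options plus the 'stats' and 'stress' subcommands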
parser = argparse.ArgumentParser(description='Hash Equivalence Client')
parser.add_argument('--address', default=DEFAULT_ADDRESS, help='Server address (default "%(default)s")')
parser.add_argument('--log', default='WARNING', help='Set logging level')

subparsers = parser.add_subparsers()

stats_parser = subparsers.add_parser('stats', help='Show server stats')
stats_parser.add_argument('--reset', action='store_true',
                          help='Reset server stats')
stats_parser.set_defaults(func=handle_stats)

stress_parser = subparsers.add_parser('stress', help='Run stress test')
stress_parser.add_argument('--clients', type=int, default=10,
                           help='Number of simultaneous clients')
stress_parser.add_argument('--requests', type=int, default=1000,
                           help='Number of requests each client will perform')
stress_parser.add_argument('--report', action='store_true',
                           help='Report new hashes')
stress_parser.add_argument('--taskhash-seed', default='',
                           help='Include string in taskhash')
stress_parser.add_argument('--outhash-seed', default='',
                           help='Include string in outhash')
stress_parser.set_defaults(func=handle_stress)

def main():
    args = parser.parse_args()

    logger = logging.getLogger('hashserv')

    level = getattr(logging, args.log.upper(), None)
    if not isinstance(level, int):
        raise ValueError('Invalid log level: %s' % args.log)

    logger.setLevel(level)
    console = logging.StreamHandler()
    console.setLevel(level)
    logger.addHandler(console)

    func = getattr(args, 'func', None)
    if func:
        with hashserv.create_client(args.address) as client:
            return func(args, client)

    return 0

if __name__ == '__main__':
    try:
        ret = main()
    except Exception:
        ret = 1
        import traceback
        traceback.print_exc()

    sys.exit(ret)