diff --git a/Makefile b/Makefile index 64feb21..6b6856c 100644 --- a/Makefile +++ b/Makefile @@ -2,3 +2,4 @@ TOPDIR = . include $(TOPDIR)/make/proj.mk include $(JWBDIR)/make/topdir.mk +include $(JWBDIR)/make/py-topdir.mk diff --git a/src/python/jw/util/ArgsContainer.py b/src/python/jw/util/ArgsContainer.py index fdc02af..45363a7 100644 --- a/src/python/jw/util/ArgsContainer.py +++ b/src/python/jw/util/ArgsContainer.py @@ -1,20 +1,17 @@ -# -*- coding: utf-8 -*- - import argparse from collections import OrderedDict -from .log import slog +from .log import get_caller_pos, slog -class ArgsContainer: # export +class ArgsContainer: # export __args: OrderedDict[str, str] = OrderedDict() - __kwargs: OrderedDict[str, str] = OrderedDict() + __kwargs: OrderedDict[str, dict[str, str]] = OrderedDict() __values: dict[str, str] = {} __specified_args: list[str] = list() - def __getattr__(self, name): - values = self.__values + def __getattr__(self, name: str) -> str: if name in self.__values: return self.__values[name] if name in self.__kwargs.keys(): @@ -25,7 +22,7 @@ class ArgsContainer: # export raise Exception(f'No argument "{name}" defined') def __setattr__(self, name, value): - if not name in self.__kwargs.keys(): + if name not in self.__kwargs.keys(): raise Exception(f'No argument "{name}" defined') self.__values[name] = value self.__specified_args.append(name) @@ -41,7 +38,7 @@ class ArgsContainer: # export else: raise Exception('Missing argument name') name = name.replace('-', '_') - self.__args[name] = args + self.__args[name] = arg self.__kwargs[name] = kwargs def keys(self): @@ -50,7 +47,7 @@ class ArgsContainer: # export def args(self, name) -> str: return self.__args[name] - def kwargs(self, name) -> str: + def kwargs(self, name) -> dict[str, str]: return self.__kwargs[name] def dump(self, prio, *args, **kwargs): @@ -59,15 +56,17 @@ class ArgsContainer: # export val = None try: val = self.__getattr__(name) - except: + except Exception: pass - slog(prio, f'{name}: {val}', caller=caller) + slog(prio, f'{name}: {val}', caller = caller) @property def specified_args(self): return self.__specified_args -def add_argument(p: argparse.ArgumentParser|ArgsContainer, name: str, *args, **kwargs): # export +def add_argument( # export + p: argparse.ArgumentParser | ArgsContainer, name: str, *args, **kwargs +): key = name.strip('--').replace('-', '_') if isinstance(p, ArgsContainer): diff --git a/src/python/jw/util/Bunch.py b/src/python/jw/util/Bunch.py index dc49a44..6ccd6bd 100644 --- a/src/python/jw/util/Bunch.py +++ b/src/python/jw/util/Bunch.py @@ -1,8 +1,6 @@ -# -*- coding: utf-8 -*- - from typing import Any -class Bunch: # export +class Bunch: # export def __init__(self, **kwargs): self.__dict__.update(kwargs) diff --git a/src/python/jw/util/Cmd.py b/src/python/jw/util/Cmd.py index 0903a7b..305a454 100644 --- a/src/python/jw/util/Cmd.py +++ b/src/python/jw/util/Cmd.py @@ -1,39 +1,49 @@ -# -*- coding: utf-8 -*- - from __future__ import annotations -import inspect, sys, re, abc, argparse -from argparse import ArgumentParser, _SubParsersAction + +import abc +import argparse +import inspect +import re +import sys + +from argparse import ArgumentParser +from typing import TYPE_CHECKING from . import log +if TYPE_CHECKING: + from .Cmds import Cmds + # full blown example of one level of nested subcommands # git -C project remote -v show -n myremote -class Cmd(abc.ABC): # export +class Cmd(abc.ABC): # export @abc.abstractmethod async def run(self, args): pass def __init__(self, name: str, help: str) -> None: - from . import Cmds self.name = name self.help = help self.parent = None self.children: list[Cmd] = [] self.child_classes: list[type[Cmd]] = [] - self.app: Cmds|None = None + self.app: Cmds | None = None async def _run(self, args): pass def add_parser(self, parsers) -> ArgumentParser: - r = parsers.add_parser(self.name, help=self.help, - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - r.set_defaults(func=self.run) + r = parsers.add_parser( + self.name, + help = self.help, + formatter_class = argparse.ArgumentDefaultsHelpFormatter + ) + r.set_defaults(func = self.run) return r - def add_subcommands(self, cmd: str|type[Cmd]|list[type[Cmd]]) -> None: + def add_subcommands(self, cmd: str | type[Cmd] | list[type[Cmd]]) -> None: if isinstance(cmd, str): sc = [] for name, obj in inspect.getmembers(sys.modules[self.__class__.__module__]): @@ -54,7 +64,7 @@ class Cmd(abc.ABC): # export def add_arguments(self, parser: ArgumentParser) -> None: pass - def conf_value(self, path, default=None): + def conf_value(self, path, default = None): ret = None if self.app is None else self.app.conf_value(path, default) if ret is None and default is not None: return default diff --git a/src/python/jw/util/Cmds.py b/src/python/jw/util/Cmds.py index 13f07fa..64ebe19 100644 --- a/src/python/jw/util/Cmds.py +++ b/src/python/jw/util/Cmds.py @@ -1,13 +1,21 @@ -# -*- coding: utf-8 -*- +import argparse +import asyncio +import cProfile +import importlib +import inspect +import os +import re +import sys -import os, sys, argcomplete, argparse, importlib, inspect, re, pickle, asyncio, cProfile from argparse import ArgumentParser from pathlib import Path, PurePath -from .log import * +import argcomplete + +from .log import DEBUG, ERR, NOTICE, add_log_file, set_flags, set_level, slog from .stree import serdes -class Cmds: # export +class Cmds: # export def __instantiate(self, cls): try: @@ -15,7 +23,7 @@ class Cmds: # export except Exception as e: slog(ERR, f'Failed to instantiate command of type {cls}: {e}') raise - r.cmds = self # TODO: Rename Cmds class to App, "Cmds" isn't very self-explanatory + r.cmds = self # TODO: Rename Cmds class to App, "Cmds" isn't self-explanatory r.app = self return r @@ -26,7 +34,9 @@ class Cmds: # export for c in cmd.child_classes: cmd.children.append(self.__instantiate(c)) if len(cmd.children) > 0: - subparsers = parser.add_subparsers(title='Available subcommands of ' + cmd.name, metavar='') + subparsers = parser.add_subparsers( + title = 'Available subcommands of ' + cmd.name, metavar = '' + ) for sub_cmd in cmd.children: self.__add_cmd_to_parser(sub_cmd, subparsers) @@ -38,7 +48,13 @@ class Cmds: # export slog(DEBUG, 'Reading configuration "{}"'.format(path)) return serdes.read(path, ''), [path] - def __init__(self, description: str = '', filter: str = '^Cmd.*', modules: None=None, eloop: None=None) -> None: + def __init__( + self, + description: str = '', + filter: str = '^Cmd.*', + modules: None = None, + eloop: None = None + ) -> None: self.__description = description self.__filter = filter self.__modules = modules @@ -68,18 +84,34 @@ class Cmds: # export set_flags(log_flags) set_level(log_level) slog(DEBUG, "set log level to {}".format(log_level)) - self.__parser = argparse.ArgumentParser(usage=os.path.basename(sys.argv[0]) + ' [options]', - formatter_class=argparse.ArgumentDefaultsHelpFormatter, description=self.__description) - self.__parser.add_argument('--log-flags', help='Log flags', default=log_flags) - self.__parser.add_argument('--log-level', help='Log level', default=log_level) - self.__parser.add_argument('--backtrace', help='Show exception backtraces', action='store_true', default=False) - self.__parser.add_argument('--write-profile', help='Profile code and store output to file', default=None) - self.__parser.add_argument('--log-file', help='Log file', default=log_file) - if self.__modules == None: - self.__modules = [ '__main__' ] + self.__parser = argparse.ArgumentParser( + usage = os.path.basename(sys.argv[0]) + ' [options]', + formatter_class = argparse.ArgumentDefaultsHelpFormatter, + description = self.__description + ) + self.__parser.add_argument( + '--log-flags', help = 'Log flags', default = log_flags + ) + self.__parser.add_argument( + '--log-level', help = 'Log level', default = log_level + ) + self.__parser.add_argument( + '--backtrace', + help = 'Show exception backtraces', + action = 'store_true', + default = False + ) + self.__parser.add_argument( + '--write-profile', + help = 'Profile code and store output to file', + default = None + ) + self.__parser.add_argument('--log-file', help = 'Log file', default = log_file) + if self.__modules is None: + self.__modules = ['__main__'] subcmds = set() slog(DEBUG, '-- searching for commands') - for m in self.__modules: # type: ignore + for m in self.__modules: # type: ignore if m != '__main__': importlib.import_module(m) for name, c in inspect.getmembers(sys.modules[m], inspect.isclass): @@ -96,24 +128,27 @@ class Cmds: # export subcmds.update(cmd.child_classes) cmds = [cmd for cmd in self.__cmds if type(cmd) not in subcmds] - subparsers = self.__parser.add_subparsers(title='Available commands', metavar='') + subparsers = self.__parser.add_subparsers( + title = 'Available commands', metavar = '' + ) for cmd in cmds: slog(DEBUG, f'Adding top-level command {cmd} to parser') self.__add_cmd_to_parser(cmd, subparsers) # Run all sub-commands. Overwrite if you want to do anything before or after - async def _run(self, argv=None): + async def _run(self, argv = None): return await self.args.func(self.args) - async def __run(self, argv=None): + async def __run(self, argv = None): argcomplete.autocomplete(self.__parser) - self.args = self.__parser.parse_args(args=argv) + self.args = self.__parser.parse_args(args = argv) set_flags(self.args.log_flags) set_level(self.args.log_level) self.__back_trace = self.args.backtrace exit_status = 0 - # This is the toplevel parser, i.e. no func member has been added to the args via + # This is the toplevel parser, i.e. no func member has been added to the args + # yet, via # # Cmds.__init__() # Cmds.__add_cmd_to_parser(cmd, subparsers) @@ -135,17 +170,15 @@ class Cmds: # export add_log_file(self.args.log_file) try: - ret = await self._run(self.args) + await self._run(self.args) except Exception as e: - if hasattr(e, 'message'): - slog(ERR, e.message) - else: - slog(ERR, f'Exception: {type(e)}: {e}') + slog(ERR, f'Exception: {type(e)}: {str(e)}') exit_status = 1 if self.__back_trace: raise finally: if pr is not None: + assert self.args.write_profile is not None, 'args.write_profile' pr.disable() slog(NOTICE, f'Writing profile statistics to {self.args.write_profile}') pr.dump_stats(self.args.write_profile) @@ -160,7 +193,7 @@ class Cmds: # export self.eloop = None self.__own_eloop = False - def conf_value(self, path, default=None): + def conf_value(self, path, default = None): ret = None if self.__conf is None else self.__conf.value(path) if ret is None and default is not None: return default @@ -169,10 +202,12 @@ class Cmds: # export def parser(self) -> ArgumentParser: return self.__parser - def run(self, argv=None) -> None: + def run(self, argv = None) -> None: #return self.__run() - return self.eloop.run_until_complete(self.__run(argv)) # type: ignore + return self.eloop.run_until_complete(self.__run(argv)) # type: ignore -def run_sub_commands(description = '', filter = '^Cmd.*', modules=None, argv=None): # export +def run_sub_commands( # export + description = '', filter = '^Cmd.*', modules = None, argv = None +): cmds = Cmds(description, filter, modules) - return cmds.run(argv=argv) + return cmds.run(argv = argv) diff --git a/src/python/jw/util/Config.py b/src/python/jw/util/Config.py index b4eed8f..fbabfe1 100644 --- a/src/python/jw/util/Config.py +++ b/src/python/jw/util/Config.py @@ -1,14 +1,16 @@ -# -*- coding: utf-8 -*- +import glob +import os +import re +import sys -from typing import Optional, Dict, cast -import os, re, glob, sys -from pathlib import Path, PosixPath +from pathlib import Path +from typing import Dict, Optional, cast -from . import stree +from .stree import serdes +from .log import DEBUG, ERR, slog, get_caller_pos from .stree.StringTree import StringTree -from .log import * -class Config(): # export +class Config(): # export def __load(self, search_dirs, glob_paths, refuse_mode_mask): @@ -33,15 +35,15 @@ class Config(): # export for path in glob_paths: dirs = search_dirs if dirs is None: - dirs = [''] if __is_abs(path) else [ str(Path.home()), str(Path.cwd()) ] + dirs = [''] if __is_abs(path) else [str(Path.home()), str(Path.cwd())] for d in dirs: g = d + '/' + path if len(d) else path slog(DEBUG, 'Looking for config "{}"'.format(g)) for f in glob.glob(g): slog(DEBUG, 'Reading config "{}"'.format(f)) paths_buf = [] - tree = stree.read(f, paths_buf=paths_buf) - assert(len(paths_buf)) + tree = serdes.read(f, paths_buf = paths_buf) + assert (len(paths_buf)) if refuse_mode_mask is not None: for p in paths_buf: st = os.stat(p) @@ -49,37 +51,44 @@ class Config(): # export for item in tree.child_list(): if item.content is None: continue - if not re.search('password|secret', cast(str, item.content), flags=re.IGNORECASE): + if not re.search('password|secret', + cast('str', item.content), + flags = re.IGNORECASE): continue - msg = "Config files define secret, but at least one has file permissions open for world" + msg = ( + 'Config files define secret, but at least one ' + 'has file permissions open for world' + ) slog(ERR, f'{msg}:') for pp in paths_buf: - slog(ERR, f' {((os.stat(pp).st_mode) & 0o7777):o} {pp}') + mode = (os.stat(pp).st_mode) & 0o7777 + slog(ERR, f' {mode:o} {pp}') raise Exception(msg) tree.dump(DEBUG, f) ret.add("", tree) return ret - def __init__(self, - search_dirs: Optional[list[str]]=None, - glob_paths: Optional[list[str]]=None, - glob_paths_env_key: Optional[str]=None, - defaults: Optional[Dict[str, str]]=None, - tree: Optional[StringTree]=None, - parent=None, - root_section=None, - refuse_mode_mask=0o0027 - ) -> None: + def __init__( + self, + search_dirs: Optional[list[str]] = None, + glob_paths: Optional[list[str]] = None, + glob_paths_env_key: Optional[str] = None, + defaults: Optional[Dict[str, str]] = None, + tree: Optional[StringTree] = None, + parent = None, + root_section = None, + refuse_mode_mask = 0o0027 + ) -> None: self.__parent = parent if tree is not None: - assert(search_dirs is None) - assert(glob_paths is None) - assert(glob_paths_env_key is None) + assert (search_dirs is None) + assert (glob_paths is None) + assert (glob_paths_env_key is None) self.__conf = tree else: - assert(tree is None) + assert (tree is None) if glob_paths_env_key is not None: glob_paths_env = os.getenv(glob_paths_env_key) if glob_paths_env is not None: @@ -87,8 +96,11 @@ class Config(): # export glob_paths = [] glob_paths.extend(glob_paths_env.split(':')) - self.__conf = self.__load(search_dirs=search_dirs, glob_paths=glob_paths, - refuse_mode_mask=refuse_mode_mask) + self.__conf = self.__load( + search_dirs = search_dirs, + glob_paths = glob_paths, + refuse_mode_mask = refuse_mode_mask + ) if root_section is not None: tmp = self.__conf.get(root_section) @@ -141,7 +153,11 @@ class Config(): # export def value(self, key: str, default = None) -> Optional[str]: return self.get(key, default) - def branch(self, path: str, throw: bool=True): # type: ignore # Optional[Config]: FIXME: Don't know how to get hold of this type here + def branch( + self, + path: str, + throw: bool = True + ): # type: ignore # Optional[Config]: FIXME: Don't know how to get hold of this type here if self.__conf: tree = self.__conf.get(path) if tree is None: @@ -151,19 +167,24 @@ class Config(): # export return None self.dump(ERR, msg) raise Exception(msg) - return Config(tree=tree, parent=self) # type: ignore + return Config(tree = tree, parent = self) # type: ignore return None def dump(self, prio: int, *args, **kwargs) -> None: caller = get_caller_pos(1, kwargs) - self.__conf.dump(prio, caller=caller, *args, **kwargs) + self.__conf.dump(prio, caller = caller, *args, **kwargs) @property def name(self): return self.__conf.content - def find(self, key: str|None, val: str|None, match:StringTree.Match=StringTree.Match.Equal) -> list[str]: - return self.__conf.find(key, val, match=match) + def find( + self, + key: str | None, + val: str | None, + match: StringTree.Match = StringTree.Match.Equal + ) -> list[str]: + return self.__conf.find(key, val, match = match) #def __getattr__(self, name: str): # return getattr(self.__conf, name) diff --git a/src/python/jw/util/CppState.py b/src/python/jw/util/CppState.py index d6cb5fe..772d192 100644 --- a/src/python/jw/util/CppState.py +++ b/src/python/jw/util/CppState.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- - -class CppState: # export +class CppState: # export def __init__(self): self.__pair_square = ['[', ']'] @@ -33,37 +31,39 @@ class CppState: # export self.things.append(self.__pair_square) elif tok == ']': self.square -= 1 - assert(self.things.pop() == self.__pair_square) + assert (self.things.pop() == self.__pair_square) elif tok == '{': self.curly += 1 self.things.append(self.__pair_curly) elif tok == '}': self.curly -= 1 - assert(self.things.pop() == self.__pair_curly) + assert (self.things.pop() == self.__pair_curly) elif tok == '(': self.paren += 1 self.things.append(self.__pair_paren) elif tok == ')': self.paren -= 1 - assert(self.things.pop() == self.__pair_paren) + assert (self.things.pop() == self.__pair_paren) elif tok == '<': self.ext += 1 self.things.append(self.__pair_ext) elif tok == '>': self.ext -= 1 - assert(self.things.pop() == self.__pair_ext) + assert (self.things.pop() == self.__pair_ext) elif tok == '?': if not self.in_special: self.in_special = True self.things.append(self.__pair_special) else: self.in_special = False - assert(self.things.pop() == self.__pair_special) + assert (self.things.pop() == self.__pair_special) elif tok == '/*': self.in_c_comment = True self.things.append(self.__pair_c_comment) elif tok == '*/': - raise Exception("Unmatched closing C-style comment mark", tok, "in line", line) + raise Exception( + "Unmatched closing C-style comment mark", tok, "in line", line + ) else: if self.in_cpp_comment: if tok == '\n': @@ -72,7 +72,7 @@ class CppState: # export if tok == '/*': raise Exception("Nested C-style comment", tok, "in line", line) elif tok == '*/': - assert(self.things.pop() == self.__pair_c_comment) + assert (self.things.pop() == self.__pair_c_comment) self.in_c_comment = False if self.curly < 0 or self.square < 0 or self.ext < 0 or self.paren < 0: @@ -101,4 +101,3 @@ class CppState: # export def is_optional(self): return self.in_list() or self.in_option() - diff --git a/src/python/jw/util/Makefile b/src/python/jw/util/Makefile index 8785ac2..59b3ac1 100644 --- a/src/python/jw/util/Makefile +++ b/src/python/jw/util/Makefile @@ -1,6 +1,4 @@ TOPDIR = ../../../.. -PY_UPDATE_INIT_PY ?= false - include $(TOPDIR)/make/proj.mk include $(JWBDIR)/make/py-mod.mk diff --git a/src/python/jw/util/Object.py b/src/python/jw/util/Object.py index 0e7ddc6..96ac95f 100644 --- a/src/python/jw/util/Object.py +++ b/src/python/jw/util/Object.py @@ -1,16 +1,14 @@ -# -*- coding: utf-8 -*- - from __future__ import print_function from . import log -class Object(object): # export +class Object(object): # export def __init__(self): - self.log_level = log.level + self.log_level = log.log_level() def log(self, prio, *args): - if self.log_level == log.level: + if self.log_level == log.log_level(): log.slog(prio, args) return if prio <= self.log_level: diff --git a/src/python/jw/util/Options.py b/src/python/jw/util/Options.py index b4034f5..db098f0 100644 --- a/src/python/jw/util/Options.py +++ b/src/python/jw/util/Options.py @@ -1,11 +1,13 @@ -import re import json -from collections import OrderedDict -from .log import * +import re import shlex import traceback -class Options: # export +from collections import OrderedDict + +from .log import ERR, get_caller_pos, slog, slog_m + +class Options: # export class OrderedData: @@ -30,8 +32,8 @@ class Options: # export if spec[0] != '{': spec = '{' + spec + '}' try: - return json.loads(spec, object_pairs_hook=cls) - except: + return json.loads(spec, object_pairs_hook = cls) + except Exception: pass return None @@ -42,7 +44,7 @@ class Options: # export r = cls() try: opt_strs = shlex.split(opts_str) - except Exception as e: + except Exception: slog_m(ERR, traceback.format_exc()) slog(ERR, 'Failed to split options string >{}<'.format(opts_str)) raise @@ -52,7 +54,7 @@ class Options: # export lhs = sides[0].strip() if not len(lhs): continue - if self.__allowed_keys and not lhs in self.__allowed_keys: + if self.__allowed_keys and lhs not in self.__allowed_keys: raise Exception('Field "{}" not supported'.format(lhs)) rhs = ' '.join(sides[1:]).strip() if len(sides) > 1 else self.__true_val if cls == OrderedDict: @@ -82,7 +84,7 @@ class Options: # export self.__str = self.__str__() def __getitem__(self, key): - if not key in self.__dict.keys(): + if key not in self.__dict.keys(): return None return self.__dict[key] @@ -99,35 +101,38 @@ class Options: # export return len(self.__data.pairs) def __contains__(self, keys): - if not type(keys) in [list, set]: + if type(keys) not in [list, set]: return keys in self.__dict.keys() for key in keys: - if not key in self.__dict.keys(): + if key not in self.__dict.keys(): return False return True def __iter__(self): return iter(self.__list) - def __next__(self): - return next(self.__list) + #def __next__(self): + # return next(self.__list) - def __init__(self, spec=None, delimiter=',', allowed_keys=None, true_val=True): + def __init__( + self, spec = None, delimiter = ',', allowed_keys = None, true_val = True + ): self.__true_val = true_val self.__allowed_keys = None self.__delimiter = delimiter - self.__data = self.OrderedData() if spec is None else self.__parse(spec, self.OrderedData) + self.__data = self.OrderedData( + ) if spec is None else self.__parse(spec, self.OrderedData) self.__dict = {} - #self.__dict = OrderedDict() if spec is None else self.__parse(spec, OrderedDict) + #self.__dict = OrderedDict() if spec is None else self.__parse(spec,OrderedDict) self.__list = [] self.__str = None self.__recache() - def dump(self, prio, caller=None): + def dump(self, prio, caller = None): if caller is None: caller = get_caller_pos() for key, val in self.__data.pairs: - slog(prio, "{}=\"{}\"".format(key, val), caller=caller) + slog(prio, "{}=\"{}\"".format(key, val), caller = caller) def keys(self): return self.__dict.keys() @@ -136,22 +141,28 @@ class Options: # export #return self.__dict.items() return self.__data.pairs - def get(self, key, default=None, by_index=False): + def get(self, key, default = None, by_index = False): if by_index: - if type(key) != int: - raise KeyError('Tried to get value from options string with ' + - 'index {} of type "{}": {}'.format(key, type(key), str(self))) + if isinstance(key, int): + raise KeyError( + 'Tried to get value from options string with ' + + 'index {} of type "{}": {}'.format(key, type(key), str(self)) + ) if key >= len(self.__data.pairs): if default is not None: return default - raise KeyError('Tried to get value from options string with ' + - 'index {} of {}: {}'.format(key, len(self.__data.pairs), str(self))) + raise KeyError( + 'Tried to get value from options string with ' + + 'index {} of {}: {}'.format(key, len(self.__data.pairs), str(self)) + ) return self.__list[key] if key in self.__dict.keys(): return self.__dict[key] if default is not None: return default - raise KeyError('Key "{}" is not present in options string: {}'.format(key, str(self))) + raise KeyError( + 'Key "{}" is not present in options string: {}'.format(key, str(self)) + ) def update(self, rhs): if hasattr(rhs, 'items'): @@ -159,9 +170,13 @@ class Options: # export self.__dict[key] = val return if isinstance(rhs, str): - self.update(self.__parse(rhs)) + self.update(self.__parse(rhs, self.OrderedData)) return - raise Exception('Tried to update options with object of incompatible type {}'.format(type(rhs))) + raise Exception( + 'Tried to update options with object of incompatible type {}'.format( + type(rhs) + ) + ) def append_to(self, obj): for opt in self.__list: diff --git a/src/python/jw/util/Process.py b/src/python/jw/util/Process.py index a20a856..fed6aa9 100644 --- a/src/python/jw/util/Process.py +++ b/src/python/jw/util/Process.py @@ -1,41 +1,46 @@ -# -*- coding: utf-8 -*- - from __future__ import annotations + +import signal + from abc import ABC, abstractmethod from enum import Enum, Flag, auto -from typing import List +from typing import TYPE_CHECKING + +from .log import ERR, slog + +if TYPE_CHECKING: + from .Signals import Signals def _sigchld_handler(signum, process): if not signum == signal.SIGCHLD: return Process.propagate_signal(signum) -class Process(ABC): # export +class Process(ABC): # export - __processes: List[Process] = [] + __processes: set[Process] = set() class State(Enum): - Running = auto() - Shutdown = auto() - Done = auto() + Running = auto() + Shutdown = auto() + Done = auto() class Flags(Flag): FailOnExitWithoutShutdown = auto() def __init__(self): - self.__state = Running - self.__flags = Flags.FailOnExitWithoutShutdown + self.__state = Process.State.Running + self.__flags = self.Flags.FailOnExitWithoutShutdown if len(self.__processes) == 0: - self._signals().add_handler(signals.SIGCHLD, _sigchld_handler) + self.signals().add_handler(signal.SIGCHLD, _sigchld_handler) self.__processes.add(self) @classmethod def propagate_signal(cls, signum): - for p in cls.__processes: - p.__signal(signum) + cls.signals().propagate(signum) def signal(self, signum): - if signum == signals.SIGCHLD: + if signum == signal.SIGCHLD: self.exited() @abstractmethod @@ -44,7 +49,7 @@ class Process(ABC): # export @classmethod @abstractmethod - def signals(cls): + def signals(cls) -> Signals: pass # to be reimplemented @@ -56,17 +61,15 @@ class Process(ABC): # export return str(self._pid()) def request_shutdown(self): - if not self.__state == Shutdown: - self.__state = Shutdown + if not self.__state == Process.State.Shutdown: + self.__state = Process.State.Shutdown self._request_shutdown() def exited(self): if self.__state == Process.State.Running: - slog(ERR, 'process "{}" exited unexpectedly'.format(process.name())) - if __flags & Process.Flags.FailOnExitWithoutShutdown: + slog(ERR, 'process exited unexpectedly') + if self.__flags & Process.Flags.FailOnExitWithoutShutdown: slog(ERR, 'exiting') exit(1) self.__state = Process.State.Done - self.__processes.erase(self) - if len(self.__processes) == 0: - self._signals().remove_handler(signals.SIGCHLD) # FIXME: broken logic + self.__processes.remove(self) diff --git a/src/python/jw/util/RedirectStdIO.py b/src/python/jw/util/RedirectStdIO.py index 0d6bd00..e68d083 100644 --- a/src/python/jw/util/RedirectStdIO.py +++ b/src/python/jw/util/RedirectStdIO.py @@ -1,13 +1,13 @@ -# -*- coding: utf-8 -*- - from __future__ import print_function -import os, io, sys, traceback +import os +import io +import sys from fcntl import fcntl, F_GETFL, F_SETFL -class RedirectStdIO: # export +class RedirectStdIO: # export - def __init__(self, stderr='on', stdout='off'): + def __init__(self, stderr = 'on', stdout = 'off'): self.__stderr = stderr self.__stdout = stdout # TODO: arguments not fully implemented, @@ -30,12 +30,12 @@ class RedirectStdIO: # export sys.stdout.flush() os.dup2(self.real_stdout_fd, 1) if type is not None: - #print("-------- Error while stdio was suppressed --------") - #traceback.print_stack() - #print(traceback) - print("-------- Captured output --------") - print(*self.rfile.readlines()) - self.rfile.close() - #print('type = ' + str(type)) - #print('value = ' + str(value)) - raise type(value) + #print("-------- Error while stdio was suppressed --------") + #traceback.print_stack() + #print(traceback) + print("-------- Captured output --------") + print(*self.rfile.readlines()) + self.rfile.close() + #print('type = ' + str(type)) + #print('value = ' + str(value)) + raise type(value) diff --git a/src/python/jw/util/Signals.py b/src/python/jw/util/Signals.py index 1179753..2753f7f 100644 --- a/src/python/jw/util/Signals.py +++ b/src/python/jw/util/Signals.py @@ -1,19 +1,10 @@ -# -*- coding: utf-8 -*- - -from typing import Dict, Callable -from abc import ABC, abstractmethod - -_handled_signals: Dict[int, Callable] = {} - -def _signal_handler(signal, frame): - if not signal in _handled_signals.keys(): - return - for h in _handled_signals[signal]: - h.func(signal, *h.args) +from abc import abstractmethod +from typing import Dict class Signals: class Handler: + def __init__(self, func, args): self.func = func self.args = args @@ -23,15 +14,27 @@ class Signals: @classmethod @abstractmethod - def _add_handler(self, signal, handler): + def _add_handler(cls, signal, handler): raise Exception("_add_handler() is not reimplemented") @classmethod def add_handler(cls, signals, handler, *args): for signal in signals: h = Signals.Handler(handler, args) - if not signal in _handled_signals.keys(): + if signal not in _handled_signals.keys(): _handled_signals[signal] = [h] - cls._add_signal_handler(signal, _signal_handler) + cls._add_handler(signal, _signal_handler) else: - _handled_signals[signal].add(h) + _handled_signals[signal].append(h) + + @classmethod + def propagate(cls, signal): + _signal_handler(signal, None) + +_handled_signals: Dict[int, list[Signals.Handler]] = {} + +def _signal_handler(signal, frame): + if signal not in _handled_signals.keys(): + return + for h in _handled_signals[signal]: + h.func(signal, *h.args) diff --git a/src/python/jw/util/StopWatch.py b/src/python/jw/util/StopWatch.py index 4e7472b..745c460 100644 --- a/src/python/jw/util/StopWatch.py +++ b/src/python/jw/util/StopWatch.py @@ -1,12 +1,10 @@ -# -*- coding: utf-8 -*- - from datetime import datetime -from .log import * +from .log import get_caller_pos, slog -class StopWatch: # export +class StopWatch: # export - def __init__(self, name=''): + def __init__(self, name = ''): self.__start = datetime.now() self.__last = self.__start self.name = name @@ -21,5 +19,9 @@ class StopWatch: # export else: msg = '------------------ ' caller = kwargs['caller'] if 'caller' in kwargs.keys() else get_caller_pos(1) - slog(prio, '{} {} {}'.format(self.name, str(now - self.__last), msg), caller=caller) + slog( + prio, + '{} {} {}'.format(self.name, str(now - self.__last), msg), + caller = caller + ) self.__last = now diff --git a/src/python/jw/util/__init__.py b/src/python/jw/util/__init__.py deleted file mode 100644 index b36383a..0000000 --- a/src/python/jw/util/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from pkgutil import extend_path - -__path__ = extend_path(__path__, __name__) diff --git a/src/python/jw/util/algo/ShuntingYard.py b/src/python/jw/util/algo/ShuntingYard.py index b197eee..483bd30 100644 --- a/src/python/jw/util/algo/ShuntingYard.py +++ b/src/python/jw/util/algo/ShuntingYard.py @@ -1,16 +1,15 @@ -# -*- coding: utf-8 -*- +import re -import re, shlex from collections import namedtuple -from ..log import * +from ..log import DEBUG, get_caller_pos, prio_gets_logged, slog L, R = 'Left Right'.split() ARG, KEYW, QUOTED, LPAREN, RPAREN = 'arg kw quoted ( )'.split() -class Operator: # export +class Operator: # export - def __init__(self, func=None, nargs=2, precedence=3, assoc=L): + def __init__(self, func = None, nargs = 2, precedence = 3, assoc = L): self.func = func self.nargs = nargs self.prec = precedence @@ -18,7 +17,7 @@ class Operator: # export class Stack: - def __init__(self, itemlist=[]): + def __init__(self, itemlist = []): self.items = itemlist def __repr__(self): @@ -39,7 +38,7 @@ class Stack: self.items.append(item) return 0 -class ShuntingYard(object): # export +class ShuntingYard(object): # export def __init__(self, operators = None): self.do_debug = prio_gets_logged(DEBUG) @@ -54,7 +53,7 @@ class ShuntingYard(object): # export for count, thing in enumerate(args): msg += ' ' + str(thing) if len(msg): - slog(DEBUG, msg[1:], caller=get_caller_pos()) + slog(DEBUG, msg[1:], caller = get_caller_pos()) def operator(self, key: str) -> Operator: return self.__ops[key] @@ -65,7 +64,7 @@ class ShuntingYard(object): # export v = self.__ops[k] buf = ", \"" + k if v.nargs == 1: - if k[len(k)-1].isalnum(): + if k[len(k) - 1].isalnum(): buf = buf + ' ' buf = buf + "xxx" buf = buf + "\"" @@ -83,16 +82,20 @@ class ShuntingYard(object): # export regex = regex[1:] - scanner = re.Scanner([ - (regex, lambda scanner,token:(KEYW, token)), - (r"\"[^\"]*\"|'[^']*'", lambda scanner,token:(QUOTED, token[1:-1])), - (r"[^\s()]+", lambda scanner,token:(ARG, token)), - (r"\s+", None), # None == skip token. - ]) + scanner = re.Scanner( # pyright: ignore[reportAttributeAccessIssue] + [ + (regex, lambda scanner, token: (KEYW, token)), + (r"\"[^\"]*\"|'[^']*'", lambda scanner, token: (QUOTED, token[1:-1])), + (r"[^\s()]+", lambda scanner, token: (ARG, token)), + (r"\s+", None), # None == skip token. + ] + ) tokens, remainder = scanner.scan(spec) - if len(remainder)>0: - raise Exception("Failed to tokenize " + spec + ", remaining bit is ", remainder) + if len(remainder) > 0: + raise Exception( + "Failed to tokenize " + spec + ", remaining bit is ", remainder + ) #self.debug(tokens) return tokens @@ -112,14 +115,14 @@ class ShuntingYard(object): # export tokenized = self.tokenize(infix) self.debug("tokenized = ", tokenized) outq, stack = [], [] - table = ['TOKEN,ACTION,RPN OUTPUT,OP STACK,NOTES'.split(',')] + table = ['TOKEN', 'ACTION', 'RPN OUTPUT', ('OP STACK', ), 'NOTES'] for toktype, token in tokenized: self.debug("Checking token", token) note = action = '' - if toktype in [ ARG, QUOTED ]: + if toktype in [ARG, QUOTED]: action = 'Add arg to output' outq.append(token) - table.append( (token, action, outq, (s[0] for s in stack), note) ) + table.append((token, action, outq, (s[0] for s in stack), note)) elif toktype == KEYW: val = self.__ops[token] t1, op1 = token, val @@ -127,7 +130,9 @@ class ShuntingYard(object): # export note = 'Pop ops from stack to output' while stack: t2, op2 = stack[-1] - if (op1.assoc == L and op1.prec <= op2.prec) or (op1.assoc == R and op1.prec < op2.prec): + if (op1.assoc == L + and op1.prec <= op2.prec) or (op1.assoc == R + and op1.prec < op2.prec): if t1 != RPAREN: if t2 != LPAREN: stack.pop() @@ -143,9 +148,11 @@ class ShuntingYard(object): # export else: stack.pop() action = '(Pop & discard "(")' - table.append( (v, action, outq, (s[0] for s in stack), note) ) + table.append( + (v, action, outq, (s[0] for s in stack), note) + ) break - table.append( (v, action, (outq), (s[0] for s in stack), note) ) + table.append((v, action, (outq), (s[0] for s in stack), note)) v = note = '' else: note = '' @@ -157,7 +164,7 @@ class ShuntingYard(object): # export action = 'Push op token to stack' else: action = 'Discard ")"' - table.append( (v, action, (outq), (s[0] for s in stack), note) ) + table.append((v, action, (outq), (s[0] for s in stack), note)) note = 'Drain stack to output' while stack: v = '' @@ -165,15 +172,27 @@ class ShuntingYard(object): # export action = '(Pop op)' stack.pop() outq.append(t2) - table.append( (v, action, outq, (s[0] for s in stack), note) ) + table.append((v, action, outq, (s[0] for s in stack), note)) v = note = '' if self.do_debug: - maxcolwidths = [len(max(x, key=len)) for x in [zip(*table)]] - caller = get_caller_pos() + maxcolwidths = [len(max(x, key = len)) for x in [zip(*table)]] + get_caller_pos() row = table[0] - slog(DEBUG, ' '.join('{cell:^{width}}'.format(width=width, cell=cell) for (width, cell) in zip(maxcolwidths, row))) + slog( + DEBUG, + ' '.join( + '{cell:^{width}}'.format(width = width, cell = cell) + for (width, cell) in zip(maxcolwidths, row) + ) + ) for row in table[1:]: - slog(DEBUG, ' '.join('{cell:<{width}}'.format(width=width, cell=cell) for (width, cell) in zip(maxcolwidths, row))) + slog( + DEBUG, + ' '.join( + '{cell:<{width}}'.format(width = width, cell = cell) + for (width, cell) in zip(maxcolwidths, row) + ) + ) return table[-1][2] def infix_to_postfix_orig(self, infix): @@ -185,7 +204,7 @@ class ShuntingYard(object): # export for tokinfo in tokens: self.debug(tokinfo) - toktype, token = tokinfo[0], tokinfo[1] + _toktype, token = tokinfo[0], tokinfo[1] self.debug("Checking token ", token) @@ -204,7 +223,8 @@ class ShuntingYard(object): # export topToken = s.pop() continue - while (not s.isEmpty()) and (self.__ops[s.peek()].prec >= self.__ops[token].prec): + while (not s.isEmpty()) and (self.__ops[s.peek()].prec + >= self.__ops[token].prec): #self.debug(token) r.append(s.pop()) #self.debug(r) @@ -240,7 +260,9 @@ class ShuntingYard(object): # export args.append(vals.pop()) #self.debug("running %s(%s)" % (token, ', '.join(reversed(args)))) val = op.func(*reversed(args)) - self.debug("%s(%s) = %s" % (token, ', '.join(map(str, reversed(args))), val)) + self.debug( + "%s(%s) = %s" % (token, ', '.join(map(str, reversed(args))), val) + ) vals.push(val) return vals.pop() @@ -266,27 +288,27 @@ if __name__ == '__main__': # return string.split() def f_mult(self, a, b): - return str(atof(a) * atof(b)); + return str(atof(a) * atof(b)) def f_div(self, a, b): - return str(atof(a) / atof(b)); + return str(atof(a) / atof(b)) def f_add(self, a, b): - return str(atof(a) + atof(b)); + return str(atof(a) + atof(b)) def f_sub(self, a, b): - return str(atof(a) - atof(b)); + return str(atof(a) - atof(b)) def __init__(self): Op = Operator operators = { - '^': Op(None, 2, 4, R), + '^': Op(None, 2, 4, R), '*': Op(self.f_mult, 2, 3, L), - '/': Op(self.f_div, 2, 3, L), - '+': Op(self.f_add, 2, 2, L), - '-': Op(self.f_sub, 2, 2, L), - '(': Op(None, 0, 9, L), - ')': Op(None, 0, 0, L), + '/': Op(self.f_div, 2, 3, L), + '+': Op(self.f_add, 2, 2, L), + '-': Op(self.f_sub, 2, 2, L), + '(': Op(None, 0, 9, L), + ')': Op(None, 0, 0, L), } super(Calculator, self).__init__(operators) @@ -295,7 +317,7 @@ if __name__ == '__main__': # ------------- testbed match object - Object = namedtuple("Object", [ "Name", "Label" ]) + Object = namedtuple("Object", ["Name", "Label"]) class Matcher(ShuntingYard): @@ -324,14 +346,14 @@ if __name__ == '__main__': def __init__(self, obj): Op = Operator operators = { - '(': Op(None, 2, 9, L), - ')': Op(None, 2, 0, L), - 'name=': Op(self.f_is_name, 1, 3, R), - 'and': Op(self.f_and, 2, 3, L), - 'label~=': Op(self.f_matches_label, 1, 3, R), - 'False': Op(None, 0, 3, L), - 'True': Op(None, 0, 3, L), - 'not': Op(self.f_is_not, 1, 3, R), + '(': Op(None, 2, 9, L), + ')': Op(None, 2, 0, L), + 'name=': Op(self.f_is_name, 1, 3, R), + 'and': Op(self.f_and, 2, 3, L), + 'label~=': Op(self.f_matches_label, 1, 3, R), + 'False': Op(None, 0, 3, L), + 'True': Op(None, 0, 3, L), + 'not': Op(self.f_is_not, 1, 3, R), } super(Matcher, self).__init__(operators) diff --git a/src/python/jw/util/asyncio/Process.py b/src/python/jw/util/asyncio/Process.py index 3a63713..8708241 100644 --- a/src/python/jw/util/asyncio/Process.py +++ b/src/python/jw/util/asyncio/Process.py @@ -1,16 +1,14 @@ -# -*- coding: utf-8 -*- - from abc import abstractmethod from ..Process import Process as ProcessBase from .Signals import Signals -class Process(ProcessBase): # export +class Process(ProcessBase): # export __signals = Signals() def __init__(self, aio): - super().__init() + super().__init__() self.aio = aio @classmethod diff --git a/src/python/jw/util/asyncio/ShellCmd.py b/src/python/jw/util/asyncio/ShellCmd.py index ab62610..1b96a24 100644 --- a/src/python/jw/util/asyncio/ShellCmd.py +++ b/src/python/jw/util/asyncio/ShellCmd.py @@ -1,9 +1,11 @@ import asyncio -from ..log import * +import re + +from ..log import DEBUG, ERR, INFO, WARNING, slog # FIXME: Derive this from Process, or merge the classes entirely -class ShellCmd: # export +class ShellCmd: # export class SubprocessProtocol(asyncio.SubprocessProtocol): @@ -26,12 +28,12 @@ class ShellCmd: # export self.process.exited() class ShutdownState: - Running = 1 - Triggered = 2 - Completed = 3 + Running = 1 + Triggered = 2 + Completed = 3 Unnecessary = 4 - def __init__(self, cmdline, eloop=None, name=None): + def __init__(self, cmdline, eloop = None, name = None): if eloop is None: eloop = asyncio.get_running_loop() self.__eloop = eloop @@ -56,12 +58,19 @@ class ShellCmd: # export return r[1:] try: - slog(INFO, "Running shell command [{}]: {}".format(self.__name, format_cmdline(self.__cmdline))) + slog( + INFO, + "Running shell command [{}]: {}".format( + self.__name, format_cmdline(self.__cmdline) + ) + ) self.__transport, self.__protocol = await self.__eloop.subprocess_exec( lambda: self.SubprocessProtocol(self, self.__name), *self.__cmdline, ) - self.__proc = self.__transport.get_extra_info('subprocess') # Popen instance + self.__proc = self.__transport.get_extra_info( + 'subprocess' + ) # Popen instance except: slog(ERR, "Failed to run process [{}]".format(self.__name)) raise @@ -69,7 +78,8 @@ class ShellCmd: # export def __reap(self): if self.__rc is None and self.__transport: self.__transport = None - self.__rc = self.__proc.wait() + if self.__proc is not None: + self.__rc = self.__proc.wait() # to be called from SubprocessProtocol / SIGCHLD handler def exited(self): @@ -78,13 +88,24 @@ class ShellCmd: # export async def __cleanup(self): pid = self.__reap() - sd_fine = self.__shutdown in [ self.ShutdownState.Unnecessary, self.ShutdownState.Completed ] + sd_fine = self.__shutdown in [ + self.ShutdownState.Unnecessary, self.ShutdownState.Completed + ] if self.__rc == 0 and sd_fine: - slog(INFO, "The shell command [{}], pid {}, has exited cleanly".format(self.__name, self.__proc.pid)) + assert self.__proc is not None + slog( + INFO, + "The shell command [{}], pid {}, has exited cleanly".format( + self.__name, self.__proc.pid + ) + ) self.monitor = self.console = self.__protocol = self.__task = None return 0 - slog(ERR, "The process ([{}], pid {}) has exited {}with status code {}, aborting".format( - self.__name, pid, "" if sd_fine else "prematurely ", self.__rc)) + slog( + ERR, + "The process ([{}], pid {}) has exited {}with status code {}, aborting". + format(self.__name, pid, "" if sd_fine else "prematurely ", self.__rc) + ) exit(1) async def init(self): @@ -100,9 +121,9 @@ class ShellCmd: # export if __name__ == '__main__': from .. import log log.set_level('info') + async def run(): - sp = ShellCmd([ 'echo', 'hello world!' ]) + sp = ShellCmd(['echo', 'hello world!']) await sp.run() asyncio.run(run()) - diff --git a/src/python/jw/util/asyncio/Signals.py b/src/python/jw/util/asyncio/Signals.py index 7096d0a..470c188 100644 --- a/src/python/jw/util/asyncio/Signals.py +++ b/src/python/jw/util/asyncio/Signals.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import asyncio from ..Signals import Signals as SignalsBase @@ -10,4 +8,4 @@ class Signals(SignalsBase): @classmethod def _add_handler(cls, signal, handler): loop = asyncio.get_running_loop() - loop.add_signal_handler(signal, handler, None) # None = *args + loop.add_signal_handler(signal, handler, None) # None = *args diff --git a/src/python/jw/util/auth/Auth.py b/src/python/jw/util/auth/Auth.py index 0f56fcc..155d976 100644 --- a/src/python/jw/util/auth/Auth.py +++ b/src/python/jw/util/auth/Auth.py @@ -1,27 +1,28 @@ -# -*- coding: utf-8 -*- - -from typing import Optional, Union, Self +from __future__ import annotations import abc -from enum import Flag, Enum, auto +from enum import Enum, Flag, auto +from typing import TYPE_CHECKING, Optional, Self, Union -from ..log import * -from ..Config import Config +from ..log import ERR from ..misc import load_object -class Access(Enum): # export - Read = auto() +if TYPE_CHECKING: + from ..Config import Config + +class Access(Enum): # export + Read = auto() Modify = auto() Create = auto() Delete = auto() -class ProjectFlags(Flag): # export - NoFlags = auto() +class ProjectFlags(Flag): # export + NoFlags = auto() Contributing = auto() - Active = auto() + Active = auto() -class Group: # export +class Group: # export def __repr__(self): return f'Group({self.name})' @@ -34,7 +35,7 @@ class Group: # export def name(self) -> str: return self._name() -class User: # export +class User: # export def __repr__(self): return f'User({self.name})' @@ -70,14 +71,14 @@ class User: # export def email(self) -> str: return self._email() -class Auth(abc.ABC): # export +class Auth(abc.ABC): # export @classmethod - def load(cls, conf: Config, tp: str='') -> Self: + def load(cls, conf: Config, tp: str = '') -> Self: if tp == '': val = conf.get('type') if val is None: - msg = f'No type specified in auth configuration' + msg = 'No type specified in auth configuration' conf.dump(ERR, msg) raise Exception(msg) tp = val @@ -92,10 +93,17 @@ class Auth(abc.ABC): # export return self.__conf @abc.abstractmethod - def _access(self, what: str, access_type: Optional[Access], who: User|Group|None) -> bool: + def _access( + self, what: str, access_type: Optional[Access], who: User | Group | None + ) -> bool: raise NotImplementedError - def access(self, what: str, access_type: Optional[Access]=None, who: Optional[Union[User|Group]]=None) -> bool: + def access( + self, + what: str, + access_type: Optional[Access] = None, + who: Optional[Union[User, Group]] = None + ) -> bool: return self._access(what, access_type, who) @abc.abstractmethod diff --git a/src/python/jw/util/auth/dummy/Auth.py b/src/python/jw/util/auth/dummy/Auth.py index 4b6b9c0..b6f51b3 100644 --- a/src/python/jw/util/auth/dummy/Auth.py +++ b/src/python/jw/util/auth/dummy/Auth.py @@ -1,16 +1,18 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional -from ...log import * -from ... import Config -from .. import Access -from .. import Auth as AuthBase -from .. import Group as GroupBase -from .. import User as UserBase -from .. import ProjectFlags +from ...log import WARNING, slog +from ..Auth import Access +from ..Auth import Auth as AuthBase +from ..Auth import Group as GroupBase +from ..Auth import ProjectFlags +from ..Auth import User as UserBase -class Group(GroupBase): # export +if TYPE_CHECKING: + from ...Config import Config + +class Group(GroupBase): # export def __init__(self, auth: AuthBase, name: str): self.__name = name @@ -19,7 +21,7 @@ class Group(GroupBase): # export def _name(self) -> str: return self.__name -class User(UserBase): # export +class User(UserBase): # export def __init__(self, auth: AuthBase, name: str, conf: Config): self.__name = name @@ -47,13 +49,13 @@ class User(UserBase): # export def _email(self) -> str: return self.__email -class Auth(AuthBase): # export +class Auth(AuthBase): # export def __init__(self, conf: Config): super().__init__(conf) self.___users: Optional[dict[str, UserBase]] = None self.__groups = None - self.__current_user: UserBase|None = None + self.__current_user: UserBase | None = None self.__user_by_email: Optional[dict[str, UserBase]] = None @property @@ -62,12 +64,18 @@ class Auth(AuthBase): # export ret: dict[str, UserBase] = {} for name in self.conf.entries('user'): conf = self.conf.branch('user.' + name) + assert conf is not None, 'Config is None' ret[name] = User(self, name, conf) self.___users = ret return self.___users - def _access(self, what: str, access_type: Optional[Access], who: User|GroupBase|None) -> bool: # type: ignore - slog(WARNING, f'Returning False for {access_type} access to resource {what} by {who}') + def _access( + self, what: str, access_type: Access | None, who: UserBase | GroupBase | None + ) -> bool: # type: ignore + slog( + WARNING, + f'Returning False for {access_type} access to resource {what} by {who}' + ) return False def _user(self, name) -> UserBase: diff --git a/src/python/jw/util/auth/ldap/Auth.py b/src/python/jw/util/auth/ldap/Auth.py index 29ebf83..1427eed 100644 --- a/src/python/jw/util/auth/ldap/Auth.py +++ b/src/python/jw/util/auth/ldap/Auth.py @@ -1,19 +1,21 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional -import ldap +import ldap # type: ignore[import-untyped] -from ...log import * from ...ldap import bind -from ...Config import Config -from .. import Access -from .. import Auth as AuthBase -from .. import Group as GroupBase -from .. import User as UserBase -from .. import ProjectFlags +from ...log import DEBUG, ERR, WARNING, slog +from ..Auth import Access +from ..Auth import Auth as AuthBase +from ..Auth import Group as GroupBase +from ..Auth import ProjectFlags +from ..Auth import User as UserBase -class Group(GroupBase): # export +if TYPE_CHECKING: + from ...Config import Config + +class Group(GroupBase): # export def __init__(self, auth: AuthBase, name: str): self.__name = name @@ -24,13 +26,7 @@ class Group(GroupBase): # export class User(UserBase): - def __init__( - self, - auth: AuthBase, - name: str, - cn: str, - email: str - ): + def __init__(self, auth: AuthBase, name: str, cn: str, email: str): self.__auth = auth self.__name = name @@ -50,14 +46,14 @@ class User(UserBase): def _display_name(self) -> str: return self.__cn -class Auth(AuthBase): # export +class Auth(AuthBase): # export def __init__(self, conf: Config): super().__init__(conf) self.___users: Optional[dict[str, UserBase]] = None self.___user_by_email: Optional[dict[str, User]] = None self.__groups = None - self.__current_user: User|None = None + self.__current_user: User | None = None self.__user_base_dn = conf['user_base_dn'] self.__conn = self.__bind() self.__dummy = self.load(conf, 'dummy') @@ -72,18 +68,16 @@ class Auth(AuthBase): # export ret_by_email: dict[str, User] = {} for res in self.__conn.find( self.__user_base_dn, - ldap.SCOPE_SUBTREE, + ldap.SCOPE_SUBTREE, # pyright: ignore[reportAttributeAccessIssue] "objectClass=inetOrgPerson", - ('uid', 'cn', 'uidNumber', 'mail', 'maildrop') - ): + ('uid', 'cn', 'uidNumber', 'mail', 'maildrop')): try: - display_name = None if 'displayName' in res[1]: cn = res[1]['displayName'][0].decode('utf-8') else: cn = res[1]['cn'][0].decode('utf-8') uid = res[1]['uid'][0].decode('utf-8') - uidNumber = res[1]['uidNumber'][0].decode('utf-8') + res[1]['uidNumber'][0].decode('utf-8') emails = [] #for attr in ['mail', 'maildrop']: for attr in ['mail']: @@ -93,7 +87,7 @@ class Auth(AuthBase): # export if not emails: slog(DEBUG, f'No email for user "{uid}", skipping') continue - user = User(self, name=uid, cn=cn, email=emails[0]) + user = User(self, name = uid, cn = cn, email = emails[0]) ret[uid] = user for email in emails: ret_by_email[email] = user @@ -111,10 +105,18 @@ class Auth(AuthBase): # export def __user_by_email(self) -> dict[str, UserBase]: if self.___user_by_email is None: self.__users - return self.___user_by_email # type: ignore # We are sure that ___user_by_email is not None at this point + return self.___user_by_email # type: ignore # We are sure that ___user_by_email is not None at this point - def _access(self, what: str, access_type: Optional[Access], who: User|GroupBase|None) -> bool: # type: ignore - slog(WARNING, f'Returning False for {access_type} access to resource {what} by {who}') + def _access( + self, + what: str, + access_type: Optional[Access], + who: UserBase | GroupBase | None + ) -> bool: # type: ignore + slog( + WARNING, + f'Returning False for {access_type} access to resource {what} by {who}' + ) return False def _user(self, name) -> UserBase: @@ -136,5 +138,9 @@ class Auth(AuthBase): # export def _projects(self, name, flags: ProjectFlags) -> list[str]: if flags & ProjectFlags.Contributing: # TODO: Ask LDAP - slog(WARNING, f'Querying LDAP for projects a user contributes to is not implemented, ignoring') + slog( + WARNING, + 'Querying LDAP for projects a user contributes to is not ' + 'implemented, ignoring' + ) return [] diff --git a/src/python/jw/util/cast.py b/src/python/jw/util/cast.py index 1579a35..65408d8 100644 --- a/src/python/jw/util/cast.py +++ b/src/python/jw/util/cast.py @@ -1,110 +1,130 @@ -# -*- coding: utf-8 -*- +import os -import pytimeparse, os -from datetime import datetime, timedelta from collections import OrderedDict +from datetime import timedelta -from .log import * +import pytimeparse # type: ignore[import-untyped] + +from .log import DEBUG, WARNING, slog _int_chars = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] -def _strip(s_, throw=True, log_level=ERR): +def _strip(s_) -> str: s = s_.strip() if len(s) != 0: return s - msg = f'Tried to strip empty string "{s_}" to int' - if throw: - raise Exception(msg) - slog(log_level, msg) - return None + raise Exception(f'Tried to strip empty string "{s_}"') -def cast_str_to_timedelta(s_: str, throw=True, log_level=DEBUG): # export - s = _strip(s_, throw=throw, log_level=log_level) - try: - return (True, timedelta(seconds=pytimeparse.parse(s_))) - except Exception as e: - msg = f'Could not convert string "{s_}" to time ({e})' - if throw: - raise Exception(msg) - slog(log_level, msg) - return (False, None) +def cast_str_to_timedelta(s_: str): # export + s = _strip(s_) + seconds = pytimeparse.parse(s) + if seconds is None: + raise Exception(f'Failed to convert {s} to timedelta') + return timedelta(seconds = seconds) -def cast_str_to_int(s_: str, throw=True, log_level=DEBUG): # export - s = _strip(s_, throw=throw, log_level=log_level) +def cast_str_to_int(s_: str): # export + s = _strip(s_) if s[0] == '-': s = s[1:] for c in s: - if not c in _int_chars: - break - else: - return (True, int(s_)) - msg = f'Could not convert string "{s_}" to int' - if throw: - raise Exception(msg) - slog(log_level, msg) - return (False, None) + if c not in _int_chars: + raise Exception(f'Could not convert string "{s}" to int') + return int(s) -def cast_str_to_bool(s_: str, throw=True, log_level=DEBUG): # export - s = _strip(s_, throw=throw, log_level=log_level).lower() +def cast_str_to_bool(s_: str): # export + s = _strip(s_).lower() if s in ['true', 'yes', '1']: - return (True, True) + return True if s in ['false', 'no', '0']: - return (True, False) - msg = f'Could not convert string "{s_}" to bool' - if throw: - raise Exception(msg) - slog(log_level, msg) - return (False, None) + return False + raise Exception(f'Could not convert string "{s_}" to bool') -_str_cast_functions = OrderedDict({ - bool: cast_str_to_bool, - int: cast_str_to_int, - timedelta: cast_str_to_timedelta +_str_cast_functions = OrderedDict( + { + bool: cast_str_to_bool, int: cast_str_to_int, timedelta: cast_str_to_timedelta + } +) -}) - -def guess_type(s: str, default=None, log_level=DEBUG, throw=False): # export +def guess_type(s: str, default = None, log_level = DEBUG, throw = False): # export if s is None: raise Exception('None string passed to guess_type()') for tp, func in _str_cast_functions.items(): try: - success, value = func(s, log_level=OFF, throw=False) - if success: - return tp - except: + func(s) + except Exception: continue + return tp msg = f'Failed to guess type of string "{s}"' if throw: raise Exception(msg) slog(log_level, msg) return default -def from_str(s: str, target_type=None, default_type=None, throw=True, log_level=WARNING, caller=None): # export +def from_str( # export + s: str, + target_type = None, + default_type = None, + throw = True, + log_level = WARNING, + caller = None +): if target_type is None: - target_type = guess_type(s, default_type) - if target_type is None: - msg = f'Could not deduce type to cast to from string "{s}"' - if throw: - raise Exception(msg) - slog(log_level, msg) - return None - result = _str_cast_functions[target_type](s, throw=throw, log_level=log_level) - if result[0]: - return result[1] - msg = f'Failed to cast string "{s}" to type {target_type}' - if throw: - raise Exception(msg) - slog(log_level, msg) + for tp, func in _str_cast_functions.items(): + try: + return func(s) + except Exception: + continue + msg = f'Could not deduce type to cast to from string "{s}"' + if throw: + raise Exception(msg) + slog(log_level, msg) + return None + try: + return _str_cast_functions[target_type](s) + except Exception as e: + msg = f'Failed to cast string "{s}" to type {target_type} ({str(e)})' + if throw: + raise Exception(msg) + slog(log_level, msg) return None -def from_env(key: str, default=None, target_type=None, default_type=None, throw=True, log_level=WARNING, caller=None): # export +def from_env( # export + key: str, + default = None, + target_type = None, + default_type = None, + throw = True, + log_level = WARNING, + caller = None +): val = os.getenv(key) if val is None: return default if target_type is None and default is not None: target_type = type(default) - return from_str(val, target_type=target_type, default_type=default_type, throw=throw, log_level=log_level, caller=caller) + return from_str( + val, + target_type = target_type, + default_type = default_type, + throw = throw, + log_level = log_level, + caller = caller + ) # deprecated name -def cast_str(s: str, target_type=None, default_type=None, throw=True, log_level=WARNING, caller=None): - return from_str(s, target_type=target_type, default_type=None, throw=True, log_level=WARNING, caller=None) +def cast_str( + s: str, + target_type = None, + default_type = None, + throw = True, + log_level = WARNING, + caller = None +): + return from_str( + s, + target_type = target_type, + default_type = None, + throw = True, + log_level = WARNING, + caller = None + ) diff --git a/src/python/jw/util/db/DataBase.py b/src/python/jw/util/db/DataBase.py index 25812dd..ef174cb 100644 --- a/src/python/jw/util/db/DataBase.py +++ b/src/python/jw/util/db/DataBase.py @@ -1,15 +1,16 @@ -# -*- coding: utf-8 -*- - -from typing import Any +from __future__ import annotations import abc -from contextlib import contextmanager -from ..Config import Config -from .schema.Schema import Schema -from ..Cmds import Cmds -from .Session import Session -from ..log import * +from contextlib import contextmanager +from typing import TYPE_CHECKING + +from ..log import NOTICE + +if TYPE_CHECKING: + from ..Config import Config + from .schema.Schema import Schema + from .Session import Session class DataBase(abc.ABC): @@ -37,6 +38,7 @@ class DataBase(abc.ABC): def session(self): ret = self._create_session() try: - yield ret + yield ret finally: - self._delete_session(ret) + if ret is not None: + self._delete_session(ret) diff --git a/src/python/jw/util/db/Session.py b/src/python/jw/util/db/Session.py index 4e653ad..d4c3f00 100644 --- a/src/python/jw/util/db/Session.py +++ b/src/python/jw/util/db/Session.py @@ -1,8 +1,6 @@ -# -*- coding: utf-8 -*- - import abc -class Session(abc.ABC): # export +class Session(abc.ABC): # export def __init__(self, db): self.__db = db diff --git a/src/python/jw/util/db/TableIoHandler.py b/src/python/jw/util/db/TableIoHandler.py index fb36beb..1bb0882 100644 --- a/src/python/jw/util/db/TableIoHandler.py +++ b/src/python/jw/util/db/TableIoHandler.py @@ -1,18 +1,17 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations -from typing import Any, List, Union, Optional, Dict from abc import ABC, abstractmethod -import re, csv, json +from typing import TYPE_CHECKING, Any, Dict, Union -from ..log import * -from ..cast import cast_str -from .schema.Schema import Schema +from ..log import ERR, INFO, OFF, slog, slog_m +from .rows import rows_check_not_null, rows_dump, rows_duplicates -from .rows import * +if TYPE_CHECKING: + from .schema.Schema import Schema TType = Union[Any, Dict[str, Any]] -class TableIoHandler(ABC): # export +class TableIoHandler(ABC): # export def __init__(self, schema: Schema): self.__table_meta = None @@ -22,7 +21,8 @@ class TableIoHandler(ABC): # export def _table_meta(self): if self.__table_meta is None: self.__table_meta = self.__schema.table_by_model_name( - self.__class__.__name__, throw=True) + self.__class__.__name__, throw = True + ) return self.__table_meta @property @@ -35,24 +35,29 @@ class TableIoHandler(ABC): # export def _check_non_nullable(self, rows): buf = [] - non_nullable = self.__table_meta.not_null_insertible_columns + non_nullable = self._table_meta.not_null_insertible_columns try: - rows_check_not_null(rows, non_nullable, buf=buf) + rows_check_not_null(rows, non_nullable, buf = buf) except: cn = self.__class__.__name__ tn = self._table_name d = '=========================================================' slog_m(ERR, f'{d} Null values in {cn}\n') for key in non_nullable: - buf = rows_check_not_null(rows, key, log_prio=OFF, throw=False) + buf = rows_check_not_null(rows, key, log_prio = OFF, throw = False) if not buf: continue slog_m(ERR, f'\n{d} Null values in {cn} / {tn}: "{key}"\n') - use_cols=self.log_columns + use_cols = list(self.log_columns) if key not in use_cols: use_cols.append(key) - rows_dump(buf, use_cols=use_cols, log_prio=ERR) - rows_dump(buf, use_cols=use_cols, out_path=f'/tmp/missing_{key}_in_{tn}.html', heading=f'Missing "{key}" in table {tn}') + rows_dump(buf, use_cols = use_cols, log_prio = ERR) + rows_dump( + buf, + use_cols = use_cols, + out_path = f'/tmp/missing_{key}_in_{tn}.html', + heading = f'Missing "{key}" in table {tn}' + ) raise @property @@ -67,7 +72,9 @@ class TableIoHandler(ABC): # export def _store(self, uri: str, data: TType): pass - def load(self, uri: str, reference, check_duplicates=False, write_csv=None) -> TType: + def load( + self, uri: str, reference, check_duplicates = False, write_csv = None + ) -> TType: slog(INFO, f'Reading table "{self._table_name}" from "{uri}"') ret = self._load(uri, reference) if check_duplicates: diff --git a/src/python/jw/util/db/query/Queries.py b/src/python/jw/util/db/query/Queries.py index ecff98d..b11d016 100644 --- a/src/python/jw/util/db/query/Queries.py +++ b/src/python/jw/util/db/query/Queries.py @@ -1,18 +1,18 @@ -# -*- coding: utf-8 -*- - -from typing import Any +from __future__ import annotations +from typing import Any, TYPE_CHECKING import abc -from ...log import * +from ...log import slog, slog_m, ERR, INFO from ...misc import load_classes -from ...Cmds import Cmds from ..DataBase import DataBase -from ..schema.Schema import Schema from .Query import Query as QueryBase -from .QueryResult import QueryResult -class Queries(abc.ABC): # export +if TYPE_CHECKING: + from ..schema.Schema import Schema + from .QueryResult import QueryResult + +class Queries(abc.ABC): # export class Query(QueryBase): @@ -38,7 +38,7 @@ class Queries(abc.ABC): # export return self.__name def __init__(self, db: DataBase) -> None: - assert(isinstance(db, DataBase)) + assert (isinstance(db, DataBase)) self.__db = db self.__queries: dict[str, Any] = dict() @@ -57,7 +57,7 @@ class Queries(abc.ABC): # export def db(self) -> DataBase: return self.__db - def load(self, modules: list[str], cls=QueryBase): + def load(self, modules: list[str], cls = QueryBase): for path in modules: slog(INFO, f'Loading modules from {path}') for c in load_classes(path, cls): @@ -69,8 +69,8 @@ class Queries(abc.ABC): # export def add(self, query: QueryBase, query_name: str, location: str, func: Any): slog(INFO, f'Adding query "{query_name}" on location "{location}"') - assert(isinstance(query_name, str)) - assert(isinstance(location, str)) + assert (isinstance(query_name, str)) + assert (isinstance(location, str)) #ret = self.Query(query, func) ret = self.Query(query, query_name, location, func) #setattr(ret, 'name', name) diff --git a/src/python/jw/util/db/query/Query.py b/src/python/jw/util/db/query/Query.py index 5773df7..41f296f 100644 --- a/src/python/jw/util/db/query/Query.py +++ b/src/python/jw/util/db/query/Query.py @@ -1,18 +1,15 @@ -# -*- coding: utf-8 -*- - -from typing import Any +from __future__ import annotations import abc -from ...log import * -from ...misc import load_classes -from ...Cmds import Cmds -from ..DataBase import DataBase -from ..Session import Session -from .QueryResult import QueryResult -#from .Queries import Queries +from typing import TYPE_CHECKING, Any -class Query(abc.ABC): # export +if TYPE_CHECKING: + from ..DataBase import DataBase + from ..Session import Session + from .QueryResult import QueryResult + +class Query(abc.ABC): # export def __init__(self, parent: Any) -> None: self.__parent = parent diff --git a/src/python/jw/util/db/query/QueryResult.py b/src/python/jw/util/db/query/QueryResult.py index eb3663e..556ba25 100644 --- a/src/python/jw/util/db/query/QueryResult.py +++ b/src/python/jw/util/db/query/QueryResult.py @@ -1,23 +1,22 @@ -# -*- coding: utf-8 -*- - -from typing import Any, Union +from __future__ import annotations import abc + from enum import Enum, auto +from typing import TYPE_CHECKING, Any, Union -from ...log import * -from ...Cmds import Cmds -from ..DataBase import DataBase -from ..Session import Session +if TYPE_CHECKING: + from ..DataBase import DataBase + from ..Session import Session -class ResType(Enum): # export - Statement = auto() - Scalars = auto() - One = auto() - First = auto() - Pages = auto() +class ResType(Enum): # export + Statement = auto() + Scalars = auto() + One = auto() + First = auto() + Pages = auto() -class QueryResult(abc.ABC): # export +class QueryResult(abc.ABC): # export def __init__(self, session: Session, query: Any) -> None: self.__query = query @@ -42,8 +41,8 @@ class QueryResult(abc.ABC): # export def rows(self) -> list[Any]: return self._cast(ResType.Scalars) - def pages(self, per_page=20, page=1) -> Any: - return self._cast(ResType.Pages, per_page=per_page, page=page) + def pages(self, per_page = 20, page = 1) -> Any: + return self._cast(ResType.Pages, per_page = per_page, page = page) def one(self) -> Any: return self._cast(ResType.One) @@ -58,5 +57,5 @@ class QueryResult(abc.ABC): # export # -- pure virtuals @abc.abstractmethod - def _cast(self, res_type: ResType, **kwargs) -> Union[Any|list[Any]]: + def _cast(self, res_type: ResType, **kwargs) -> Union[Any, list[Any]]: pass diff --git a/src/python/jw/util/db/rows.py b/src/python/jw/util/db/rows.py index 0eea9b6..370c793 100644 --- a/src/python/jw/util/db/rows.py +++ b/src/python/jw/util/db/rows.py @@ -1,19 +1,24 @@ -# -*- coding: utf-8 -*- +import csv +import io +import json +import os +import re +import textwrap -import io, os, re, textwrap, json, csv -from tabulate import tabulate # type: ignore +from tabulate import TableFormat, tabulate # type: ignore -from ..log import * +from ..log import (ERR, INFO, WARNING, get_caller_pos, prio_gets_logged, slog, slog_m) -def rows_pretty(rows): # export - if type(rows) == dict: +def rows_pretty(rows): # export + if isinstance(rows, dict): rows = [rows] out = [] for row in rows: - out.append(json.dumps(row, sort_keys=True, indent=4, default=str)) + out.append(json.dumps(row, sort_keys = True, indent = 4, default = str)) return '\n'.join(out) -def rows_duplicates(rows, log_prio=INFO, caller=None): # export +def rows_duplicates(rows, log_prio = INFO, caller = None): # export + def __equal(r1, r2): for col in set(r1.keys()) | set(r2.keys()): if col in r1: @@ -25,11 +30,12 @@ def rows_duplicates(rows, log_prio=INFO, caller=None): # export if r1[col] != r2[col]: return False return True + ret = [] last = len(rows) - 1 i = last while last > 0: - for i in reversed(range(0, last-1)): + for i in reversed(range(0, last - 1)): if __equal(rows[last], rows[i]): ret.append(last) last -= 1 @@ -37,12 +43,15 @@ def rows_duplicates(rows, log_prio=INFO, caller=None): # export last -= 1 return ret -def rows_remove(rows, callback=None, candidates=None, log_prio=INFO, caller=None): # export +def rows_remove( # export + rows, callback = None, candidates = None, log_prio = INFO, caller = None +): def __is_remove_candidate(row): + assert candidates is not None, 'Candidates is None' for remove_row in candidates: for col, val in row.items(): - if not col in remove_row.keys(): + if col not in remove_row.keys(): break if val != remove_row[col]: break @@ -65,14 +74,14 @@ def rows_remove(rows, callback=None, candidates=None, log_prio=INFO, caller=None remove.append(index) continue for index in reversed(remove): - slog(log_prio, f'Removing row {rows[index]}', caller=caller) + slog(log_prio, f'Removing row {rows[index]}', caller = caller) del rows[index] -def rows_select(rows, rules): # export +def rows_select(rows, rules): # export ret = [] for row in rows: for rule in rules: - if type(rule) == tuple(): + if isinstance(rule, tuple): search_rule = rule[0] else: search_rule = rule @@ -84,7 +93,7 @@ def rows_select(rows, rules): # export break return ret -def rows_rewrite_regex(rows, rules): # export +def rows_rewrite_regex(rows, rules): # export for row in rows: for rule in rules: try: @@ -93,14 +102,25 @@ def rows_rewrite_regex(rows, rules): # export break else: for exec_col_name, exec_val in rule[1].items(): - slog(INFO, f'Rewriting {row} {row.get(exec_col_name)} -> {exec_val}') + slog( + INFO, + f'Rewriting {row} {row.get(exec_col_name)} -> {exec_val}' + ) row[exec_col_name] = exec_val except Exception as e: slog(ERR, f'Failed to run rule {rule} against {row} ({e})') raise -def rows_check_not_null(rows, keys, log_prio=WARNING, buf=None, stat_key=None, throw=True, caller=None): # export - if type(keys) == str: +def rows_check_not_null( # export + rows, + keys, + log_prio = WARNING, + buf = None, + stat_key = None, + throw = True, + caller = None +): + if isinstance(keys, str): keys = [keys] if caller is None: caller = get_caller_pos() @@ -113,11 +133,11 @@ def rows_check_not_null(rows, keys, log_prio=WARNING, buf=None, stat_key=None, t for row in rows: for key in keys: if row.get(key) is None: - slog(log_prio, f'{key} is missing in row {row}', caller=caller) + slog(log_prio, f'{key} is missing in row {row}', caller = caller) buf.append(row) if stat_key is not None: stat_val = row[stat_key] - if not stat_val in stats.keys(): + if stat_val not in stats.keys(): stats[stat_val] = 0 stats[stat_val] += 1 count += 1 @@ -125,14 +145,27 @@ def rows_check_not_null(rows, keys, log_prio=WARNING, buf=None, stat_key=None, t if count > 0: if stat_key is not None: i = 0 - for k, v in reversed(sorted(stats.items(), key=lambda item: item[1])): + for k, v in reversed(sorted(stats.items(), key = lambda item: item[1])): i += 1 - slog(ERR, f'{i:>3}. {k:<23}: {v}', caller=caller) + slog(ERR, f'{i:>3}. {k:<23}: {v}', caller = caller) if throw: - raise Exception(f'Found {count} rows violating null-constraint for keys {keys}') + raise Exception( + f'Found {count} rows violating null-constraint for keys {keys}' + ) return buf -def rows_dumps(rows, log_prio=INFO, caller=None, use_cols=None, skip_cols=None, table_name=None, out_path='log', heading=None, lead=None, tablefmt=None): # export +def rows_dumps( # export + rows, + log_prio = INFO, + caller = None, + use_cols = None, + skip_cols = None, + table_name = None, + out_path = 'log', + heading = None, + lead = None, + tablefmt = None +): headers = 'keys' dump_rows = rows @@ -152,20 +185,21 @@ def rows_dumps(rows, log_prio=INFO, caller=None, use_cols=None, skip_cols=None, new_row[col] = val new_dump_rows.append(new_row) dump_rows = new_dump_rows - out = header = footer = "" + header = footer = "" match tablefmt: case 'html': if heading is not None: heading = f'

{heading}

\n' - if type(lead) == str: + if isinstance(lead, str): lead = f'
\n {lead}\n
\n' - elif type(lead) == list: - l = '\n' + lead = lst + header = textwrap.dedent( + '''\ @@ -185,30 +219,47 @@ def rows_dumps(rows, log_prio=INFO, caller=None, use_cols=None, skip_cols=None, - ''') - footer = textwrap.dedent(''' + ''' + ) + footer = textwrap.dedent( + ''' - ''') + ''' + ) case _: - if type(heading) == str: + if isinstance(heading, str): heading = '\n' + heading - if type(lead) == str: + if isinstance(lead, str): pass - elif type(lead) == list: - l ='' + elif isinstance(lead, list): + lst = '' for li in lead: - l += f' - {li}\n' - lead = '\n\n' + l + '\n' + lst += f' - {li}\n' + lead = '\n\n' + lst + '\n' if heading is None: heading = '' if lead is None: lead = '' - return header + heading + lead + tabulate(dump_rows, headers=headers, tablefmt=tablefmt) + footer + assert isinstance(tablefmt, str) or isinstance(tablefmt, TableFormat), 'tablefmt' + return header + heading + lead + tabulate( + dump_rows, headers = headers, tablefmt = tablefmt + ) + footer -def rows_dump(rows, log_prio=INFO, caller=None, use_cols=None, skip_cols=None, table_name=None, out_path='log', heading=None, lead=None, tablefmt=None): # export +def rows_dump( # export + rows, + log_prio = INFO, + caller = None, + use_cols = None, + skip_cols = None, + table_name = None, + out_path = 'log', + heading = None, + lead = None, + tablefmt = None +): if not prio_gets_logged(log_prio): return @@ -218,18 +269,35 @@ def rows_dump(rows, log_prio=INFO, caller=None, use_cols=None, skip_cols=None, t if tablefmt is None and out_path: tablefmt = os.path.splitext(out_path)[1][1:] - out = rows_dumps(rows, log_prio=log_prio, caller=caller, use_cols=use_cols, skip_cols=skip_cols, table_name=table_name, heading=heading, lead=lead, tablefmt=tablefmt) + out = rows_dumps( + rows, + log_prio = log_prio, + caller = caller, + use_cols = use_cols, + skip_cols = skip_cols, + table_name = table_name, + heading = heading, + lead = lead, + tablefmt = tablefmt + ) match out_path: case 'log': - slog_m(log_prio, out, caller=caller) + slog_m(log_prio, out, caller = caller) case _: with open(out_path, 'w') as fp: fp.write(out) -def rows_to_csv(rows, use_tmpfile=False): # export +def rows_to_csv(rows, use_tmpfile = False): # export + def __write(rows, out): - writer = csv.DictWriter(out, fieldnames=field_names, delimiter=';', quotechar='"', quoting=csv.QUOTE_NONNUMERIC) + writer = csv.DictWriter( + out, + fieldnames = field_names, + delimiter = ';', + quotechar = '"', + quoting = csv.QUOTE_NONNUMERIC + ) writer.writeheader() for row in rows: writer.writerow(row) @@ -244,7 +312,7 @@ def rows_to_csv(rows, use_tmpfile=False): # export __write(rows, out) return out.getvalue() import tempfile - with tempfile.TemporaryFile(mode='w', newline='') as out: + with tempfile.TemporaryFile(mode = 'w', newline = '') as out: __write(rows, out) out.seek(0) return out.read() diff --git a/src/python/jw/util/db/schema/Column.py b/src/python/jw/util/db/schema/Column.py index a8fce9b..84cbd8c 100644 --- a/src/python/jw/util/db/schema/Column.py +++ b/src/python/jw/util/db/schema/Column.py @@ -1,17 +1,20 @@ -# -*- coding: utf-8 -*- - -from typing import Optional, Any +from __future__ import annotations import abc -from .DataType import DataType -from ...log import * +from typing import TYPE_CHECKING, Any, Optional -class Column(abc.ABC): # export +from ...log import ERR, throw - def __init__(self, table, name, data_type: DataType): +if TYPE_CHECKING: + from .DataType import DataType + from .Table import Table + +class Column(abc.ABC): # export + + def __init__(self, table: Table, name: str, data_type: DataType) -> None: self.__name: str = name - self.__table: Any = table + self.__table: Table = table self.__is_nullable: Optional[bool] = None self.__is_null_insertible: Optional[bool] = None self.__is_primary_key: Optional[bool] = None @@ -39,18 +42,18 @@ class Column(abc.ABC): # export return False return True throw(ERR, f'Tried to compare column {self} to type {type(rhs)}: {rhs}') - return False # Unreachable but requested by mypy + return False # Unreachable but requested by mypy @property def name(self) -> str: return self.__name @property - def data_type(self): + def data_type(self) -> DataType: return self.__data_type @property - def table(self) -> str: + def table(self) -> Table: return self.__table @property @@ -60,7 +63,7 @@ class Column(abc.ABC): # export return self.__is_nullable @property - def is_null_insertible(self): + def is_null_insertible(self) -> bool: if self.__is_null_insertible is None: ret = False if self.is_nullable: @@ -81,7 +84,9 @@ class Column(abc.ABC): # export @property def is_auto_increment(self) -> bool: if self.__is_auto_increment is None: - self.__is_auto_increment = self.__name in self.__table.auto_increment_columns + self.__is_auto_increment = ( + self.__name in self.__table.auto_increment_columns + ) return self.__is_auto_increment @property @@ -113,8 +118,8 @@ class Column(abc.ABC): # export def foreign_key(self, table) -> Optional[Any]: if self.__foreign_keys_by_table is None: self.__foreign_keys_by_table = dict() - for col in self.foreign_keys: # type: ignore # Any not iterable - assert(col.table.name not in self.__foreign_keys_by_table) + for col in self.foreign_keys: # type: ignore # Any not iterable + assert (col.table.name not in self.__foreign_keys_by_table) self.__foreign_keys_by_table[col.table.name] = col table_name = table if isinstance(table, str) else table.name return self.__foreign_keys_by_table.get(table_name) diff --git a/src/python/jw/util/db/schema/ColumnSet.py b/src/python/jw/util/db/schema/ColumnSet.py index 8a98cc8..47f1eb6 100644 --- a/src/python/jw/util/db/schema/ColumnSet.py +++ b/src/python/jw/util/db/schema/ColumnSet.py @@ -1,20 +1,24 @@ -# -*- coding: utf-8 -*- +from typing import Optional, Any -from typing import Optional, Iterable, Any +class ColumnSet: # export -class ColumnSet: # export - - def __init__(self, *args: list[Any], columns: list[Any]=[], table: Optional[Any]=None, names: Optional[list[str]]=None): + def __init__( + self, + *args: list[Any], + columns: list[Any] = [], + table: Optional[Any] = None, + names: Optional[list[str]] = None + ): self.__columns: list[Any] = [*args] self.__columns.extend(columns) self.__table = table if names is not None: - assert(table is not None) + assert (table is not None) for name in names: self.__columns.append(table.column(name)) if self.__table is not None: for col in columns: - assert(col.table == self.__table) + assert (col.table == self.__table) def __len__(self): return len(self.__columns) diff --git a/src/python/jw/util/db/schema/CompositeForeignKey.py b/src/python/jw/util/db/schema/CompositeForeignKey.py index 5a4062b..5794326 100644 --- a/src/python/jw/util/db/schema/CompositeForeignKey.py +++ b/src/python/jw/util/db/schema/CompositeForeignKey.py @@ -1,15 +1,18 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations -from typing import Optional, Any +from typing import TYPE_CHECKING, Any, Optional -from ...log import * - -from .ColumnSet import ColumnSet +from ...log import WARNING, slog from .SingleForeignKey import SingleForeignKey -class CompositeForeignKey: # export +if TYPE_CHECKING: + from .ColumnSet import ColumnSet - def __init__(self, child_col_set: ColumnSet, parent_col_set: ColumnSet): # TODO: Implement alternative ways to construct +class CompositeForeignKey: # export + + def __init__( + self, child_col_set: ColumnSet, parent_col_set: ColumnSet + ): # TODO: Implement alternative ways to construct def __table(s): ret = None @@ -17,8 +20,8 @@ class CompositeForeignKey: # export if ret is None: ret = c.table else: - assert(ret == c.table) - assert(ret is not None) + assert (ret == c.table) + assert (ret is not None) return ret self.__child_col_set = child_col_set @@ -26,7 +29,7 @@ class CompositeForeignKey: # export self.__child_table = __table(child_col_set) self.__parent_table = __table(parent_col_set) - assert(len(self.__child_col_set) == len(self.__parent_col_set)) + assert (len(self.__child_col_set) == len(self.__parent_col_set)) self.__len = len(self.__child_col_set) self.__column_relations: Optional[list[SingleForeignKey]] = None self.__parent_columns_by_child_column: Optional[dict[str, Any]] = None @@ -46,7 +49,12 @@ class CompositeForeignKey: # export def __repr__(self): ret = self.__table_rel_str() - ret += ': ' + ', '.join([self.__cols_rel_str(rel.child_column, rel.parent_column) for rel in self.column_relations]) + ret += ': ' + ', '.join( + [ + self.__cols_rel_str(rel.child_column, rel.parent_column) + for rel in self.column_relations + ] + ) return ret def __eq__(self, rhs): @@ -73,21 +81,25 @@ class CompositeForeignKey: # export return self.__parent_col_set def parent_column(self, child_column) -> Any: - child_column_name = child_column if isinstance(child_column, str) else child_column.name + child_column if isinstance(child_column, str) else child_column.name if self.__parent_columns_by_child_column is None: d: dict[str, Any] = {} - assert(len(self.__child_col_set) == len(self.__parent_col_set)) + assert (len(self.__child_col_set) == len(self.__parent_col_set)) for i in range(0, len(self.__child_col_set)): d[self.__child_col_set[i].name] = self.__parent_col_set[i] self.__parent_columns_by_child_column = d return self.__parent_columns_by_child_column[child_column] def child_column(self, parent_column) -> Any: - slog(WARNING, f'{self}: Looking for child column belonging to parent column "{parent_column}"') - parent_column_name = parent_column if isinstance(parent_column, str) else parent_column.name + slog( + WARNING, + f'{self}: Looking for child column belonging to parent column ' + f'"{parent_column}"' + ) + parent_column if isinstance(parent_column, str) else parent_column.name if self.__child_columns_by_parent_column is None: d: dict[str, Any] = {} - assert(len(self.__parent_col_set) == len(self.__child_col_set)) + assert (len(self.__parent_col_set) == len(self.__child_col_set)) for i in range(0, len(self.__parent_col_set)): d[self.__parent_col_set[i].name] = self.__child_col_set[i] self.__child_columns_by_parent_column = d @@ -98,6 +110,8 @@ class CompositeForeignKey: # export ret = [] if self.__column_relations is None: for i in range(0, self.__len): - ret.append(SingleForeignKey(self.__child_col_set[i], self.__parent_col_set[i])) + ret.append( + SingleForeignKey(self.__child_col_set[i], self.__parent_col_set[i]) + ) self.__column_relations = ret return self.__column_relations diff --git a/src/python/jw/util/db/schema/DataType.py b/src/python/jw/util/db/schema/DataType.py index ecda420..d982101 100644 --- a/src/python/jw/util/db/schema/DataType.py +++ b/src/python/jw/util/db/schema/DataType.py @@ -1,10 +1,8 @@ -# -*- coding: utf-8 -*- - -from typing import Optional -from enum import Enum, auto from datetime import datetime +from enum import Enum, auto +from typing import Optional -from ...log import * +from ...log import ERR, throw class Id(Enum): Integer = auto() @@ -16,7 +14,7 @@ class Id(Enum): Text = auto() Invalid = auto() -def py_type(type_id: Id) -> type: # export +def py_type(type_id: Id) -> type: # export match type_id: case Id.Integer: @@ -38,14 +36,17 @@ def py_type(type_id: Id) -> type: # export raise Exception(f'Unknown column type-id "{type_id}"') -class DataType: # export +class DataType: # export - def __init__(self, type_id: Id, size: Optional[int]=None): + def __init__(self, type_id: Id, size: Optional[int] = None): if not isinstance(type_id, Id): - throw(ERR, f'Passed type id "{type_id}" with unsupported data type {type(type_id)}') + throw( + ERR, + f'Passed type id "{type_id}" with unsupported data type {type(type_id)}' + ) if size is not None: - assert(isinstance(size, int)) - assert(size > 0) + assert (isinstance(size, int)) + assert (size > 0) self.__id = type_id self.__size = size @@ -80,4 +81,4 @@ class DataType: # export @property def py_type_annotation(self) -> str: - return self.py_type_str # FIXME: This is not always correct + return self.py_type_str # FIXME: This is not always correct diff --git a/src/python/jw/util/db/schema/Schema.py b/src/python/jw/util/db/schema/Schema.py index 67c0644..7fdf438 100644 --- a/src/python/jw/util/db/schema/Schema.py +++ b/src/python/jw/util/db/schema/Schema.py @@ -1,20 +1,20 @@ -# -*- coding: utf-8 -*- - -from typing import Optional, Iterable +from __future__ import annotations import abc -from ...log import * +from typing import TYPE_CHECKING, Iterable, Optional -from .Table import Table -from .Column import Column -from .DataType import DataType -from .CompositeForeignKey import CompositeForeignKey +from ...log import DEBUG, ERR, slog, throw -class Schema(abc.ABC): # export +if TYPE_CHECKING: + from .Column import Column + from .CompositeForeignKey import CompositeForeignKey + from .Table import Table + +class Schema(abc.ABC): # export def __init__(self) -> None: - self.___tables: Optional[list[Table]] = None + self.___tables: Optional[dict[str, Table]] = None self.__foreign_keys: Optional[list[CompositeForeignKey]] = None self.__access_defining_columns: Optional[list[str]] = None @@ -24,7 +24,7 @@ class Schema(abc.ABC): # export ret = dict() for name in self._table_names(): slog(DEBUG, f'Caching metadata for table "{name}"') - assert(isinstance(name, str)) + assert (isinstance(name, str)) ret[name] = self._table(name) self.___tables = ret return self.___tables @@ -39,7 +39,7 @@ class Schema(abc.ABC): # export @abc.abstractmethod def _table(self, name: str) -> Table: throw(ERR, "Called pure virtual base class method") - return None # type: ignore + return None # type: ignore @abc.abstractmethod def _foreign_keys(self) -> list[CompositeForeignKey]: @@ -50,7 +50,7 @@ class Schema(abc.ABC): # export pass @abc.abstractmethod - def _model_module_search_paths(self) -> list[tuple[str, type]]: + def _model_module_search_paths(self) -> list[tuple[str, type]]: pass # ------ API to be called @@ -62,7 +62,7 @@ class Schema(abc.ABC): # export yield from self.__tables.values() def __repr__(self): - return '|'.join([table.name for table in self.__tables]) + return '|'.join([table.name for table in self.__tables.values()]) def __getitem__(self, index): return self.__tables[index] @@ -90,13 +90,13 @@ class Schema(abc.ABC): # export def table(self, name: str) -> Table: return self.__tables[name] - def table_by_model_name(self, name: str, throw=False) -> Table: + def table_by_model_name(self, name: str, throw = False) -> Table: for table in self.__tables.values(): if table.model_name == name: return table if throw: raise Exception(f'Table "{name}" not found in database metadata') - return None # type: ignore + return None # type: ignore def primary_keys(self, table_name: str) -> Iterable[str]: return self.__tables[table_name].primary_keys @@ -105,5 +105,5 @@ class Schema(abc.ABC): # export return self.__tables[table_name].columns @property - def model_module_search_paths(self) -> list[tuple[str, type]]: + def model_module_search_paths(self) -> list[tuple[str, type]]: return self._model_module_search_paths() diff --git a/src/python/jw/util/db/schema/SingleForeignKey.py b/src/python/jw/util/db/schema/SingleForeignKey.py index 0045fdf..faa6dd2 100644 --- a/src/python/jw/util/db/schema/SingleForeignKey.py +++ b/src/python/jw/util/db/schema/SingleForeignKey.py @@ -1,9 +1,9 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations -from typing import Optional, Any +from typing import TYPE_CHECKING -from .Column import Column -from .ColumnSet import ColumnSet +if TYPE_CHECKING: + from .Column import Column class SingleForeignKey: @@ -19,7 +19,10 @@ class SingleForeignKey: yield from self.__iterable def __repr__(self): - return f'{self.__child_col.table.name}.{self.__child_col.name} -> {self.__parent_col.table.name}.{self.__parent_col.name}' + return ( + f'{self.__child_col.table.name}.{self.__child_col.name} -> ' + '{self.__parent_col.table.name}.{self.__parent_col.name}' + ) def __getitem__(self, index): return self.__iterable[index] diff --git a/src/python/jw/util/db/schema/Table.py b/src/python/jw/util/db/schema/Table.py index 0f2fb3d..10fb076 100644 --- a/src/python/jw/util/db/schema/Table.py +++ b/src/python/jw/util/db/schema/Table.py @@ -1,47 +1,51 @@ -# -*- coding: utf-8 -*- - -from typing import Optional, Union, Iterable, Self, Any # TODO: Need any for many things, as I can't figure out how to avoid circular imports from here +from __future__ import annotations import abc +import re + from collections import OrderedDict +from typing import TYPE_CHECKING from urllib.parse import quote_plus -from ...log import * +from ...log import ERR, WARNING, slog, throw from ...misc import load_class - -from .ColumnSet import ColumnSet -from .DataType import DataType -from .CompositeForeignKey import CompositeForeignKey from .Column import Column +from .ColumnSet import ColumnSet -class Table(abc.ABC): # export +if TYPE_CHECKING: + from typing import Any, Iterable, Optional, Self, Union + + from .CompositeForeignKey import CompositeForeignKey + from .DataType import DataType + +class Table(abc.ABC): # export def __init__(self, schema, name: str): - assert(isinstance(name, str)) + assert (isinstance(name, str)) self.__schema = schema self.__name = name self.___columns: Optional[OrderedDict[str, Any]] = None - self.___foreign_key_parent_tables: Optional[OrderedDict[str, Any]] = None + self.___foreign_key_parent_tables: Optional[OrderedDict[str, Any]] = None self.__primary_keys: Optional[Iterable[str]] = None self.__unique_constraints: Optional[list[ColumnSet]] = None self.__foreign_key_constraints: Optional[list[CompositeForeignKey]] = None - self.__nullable_columns: Optional[Iterable[str]] = None - self.__non_nullable_columns: Optional[Iterable[str]] = None - self.__null_insertible_columns: Optional[Iterable[str]] = None - self.__not_null_insertible_columns: Optional[Iterable[str]] = None - self.__log_columns: Optional[Iterable[str]] = None - self.__edit_columns: Optional[Iterable[str]] = None - self.__translate_columns: Optional[Iterable[str]] = None - self.__display_columns: Optional[Iterable[str]] = None - self.__default_sort_columns: Optional[Iterable[str]] = None - self.__column_default: Optional[dict[str, Any]] = None - self.__base_location_rule: Optional[Iterable[str]] = None - self.__location_rule: Optional[Iterable[str]] = None - self.__row_location_rule: Optional[Iterable[str]] = None - self.__add_row_location_rule: Optional[Iterable[str]] = None - self.___add_child_row_location_rules: Optional[dict[str, str]] = None + self.__nullable_columns: Optional[Iterable[str]] = None + self.__non_nullable_columns: Optional[Iterable[str]] = None + self.__null_insertible_columns: Optional[Iterable[str]] = None + self.__not_null_insertible_columns: Optional[Iterable[str]] = None + self.__log_columns: Optional[Iterable[str]] = None + self.__edit_columns: Optional[Iterable[str]] = None + self.__translate_columns: Optional[Iterable[str]] = None + self.__display_columns: Optional[Iterable[str]] = None + self.__default_sort_columns: Optional[Iterable[str]] = None + self.__column_default: Optional[dict[str, Any]] = None + self.__base_location_rule: Optional[Iterable[str]] = None + self.__location_rule: Optional[Iterable[str]] = None + self.__row_location_rule: Optional[Iterable[str]] = None + self.__add_row_location_rule: Optional[Iterable[str]] = None + self.___add_child_row_location_rules: Optional[dict[str, str]] = None self.__foreign_keys_to_parent_table: Optional[OrderedDict[str, Any]] = None self.__relationships: Optional[list[tuple[str, Self]]] = None self.__model_class: Optional[Any] = None @@ -61,7 +65,8 @@ class Table(abc.ABC): # export if self.___foreign_key_parent_tables is None: self.___foreign_key_parent_tables = OrderedDict() for cfk in self.foreign_key_constraints: - self.___foreign_key_parent_tables[cfk.parent_table.name] = cfk.parent_table + self.___foreign_key_parent_tables[cfk.parent_table.name + ] = cfk.parent_table return self.___foreign_key_parent_tables @property @@ -77,12 +82,12 @@ class Table(abc.ABC): # export def __add_child_row_location_rules(self) -> dict[str, str]: if self.___add_child_row_location_rules is None: ret: dict[str, str] = {} - for foreign_table_name, foreign_table in self.__relationship_by_foreign_table.items(): - if len([self.foreign_keys_to_parent_table(foreign_table)]): - rule = self._add_child_row_location_rule(foreign_table_name) + for table_name, table in self.__relationship_by_foreign_table.items(): + if len([self.foreign_keys_to_parent_table(table)]): + rule = self._add_child_row_location_rule(table_name) if rule is None: continue - ret[foreign_table_name] = rule + ret[table_name] = rule self.___add_child_row_location_rules = ret return self.___add_child_row_location_rules @@ -108,7 +113,7 @@ class Table(abc.ABC): # export return False return True throw(ERR, f'Tried to compare table {self} to type {type(rhs)}: {rhs}') - return False # Unreachable but requested by mypy + return False # Unreachable but requested by mypy def __hash__(self) -> int: return hash(self.name) @@ -167,8 +172,9 @@ class Table(abc.ABC): # export slog(WARNING, f'Returning None model name for table {self.name}') return None - def _model_module_search_paths(self) -> list[tuple[str, type]]: - return self.schema.model_module_search_paths # Fall back to Schema-global default + def _model_module_search_paths(self) -> list[tuple[str, type]]: + # Fall back to Schema-global default + return self.schema.model_module_search_paths @abc.abstractmethod def _query_name(self) -> str: @@ -190,7 +196,9 @@ class Table(abc.ABC): # export for col in self.__schema.access_defining_columns: if col in self.primary_keys: ret += f'/<{col}>' - ret += self.base_location_rule + base = self.base_location_rule + if base is not None: + ret += base if isinstance(base, str) else 'what-goes-here?'.join(base) return ret def _row_location_rule(self) -> Optional[str]: @@ -261,7 +269,7 @@ class Table(abc.ABC): # export return None pattern = r'^' + model_name + '$' for module_path, base_class in self._model_module_search_paths(): - ret = load_class(module_path, base_class, class_name_filter=pattern) + ret = load_class(module_path, base_class, class_name_filter = pattern) if ret is not None: self.__model_class = ret break @@ -288,8 +296,8 @@ class Table(abc.ABC): # export return self.__location_rule def location(self, *args, **kwargs): - ret = self.location_rule - for token, val in kwargs.items(): # FIXME: Poor man's row location assembly + ret = str(self.location_rule) + for token, val in kwargs.items(): # FIXME: Poor man's row location assembly ret = re.sub(f'<{token}>', quote_plus(quote_plus(str(val))), ret) return ret @@ -300,9 +308,9 @@ class Table(abc.ABC): # export return self.__row_location_rule def row_location(self, *args, **kwargs): - ret = self.row_location_rule + ret = str(self.row_location_rule) for col in self.primary_keys: - if col in kwargs: # FIXME: Poor man's row location assembly + if col in kwargs: # FIXME: Poor man's row location assembly ret = re.sub(f'<{col}>', quote_plus(quote_plus(str(kwargs[col]))), ret) return ret @@ -313,9 +321,9 @@ class Table(abc.ABC): # export return self.__add_row_location_rule def add_row_location(self, *args, **kwargs) -> Optional[str]: - ret = self.add_row_location_rule + ret = str(self.add_row_location_rule) for col in self.primary_keys: - if col in kwargs: # FIXME: Poor man's row location assembly + if col in kwargs: # FIXME: Poor man's row location assembly ret = re.sub(f'<{col}>', quote_plus(quote_plus(str(kwargs[col]))), ret) return ret @@ -323,12 +331,14 @@ class Table(abc.ABC): # export def add_child_row_location_rules(self) -> Iterable[str]: return self.__add_child_row_location_rules.values() - def add_child_row_location_rule(self, child_table: Union[Self, str]) -> Optional[str]: + def add_child_row_location_rule(self, child_table: Union[Self, + str]) -> Optional[str]: if isinstance(child_table, Table): child_table = child_table.name return self.__add_child_row_location_rules.get(child_table) - def add_child_row_location(self, parent_table: Union[Self, str], **kwargs) -> Optional[str]: + def add_child_row_location(self, parent_table: Union[Self, str], + **kwargs) -> Optional[str]: ret = self.add_child_row_location_rule(parent_table) if isinstance(parent_table, str): parent_table = self.schema[parent_table] @@ -337,7 +347,11 @@ class Table(abc.ABC): # export for cfk in self.foreign_keys_to_parent_table(parent_table): for fk in cfk: if fk.parent_column.name in kwargs: - ret = re.sub(f'<{fk.child_column.name}>', quote_plus(quote_plus(str(kwargs[fk.parent_column.name]))), ret) + ret = re.sub( + f'<{fk.child_column.name}>', + quote_plus(quote_plus(str(kwargs[fk.parent_column.name]))), + ret + ) return ret @property @@ -425,7 +439,7 @@ class Table(abc.ABC): # export impl = self._unique_constraints() if impl is not None: for columns in impl: - ret.append(ColumnSet(columns=columns)) + ret.append(ColumnSet(columns = columns)) self.__unique_constraints = ret return self.__unique_constraints @@ -443,7 +457,8 @@ class Table(abc.ABC): # export def foreign_key_parent_tables(self): return self.__foreign_key_parent_tables.values() - def foreign_keys_to_parent_table(self, parent_table) -> Iterable[CompositeForeignKey]: + def foreign_keys_to_parent_table(self, + parent_table) -> Iterable[CompositeForeignKey]: if self.__foreign_keys_to_parent_table is None: self.__foreign_keys_to_parent_table = OrderedDict() for cfk in self.foreign_key_constraints: @@ -451,8 +466,12 @@ class Table(abc.ABC): # export if pt not in self.__foreign_keys_to_parent_table: self.__foreign_keys_to_parent_table[pt] = [] self.__foreign_keys_to_parent_table[pt].append(cfk) - parent_table_name = parent_table if isinstance(parent_table, str) else parent_table.name - return self.__foreign_keys_to_parent_table[parent_table_name] if parent_table_name in self.__foreign_keys_to_parent_table else [] + parent_table_name = parent_table if isinstance( + parent_table, str + ) else parent_table.name + return self.__foreign_keys_to_parent_table[ + parent_table_name + ] if parent_table_name in self.__foreign_keys_to_parent_table else [] @property def relationships(self) -> list[tuple[str, Self]]: diff --git a/src/python/jw/util/db/schema/utils.py b/src/python/jw/util/db/schema/utils.py index f2226c6..43f137b 100644 --- a/src/python/jw/util/db/schema/utils.py +++ b/src/python/jw/util/db/schema/utils.py @@ -1,12 +1,18 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations -from .Schema import Schema +from typing import TYPE_CHECKING -from ...log import * +from ...log import NOTICE, slog -def check_schema(schema: Schema): # export +if TYPE_CHECKING: + from .Schema import Schema + +def check_schema(schema: Schema): # export slog(NOTICE, f'There are {len(schema)} tables in the database') for cfk in schema.foreign_key_constraints: for fk in cfk: if fk.child_column.data_type != fk.parent_column.data_type: - raise Exception(f'Type mismatch in foreign key {fk}: {fk.child_column.data_type} != {fk.parent_column.data_type}') + raise Exception( + f'Type mismatch in foreign key {fk}: {fk.child_column.data_type} ' + f'!= {fk.parent_column.data_type}' + ) diff --git a/src/python/jw/util/graph/yed/MapAttr2Shape.py b/src/python/jw/util/graph/yed/MapAttr2Shape.py index eaf2b13..2d8b9be 100644 --- a/src/python/jw/util/graph/yed/MapAttr2Shape.py +++ b/src/python/jw/util/graph/yed/MapAttr2Shape.py @@ -1,29 +1,35 @@ -# -*- coding: utf-8 -*- - -from collections.abc import Callable +from __future__ import annotations import xml.etree.ElementTree as ET -from ...log import * +from typing import TYPE_CHECKING -class MapAttr2Shape: # export +if TYPE_CHECKING: + from collections.abc import Callable + from typing import Any - def __init__(self, mappings: dict[str, str|Callable[[dict[str, str]], str]]|None=None) -> None: +class MapAttr2Shape: # export + + def __init__( + self, + mappings: dict[str, str | Callable[[dict[str, str]], str]] | None = None + ) -> None: self.__mappings = mappings if mappings is not None else {} self.__shape_node_key = 'd25' self.__ns_gml = "http://graphml.graphdrawing.org/xmlns" self.__ns = { # -- Standard GraphML - "": self.__ns_gml, - "xsi": "http://www.w3.org/2001/XMLSchema-instance", - "xsi:schemaLocation": "http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd", + "": self.__ns_gml, + "xsi": "http://www.w3.org/2001/XMLSchema-instance", + "xsi:schemaLocation": + "http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd", # -- YWorks GraphML - "java": "http://www.yworks.com/xml/yfiles-common/1.0/java", - "sys": "http://www.yworks.com/xml/yfiles-common/markup/primitives/2.0", - "x": "http://www.yworks.com/xml/yfiles-common/markup/2.0", - "y": "http://www.yworks.com/xml/graphml", - "yed": "http://www.yworks.com/xml/yed/3", + "java": "http://www.yworks.com/xml/yfiles-common/1.0/java", + "sys": "http://www.yworks.com/xml/yfiles-common/markup/primitives/2.0", + "x": "http://www.yworks.com/xml/yfiles-common/markup/2.0", + "y": "http://www.yworks.com/xml/graphml", + "yed": "http://www.yworks.com/xml/yed/3", } # https://stackoverflow.com/questions/4997848/ for name, url in self.__ns.items(): @@ -72,7 +78,7 @@ class MapAttr2Shape: # export ns, tag = tag.split(':') tag = '{' + self.__ns[ns] + '}' + tag attrib = content.get('a') or {} - el = ET.Element(tag, attrib=attrib) + el = ET.Element(tag, attrib = attrib) text = content.get('t') if text is not None: el.text = text @@ -81,10 +87,7 @@ class MapAttr2Shape: # export if children is not None: __add(el, children) - default_values = { - 'color': '#FFCC00', - 'text': '' - } + default_values = {'color': '#FFCC00', 'text': ''} values = {} for key, default in default_values.items(): @@ -98,11 +101,11 @@ class MapAttr2Shape: # export continue mapped = mapping(self.__attribs(node, keys)) values[key] = mapped or default - except: + except Exception: pass color = values['color'] - text = values['text'] + text = values['text'] has_text = 'true' if text else 'false' width_text = round(len(text) * 5.75, 5) if text else 0 @@ -110,61 +113,89 @@ class MapAttr2Shape: # export shape = { 'data': { - 'a': {'key': self.__shape_node_key}, + 'a': { + 'key': self.__shape_node_key + }, 'c': { 'y:ShapeNode': { 'a': {}, 'c': { - 'y:Geometry': {'a': {'height': '30.0', 'width': str(width_box), 'x': str(-(width_box / 2)), 'y':' -15.0'}}, - 'y:Fill': {'a': {'color': color, 'transparent': 'false'}}, - 'y:BorderStyle': {'a': {'color': '#000000', 'raised': 'false', 'type': 'line', 'width': '1.0'}}, + 'y:Geometry': { + 'a': { + 'height': '30.0', + 'width': str(width_box), + 'x': str(-(width_box / 2)), + 'y': ' -15.0' + } + }, + 'y:Fill': { + 'a': { + 'color': color, 'transparent': 'false' + } + }, + 'y:BorderStyle': { + 'a': { + 'color': '#000000', + 'raised': 'false', + 'type': 'line', + 'width': '1.0' + } + }, 'y:NodeLabel': { 'a': { - 'alignment': 'center', - 'autoSizePolicy': 'content', - 'fontFamily': 'Dialog', - 'fontSize': '12', - 'fontStyle': 'plain', - 'hasBackgroundColor': 'false', - 'hasLineColor': 'false', - 'hasText': has_text, - 'height': '18', - 'horizontalTextPosition': 'center', - 'iconTextGap': '4', - 'modelName': 'custom', - 'textColor': '#000000', - 'verticalTextPosition': 'bottom', - 'visible': 'true', - 'width': str(width_text), - 'x': '13.0', - 'y': '13.0', + 'alignment': 'center', + 'autoSizePolicy': 'content', + 'fontFamily': 'Dialog', + 'fontSize': '12', + 'fontStyle': 'plain', + 'hasBackgroundColor': 'false', + 'hasLineColor': 'false', + 'hasText': has_text, + 'height': '18', + 'horizontalTextPosition': 'center', + 'iconTextGap': '4', + 'modelName': 'custom', + 'textColor': '#000000', + 'verticalTextPosition': 'bottom', + 'visible': 'true', + 'width': str(width_text), + 'x': '13.0', + 'y': '13.0', }, 'c': { - 'y:LabelModel': { - 'c': { - 'y:SmartNodeLabelModel': {'a': {'distance': '4.0'}} - }, - }, - 'y:ModelParameter': { - 'c': { - 'y:SmartNodeLabelModelParameter': { - 'a': { - 'labelRatioX':'0.0', - 'labelRatioY': '0.0', - 'nodeRatioX': '0.0', - 'nodeRatioY': '0.0', - 'offsetX': '0.0', - 'offsetY': '0.0', - 'upX': '0.0', - 'upY': '-1.0', - } - } - } - } + 'y:LabelModel': { + 'c': { + 'y:SmartNodeLabelModel': { + 'a': { + 'distance': '4.0' + } + } + }, + }, + 'y:ModelParameter': { + 'c': { + 'y:SmartNodeLabelModelParameter': { + 'a': { + 'labelRatioX': '0.0', + 'labelRatioY': '0.0', + 'nodeRatioX': '0.0', + 'nodeRatioY': '0.0', + 'offsetX': '0.0', + 'offsetY': '0.0', + 'upX': '0.0', + 'upY': '-1.0', + } + } + } + } }, 't': text - }, - 'y:Shape': {'a': {'type': 'rectangle'}} + }, + 'y:Shape': { + 'a': { + 'type': 'rectangle' + } + } } } } @@ -175,17 +206,17 @@ class MapAttr2Shape: # export def __massage_nodes(self, root) -> None: keys = self.__keys(root) - graph = root.find(f'graph', self.__ns) + graph = root.find('graph', self.__ns) for node in graph: self.__massage_node(node, keys) def run(self, path_in: str, path_out: str) -> None: - parser = ET.XMLParser(encoding="utf-8") - tree = ET.parse(path_in, parser=parser) + parser = ET.XMLParser(encoding = "utf-8") + tree = ET.parse(path_in, parser = parser) root = tree.getroot() self.__add_key_nodegraphics(root) self.__massage_nodes(root) - ET.indent(tree, space=' ', level=0) - tree.write(path_out, xml_declaration=True, encoding='utf-8') + ET.indent(tree, space = ' ', level = 0) + tree.write(path_out, xml_declaration = True, encoding = 'utf-8') diff --git a/src/python/jw/util/ldap.py b/src/python/jw/util/ldap.py index bbd94ef..2a49fcd 100644 --- a/src/python/jw/util/ldap.py +++ b/src/python/jw/util/ldap.py @@ -1,36 +1,47 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations + +import copy +import getpass +import pathlib -import ldap, getpass, pathlib, copy -from ldap.schema.models import ObjectClass from enum import Flag, auto -import networkx as nx -from typing import Any, Self -from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Self -from .Config import Config as BaseConfig -from .log import * +import ldap # type: ignore[import-untyped] +import networkx as nx # type: ignore[import-untyped] + +from ldap.schema.models import ObjectClass # type: ignore[import-untyped] + +from .log import ERR, INFO, WARNING, slog + +if TYPE_CHECKING: + from collections.abc import Callable + + from .Config import Config as BaseConfig class Config: - def __init__(self, external: BaseConfig|None=None): + + def __init__(self, external: BaseConfig | None = None): self.__external = external for attr in ['ldap_uri', 'bind_dn', 'bind_pw', 'base_dn']: setattr(self, '_Config__' + attr, None) - def __get(self, key: str, default: str): + def __get(self, key: str, default: str | None): if not self.__external: return default - return self.__external.value(key, default=default) + return self.__external.value(key, default = default) @property def ldap_uri(self): if self.__ldap_uri is None: for key in ['ldap_uri', 'uri']: - self.__ldap_uri = self.__get(key, default=None) + self.__ldap_uri = self.__get(key, default = None) if self.__ldap_uri is not None: break else: self.__ldap_uri = 'ldap://ldap.janware.com' return self.__ldap_uri + @ldap_uri.setter def ldap_uri(self, rhs): self.__ldap_uri = rhs @@ -38,8 +49,12 @@ class Config: @property def bind_dn(self): if self.__bind_dn is None: - self.__bind_dn = self.__get('bind_dn', default=f'uid={getpass.getuser()},ou=users,dc=jannet,dc=de') + self.__bind_dn = self.__get( + 'bind_dn', + default = f'uid={getpass.getuser()},ou=users,dc=jannet,dc=de' + ) return self.__bind_dn + @bind_dn.setter def bind_dn(self, rhs): self.__bind_dn = rhs @@ -48,17 +63,21 @@ class Config: def bind_pw(self): if self.__bind_pw is None: for key in ['bind_pw', 'password']: - ret = self.__get(key, default=None) + ret = self.__get(key, default = None) if ret is not None: break if ret is None: - ldap_secret_file = self.__get('secret_file', f'{pathlib.Path.home()}/.ldap.secret') + ldap_secret_file = self.__get( + 'secret_file', f'{pathlib.Path.home()}/.ldap.secret' + ) + assert ldap_secret_file is not None, 'ldap_secret_file' with open(ldap_secret_file, 'r') as file: ret = file.read() file.closed ret = ret.strip() self.__bind_pw = ret return self.__bind_pw + @bind_pw.setter def bind_pw(self, rhs): self.__bind_pw = rhs @@ -66,25 +85,28 @@ class Config: @property def base_dn(self): if self.__base_dn is None: - self.__base_dn = self.__get('base_dn', default=f'dc=jannet,dc=de') + self.__base_dn = self.__get('base_dn', default = 'dc=jannet,dc=de') return self.__base_dn + @base_dn.setter def base_dn(self, rhs): self.__base_dn = rhs -class Connection: # export +class Connection: # export class AttrType(Flag): Must = auto() - May = auto() + May = auto() - def __init__(self, conf: Config|BaseConfig|None=None, backtrace=False): - uri: str|None = None + def __init__(self, conf: Config | BaseConfig | None = None, backtrace = False): + uri: str | None = None c = conf if isinstance(conf, Config) else Config(conf) try: uri = c.ldap_uri - except: - uri = c.uri + except Exception: + # mypy says: E: "Config" has no attribute "uri" [attr-defined] + # FIXME: Who adds .uri? + uri = c.uri # type: ignore try: ret = ldap.initialize(uri) ret.start_tls_s() @@ -92,46 +114,60 @@ class Connection: # export slog(ERR, f'Failed to initialize LDAP connection to "{uri}" ({str(e)})') raise try: - rr = ret.bind_s(c.bind_dn, c.bind_pw) # method) + ret.bind_s(c.bind_dn, c.bind_pw) # method) except Exception as e: slog(ERR, f'Failed to bind to "{uri}" with dn "{c.bind_dn}" ({str(e)})') raise self.__ldap = ret self.__backtrace = backtrace - self.__object_classes_by_oid: dict[str, ObjectClass]|None = None - self.__object_class_tree: nx.Graph|None = None - self.__object_classes_by_name: dict[str, ObjectClass]|None = None + self.__object_classes_by_oid: dict[str, ObjectClass] | None = None + self.__object_class_tree: nx.Graph | None = None + self.__object_classes_by_name: dict[str, ObjectClass] | None = None @property def ldap(self): return self.__ldap - def add(self, attrs: dict[str, bytes], dn: str|None=None): + def add(self, attrs: dict[str, bytes], dn: str | None = None): if dn is None: - if not 'dn' in attrs: + if 'dn' not in attrs: raise Exception('No DN to add an LDAP entry to') attrs = copy.deepcopy(attrs) del attrs['dn'] try: slog(INFO, f'LDAP: Add [{dn}] -> {attrs}') - self.__ldap.add_s(dn, ldap.modlist.addModlist(attrs)) + ml =ldap.modlist.addModlist( # pyright: ignore[reportAttributeAccessIssue] + attrs + ) + self.__ldap.add_s(dn, ml) except Exception as e: slog(ERR, f'{dn}: Failed to add entry {attrs} ({e})') raise - def delete(self, dn: str, recursive=False, force_existence: bool=False): + def delete(self, dn: str, recursive = False, force_existence: bool = False): def __walk_cb_delete(conn: Connection, entry, context): - self.walk(__walk_cb_delete, base=entry[0], scope=ldap.SCOPE_ONELEVEL, context=context) + self.walk( + __walk_cb_delete, + base = entry[0], + scope = ldap. + SCOPE_ONELEVEL, # pyright: ignore[reportAttributeAccessIssue] + context = context + ) self.__ldap.delete_s(entry[0]) try: if recursive: - self.walk(__walk_cb_delete, dn, scope=ldap.SCOPE_ONELEVEL) + self.walk( + __walk_cb_delete, + dn, + scope = ldap. + SCOPE_ONELEVEL # pyright: ignore[reportAttributeAccessIssue] + ) self.__ldap.delete_s(dn) else: self.__ldap.delete_s(dn) - except ldap.NO_SUCH_OBJECT as e: + except ldap.NO_SUCH_OBJECT: # pyright: ignore[reportAttributeAccessIssue] if force_existence: raise except Exception as e: @@ -139,37 +175,42 @@ class Connection: # export raise def walk( - self, - callback: Callable[[Self, Any, Any], None], - base: str, - scope, - context=None, - filterstr=None, - attrlist=None, - attrsonly=0, - serverctrls=None, - clientctrls=None, - timeout=-1, - sizelimit=0, - decode: bool=False, - unroll: bool=False - ): + self, + callback: Callable[[Self, Any, Any], None], + base: str, + scope, + context = None, + filterstr = None, + attrlist = None, + attrsonly = 0, + serverctrls = None, + clientctrls = None, + timeout = -1, + sizelimit = 0, + decode: bool = False, + unroll: bool = False + ): # TODO: Support ignored arguments - search_return = self.__ldap.search(base=base, - scope=scope, - filterstr=filterstr, - attrlist=attrlist, - attrsonly=attrsonly) + search_return = self.__ldap.search( + base = base, + scope = scope, + filterstr = filterstr, + attrlist = attrlist, + attrsonly = attrsonly + ) while True: result_type, result_data = self.__ldap.result(search_return, 0) - if (result_data == []): + if (not result_data): break - if result_type != ldap.RES_SEARCH_ENTRY: + if result_type != ldap.RES_SEARCH_ENTRY: # pyright: ignore[reportAttributeAccessIssue] continue for entry in result_data: if decode: - entry = entry[0], {key: [val.decode() for val in vals] for key, vals in entry[1].items()} + entry = entry[0], { + key: [val.decode() for val in vals] + for key, vals in entry[1].items() + } if unroll and False: entry = entry[0], {key: val[0] for key, val in entry[1].items()} try: @@ -182,19 +223,20 @@ class Connection: # export slog(WARNING, msg) continue - def find(self, - base: str, - scope, - filterstr=None, - attrlist=None, - attrsonly=0, - serverctrls=None, - clientctrls=None, - timeout=-1, - sizelimit=0, - assert_unique=False, - assert_not_empty=False, - ): + def find( + self, + base: str, + scope, + filterstr = None, + attrlist = None, + attrsonly = 0, + serverctrls = None, + clientctrls = None, + timeout = -1, + sizelimit = 0, + assert_unique = False, + assert_not_empty = False, + ): def __walk_cb_find(conn: Connection, entry: Any, context: Any): result.append(entry) @@ -204,7 +246,13 @@ class Connection: # export try: result: list[Any] = [] - self.walk(__walk_cb_find, base, scope=scope, filterstr=filterstr, attrlist=attrlist) + self.walk( + __walk_cb_find, + base, + scope = scope, + filterstr = filterstr, + attrlist = attrlist + ) except Exception as e: slog(ERR, f'Failed search {__search()} ({e})') raise @@ -216,17 +264,34 @@ class Connection: # export @property def object_classes(self) -> dict[str, ObjectClass]: - #def object_classes(self): + #def object_classes(self): if self.__object_classes_by_oid is None: - res = self.find(base='', scope=ldap.SCOPE_BASE, filterstr='(objectClass=*)', attrlist=['subschemaSubentry']) - dn = res[0][1]['subschemaSubentry'][0].decode('utf-8') # Usually yields cn=Subschema - res = self.find(base=dn, scope=ldap.SCOPE_BASE, filterstr='(objectClass=*)', attrlist=['*', '+']) + res = self.find( + base = '', + scope = ldap.SCOPE_BASE, # pyright: ignore[reportAttributeAccessIssue] + filterstr = '(objectClass=*)', + attrlist = ['subschemaSubentry'] + ) + dn = res[0][1]['subschemaSubentry'][0].decode( + 'utf-8' + ) # Usually yields cn=Subschema + res = self.find( + base = dn, + scope = ldap.SCOPE_BASE, # pyright: ignore[reportAttributeAccessIssue] + filterstr = '(objectClass=*)', + attrlist = ['*', '+'] + ) subschema_entry = res[0] - subschema_subentry = ldap.cidict.cidict(subschema_entry[1]) - subschema = ldap.schema.SubSchema(subschema_subentry) + subschema_subentry = ldap.cidict.cidict( # pyright: ignore[reportAttributeAccessIssue] + subschema_entry[1] + ) + subschema = ldap.schema.SubSchema( # pyright: ignore[reportAttributeAccessIssue] + subschema_subentry + ) object_class_oids = subschema.listall(ObjectClass) self.__object_classes_by_oid = { - oid: subschema.get_obj(ObjectClass, oid) for oid in object_class_oids + oid: subschema.get_obj(ObjectClass, oid) + for oid in object_class_oids } return self.__object_classes_by_oid @@ -242,15 +307,18 @@ class Connection: # export ret[name.lower()] = oc return self.__object_classes_by_name - def __oc_recurse_to_top(self, cur: str|ObjectClass, cb, context): - cur_oc = cur if isinstance(cur, ObjectClass) else self.object_class_by_name[cur.lower()] + def __oc_recurse_to_top(self, cur: str | ObjectClass, cb, context): + cur_oc = cur if isinstance(cur, ObjectClass) else self.object_class_by_name[ + cur.lower()] for s in cur_oc.sup: self.__oc_recurse_to_top(s, cb, context) cb(cur_oc, context) - def object_class_path(self, leaf: str|ObjectClass): + def object_class_path(self, leaf: str | ObjectClass): + def cb(oc, context): ret.append(oc) + ret: list[str] = [] self.__oc_recurse_to_top(leaf, cb, None) return reversed(ret) @@ -262,47 +330,55 @@ class Connection: # export def collect(root, attr): ret = set() + def cb(oc, attr): vals = getattr(oc, attr) if vals is None: return for val in vals: ret.add(val) + self.__oc_recurse_to_top(root, cb, attr) return ret - kind = { - 0: 'STRUCTURAL', - 1: 'ABSTRACT', - 2: 'AUXILIARY' - } + kind = {0: 'STRUCTURAL', 1: 'ABSTRACT', 2: 'AUXILIARY'} ret = nx.DiGraph() for oid, oc in self.object_classes.items(): ret.add_node( - oid, - oid=oid, - name=oc.names[0], - kind=kind[oc.kind], - must=', '.join(collect(oc, 'must')), - may=', '.join(collect(oc, 'may')) - ) + oid, + oid = oid, + name = oc.names[0], + kind = kind[oc.kind], + must = ', '.join(collect(oc, 'must')), + may = ', '.join(collect(oc, 'may')) + ) for base_class in oc.sup: try: - ret.add_edge(oid, self.object_class_by_name[base_class.lower()].oid) + ret.add_edge( + oid, self.object_class_by_name[base_class.lower()].oid + ) except Exception as e: - slog(WARNING, f'Failed to add edge {oid}:{oc.names} -> {base_class} ({e})') + slog( + WARNING, + f'Failed to add edge {oid}:{oc.names} -> {base_class} ({e})' + ) self.__object_class_tree = ret return self.__object_class_tree - def object_class_attrs(self, oc: str|ObjectClass, required: AttrType = AttrType.Must, origins: bool=False) -> dict[str, set[str]]|set[str]: + def object_class_attrs( + self, + oc: str | ObjectClass, + required: AttrType = AttrType.Must, + origins: bool = False + ) -> dict[str, set[str]] | set[str]: all_attrs: set[str] = set() attrs_by_origin: dict[str, set[str]] = {} for oc in self.object_class_path(oc): cur = set() if required & self.AttrType.Must: - cur |= set(oc.must) + cur |= set(oc.must) # pyright: ignore[reportAttributeAccessIssue] if required & self.AttrType.May: - cur |= set(oc.may) + cur |= set(oc.may) # pyright: ignore[reportAttributeAccessIssue] if cur: all_attrs |= cur attrs_by_origin[oc] = cur @@ -313,10 +389,14 @@ class Connection: # export #base_oid = self.object_class_by_name[base_candidate].oid #if base_oid in [oc.oid for oc in self.object_class_path(name)]: # return True - return nx.has_path(self.object_class_tree, self.object_class_by_name[name.lower()].oid, self.object_class_by_name[base_candidate.lower()].oid) + return nx.has_path( + self.object_class_tree, + self.object_class_by_name[name.lower()].oid, + self.object_class_by_name[base_candidate.lower()].oid + ) -def default_config() -> Config: # export +def default_config() -> Config: # export return Config() -def bind(conf: Config|BaseConfig|None=None) -> Connection: +def bind(conf: Config | BaseConfig | None = None) -> Connection: return Connection(conf) diff --git a/src/python/jw/util/log.py b/src/python/jw/util/log.py index b324a5b..f7f8f3a 100644 --- a/src/python/jw/util/log.py +++ b/src/python/jw/util/log.py @@ -1,22 +1,28 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations, print_function -from __future__ import print_function +import inspect +import re +import sys +import syslog +import unicodedata -from typing import List, Tuple, Optional, Any - -import sys, re, io, syslog, inspect, unicodedata - -from os.path import basename from datetime import datetime +from os.path import basename +from typing import TYPE_CHECKING, Any, List, Optional, Tuple from . import misc +if TYPE_CHECKING: + import io + # --- python 2 / 3 compatibility stuff try: - basestring # type: ignore + basestring # type: ignore except NameError: basestring = str +# fmt: disable # don't conflate +# yapf: disable # don't conflate _special_chars = { '\a' : '\\a', '\b' : '\\b', @@ -26,12 +32,20 @@ _special_chars = { '\f' : '\\f', '\r' : '\\r', } +# yapf: enable +# fmt: enable -_special_char_regex = re.compile("(%s)" % "|".join(map(re.escape, _special_chars.keys()))) +_special_char_regex = re.compile( + "(%s)" % "|".join(map(re.escape, _special_chars.keys())) +) -_all_control_chars = ''.join(chr(c) for c in range(sys.maxunicode) if unicodedata.category(chr(c)) in {'Cc'}) -_clean_str_regex = re.compile(r'(\033\[[0-9]*m|[%s])' % re.escape(_all_control_chars)) +_all_control_chars = ''.join( + chr(c) for c in range(sys.maxunicode) if unicodedata.category(chr(c)) in {'Cc'} +) +_clean_str_regex = re.compile(r'(\033\[[0-9]*m|[%s])' % re.escape(_all_control_chars)) +# fmt: disable # don't conflate +# yapf: disable # don't conflate EMERG = int(syslog.LOG_EMERG) ALERT = int(syslog.LOG_ALERT) CRIT = int(syslog.LOG_CRIT) @@ -98,44 +112,49 @@ _prio_colors = { EMERG : [ CONSOLE_FONT_BOLD + CONSOLE_FONT_MAGENTA, CONSOLE_FONT_OFF ], } +# yapf: enable +# fmt: enable + class Stream: + def __init__(self, stream, flags): self.stream = stream self.flags = flags _streams: dict[int, Stream] = dict() -_stream_descriptors = [reversed(range(1, 16))] +_stream_descriptors = list(reversed(range(1, 16))) -def add_capture_stream(stream, flags=0x0): +def add_capture_stream(stream, flags = 0x0): ret = _stream_descriptors.pop() - _streams[ret] = Stream(stream=stream, flags=flags) + _streams[ret] = Stream(stream = stream, flags = flags) return ret def rm_capture_stream(sd): del _streams[sd] _stream_descriptors.append(sd) -def prio_gets_logged(prio: int) -> bool: # export +def prio_gets_logged(prio: int) -> bool: # export if prio > _level: return False return True -def log_level(s: Optional[str]=None) -> int: # export +def log_level(s: Optional[str] = None) -> int: # export if s is None: return _level return parse_log_prio_str(s) -def get_caller_pos(up: int = 1, kwargs: Optional[dict[str, Any]] = None) -> Tuple[str, str, int]: +def get_caller_pos(up: int = 1, + kwargs: Optional[dict[str, Any]] = None) -> Tuple[str, str, int]: if kwargs and 'caller' in kwargs: r = kwargs['caller'] del kwargs['caller'] return r - caller = inspect.stack()[up+1] + caller = inspect.stack()[up + 1] mod = inspect.getmodule(caller[0]) mod_name = '' if mod is None else mod.__name__ return (mod_name, basename(caller.filename), caller.lineno) -def slog_m(prio: int, *args, **kwargs) -> None: # export +def slog_m(prio: int, *args, **kwargs) -> None: # export if prio > _level: return if len(args): @@ -151,9 +170,9 @@ def slog_m(prio: int, *args, **kwargs) -> None: # export caller = kwargs['caller'] del kwargs['caller'] for line in margs[1:].split('\n'): - slog(prio, line, **kwargs, caller=caller) + slog(prio, line, **kwargs, caller = caller) -def slog(prio: int, *args, only_printable: bool=False, **kwargs) -> None: # export +def slog(prio: int, *args, only_printable: bool = False, **kwargs) -> None: # export if prio > _level: return @@ -188,11 +207,13 @@ def slog(prio: int, *args, only_printable: bool=False, **kwargs) -> None: # expo for a in args: margs += ' ' + str(a) if only_printable: - margs = _special_char_regex.sub(lambda mo: _special_chars[mo.string[mo.start():mo.end()]], margs) + margs = _special_char_regex.sub( + lambda mo: _special_chars[mo.string[mo.start():mo.end()]], margs + ) margs = re.sub('[\x01-\x1f]', '.', margs) for file in _log_file_streams: - print(msg + _clean_log_prefix + margs, file=file) + print(msg + _clean_log_prefix + margs, file = file) msg += _log_prefix @@ -215,24 +236,26 @@ def slog(prio: int, *args, only_printable: bool=False, **kwargs) -> None: # expo files.append(sys.stderr) if not len(files): - files = [ sys.stdout ] + files = [sys.stdout] for file in files: - print(msg, file=file) + print(msg, file = file) -def throw(*args, prio=ERR, caller=None, **kwargs) -> None: +def throw(*args, prio = ERR, caller = None, **kwargs) -> None: if caller is None: caller = get_caller_pos(1) msg = ' '.join([str(arg) for arg in args]) - slog(prio, msg, caller=caller) + slog(prio, msg, caller = caller) raise Exception(msg) -def parse_log_prio_str(prio: str) -> int: # export +def parse_log_prio_str(prio: str) -> int: # export try: r = int(prio) if r < 0 or r > DEVEL: raise Exception("Invalid log priority ", prio) except ValueError: + # fmt: disable # don't conflate + # yapf: disable # don't conflate map_prio_str_to_val = { "EMERG" : EMERG, "emerg" : EMERG, @@ -255,23 +278,25 @@ def parse_log_prio_str(prio: str) -> int: # export "OFF" : OFF, "off" : OFF, } + # yapf: enable + # fmt: enable if prio in map_prio_str_to_val: return map_prio_str_to_val[prio] raise Exception("Unknown priority string \"", prio, "\"") -def console_color_chars(prio: int) -> List[str]: # export +def console_color_chars(prio: int) -> List[str]: # export if not sys.stdout.isatty(): - return [ '', '' ] + return ['', ''] return _prio_colors[prio] -def set_level(level_: str) -> None: # export +def set_level(level_: str) -> None: # export global _level if isinstance(level_, basestring): _level = parse_log_prio_str(level_) return _level = level_ -def set_flags(flags: str|None) -> str: # export +def set_flags(flags: str | None) -> str: # export global _flags ret = ','.join(_flags) if flags is not None: @@ -293,7 +318,7 @@ def set_flags(flags: str|None) -> str: # export #pid #highlight_first_error -def append_to_prefix(prefix: str) -> str: # export +def append_to_prefix(prefix: str) -> str: # export global _log_prefix global _clean_log_prefix r = _log_prefix @@ -302,7 +327,7 @@ def append_to_prefix(prefix: str) -> str: # export _clean_log_prefix = _clean_str_regex.sub('', _log_prefix) return r -def remove_from_prefix(count) -> str: # export +def remove_from_prefix(count) -> str: # export if isinstance(count, str): count = len(count) global _log_prefix @@ -312,21 +337,21 @@ def remove_from_prefix(count) -> str: # export _clean_log_prefix = _clean_str_regex.sub('', _log_prefix) return r -def set_filename_length(l: int) -> int: # export +def set_filename_length(length: int) -> int: # export global _file_name_len r = _file_name_len - if l: - _file_name_len = l + if length: + _file_name_len = length return r -def set_module_name_length(l: int) -> int: # export +def set_module_name_length(length: int) -> int: # export global _module_name_len r = _module_name_len - if l: - _module_name_len = l + if length: + _module_name_len = length return r -def add_log_file(path: str) -> None: # export +def add_log_file(path: str) -> None: # export global _log_file_streams - fd = open(path, 'w', buffering=1) + fd = open(path, 'w', buffering = 1) _log_file_streams.append(fd) diff --git a/src/python/jw/util/misc.py b/src/python/jw/util/misc.py index e1f9781..4b53e13 100644 --- a/src/python/jw/util/misc.py +++ b/src/python/jw/util/misc.py @@ -1,6 +1,11 @@ -# -*- coding: utf-8 -*- - -import os, errno, atexit, tempfile, filecmp, inspect, importlib, re +import atexit +import errno +import filecmp +import importlib +import inspect +import os +import re +import tempfile from typing import Iterable @@ -12,12 +17,12 @@ def _cleanup(): for f in _tmpfiles: silentremove(f) -def silentremove(filename): #export +def silentremove(filename): #export try: os.remove(filename) except OSError as e: if e.errno != errno.ENOENT: - raise # re-raise exception if a different error occurred + raise # re-raise exception if a different error occurred def update_symlink(target, link_name): try: @@ -38,12 +43,14 @@ def pad(token: str, total_size: int, right_align: bool = False) -> str: return space + token return token + space -def atomic_store(contents, path): # export +def atomic_store(contents, path): # export if path[0:3] == '/dev': with open(path, 'w') as outfile: outfile.write(contents) return - outfile = tempfile.NamedTemporaryFile(prefix=os.path.basename(path), delete=False, dir=os.path.dirname(path)) + outfile = tempfile.NamedTemporaryFile( + prefix = os.path.basename(path), delete = False, dir = os.path.dirname(path) + ) name = outfile.name _tmpfiles.add(name) outfile.write(contents) @@ -52,7 +59,7 @@ def atomic_store(contents, path): # export _tmpfiles.remove(name) # see https://stackoverflow.com/questions/2020014 -def object_builtin_name(o, full=True): # export +def object_builtin_name(o, full = True): # export #if not full: # return o.__class__.__name__ module = o.__class__.__module__ @@ -60,7 +67,7 @@ def object_builtin_name(o, full=True): # export return o.__class__.__name__ # Avoid reporting __builtin__ return module + '.' + o.__class__.__name__ -def get_derived_classes(mod, base, flt=None): # export +def get_derived_classes(mod, base, flt = None): # export members = inspect.getmembers(mod, inspect.isclass) r = [] for name, c in members: @@ -68,8 +75,10 @@ def get_derived_classes(mod, base, flt=None): # export if inspect.isabstract(c): log.slog(log.DEBUG, " is abstract") continue - if not base in inspect.getmro(c): - log.slog(log.DEBUG, " is not derived from", base, "only", inspect.getmro(c)) + if base not in inspect.getmro(c): + log.slog( + log.DEBUG, " is not derived from", base, "only", inspect.getmro(c) + ) continue if flt and not re.match(flt, name): log.slog(log.DEBUG, ' "{}.{}" has wrong name'.format(mod, name)) @@ -77,7 +86,7 @@ def get_derived_classes(mod, base, flt=None): # export r.append(c) return r -def load_classes(path, baseclass, flt=None): # export +def load_classes(path, baseclass, flt = None): # export r = [] for p in path.split(':'): mod = importlib.import_module(path) @@ -85,22 +94,28 @@ def load_classes(path, baseclass, flt=None): # export r.extend(get_derived_classes(mod, baseclass, flt)) return r -def load_class(module_path, baseclass, class_name_filter=None): # export +def load_class(module_path, baseclass, class_name_filter = None): # export mod = importlib.import_module(module_path) - classes = get_derived_classes(mod, baseclass, flt=class_name_filter) + classes = get_derived_classes(mod, baseclass, flt = class_name_filter) if len(classes) == 0: - raise Exception(f'no class matching "{class_name_filter}" of type "{baseclass}" found in module "{module_path}"') + raise Exception( + f'no class matching "{class_name_filter}" of type "{baseclass}" ' + f'found in module "{module_path}"' + ) if len(classes) > 1: - raise Exception(f'{len(classes)} classes matching "{class_name_filter}" of type "{baseclass}" found in module "{module_path}"') + raise Exception( + f'{len(classes)} classes matching "{class_name_filter}" of type ' + f'"{baseclass}" found in module "{module_path}"' + ) return classes[0] -def load_class_names(path, baseclass, flt=None, remove_flt=False): # export +def load_class_names(path, baseclass, flt = None, remove_flt = False): # export classes = load_classes(path, baseclass, flt) r = [] for c in classes: name = c.__name__ if flt and remove_flt: - name = re.subst(flt, "", name) + name = re.sub(flt, '', name) if name not in r: r.append(name) else: @@ -108,64 +123,72 @@ def load_class_names(path, baseclass, flt=None, remove_flt=False): # export #log.slog(log.WARNING, "{} is already in in {}".format(name, r)) return r -def load_object(module_path, baseclass, class_name_filter=None, *args, **kwargs): # export - return load_class(module_path, baseclass, class_name_filter=class_name_filter)(*args, **kwargs) +def load_object( # export + module_path, baseclass, class_name_filter = None, *args, **kwargs +): + return load_class( + module_path, baseclass, class_name_filter = class_name_filter + )(*args, **kwargs) -def load_function(module_path, name): # export +def load_function(module_path, name): # export mod = importlib.import_module(module_path) return getattr(mod, name) -def commit_tmpfile(tmp: str, path: str) -> None: # export +def commit_tmpfile(tmp: str, path: str) -> None: # export caller = log.get_caller_pos() if os.path.isfile(path) and filecmp.cmp(tmp, path): - log.slog(log.INFO, "{} is up to date".format(path), caller=caller) + log.slog(log.INFO, "{} is up to date".format(path), caller = caller) os.unlink(tmp) else: - log.slog(log.NOTICE, "saving {}".format(path), caller=caller) + log.slog(log.NOTICE, "saving {}".format(path), caller = caller) os.rename(path + '.tmp', path) -def multi_regex_edit(spec, strings): # export +def multi_regex_edit(spec, strings): # export for cmd in spec: if len(cmd) < 2: - raise Exception('Invalid command in multi_regex_edit(): {}'.format(str(cmd))) + raise Exception( + 'Invalid command in multi_regex_edit(): {}'.format(str(cmd)) + ) if cmd[0] == 'sub': rx = re.compile(cmd[1]) replacement = cmd[2] r = [] - for l in strings: - r.append(re.sub(rx, replacement, l)) + for string in strings: + r.append(re.sub(rx, replacement, string)) strings = r continue if cmd[0] == 'del': rx = re.compile(cmd[1]) r = [] - for l in strings: - if rx.search(l) is not None: + for string in strings: + if rx.search(string) is not None: continue - r.append(l) + r.append(string) strings = r continue if cmd[0] == 'match': rx = re.compile(cmd[1]) r = [] - for l in strings: - if rx.search(l) is not None: - r.append(l) + for string in strings: + if rx.search(string) is not None: + r.append(string) strings = r continue raise Exception('Invalid command in multi_regex_edit(): {}'.format(str(cmd))) return strings -def dump(prio: int, objects: Iterable, *args, **kwargs) -> None: # export - caller = log.get_caller_pos(kwargs=kwargs) - log.slog(prio, ",---------- {}".format(' '.join(args)), caller=caller) +def dump(prio: int, objects: Iterable, *args, **kwargs) -> None: # export + caller = log.get_caller_pos(kwargs = kwargs) + log.slog(prio, ",---------- {}".format(' '.join(args)), caller = caller) prefix = " | " log.append_to_prefix(prefix) i = 1 for o in objects: - o.dump(prio, "{} ({})".format(i, o.__class__.__name__), caller=caller, **kwargs) + o.dump( + prio, "{} ({})".format(i, o.__class__.__name__), caller = caller, **kwargs + ) i += 1 log.remove_from_prefix(prefix) - log.slog(prio, "`---------- {}".format(' '.join(args)), caller=caller) + log.slog(prio, "`---------- {}".format(' '.join(args)), caller = caller) atexit.register(_cleanup) diff --git a/src/python/jw/util/multi_key_dict.py b/src/python/jw/util/multi_key_dict.py index 333e72e..85a37a1 100644 --- a/src/python/jw/util/multi_key_dict.py +++ b/src/python/jw/util/multi_key_dict.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - ''' Created on 26 May 2013 @@ -14,65 +13,74 @@ ___________________________________ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software -without restriction, including without limitation the rights to use, copy, modify, merge, -publish, distribute, sub-license, and/or sell copies of the Software, and to permit persons -to whom the Software is furnished to do so, subject to the following conditions: +without restriction, including without limitation the rights to use, copy, modify, +merge, publish, distribute, sub-license, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to the following +conditions: - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR -PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE -FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR -OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. +PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT +OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. ''' import platform + _python3 = int(platform.python_version_tuple()[0]) >= 3 class multi_key_dict(object): - """ The purpose of this type is to provide a multi-key dictionary. - This kind of dictionary has a similar interface to the standard dictionary, and indeed if used + """ + The purpose of this type is to provide a multi-key dictionary. This kind of + dictionary has a similar interface to the standard dictionary, and indeed if used with single key key elements - it's behaviour is the same as for a standard dict(). - However it also allows for creation of elements using multiple keys (using tuples/lists). - Such elements can be accessed using either of those keys (e.g read/updated/deleted). - Dictionary provides also an extended interface for iterating over items and keys by the key type. - This can be useful e.g.: when creating dictionaries with (index,name) allowing one to iterate over - items using either: names or indexes. It can be useful for many many other similar use-cases, - and there is no limit to the number of keys used to map to the value. + However it also allows for creation of elements using multiple keys (using + tuples/lists). Such elements can be accessed using either of those keys (e.g + read/updated/deleted). Dictionary provides also an extended interface for iterating + over items and keys by the key type. This can be useful e.g.: when creating + dictionaries with (index,name) allowing one to iterate over items using either: + names or indexes. It can be useful for many many other similar use-cases, and there + is no limit to the number of keys used to map to the value. - There are also methods to find other keys mapping to the same value as the specified keys etc. - Refer to examples and test code to see it in action. + There are also methods to find other keys mapping to the same value as the specified + keys etc. Refer to examples and test code to see it in action. simple example: k = multi_key_dict() k[100] = 'hundred' # add item to the dictionary (as for normal dictionary) # but also: - # below creates entry with two possible key types: int and str, + # below creates entry with two possible key types: int and str, # mapping all keys to the assigned value k[1000, 'kilo', 'k'] = 'kilo (x1000)' print k[1000] # will print 'kilo (x1000)' print k['k'] # will also print 'kilo (x1000)' - # the same way objects can be updated, and if an object is updated using one key, the new value will - # be accessible using any other key, e.g. for example above: + # the same way objects can be updated, and if an object is updated using one + # key, the new value will be accessible using any other key, e.g. for example + # above: k['kilo'] = 'kilo' print k[1000] # will print 'kilo' as value was updated """ - def __init__(self, mapping_or_iterable=None, **kwargs): - """ Initializes dictionary from an optional positional argument and a possibly empty set of keyword arguments.""" + def __init__(self, mapping_or_iterable = None, **kwargs): + """ Initializes dictionary from an optional positional argument and a possibly + empty set of keyword arguments.""" self.items_dict = {} if mapping_or_iterable is not None: if type(mapping_or_iterable) is dict: mapping_or_iterable = mapping_or_iterable.items() for kv in mapping_or_iterable: if len(kv) != 2: - raise Exception('Iterable should contain tuples with exactly two values but specified: {0}.'.format(kv)) + raise Exception( + 'Iterable should contain tuples with exactly two values ' + 'but specified: {0}.'.format(kv) + ) self[kv[0]] = kv[1] for keys, value in kwargs.items(): self[keys] = value @@ -83,20 +91,19 @@ class multi_key_dict(object): def __setitem__(self, keys, value): """ Set the value at index (or list of indexes) specified as keys. - Note, that if multiple key list is specified, either: - - none of keys should map to an existing item already (item creation), or + Note, that if multiple key list is specified, either: + - none of keys should map to an existing item already (item creation), or - all of keys should map to exactly the same item (as previously created) (item update) If this is not the case - KeyError is raised. """ - if(type(keys) in [tuple, list]): - at_least_one_key_exists = False + if (type(keys) in [tuple, list]): num_of_keys_we_have = 0 for x in keys: try: self.__getitem__(x) num_of_keys_we_have += 1 - except Exception as err: + except Exception: continue if num_of_keys_we_have: @@ -112,36 +119,37 @@ class multi_key_dict(object): if new != direct_key: all_select_same_item = False break - except Exception as err: + except Exception: all_select_same_item = False - break; + break if not all_select_same_item: raise KeyError(', '.join(str(key) for key in keys)) - first_key = keys[0] # combination if keys is allowed, simply use the first one + first_key = keys[ + 0] # combination if keys is allowed, simply use the first one else: first_key = keys - key_type = str(type(first_key)) # find the intermediate dictionary.. + key_type = str(type(first_key)) # find the intermediate dictionary.. if first_key in self: - self.items_dict[self.__dict__[key_type][first_key]] = value # .. and update the object if it exists.. + self.items_dict[self.__dict__[key_type][first_key] + ] = value # .. and update the object if it exists.. else: - if(type(keys) not in [tuple, list]): + if (type(keys) not in [tuple, list]): key = keys keys = [keys] - self.__add_item(value, keys) # .. or create it - if it doesn't + self.__add_item(value, keys) # .. or create it - if it doesn't def __delitem__(self, key): """ Called to implement deletion of self[key].""" key_type = str(type(key)) - if (key in self and - self.items_dict and - (self.__dict__[key_type][key] in self.items_dict) ): + if (key in self and self.items_dict + and (self.__dict__[key_type][key] in self.items_dict)): intermediate_key = self.__dict__[key_type][key] - # remove the item in main dictionary + # remove the item in main dictionary del self.items_dict[intermediate_key] # and remove all references (if there were other keys) @@ -166,10 +174,10 @@ class multi_key_dict(object): """ Returns True if this object contains an item referenced by the key.""" return key in self - def get_other_keys(self, key, including_current=False): - """ Returns list of other keys that are mapped to the same value as specified key. - @param key - key for which other keys should be returned. - @param including_current if set to True - key will also appear on this list.""" + def get_other_keys(self, key, including_current = False): + """ Returns list of other keys that are mapped to the same value as specified + key. @param key - key for which other keys should be returned. @param + including_current if set to True - key will also appear on this list.""" other_keys = [] if key in self: other_keys.extend(self.__dict__[str(type(key))][key]) @@ -177,12 +185,17 @@ class multi_key_dict(object): other_keys.remove(key) return other_keys - def iteritems(self, key_type=None, return_all_keys=False): + def iteritems(self, key_type = None, return_all_keys = False): """ Returns an iterator over the dictionary's (key, value) pairs. - @param key_type if specified, iterator will be returning only (key,value) pairs for this type of key. - Otherwise (if not specified) ((keys,...), value) - i.e. (tuple of keys, values) pairs for all items in this dictionary will be generated. - @param return_all_keys if set to True - tuple of keys is retuned instead of a key of this type.""" + + @param key_type if specified, iterator will be returning only (key,value) + pairs for this type of key. + + Otherwise (if not specified) ((keys,...), value) i.e. (tuple of keys, + values) pairs for all items in this dictionary will be generated. + + @param return_all_keys if set to True - tuple of keys is retuned instead of + a key of this type.""" if key_type is None: for item in self.items_dict.items(): @@ -200,29 +213,34 @@ class multi_key_dict(object): keys = tuple(k for k in keys if isinstance(k, key_type)) yield keys, value - def iterkeys(self, key_type=None, return_all_keys=False): + def iterkeys(self, key_type = None, return_all_keys = False): """ Returns an iterator over the dictionary's keys. - @param key_type if specified, iterator for a dictionary of this type will be used. + @param key_type if specified, iterator for a dictionary of this type will + be used. Otherwise (if not specified) tuples containing all (multiple) keys for this dictionary will be generated. - @param return_all_keys if set to True - tuple of keys is retuned instead of a key of this type.""" - if(key_type is not None): + @param return_all_keys if set to True - tuple of keys is retuned instead of + a key of this type.""" + if (key_type is not None): the_key = str(key_type) if the_key in self.__dict__: for key in self.__dict__[the_key].keys(): if return_all_keys: yield self.__dict__[the_key][key] else: - yield key + yield key else: for keys in self.items_dict.keys(): yield keys - def itervalues(self, key_type=None): + def itervalues(self, key_type = None): """ Returns an iterator over the dictionary's values. - @param key_type if specified, iterator will be returning only values pointed by keys of this type. - Otherwise (if not specified) all values in this dictinary will be generated.""" - if(key_type is not None): + @param key_type if specified, iterator will be returning only values pointed + by keys of this type. + Otherwise (if not specified) all values in this dictinary will be + generated.""" + + if (key_type is not None): intermediate_key = str(key_type) if intermediate_key in self.__dict__: for direct_key in self.__dict__[intermediate_key].values(): @@ -232,37 +250,42 @@ class multi_key_dict(object): yield value if _python3: - items = iteritems + items = iteritems # type: ignore else: - def items(self, key_type=None, return_all_keys=False): + + def items(self, key_type = None, return_all_keys = False): return list(self.iteritems(key_type, return_all_keys)) + items.__doc__ = iteritems.__doc__ - def keys(self, key_type=None): + def keys(self, key_type = None): """ Returns a copy of the dictionary's keys. @param key_type if specified, only keys for this type will be returned. - Otherwise list of tuples containing all (multiple) keys will be returned.""" + Otherwise list of tuples containing all (multiple) keys will be + returned.""" if key_type is not None: intermediate_key = str(key_type) if intermediate_key in self.__dict__: return self.__dict__[intermediate_key].keys() else: - all_keys = {} # in order to preserve keys() type (dict_keys for python3) + all_keys = {} # in order to preserve keys() type (dict_keys for python3) for keys in self.items_dict.keys(): all_keys[keys] = None return all_keys.keys() - def values(self, key_type=None): + def values(self, key_type = None): """ Returns a copy of the dictionary's values. - @param key_type if specified, only values pointed by keys of this type will be returned. - Otherwise list of all values contained in this dictionary will be returned.""" - if(key_type is not None): - all_items = {} # in order to preserve keys() type (dict_values for python3) + @param key_type if specified, only values pointed by keys of this type + will be returned + Otherwise list of all values contained in this dictionary will be + returned.""" + if (key_type is not None): + all_items = {} # in order to preserve keys() type (dict_values for python3) keys_used = set() direct_key = str(key_type) if direct_key in self.__dict__: for intermediate_key in self.__dict__[direct_key].values(): - if not intermediate_key in keys_used: + if intermediate_key not in keys_used: all_items[intermediate_key] = self.items_dict[intermediate_key] keys_used.add(intermediate_key) return all_items.values() @@ -276,26 +299,29 @@ class multi_key_dict(object): length = len(self.items_dict) return length - def __add_item(self, item, keys=None): + def __add_item(self, item, keys = None): """ Internal method to add an item to the multi-key dictionary""" - if(not keys or not len(keys)): - raise Exception('Error in %s.__add_item(%s, keys=tuple/list of items): need to specify a tuple/list containing at least one key!' - % (self.__class__.__name__, str(item))) - direct_key = tuple(keys) # put all keys in a tuple, and use it as a key + if (not keys or not len(keys)): + raise Exception( + 'Error in %s.__add_item(%s, keys=tuple/list of items): need to specify' + 'a tuple/list containing at least one key!' % + (self.__class__.__name__, str(item)) + ) + direct_key = tuple(keys) # put all keys in a tuple, and use it as a key for key in keys: key_type = str(type(key)) # store direct key as a value in an intermediate dictionary - if(not key_type in self.__dict__): + if (key_type not in self.__dict__): self.__setattr__(key_type, dict()) self.__dict__[key_type][key] = direct_key # store the value in the actual dictionary - if(not 'items_dict' in self.__dict__): + if ('items_dict' not in self.__dict__): self.items_dict = dict() self.items_dict[direct_key] = item - def get(self, key, default=None): + def get(self, key, default = None): """ Return the value at index specified as key.""" if key in self: return self.items_dict[self.__dict__[str(type(key))][key]] @@ -304,74 +330,91 @@ class multi_key_dict(object): def __str__(self): items = [] - str_repr = lambda x: '\'%s\'' % x if type(x) == str else str(x) + + def str_repr(x): + return '\'%s\'' % x if isinstance(x, str) else str(x) + if hasattr(self, 'items_dict'): for (keys, value) in self.items(): keys_str = [str_repr(k) for k in keys] - items.append('(%s): %s' % (', '.join(keys_str), - str_repr(value))) - dict_str = '{%s}' % ( ', '.join(items)) + items.append('(%s): %s' % (', '.join(keys_str), str_repr(value))) + dict_str = '{%s}' % (', '.join(items)) return dict_str def test_multi_key_dict(): - contains_all = lambda cont, in_items: not (False in [c in cont for c in in_items]) + + def contains_all(cont, in_items): + return False not in [c in cont for c in in_items] m = multi_key_dict() - assert( len(m) == 0 ), 'expected len(m) == 0' + assert (len(m) == 0), 'expected len(m) == 0' all_keys = list() m['aa', 12, 32, 'mmm'] = 123 # create a value with multiple keys.. - assert( len(m) == 1 ), 'expected len(m) == 1' - all_keys.append(('aa', 'mmm', 32, 12)) # store it for later + assert (len(m) == 1), 'expected len(m) == 1' + all_keys.append(('aa', 'mmm', 32, 12)) # store it for later # try retrieving other keys mapped to the same value using one of them res = m.get_other_keys('aa') expected = ['mmm', 32, 12] - assert(set(res) == set(expected)), 'get_other_keys(\'aa\'): {0} other than expected: {1} '.format(res, expected) - # try retrieving other keys mapped to the same value using one of them: also include this key + assert (set(res) == set(expected)), ( + 'get_other_keys(\'aa\'): {0} other ' + 'than expected: {1} '.format(res, expected) + ) + + # try retrieving other keys mapped to the same value using one of them: also include + # this key res = m.get_other_keys(32, True) expected = ['aa', 'mmm', 32, 12] - assert(set(res) == set(expected)), 'get_other_keys(32): {0} other than expected: {1} '.format(res, expected) + assert (set(res) == set(expected)), ( + 'get_other_keys(32): {0} other than expected: ' + '{1} '.format(res, expected) + ) - assert( m.has_key('aa') == True ), 'expected m.has_key(\'aa\') == True' - assert( m.has_key('aab') == False ), 'expected m.has_key(\'aab\') == False' + assert (m.has_key('aa')), 'expected m.has_key(\'aa\') == True' + assert (not m.has_key('aab')), 'expected m.has_key(\'aab\') == False' - assert( m.has_key(12) == True ), 'expected m.has_key(12) == True' - assert( m.has_key(13) == False ), 'expected m.has_key(13) == False' - assert( m.has_key(32) == True ), 'expected m.has_key(32) == True' + assert (m.has_key(12)), 'expected m.has_key(12) == True' + assert (not m.has_key(13)), 'expected m.has_key(13) == False' + assert (m.has_key(32)), 'expected m.has_key(32) == True' m['something else'] = 'abcd' - assert( len(m) == 2 ), 'expected len(m) == 2' - all_keys.append(('something else',)) # store for later + assert (len(m) == 2), 'expected len(m) == 2' + all_keys.append(('something else', )) # store for later m[23] = 0 - assert( len(m) == 3 ), 'expected len(m) == 3' - all_keys.append((23,)) # store for later + assert (len(m) == 3), 'expected len(m) == 3' + all_keys.append((23, )) # store for later # check if it's possible to read this value back using either of keys - assert( m['aa'] == 123 ), 'expected m[\'aa\'] == 123' - assert( m[12] == 123 ), 'expected m[12] == 123' - assert( m[32] == 123 ), 'expected m[32] == 123' - assert( m['mmm'] == 123 ), 'expected m[\'mmm\'] == 123' + assert (m['aa'] == 123), 'expected m[\'aa\'] == 123' + assert (m[12] == 123), 'expected m[12] == 123' + assert (m[32] == 123), 'expected m[32] == 123' + assert (m['mmm'] == 123), 'expected m[\'mmm\'] == 123' # now update value and again - confirm it back - using different keys.. m['aa'] = 45 - assert( m['aa'] == 45 ), 'expected m[\'aa\'] == 45' - assert( m[12] == 45 ), 'expected m[12] == 45' - assert( m[32] == 45 ), 'expected m[32] == 45' - assert( m['mmm'] == 45 ), 'expected m[\'mmm\'] == 45' + assert (m['aa'] == 45), 'expected m[\'aa\'] == 45' + assert (m[12] == 45), 'expected m[12] == 45' + assert (m[32] == 45), 'expected m[32] == 45' + assert (m['mmm'] == 45), 'expected m[\'mmm\'] == 45' m[12] = '4' - assert( m['aa'] == '4' ), 'expected m[\'aa\'] == \'4\'' - assert( m[12] == '4' ), 'expected m[12] == \'4\'' + assert (m['aa'] == '4'), 'expected m[\'aa\'] == \'4\'' + assert (m[12] == '4'), 'expected m[12] == \'4\'' # test __str__ - m_str_exp = '{(23): 0, (\'aa\', \'mmm\', 32, 12): \'4\', (\'something else\'): \'abcd\'}' + m_str_exp = ( + '{(23): 0, (\'aa\', \'mmm\', 32, 12): \'4\', ' + '(\'something else\'): \'abcd\'}' + ) m_str = str(m) - assert (len(m_str) > 0), 'str(m) should not be empty!' - assert (m_str[0] == '{'), 'str(m) should start with \'{\', but does with \'%c\'' % m_str[0] - assert (m_str[-1] == '}'), 'str(m) should end with \'}\', but does with \'%c\'' % m_str[-1] + assert (len(m_str) > 0), 'str(m) should not be empty!' + assert (m_str[0] == '{' + ), ('str(m) should start with \'{\', but does with \'%c\'' % m_str[0]) + assert (m_str[-1] == '}' + ), ('str(m) should end with \'}\', but does with \'%c\'' % m_str[-1]) # check if all key-values are there as expected. They might be sorted differently def get_values_from_str(dict_str): @@ -381,41 +424,52 @@ def test_multi_key_dict(): keys = tuple(sorted([k.strip() for k in keys.split(',')])) sorted_keys_and_values.append((keys, val)) return sorted_keys_and_values + exp = get_values_from_str(m_str_exp) act = get_values_from_str(m_str) - assert (set(act) == set(exp)), 'str(m) values: \'{0}\' are not {1} '.format(act, exp) + assert (set(act) == set(exp) + ), ('str(m) values: \'{0}\' are not {1} '.format(act, exp)) # try accessing / creating new (keys)-> value mapping whilst one of these # keys already maps to a value in this dictionaries try: m['aa', 'bb'] = 'something new' - assert(False), 'Should not allow adding multiple-keys when one of keys (\'aa\') already exists!' - except KeyError as err: + assert(False), ( + 'Should not allow adding multiple-keys when one of keys ' + '(\'aa\') already exists!' + ) + except KeyError: pass # now check if we can get all possible keys (formed in a list of tuples) # each tuple containing all keys) - res = sorted([sorted([str(x) for x in k]) for k in m.keys()]) + res = sorted([sorted([str(x) for x in k]) for k in m.keys()]) # type: ignore expected = sorted([sorted([str(x) for x in k]) for k in all_keys]) - assert(res == expected), 'unexpected values from m.keys(), got:\n%s\n expected:\n%s' %(res, expected) + assert (res == expected), ( + 'unexpected values from m.keys(), got:\n%s\n expected:\n%s' % (res, expected) + ) # check default items (which will unpack tupe with key(s) and value) num_of_elements = 0 for keys, value in m.items(): sorted_keys = sorted([str(k) for k in keys]) num_of_elements += 1 - assert(sorted_keys in expected), 'm.items(): unexpected keys: %s' % (sorted_keys) - assert(m[keys[0]] == value), 'm.items(): unexpected value: %s (keys: %s)' % (value, keys) - assert(num_of_elements > 0), 'm.items() returned generator that did not produce anything' + assert (sorted_keys + in expected), ('m.items(): unexpected keys: %s' % (sorted_keys)) + assert (m[keys[0]] == value + ), ('m.items(): unexpected value: %s (keys: %s)' % (value, keys)) + assert (num_of_elements + > 0), ('m.items() returned generator that did not produce anything') # test default iterkeys() num_of_elements = 0 - for keys in m.keys(): + for keys in m.keys(): # type: ignore num_of_elements += 1 keys_s = sorted([str(k) for k in keys]) - assert(keys_s in expected), 'm.keys(): unexpected keys: {0}'.format(keys_s) + assert (keys_s in expected), 'm.keys(): unexpected keys: {0}'.format(keys_s) - assert(num_of_elements > 0), 'm.iterkeys() returned generator that did not produce anything' + assert (num_of_elements + > 0), ('m.iterkeys() returned generator that did not produce anything') # test iterkeys(int, True): useful to get all info from the dictionary # dictionary is iterated over the type specified, but all keys are returned. @@ -423,75 +477,93 @@ def test_multi_key_dict(): for keys in m.iterkeys(int, True): keys_s = sorted([str(k) for k in keys]) num_of_elements += 1 - assert(keys_s in expected), 'm.iterkeys(int, True): unexpected keys: {0}'.format(keys_s) - assert(num_of_elements > 0), 'm.iterkeys(int, True) returned generator that did not produce anything' - + assert (keys_s in expected + ), ('m.iterkeys(int, True): unexpected keys: {0}'.format(keys_s)) + assert (num_of_elements > 0), ( + 'm.iterkeys(int, True) returned generator that did not produce anything' + ) # test values for different types of keys() expected = set([0, '4']) res = set(m.values(int)) - assert (res == expected), 'm.values(int) are {0}, but expected: {1}.'.format(res, expected) + assert (res == expected + ), ('m.values(int) are {0}, but expected: {1}.'.format(res, expected)) expected = sorted(['4', 'abcd']) - res = sorted(m.values(str)) - assert (res == expected), 'm.values(str) are {0}, but expected: {1}.'.format(res, expected) + res = sorted(m.values(str)) + assert (res == expected + ), ('m.values(str) are {0}, but expected: {1}.'.format(res, expected)) - current_values = set([0, '4', 'abcd']) # default (should give all values) - res = set(m.values()) - assert (res == current_values), 'm.values() are {0}, but expected: {1}.'.format(res, current_values) + current_values = set([0, '4', 'abcd']) # default (should give all values) + res = set(m.values()) + assert (res == current_values + ), ('m.values() are {0}, but expected: {1}.'.format(res, current_values)) + + #test itervalues() (default) - should return all values. (Itervalues for other types + # are tested below) - #test itervalues() (default) - should return all values. (Itervalues for other types are tested below) vals = set() for value in m.itervalues(): vals.add(value) - assert (current_values == vals), 'itervalues(): expected {0}, but collected {1}'.format(current_values, vals) + assert (current_values == vals), ( + 'itervalues(): expected {0}, but collected {1}'.format(current_values, vals) + ) #test items(int) - items_for_int = sorted([((12, 32), '4'), ((23,), 0)]) - assert (items_for_int == sorted(m.items(int))), 'items(int): expected {0}, but collected {1}'.format(items_for_int, - sorted(m.items(int))) + items_for_int = sorted([((12, 32), '4'), ((23, ), 0)]) + assert (items_for_int == sorted(m.items(int))), ( + 'items(int): expected {0}, but collected {1}'.format( + items_for_int, sorted(m.items(int)) + ) + ) # test items(str) - items_for_str = set([(('aa','mmm'), '4'), (('something else',), 'abcd')]) + items_for_str = set([(('aa', 'mmm'), '4'), (('something else', ), 'abcd')]) res = set(m.items(str)) - assert (set(res) == items_for_str), 'items(str): expected {0}, but collected {1}'.format(items_for_str, res) + assert (set(res) == items_for_str), ( + 'items(str): expected {0}, but collected {1}'.format(items_for_str, res) + ) # test items() (default - all items) - # we tested keys(), values(), and __get_item__ above so here we'll re-create all_items using that + # we tested keys(), values(), and __get_item__ above so here we'll re-create + # all_items using that all_items = set() keys = m.keys() - values = m.values() - for k in keys: - all_items.add( (tuple(k), m[k[0]]) ) + m.values() + for k in keys: # type: ignore + all_items.add((tuple(k), m[k[0]])) res = set(m.items()) - assert (all_items == res), 'items() (all items): expected {0},\n\t\t\t\tbut collected {1}'.format(all_items, res) + assert (all_items == res), ( + 'items() (all items): expected {0},\n\t\t\t\tbut ' + 'collected {1}'.format(all_items, res) + ) # now test deletion.. curr_len = len(m) del m[12] - assert( len(m) == curr_len - 1 ), 'expected len(m) == %d' % (curr_len - 1) - assert(not m.has_key(12)), 'expected deleted key to no longer be found!' + assert (len(m) == curr_len - 1), 'expected len(m) == %d' % (curr_len - 1) + assert (not m.has_key(12)), 'expected deleted key to no longer be found!' - # try again + # try again try: del m['aa'] - assert(False), 'cant remove again: item m[\'aa\'] should not exist!' - except KeyError as err: + assert (False), 'cant remove again: item m[\'aa\'] should not exist!' + except KeyError: pass - # try to access non-existing + # try to access non-existing try: - k = m['aa'] - assert(False), 'removed item m[\'aa\'] should not exist!' - except KeyError as err: + k = m['aa'] + assert (False), 'removed item m[\'aa\'] should not exist!' + except KeyError: pass - # try to access non-existing with a different key + # try to access non-existing with a different key try: - k = m[12] - assert(False), 'removed item m[12] should not exist!' - except KeyError as err: + k = m[12] + assert (False), 'removed item m[12] should not exist!' + except KeyError: pass # prepare for other tests (also testing creation of new items) @@ -499,11 +571,12 @@ def test_multi_key_dict(): m = multi_key_dict() tst_range = list(range(10, 40)) + list(range(50, 70)) for i in tst_range: - m[i] = i # will create a dictionary, where keys are same as items + m[i] = i # will create a dictionary, where keys are same as items # test items() for key, value in m.items(int): - assert(key == (value,)), 'items(int): expected {0}, but received {1}'.format(key, value) + assert (key == (value, ) + ), ('items(int): expected {0}, but received {1}'.format(key, value)) # test iterkeys() num_of_elements = 0 @@ -511,88 +584,103 @@ def test_multi_key_dict(): for key in m.iterkeys(int): returned_keys.add(key) num_of_elements += 1 - assert(num_of_elements > 0), 'm.iteritems(int) returned generator that did not produce anything' - assert (returned_keys == set(tst_range)), 'iterkeys(int): expected {0}, but received {1}'.format(expected, key) - + assert (num_of_elements + > 0), ('m.iteritems(int) returned generator that did not produce anything') + assert (returned_keys == set(tst_range) + ), ('iterkeys(int): expected {0}, but received {1}'.format(expected, key)) #test itervalues(int) num_of_elements = 0 - returned_values = set() + returned_values = set() for value in m.itervalues(int): returned_values.add(value) num_of_elements += 1 - assert (num_of_elements > 0), 'm.itervalues(int) returned generator that did not produce anything' - assert (returned_values == set(tst_range)), 'itervalues(int): expected {0}, but received {1}'.format(expected, value) + assert (num_of_elements > 0 + ), ('m.itervalues(int) returned generator that did not produce anything') + assert (returned_values == set(tst_range)), ( + 'itervalues(int): expected {0}, ' + 'but received {1}'.format(expected, value) + ) # test values(int) res = sorted([x for x in m.values(int)]) assert (res == tst_range), 'm.values(int) is not as expected.' # test keys() - assert (set(m.keys(int)) == set(tst_range)), 'm.keys(int) is not as expected.' + assert (set(m.keys(int)) == set(tst_range)), 'm.keys(int) is not as expected.' # type: ignore # test setitem with multiple keys m['xy', 999, 'abcd'] = 'teststr' try: m['xy', 998] = 'otherstr' - assert(False), 'creating / updating m[\'xy\', 998] should fail!' - except KeyError as err: + assert (False), 'creating / updating m[\'xy\', 998] should fail!' + except KeyError: pass # test setitem with multiple keys m['cd'] = 'somethingelse' try: m['cd', 999] = 'otherstr' - assert(False), 'creating / updating m[\'cd\', 999] should fail!' - except KeyError as err: + assert (False), 'creating / updating m[\'cd\', 999] should fail!' + except KeyError: pass m['xy', 999] = 'otherstr' - assert (m['xy'] == 'otherstr'), 'm[\'xy\'] is not as expected.' - assert (m[999] == 'otherstr'), 'm[999] is not as expected.' + assert (m['xy'] == 'otherstr'), 'm[\'xy\'] is not as expected.' + assert (m[999] == 'otherstr'), 'm[999] is not as expected.' assert (m['abcd'] == 'otherstr'), 'm[\'abcd\'] is not as expected.' - m['abcd', 'xy'] = 'another' - assert (m['xy'] == 'another'), 'm[\'xy\'] is not == \'another\'.' - assert (m[999] == 'another'), 'm[999] is not == \'another\'' + m['abcd', 'xy'] = 'another' + assert (m['xy'] == 'another'), 'm[\'xy\'] is not == \'another\'.' + assert (m[999] == 'another'), 'm[999] is not == \'another\'' assert (m['abcd'] == 'another'), 'm[\'abcd\'] is not == \'another\'.' # test get functionality of basic dictionaries m['CanIGet'] = 'yes' assert (m.get('CanIGet') == 'yes') - assert (m.get('ICantGet') == None) + assert (m.get('ICantGet') is None) assert (m.get('ICantGet', "Ok") == "Ok") k = multi_key_dict() k['1:12', 1] = 'key_has_:' - k.items() # should not cause any problems to have : in key + k.items() # should not cause any problems to have : in key assert (k[1] == 'key_has_:'), 'k[1] is not equal to \'abc:def:ghi\'' import datetime n = datetime.datetime.now() - l = multi_key_dict() - l[n] = 'now' # use datetime obj as a key + d = multi_key_dict() + d[n] = 'now' # use datetime obj as a key #test keys.. - res = [x for x in l.keys()][0] # for python3 keys() returns dict_keys dictionarly + res = [ + x for x in d.keys( + ) # type: ignore # for python3 keys() returns dict_keys dictionarly + ][0] expected = n, - assert(expected == res), 'Expected \"{0}\", but got: \"{1}\"'.format(expected, res) + assert (expected == res), 'Expected \"{0}\", but got: \"{1}\"'.format(expected, res) - res = [x for x in l.keys(datetime.datetime)][0] - assert(n == res), 'Expected {0} as a key, but got: {1}'.format(n, res) - - res = [x for x in l.values()] # for python3 keys() returns dict_values dictionarly + res = [x for x in d.keys(datetime.datetime)][0] # type: ignore + assert (n == res), 'Expected {0} as a key, but got: {1}'.format(n, res) + + res = [x for x in d.values()] # for python3 keys() returns dict_values dictionarly expected = ['now'] - assert(res == expected), 'Expected values: {0}, but got: {1}'.format(expected, res) + assert (res == expected), 'Expected values: {0}, but got: {1}'.format(expected, res) # test items.. - exp_items = [((n,), 'now')] - r = list(l.items()) - assert(r == exp_items), 'Expected for items(): tuple of keys: {0}, but got: {1}'.format(r, exp_items) - assert(exp_items[0][1] == 'now'), 'Expected for items(): value: {0}, but got: {1}'.format('now', - exp_items[0][1]) + exp_items = [((n, ), 'now')] + r = list(d.items()) + assert (r == exp_items), ( + 'Expected for items(): tuple of keys: {0}, but got: {1}'.format(r, exp_items) + ) + assert (exp_items[0][1] == 'now'), ( + 'Expected for items(): value: {0}, but got: {1}'.format('now', exp_items[0][1]) + ) - x = multi_key_dict({('k', 'kilo'):1000, ('M', 'MEGA', 1000000):1000000}, milli=0.01) + x = multi_key_dict( + { + ('k', 'kilo'): 1000, ('M', 'MEGA', 1000000): 1000000 + }, milli = 0.01 + ) assert (x['k'] == 1000), 'x[\'k\'] is not equal to 1000' x['kilo'] = 'kilo' assert (x['kilo'] == 'kilo'), 'x[\'kilo\'] is not equal to \'kilo\'' @@ -605,11 +693,13 @@ def test_multi_key_dict(): try: y = multi_key_dict([(('two', 'duo'), 2), ('one', 'uno', 1), ('three', 3)]) - assert(False), 'creating dictionary using iterable with tuples of size > 2 should fail!' - except: + assert (False), ( + 'creating dictionary using iterable with tuples of size > 2 should fail!' + ) + except Exception: pass - print ('All test passed OK!') + print('All test passed OK!') __all__ = ["multi_key_dict"] @@ -617,5 +707,4 @@ if __name__ == '__main__': try: test_multi_key_dict() except KeyboardInterrupt: - print ('\n(interrupted by user)') - + print('\n(interrupted by user)') diff --git a/src/python/jw/util/stree/StringTree.py b/src/python/jw/util/stree/StringTree.py index 8bc4d89..a04b14c 100644 --- a/src/python/jw/util/stree/StringTree.py +++ b/src/python/jw/util/stree/StringTree.py @@ -1,14 +1,16 @@ -# -*- coding: utf-8 -*- - from __future__ import annotations -from typing import Any, List, Optional, Union +import fnmatch +import re -import re, fnmatch from collections import OrderedDict from enum import Enum, auto +from typing import TYPE_CHECKING -from ..log import * +if TYPE_CHECKING: + from typing import List, Optional, Union, Any + +from ..log import DEBUG, get_caller_pos, slog def quote(s): if is_quoted(s): @@ -26,7 +28,7 @@ def is_quoted(s: str) -> bool: if len(s) < 2: return False d = s[0] - if d == s[-1] and d in [ '"', "'" ]: + if d == s[-1] and d in ['"', "'"]: return True return False @@ -38,16 +40,18 @@ def cleanup_string(s: str) -> str: return s[1:-1].replace('\\' + s[0], s[0]) return s -class StringTree: # export +class StringTree: # export - def __init__(self, path: str, content: str, parent: StringTree|None=None) -> None: + def __init__( + self, path: str, content: str, parent: StringTree | None = None + ) -> None: slog(DEBUG, f'Constructing StringTree(path="{path}", content="{content}")') self.__parent = parent self.children: OrderedDict[str, StringTree] = OrderedDict() self.content: Optional[str] = None self.__set(path, content) - assert(hasattr(self, "content")) + assert (hasattr(self, "content")) #assert self.content is not None # root (content = [ symbols ]) @@ -65,60 +69,89 @@ class StringTree: # export #parent.dump(INFO, "These children are added") self.content = parent.content for name, c in parent.children.items(): - if not name in self.children.keys(): + if name not in self.children.keys(): slog(DEBUG, f'At {self.content}: Adding new child {c}') self.children[name] = c else: self.children[name].__adopt_children(c) - def __set(self, path_, content, split=True): - slog(DEBUG, ('At "{}": '.format(str(self.content)) if hasattr(self, "content") else "") + f'Setting "{path_}" -> "{content}"') - #assert self.content != str(content) # Not sure what the idea behind this was. It often goes off, and all works fine without. - if content is not None and not type(content) in [str, StringTree]: - raise Exception("Tried to add content of unsupported type {}".format(type(content).__name__)) + def __set(self, path_, content, split = True): + slog( + DEBUG, + ('At "{}": '.format(str(self.content)) if hasattr(self, "content") else "") + + f'Setting "{path_}" -> "{content}"' + ) + + # Not sure what the idea behind this was. It often goes off, and all + # works fine without. + #assert self.content != str(content) + + if content is not None and type(content) not in [str, StringTree]: + raise Exception( + "Tried to add content of unsupported type {}".format( + type(content).__name__ + ) + ) if path_ is None: if isinstance(content, str): self.content = cleanup_string(content) elif isinstance(content, StringTree): self.__adopt_children(content) else: - raise Exception("Tried to add content of unsupported type {}".format(type(content).__name__)) - slog(DEBUG, " -- content = >" + str(content) + "<, self.content = >" + str(self.content) + "<") + raise Exception( + "Tried to add content of unsupported type {}".format( + type(content).__name__ + ) + ) + slog( + DEBUG, + " -- content = >" + str(content) + "<, self.content = >" + + str(self.content) + "<" + ) return self path = cleanup_string(path_) - components = path.split('.') if split else [ path ] - l = len(components) - if len(path) == 0 or l == 0: - #assert self.content is None or (isinstance(content, StringTree) and content.content == self.content) + components = path.split('.') if split else [path] + length = len(components) + if len(path) == 0 or length == 0: + + #assert self.content is None or ( + # isinstance(content, StringTree) and content.content == self.content + #) + if isinstance(content, StringTree): - #assert isinstance(content, StringTree), "Type: " + type(content).__name__ + #assert isinstance(content, StringTree), ( + # f'Type: {type(content).__name__ }' + #) self.__adopt_children(content) else: if self.content != content: #self.content = cleanup_string(content) slog(DEBUG, f'Changing content: "{self.content}" ->"{content}"') - assert(content != '"[a-zA-Z0-9+_*/-]"') + assert (content != '"[a-zA-Z0-9+_*/-]"') self.content = content #assert(content != "'antlr_doesnt_understand_vertical_tab'") #self.children[content] = StringTree(None, content) return self - #assert self.content is not None, "tried to set empty content to {}".format(path_) + #assert self.content is not None, f'Tried to set empty content to "{path_}"' nibble = components[0] rest = '.'.join(components[1:]) if nibble not in self.children: - self.children[nibble] = StringTree('', content=nibble, parent=self) - if l > 1: + self.children[nibble] = StringTree('', content = nibble, parent = self) + if length > 1: assert len(rest) > 0 - return self.children[nibble].__set(rest, content=content) + return self.children[nibble].__set(rest, content = content) # last component, a.k.a. leaf if content is not None: - gc = content if isinstance(content, StringTree) else StringTree('', content=content, parent=self.children[nibble]) + gc = content if isinstance(content, StringTree) else StringTree( + '', content = content, parent = self.children[nibble] + ) # Make sure no existing grand child is updated. It would reside too # far up in the grand child OrderedDict, we need it last if gc.content in self.children[nibble].children: del self.children[nibble].children[gc.content] + assert gc.content is not None, 'Grand-child content is None' self.children[nibble].children[gc.content] = gc return self.children[nibble] @@ -129,17 +162,17 @@ class StringTree: # export r = self.get(path) if r is None: raise KeyError(path) - return r.value() # type: ignore + return r.value() # type: ignore def __setitem__(self, key, value): return self.__set(key, value) - def __dump(self, prio, indent=0, **kwargs): + def __dump(self, prio, indent = 0, **kwargs): caller = kwargs['caller'] if 'caller' in kwargs.keys() else get_caller_pos(1) - slog(prio, '|' + (' ' * indent) + str(self.content), caller=caller) + slog(prio, '|' + (' ' * indent) + str(self.content), caller = caller) indent += 2 for name, child in self.children.items(): - child.__dump(prio, indent=indent, caller=caller) + child.__dump(prio, indent = indent, caller = caller) @property def path(self): @@ -164,7 +197,12 @@ class StringTree: # export raise Exception("Tried to set empty content") self.content = content - def add(self, path: str, content: Optional[Union[str, StringTree]] = None, split: bool = True) -> StringTree: + def add( + self, + path: str, + content: Optional[Union[str, StringTree]] = None, + split: bool = True + ) -> StringTree: slog(DEBUG, f'-- At "{self.content}": Adding "{path}" -> "{content}"') return self.__set(path, content, split) @@ -176,7 +214,7 @@ class StringTree: # export slog(DEBUG, "returning myself") return self if is_quoted(path_): - if not path in self.children.keys(): + if path not in self.children.keys(): return None return self.children[path] components = path.split('.') @@ -185,7 +223,7 @@ class StringTree: # export name = cleanup_string(components[0]) if not hasattr(self, "children"): return None - if not name in self.children.keys(): + if name not in self.children.keys(): slog(DEBUG, "Name \"" + name + "\" is not in children of", self.content) for child in self.children: slog(DEBUG, "child = ", child) @@ -193,7 +231,7 @@ class StringTree: # export relpath = '.'.join(components[1:]) return self.children[name].get(relpath) - def value(self, path = None, default=None) -> Optional[str]: + def value(self, path = None, default = None) -> Optional[str]: if path: child = self.get(path) if child is None: @@ -204,7 +242,7 @@ class StringTree: # export if len(self.children) == 0: raise Exception('tried to get value from leaf "{}"'.format(self.content)) slog(DEBUG, f'Returning value from children {self.children}') - return self.children[next(reversed(self.children))].content # type: ignore + return self.children[next(reversed(self.children))].content # type: ignore @property def parent(self): @@ -216,9 +254,12 @@ class StringTree: # export return self return self.__parent.root - def child_list(self, depth_first: bool=True) -> List[StringTree]: - if depth_first == False: - raise Exception("tried to retrieve child list with breadth-first search, not yet implemented") + def child_list(self, depth_first: bool = True) -> List[StringTree]: + if not depth_first: + raise Exception( + 'Tried to retrieve child list with breadth-first ' + 'search, not yet implemented' + ) r = [] for name, c in self.children.items(): r.append(c) @@ -230,32 +271,30 @@ class StringTree: # export msg = '' if args is not None: msg = ' ' + ' '.join(args) + ' ' - slog(prio, ",------------" + msg + "----------- >", caller=caller) - self.__dump(prio, indent=0, caller=caller) - slog(prio, "`------------" + msg + "----------- <", caller=caller) + slog(prio, ",------------" + msg + "----------- >", caller = caller) + self.__dump(prio, indent = 0, caller = caller) + slog(prio, "`------------" + msg + "----------- <", caller = caller) class Match(Enum): - Equal = auto() - RegExArg = auto() - RegExConf = auto() - GlobArg = auto() - GlobConf = auto() + Equal = auto() + RegExArg = auto() + RegExConf = auto() + GlobArg = auto() + GlobConf = auto() - def __find(self, key: str|None, val: str|None, match: Match, depth_first: bool): + def __find(self, key: str | None, val: str | None, m: Match, depth_first: bool): def __children(): for name, child in self.children.items(): - ret.extend(child.__find(key, val, match, depth_first)) + ret.extend(child.__find(key, val, m, depth_first)) def __self(): _val = self.value() _content = self.content try: - if ( - (key == _content and matcher(val, _val)) - or (key is None and matcher(val, _val)) - or (key == _content and val is None) - ): + if ((key == _content and matcher(val, _val)) + or (key is None and matcher(val, _val)) + or (key == _content and val is None)): ret.append(self) except Exception as e: if isinstance(e, re.PatternError): @@ -263,29 +302,33 @@ class StringTree: # export else: raise - def __debug_matcher(matcher, log_level=DEBUG): + def __select_matcher(m: StringTree.Match) -> Any: + match m: + case self.Match.Equal: + return lambda x, y: x == y + case self.Match.RegExArg: + return lambda x, y: re.search(x, y) is not None + case self.Match.RegExConf: + return lambda x, y: re.search(y, x) is not None + case self.Match.GlobArg: + return lambda x, y: fnmatch.fnmatch(y, x) + case self.Match.GlobConf: + return lambda x, y: fnmatch.fnmatch(x, y) + case _: + raise NotImplementedError(f'Matcher {m} is not yet implemented') + + def __debug_matcher(matcher, log_level = DEBUG): + def __matcher(x, y): slog(log_level, f'Comparing "{x}" ~ "{y}"') return matcher(x, y) + return __matcher if not self.children: return [] - matcher = lambda x, y: x == y - match match: - case self.Match.Equal: - pass - case self.Match.RegExArg: - matcher = lambda x, y: re.search(x, y) is not None - case self.Match.RegExConf: - matcher = lambda x, y: re.search(y, x) is not None - case self.Match.GlobArg: - matcher = lambda x, y: fnmatch.fnmatch(y, x) - case self.Match.GlobConf: - matcher = lambda x, y: fnmatch.fnmatch(x, y) - case _: - raise NotImplementedError(f'Matcher {match} is not yet implemented') + matcher = __select_matcher(m) ret: list[StringTree] = [] @@ -298,5 +341,16 @@ class StringTree: # export return ret - def find(self, key: str|None=None, val: str|None=None, match: Match=Match.Equal, depth_first: bool=False): - return [ node.parent.path for node in self.__find(key, val, match=match, depth_first=depth_first)] + def find( + self, + key: str | None = None, + val: str | None = None, + match: Match = Match.Equal, + depth_first: bool = False + ): + ret: list[str] = [] + for node in self.__find(key, val, m = match, depth_first = depth_first): + if node.parent is None: + break + ret.append(node.parent.path) + return ret diff --git a/src/python/jw/util/stree/serdes.py b/src/python/jw/util/stree/serdes.py index 9c8621c..8be6287 100644 --- a/src/python/jw/util/stree/serdes.py +++ b/src/python/jw/util/stree/serdes.py @@ -1,9 +1,9 @@ -# -*- coding: utf-8 -*- +import glob +import os +import re -import os, glob - -from .StringTree import * -from ..log import * +from ..log import DEBUG, ERR, INFO, slog, slog_m +from .StringTree import StringTree, cleanup_string def _cleanup_line(line: str) -> str: line = line.strip() @@ -15,18 +15,22 @@ def _cleanup_line(line: str) -> str: if c == in_quote: in_quote = None else: - if c in [ '"', "'" ]: + if c in ['"', "'"]: in_quote = c elif in_quote is None and c == '#': return r.strip() r += c - if len(r) >= 2 and r[0] in [ '"', "'" ] and r[-1] == r[0]: + if len(r) >= 2 and r[0] in ['"', "'"] and r[-1] == r[0]: return r[1:-1] return r -def parse(s: str, allow_full_lines: bool=True, root_content: str='root') -> StringTree: # export +def parse( # export + s: str, + allow_full_lines: bool = True, + root_content: str = 'root' +) -> StringTree: slog_m(DEBUG, "--->--- parsing --->---\n" + s + "\n---<--- parsing ---<---\n") - root = StringTree('', content=root_content) + root = StringTree('', content = root_content) sec = '' for line in s.splitlines(): slog(DEBUG, f'Parsing: "{line}"') @@ -47,7 +51,7 @@ def parse(s: str, allow_full_lines: bool=True, root_content: str='root') -> Stri root.add(sec) continue elif line[0] == ']': - assert(len(sec) > 0) + assert (len(sec) > 0) sec = '.'.join(sec.split('.')[0:-1]) continue lhs = '' @@ -67,17 +71,19 @@ def parse(s: str, allow_full_lines: bool=True, root_content: str='root') -> Stri raise Exception("failed to parse assignment", line) rhs = 'empty' split = False - root.add(sec + '.' + cleanup_string(lhs), cleanup_string(rhs), split=split) + root.add(sec + '.' + cleanup_string(lhs), cleanup_string(rhs), split = split) return root -def _read_lines_from_one_path(path: str, throw=True, level=0, log_prio=INFO, paths_buf=None): +def _read_lines_from_one_path( + path: str, throw = True, level = 0, log_prio = INFO, paths_buf = None +): try: with open(path, 'r') as infile: slog(log_prio, 'Reading {}"{}"'.format(' ' * level * 2, path)) if paths_buf is not None: paths_buf.append(path) ret = [] - for line in infile: # lines are all trailed by \n + for line in infile: # lines are all trailed by \n m = re.search(r'^\s*(-)*include\s+(\S+)', line) if m: optional = m.group(1) == '-' @@ -86,7 +92,12 @@ def _read_lines_from_one_path(path: str, throw=True, level=0, log_prio=INFO, pat dir_name = os.path.dirname(path) if len(dir_name): include_path = dir_name + '/' + include_path - include_lines = _read_lines(include_path, throw=(not optional), level=level+1, paths_buf=paths_buf) + include_lines = _read_lines( + include_path, + throw = (not optional), + level = level + 1, + paths_buf = paths_buf + ) if include_lines is None: slog(DEBUG, f'{path}: Failed to process "{line}"') continue @@ -100,17 +111,26 @@ def _read_lines_from_one_path(path: str, throw=True, level=0, log_prio=INFO, pat raise return None -def _read_lines(path: str, throw=True, level=0, log_prio=INFO, paths_buf=None): +def _read_lines(path: str, throw = True, level = 0, log_prio = INFO, paths_buf = None): paths = glob.glob(path) ret = [] for p in paths: - rr = _read_lines_from_one_path(p, throw=throw, level=level, log_prio=log_prio, paths_buf=paths_buf) + rr = _read_lines_from_one_path( + p, throw = throw, level = level, log_prio = log_prio, paths_buf = paths_buf + ) if rr is None: return None ret.extend(rr) return ret -def read(path: str, root_content: str='root', log_prio=INFO, paths_buf=None) -> StringTree: # export - lines = _read_lines_from_one_path(path, log_prio=log_prio, paths_buf=paths_buf) +def read( # export + path: str, + root_content: str = 'root', + log_prio = INFO, + paths_buf = None +) -> StringTree: + lines = _read_lines_from_one_path(path, log_prio = log_prio, paths_buf = paths_buf) + if lines is None: + raise Exception(f'Could not read ini file from "{path}"') s = ''.join(lines) - return parse(s, root_content=root_content) + return parse(s, root_content = root_content)