diff --git a/cutekit/__init__.py b/cutekit/__init__.py index efb0d7f..bb7ec34 100644 --- a/cutekit/__init__.py +++ b/cutekit/__init__.py @@ -2,11 +2,12 @@ import sys import os import logging +from pathlib import Path + from . import ( - builder, # noqa: F401 this is imported for side effects - cli, + builder, + cli, # noqa: F401 this is imported for side effects const, - graph, # noqa: F401 this is imported for side effects model, plugins, pods, # noqa: F401 this is imported for side effects @@ -28,48 +29,75 @@ def ensure(version: tuple[int, int, int]): ) -def setupLogger(verbose: bool): - if verbose: - logging.basicConfig( - level=logging.DEBUG, - format=f"{vt100.CYAN}%(asctime)s{vt100.RESET} {vt100.YELLOW}%(levelname)s{vt100.RESET} %(name)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - else: - projectRoot = model.Project.topmost() - logFile = const.GLOBAL_LOG_FILE - if projectRoot is not None: - logFile = os.path.join(projectRoot.dirname(), const.PROJECT_LOG_FILE) +class logger: + class LoggerArgs: + verbose: bool = cli.arg(None, "verbose", "Enable verbose logging") - shell.mkdir(os.path.dirname(logFile)) + @staticmethod + def setup(args: LoggerArgs): + if args.verbose: + logging.basicConfig( + level=logging.DEBUG, + format=f"{vt100.CYAN}%(asctime)s{vt100.RESET} {vt100.YELLOW}%(levelname)s{vt100.RESET} %(name)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + else: + projectRoot = model.Project.topmost() + logFile = const.GLOBAL_LOG_FILE + if projectRoot is not None: + logFile = os.path.join(projectRoot.dirname(), const.PROJECT_LOG_FILE) - logging.basicConfig( - level=logging.INFO, - filename=logFile, - filemode="w", - format="%(asctime)s %(levelname)s %(name)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) + shell.mkdir(os.path.dirname(logFile)) + + logging.basicConfig( + level=logging.INFO, + filename=logFile, + filemode="w", + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + +class RootArgs( + plugins.PluginsArgs, + pods.PodSetupArgs, + logger.LoggerArgs, +): + pass + + +@cli.command(None, "/", const.DESCRIPTION) +def _(args: RootArgs): + const.setup() + logger.setup(args) + plugins.setup(args) + pods.setup(args) + + +@cli.command("u", "usage", "Show usage information") +def _(): + print(f"Usage: {const.ARGV0} [args...]") + + +@cli.command("v", "version", "Show current version") +def _(): + print(f"CuteKit v{const.VERSION_STR}") def main() -> int: try: shell.mkdir(const.GLOBAL_CK_DIR) - extraArgs = os.environ.get("CK_EXTRA_ARGS", None) - args = cli.parse((extraArgs.split(" ") if extraArgs else []) + sys.argv[1:]) - setupLogger(args.consumeOpt("verbose", False) is True) - - const.setup() - plugins.setup(args) - pods.setup(args) - cli.exec(args) - + extra = os.environ.get("CK_EXTRA_ARGS", None) + args = [const.ARGV0] + (extra.split(" ") if extra else []) + sys.argv[1:] + cli._root.eval(args) return 0 + except RuntimeError as e: logging.exception(e) - cli.error(str(e)) + vt100.error(str(e)) cli.usage() + return 1 + except KeyboardInterrupt: print() - - return 1 + return 1 diff --git a/cutekit/builder.py b/cutekit/builder.py index a2d4d55..eefb981 100644 --- a/cutekit/builder.py +++ b/cutekit/builder.py @@ -6,7 +6,7 @@ from pathlib import Path import sys from typing import Callable, Literal, TextIO, Union -from . import shell, rules, model, ninja, const, cli, vt100 +from . import cli, shell, rules, model, ninja, const, vt100 _logger = logging.getLogger(__name__) @@ -16,8 +16,8 @@ class Scope: registry: model.Registry @staticmethod - def use(args: cli.Args, props: model.Props = {}) -> "Scope": - registry = model.Registry.use(args, props) + def use(args: model.RegistryArgs) -> "Scope": + registry = model.Registry.use(args) return Scope(registry) def key(self) -> str: @@ -33,9 +33,9 @@ class TargetScope(Scope): target: model.Target @staticmethod - def use(args: cli.Args, props: model.Props = {}) -> "TargetScope": - registry = model.Registry.use(args, props) - target = model.Target.use(args, props) + def use(args: model.TargetArgs) -> "TargetScope": + registry = model.Registry.use(args) + target = model.Target.use(args) return TargetScope(registry, target) def key(self) -> str: @@ -384,32 +384,35 @@ def build( @cli.command("b", "builder", "Build/Run/Clean a component or all components") -def _(args: cli.Args): +def _(): pass +class BuildArgs(model.TargetArgs): + component: str = cli.operand("component", "Component to build") + + @cli.command("b", "builder/build", "Build a component or all components") -def _(args: cli.Args): +def _(args: BuildArgs): scope = TargetScope.use(args) - componentSpec = args.consumeArg() component = None - if componentSpec is not None: - component = scope.registry.lookup(componentSpec, model.Component) + if args.component is not None: + component = scope.registry.lookup(args.component, model.Component) build(scope, component if component is not None else "all")[0] +class RunArgs(BuildArgs, shell.DebugArgs, shell.ProfileArgs): + debug: bool = cli.arg(None, "debug", "Attach a debugger") + profile: bool = cli.arg(None, "profile", "Profile the execution") + args: list[str] = cli.extra("args", "Arguments to pass to the component") + + @cli.command("r", "builder/run", "Run a component") -def runCmd(args: cli.Args): - debug = args.consumeOpt("debug", False) is True - wait = args.consumeOpt("wait", False) is True - debugger = str(args.consumeOpt("debugger", "lldb")) - - profile = args.consumeOpt("profile", False) is True - what = str(args.consumeOpt("what", "cpu")) - rate = int(args.consumeOpt("rate", 1000)) - +def runCmd(args: RunArgs): componentSpec = args.consumeArg() or "__main__" - scope = TargetScope.use(args, {"debug": debug}) + + args.props |= {"debug": args.debug} + scope = TargetScope.use(args) component = scope.registry.lookup( componentSpec, model.Component, includeProvides=True @@ -423,39 +426,39 @@ def runCmd(args: cli.Args): os.environ["CK_BUILDDIR"] = product.target.builddir os.environ["CK_COMPONENT"] = product.component.id - command = [str(product.path), *args.extra] + command = [str(product.path), *args.args] - if debug: - shell.debug(command, debugger=debugger, wait=wait) - elif profile: - shell.profile(command, what=what, rate=rate) + if args.debug: + shell.debug(command, debugger=args.debugger, wait=args.wait) + elif args.profile: + shell.profile(command, what=args.what, rate=args.rate) else: shell.exec(*command) @cli.command("t", "builder/test", "Run all test targets") -def _(args: cli.Args): +def _(args: RunArgs): # This is just a wrapper around the `run` command that try # to run a special hook component named __tests__. - args.args.insert(0, "__tests__") + args.component = "__tests__" runCmd(args) @cli.command("d", "builder/debug", "Debug a component") -def _(args: cli.Args): +def _(args: RunArgs): # This is just a wrapper around the `run` command that # always enable debug mode. - args.opts["debug"] = True + args.debug = True runCmd(args) @cli.command("c", "builder/clean", "Clean build files") -def _(args: cli.Args): - model.Project.use(args) +def _(): + model.Project.use() shell.rmrf(const.BUILD_DIR) @cli.command("n", "builder/nuke", "Clean all build files and caches") -def _(args: cli.Args): - model.Project.use(args) +def _(): + model.Project.use() shell.rmrf(const.PROJECT_CK_DIR) diff --git a/cutekit/cli.py b/cutekit/cli.py index 52e14d7..5888c0a 100644 --- a/cutekit/cli.py +++ b/cutekit/cli.py @@ -1,217 +1,685 @@ -import inspect -import logging import sys +from enum import Enum +from types import GenericAlias +import typing as tp import dataclasses as dt -from pathlib import Path -from typing import Optional, Union, Callable - -from . import const, vt100 - -Value = Union[str, bool, int] - -_logger = logging.getLogger(__name__) +from typing import Any, Callable, Optional, Union +from cutekit import vt100, const -class Args: - opts: dict[str, Value] - args: list[str] - extra: list[str] +T = tp.TypeVar("T") - def __init__(self): - self.opts = {} - self.args = [] - self.extra = [] +# --- Scan -------------------------------------------------------------- # - def consumePrefix(self, prefix: str) -> dict[str, Value]: - result: dict[str, Value] = {} - copy = self.opts.copy() - for key, value in copy.items(): - if key.startswith(prefix): - result[key[len(prefix) :]] = value - del self.opts[key] + +class Scan: + _src: str + _off: int + _save: list[int] + + def __init__(self, src: str, off: int = 0): + self._src = src + self._off = 0 + self._save = [] + + def curr(self) -> str: + if self.eof(): + return "\0" + return self._src[self._off] + + def next(self) -> str: + if self.eof(): + return "\0" + + self._off += 1 + return self.curr() + + def peek(self, off: int = 1) -> str: + if self._off + off >= len(self._src): + return "\0" + + return self._src[self._off + off] + + def eof(self) -> bool: + return self._off >= len(self._src) + + def skipStr(self, s: str) -> bool: + if self._src[self._off :].startswith(s): + self._off += len(s) + return True + + return False + + def isStr(self, s: str) -> bool: + self.save() + if self.skipStr(s): + self.restore() + return True + + self.restore() + return False + + def save(self) -> None: + self._save.append(self._off) + + def restore(self) -> None: + self._off = self._save.pop() + + def skipWhitespace(self) -> bool: + result = False + while not self.eof() and self.curr().isspace(): + self.next() + result = True return result - def consumeOpt(self, key: str, default: Value = False) -> Value: - if key in self.opts: - result = self.opts[key] - del self.opts[key] - return result - return default + def skipSeparator(self, sep: str) -> bool: + self.save() + self.skipWhitespace() + if self.skipStr(sep): + self.skipWhitespace() + return True - def tryConsumeOpt(self, key: str) -> Optional[Value]: - if key in self.opts: - result = self.opts[key] - del self.opts[key] - return result - return None + self.restore() + return False - def consumeArg(self, default: Optional[str] = None) -> Optional[str]: - if len(self.args) == 0: - return default + def isSeparator(self, sep: str) -> bool: + self.save() + self.skipWhitespace() + if self.skipStr(sep): + self.skipWhitespace() + self.restore() + return True - first = self.args[0] - del self.args[0] - return first + self.restore() + return False + + def skipKeyword(self, keyword: str) -> bool: + self.save() + self.skipWhitespace() + if self.skipStr(keyword) and not self.curr().isalnum(): + return True + + self.restore() + return False + + def isKeyword(self, keyword: str) -> bool: + self.save() + self.skipWhitespace() + if self.skipStr(keyword) and not self.curr().isalnum(): + self.restore() + return True + + self.restore() + return False -def parse(args: list[str]) -> Args: - result = Args() +# --- Parser ------------------------------------------------------------ # - for i in range(len(args)): - arg = args[i] - if arg.startswith("--") and not arg == "--": - if "=" in arg: - key, value = arg[2:].split("=", 1) - result.opts[key] = value - else: - result.opts[arg[2:]] = True - elif arg == "--": - result.extra += args[i + 1 :] +PrimitiveValue = str | bool | int +Object = dict[str, PrimitiveValue] +List = list[PrimitiveValue] +Value = str | bool | int | Object | List + + +@dt.dataclass +class Token: + pass + + +@dt.dataclass +class ArgumentToken(Token): + key: str + subkey: Optional[str] + value: Value + short: bool + + +@dt.dataclass +class OperandToken(Token): + value: str + + +@dt.dataclass +class ExtraToken(Token): + args: list[str] + + +def _parseIdent(s: Scan) -> str: + res = "" + while not s.eof() and (s.curr().isalnum() or s.curr() in "_-+"): + res += s.curr() + s.next() + return res + + +def _parseUntilComma(s: Scan) -> str: + res = "" + while not s.eof() and s.curr() != ",": + res += s.curr() + s.next() + return res + + +def _expectIdent(s: Scan) -> str: + res = _parseIdent(s) + if len(res) == 0: + raise RuntimeError("Expected identifier") + return res + + +def _parseString(s: Scan, quote: str) -> str: + s.skipStr(quote) + res = "" + escaped = False + while not s.eof(): + c = s.curr() + if escaped: + res += c + escaped = False + elif c == "\\": + escaped = True + elif c == quote: break else: - result.args.append(arg) - - return result + res += c + s.next() + if not s.skipStr(quote): + raise RuntimeError("Unterminated string") + return res -Callback = Callable[[Args], None] +def _tryParseInt(ident) -> Optional[int]: + try: + return int(ident) + except ValueError: + return None + + +def _parsePrimitive(s: Scan) -> PrimitiveValue: + if s.curr() == '"': + return _parseString(s, '"') + elif s.curr() == "'": + return _parseString(s, "'") + else: + ident = _parseUntilComma(s) + + if ident in ("true", "True", "y", "yes", "Y", "Yes"): + return True + elif ident in ("false", "False", "n", "no", "N", "No"): + return False + elif n := _tryParseInt(ident): + return n + else: + return ident + + +def _parseValue(s: Scan) -> Value: + lhs = _parsePrimitive(s) + if s.eof(): + return lhs + values: List = [lhs] + while not s.eof() and s.skipStr(","): + values.append(_parsePrimitive(s)) + return values + + +def parseValue(s: str) -> Value: + return _parseValue(Scan(s)) + + +def parseArg(arg: str) -> list[Token]: + s = Scan(arg) + if s.skipStr("--"): + key = _expectIdent(s) + subkey = "" + if s.skipStr(":"): + subkey = _expectIdent(s) + if s.skipStr("="): + value = _parseValue(s) + else: + value = True + return [ArgumentToken(key, subkey, value, False)] + elif s.skipStr("-"): + res = [] + while not s.eof(): + key = s.curr() + if not key.isalnum(): + raise RuntimeError("Expected alphanumeric") + s.next() + res.append(ArgumentToken(key, None, True, True)) + return tp.cast(list[Token], res) + else: + return [OperandToken(arg)] + + +def parseArgs(args: list[str]) -> list[Token]: + res: list[Token] = [] + while len(args) > 0: + arg = args.pop(0) + if arg == "--": + res.append(ExtraToken(args)) + break + else: + res.extend(parseArg(arg)) + return res + + +# --- Schema ----------------------------------------------------------------- # + + +class FieldKind(Enum): + FLAG = 0 + OPERAND = 1 + EXTRA = 2 + + +@dt.dataclass +class Field: + kind: FieldKind + shortName: Optional[str] + longName: str + description: str = "" + default: Any = None + + _fieldName: str | None = dt.field(init=False, default=None) + _fieldType: type | None = dt.field(init=False, default=None) + + def bind(self, typ: type, name: str): + self._fieldName = name + self._fieldType = typ.__annotations__[name] + if self.longName is None: + self.longName = name + + def isList(self) -> bool: + return ( + isinstance(self._fieldType, GenericAlias) + and self._fieldType.__origin__ == list + ) + + def isDict(self) -> bool: + return ( + isinstance(self._fieldType, GenericAlias) + and self._fieldType.__origin__ == dict + ) + + def innerType(self) -> type: + if self.isList(): + return self._fieldType.__args__[0] + + if self.isDict(): + return self._fieldType.__args__[1] + + return self._fieldType + + def defaultValue(self) -> Any: + if self._fieldType is None: + return None + + if self.default is not None: + return self.default + + if self._fieldType == bool: + return False + elif self._fieldType == int: + return 0 + elif self._fieldType == str: + return "" + elif self.isList(): + return [] + elif self.isDict(): + return {} + else: + return None + + def setDefault(self, obj: Any): + if self._fieldName: + setattr(obj, self._fieldName, self.defaultValue()) + + def castValue(self, val: Any, subkey: Optional[str]): + try: + val = int(val) + except ValueError: + pass + except TypeError: + pass + + if isinstance(val, list): + return [self.castValue(v, subkey) for v in val] + + val = self.innerType()(val) + + if self.isDict() and subkey: + return {subkey: val} + + if self.isDict(): + return {str(val): True} + + return val + + def putValue(self, obj: Any, value: Any, subkey: Optional[str] = None): + value = self.castValue(value, subkey) + field = getattr(obj, self._fieldName) + if isinstance(field, list): + if isinstance(value, list): + field.extend(value) + else: + field.append(value) + elif isinstance(field, dict): + field.update(value) + else: + setattr(obj, self._fieldName, value) + + def getAttr(self, obj: Any) -> Any: + return getattr(obj, self._fieldName) + + +def arg( + shortName: str | None = None, + longName: str = "", + description: str = "", + default: Any = None, +) -> Any: + return Field(FieldKind.FLAG, shortName, longName, description, default) + + +def operand(longName: str = "", description: str = "") -> Any: + return Field(FieldKind.OPERAND, None, longName, description) + + +def extra(longName: str = "", description: str = "") -> Any: + return Field(FieldKind.EXTRA, None, longName, description) + + +@dt.dataclass +class Schema: + typ: Optional[type] = None + args: list[Field] = dt.field(default_factory=list) + operands: list[Field] = dt.field(default_factory=list) + extras: Optional[Field] = None + + @staticmethod + def extract(typ: type) -> "Schema": + s = Schema(typ) + + for f in typ.__annotations__.keys(): + field = getattr(typ, f, None) + + if field is None: + raise ValueError(f"Field '{f}' is not defined") + + if not isinstance(field, Field): + raise ValueError(f"Field '{f}' is not a Field") + + field.bind(typ, f) + + if field.kind == FieldKind.FLAG: + s.args.append(field) + elif field.kind == FieldKind.OPERAND: + s.operands.append(field) + elif field.kind == FieldKind.EXTRA: + if s.extras: + raise ValueError("Only one extra argument is allowed") + s.extras = field + + # now move to the base class + for base in typ.__bases__: + if base == object: + continue + baseSchema = Schema.extract(base) + s.args.extend(baseSchema.args) + s.operands.extend(baseSchema.operands) + if not s.extras: + s.extras = baseSchema.extras + elif baseSchema.extras: + raise ValueError("Only one extra argument is allowed") + + return s + + @staticmethod + def extractFromCallable(fn: tp.Callable) -> Optional["Schema"]: + typ: type | None = ( + None + if len(fn.__annotations__) == 0 + else next(iter(fn.__annotations__.values())) + ) + + if typ is None: + return None + + return Schema.extract(typ) + + def usage(self) -> str: + res = "" + for arg in self.args: + flag = "" + if arg.shortName: + flag += f"-{arg.shortName}" + + if arg.longName: + if flag: + flag += ", " + flag += f"--{arg.longName}" + res += f"[{flag}] " + for operand in self.operands: + res += f"<{operand.longName}> " + if self.extras: + res += f"[-- {self.extras.longName}]" + return res + + def _lookupArg(self, key: str, short: bool) -> Field: + for arg in self.args: + if short and arg.shortName == key: + return arg + elif not short and arg.longName == key: + return arg + raise ValueError(f"Unknown argument '{key}'") + + def _setOperand(self, tok: OperandToken): + return + + def _instanciate(self) -> Any: + if self.typ is None: + return None + res = self.typ() + for arg in self.args: + arg.setDefault(res) + return res + + def parse(self, args: list[str]) -> Any: + res = self._instanciate() + if res is None: + if len(args) > 0: + raise ValueError("Unexpected arguments") + else: + return None + + stack = args[:] + while len(stack) > 0: + if stack[0] == "--": + if not self.extras: + raise ValueError("Unexpected '--'") + self._setExtra(res, stack.pop(0)) + break + + toks = parseArg(stack.pop(0)) + for tok in toks: + if isinstance(tok, ArgumentToken): + arg = self._lookupArg(tok.key, tok.short) + arg.putValue(res, tok.value, tok.subkey) + elif isinstance(tok, OperandToken): + self._setOperand(tok) + else: + raise ValueError(f"Unexpected token: {type(tok)}") + + return res @dt.dataclass class Command: shortName: Optional[str] longName: str - helpText: str - isPlugin: bool - callback: Callback + description: str = "" + epilog: Optional[str] = None + schema: Optional[Schema] = None + callable: Optional[tp.Callable] = None subcommands: dict[str, "Command"] = dt.field(default_factory=dict) + populated: bool = False + + def _spliceArgs(self, args: list[str]) -> tuple[list[str], list[str]]: + rest = args[:] + curr = [] + if len(self.subcommands) > 0: + while len(rest) > 0 and rest[0].startswith("-") and rest[0] != "--": + curr.append(rest.pop(0)) + else: + curr = rest + rest = [] + return curr, rest + + def help(self, cmd): + vt100.title(f"{cmd}") + print() + + vt100.subtitle("Usage") + print(vt100.indent(f"{cmd}{self.usage()}")) + print() + + vt100.subtitle("Description") + print(vt100.indent(self.description)) + print() + + if self.schema and any(self.schema.args): + vt100.subtitle("Options") + for arg in self.schema.args: + flag = "" + if arg.shortName: + flag += f"-{arg.shortName}" + + if arg.longName: + if flag: + flag += ", " + flag += f"--{arg.longName}" + + if arg.description: + flag += f" {arg.description}" + + print(vt100.indent(flag)) + print() + + if any(self.subcommands): + vt100.subtitle("Subcommands") + for name, sub in self.subcommands.items(): + print( + vt100.indent( + f"{vt100.GREEN}{sub.shortName or ' '}{vt100.RESET} {name} - {sub.description}" + ) + ) + print() + + if self.epilog: + print(self.epilog) + print() + + # for name, sub in self.subcommands.items(): + # sub.help(f"{cmd} {name}") + + def usage(self) -> str: + res = " " + if self.schema: + res += self.schema.usage() + + if len(self.subcommands) == 1: + res += "[subcommand] [args...]" + + elif len(self.subcommands) > 0: + res += "{" + first = True + for name, cmd in self.subcommands.items(): + if not first: + res += "|" + res += f"{name}" + first = False + res += "}" + + res += " [args...]" + + return res + + def lookupSubcommand(self, name: str) -> "Command": + if name in self.subcommands: + return self.subcommands[name] + for sub in self.subcommands.values(): + if sub.shortName == name: + return sub + raise ValueError(f"Unknown subcommand '{name}'") + + def eval(self, args: list[str]): + cmd = args.pop(0) + curr, rest = self._spliceArgs(args) + if "-h" in curr or "--help" in curr: + self.help(cmd) + return + if "-u" in curr or "--usage" in curr: + print("Usage: " + cmd + self.usage(), end="\n\n") + return + + try: + if self.callable: + if self.schema: + args = self.schema.parse(curr) + self.callable(args) + else: + self.callable() + + if self.subcommands: + if len(rest) == 0 and not self.populated: + raise ValueError("Expected subcommand") + else: + self.lookupSubcommand(rest[0]).eval(rest) + elif len(rest) > 0: + raise ValueError(f"Unknown operand '{rest[0]}'") + + except ValueError as e: + vt100.error(str(e)) + print("Usage: " + cmd + self.usage(), end="\n\n") + return -commands: dict[str, Command] = {} +_root = Command(None, const.ARGV0) -def command(shortName: Optional[str], longName: str, helpText: str): - curframe = inspect.currentframe() - calframe = inspect.getouterframes(curframe, 2) +def _splitPath(path: str) -> list[str]: + if path == "/": + return [] + return path.split("/") - def wrap(fn: Callable[[Args], None]): - _logger.debug(f"Registering command {longName}") - path = longName.split("/") - parent = commands - for p in path[:-1]: - parent = parent[p].subcommands - parent[path[-1]] = Command( - shortName, - path[-1], - helpText, - Path(calframe[1].filename).parent != Path(__file__).parent, - fn, - ) +def _resolvePath(path: list[str]) -> Command: + if path == "/": + return _root + cmd = _root + for name in path: + if name not in cmd.subcommands: + cmd.subcommands[name] = Command(None, name) + cmd = cmd.subcommands[name] + return cmd + + +def command(shortName: str, longName: str, description: str = "") -> Callable: + def wrap(fn: Callable): + schema = Schema.extractFromCallable(fn) + path = _splitPath(longName) + cmd = _resolvePath(path) + if cmd.populated: + raise ValueError(f"Command '{longName}' is already defined") + cmd.shortName = shortName + cmd.longName = len(path) > 0 and path[-1] or "" + cmd.description = description + cmd.schema = schema + cmd.callable = fn + cmd.populated = True return fn return wrap - - -# --- Builtins Commands ------------------------------------------------------ # - - -@command("u", "usage", "Show usage information") -def usage(args: Optional[Args] = None): - print(f"Usage: {const.ARGV0} [args...]") - - -def error(msg: str) -> None: - print(f"{vt100.RED}Error:{vt100.RESET} {msg}\n", file=sys.stderr) - - -def warning(msg: str) -> None: - print(f"{vt100.YELLOW}Warning:{vt100.RESET} {msg}\n", file=sys.stderr) - - -def ask(msg: str, default: Optional[bool] = None) -> bool: - if default is None: - msg = f"{msg} [y/n] " - elif default: - msg = f"{msg} [Y/n] " - else: - msg = f"{msg} [y/N] " - - while True: - result = input(msg).lower() - if result in ("y", "yes"): - return True - elif result in ("n", "no"): - return False - elif result == "" and default is not None: - return default - - -@command("h", "help", "Show this help message") -def helpCmd(args: Args): - usage() - - print() - - vt100.title("Description") - print(f" {const.DESCRIPTION}") - - print() - vt100.title("Commands") - for cmd in sorted(commands.values(), key=lambda c: c.longName): - if cmd.longName.startswith("_") or len(cmd.subcommands) > 0: - continue - - pluginText = "" - if cmd.isPlugin: - pluginText = f"{vt100.CYAN}(plugin){vt100.RESET}" - - print( - f" {vt100.GREEN}{cmd.shortName or ' '}{vt100.RESET} {cmd.longName} - {cmd.helpText} {pluginText}" - ) - - for cmd in sorted(commands.values(), key=lambda c: c.longName): - if cmd.longName.startswith("_") or len(cmd.subcommands) == 0: - continue - - print() - vt100.title(f"{cmd.longName.capitalize()} - {cmd.helpText}") - for subcmd in sorted(cmd.subcommands.values(), key=lambda c: c.longName): - pluginText = "" - if subcmd.isPlugin: - pluginText = f"{vt100.CYAN}(plugin){vt100.RESET}" - - print( - f" {vt100.GREEN}{subcmd.shortName or ' '}{vt100.RESET} {subcmd.longName} - {subcmd.helpText} {pluginText}" - ) - - print() - vt100.title("Logging") - print(" Logs are stored in:") - print(f" - {const.PROJECT_LOG_FILE}") - print(f" - {const.GLOBAL_LOG_FILE}") - - -@command("v", "version", "Show current version") -def versionCmd(args: Args): - print(f"CuteKit v{const.VERSION_STR}") - - -def exec(args: Args, cmds=commands): - cmd = args.consumeArg() - - if cmd is None: - raise RuntimeError("No command specified") - - for c in cmds.values(): - if c.shortName == cmd or c.longName == cmd: - if len(c.subcommands) > 0: - exec(args, c.subcommands) - return - else: - c.callback(args) - return - - raise RuntimeError(f"Unknown command {cmd}") diff --git a/cutekit/graph.py b/cutekit/graph.py deleted file mode 100644 index aac391e..0000000 --- a/cutekit/graph.py +++ /dev/null @@ -1,95 +0,0 @@ -import os - -from typing import Optional, cast - -from . import vt100, cli, model - - -def view( - registry: model.Registry, - target: model.Target, - scope: Optional[str] = None, - showExe: bool = True, - showDisabled: bool = False, -): - from graphviz import Digraph # type: ignore - - g = Digraph(target.id, filename="graph.gv") - - g.attr("graph", splines="ortho", rankdir="BT", ranksep="1.5") - g.attr("node", shape="ellipse") - g.attr( - "graph", - label=f"<{scope or 'Full Dependency Graph'}
{target.id}>", - labelloc="t", - ) - - scopeInstance = None - - if scope is not None: - scopeInstance = registry.lookup(scope, model.Component) - - for component in registry.iterEnabled(target): - if not component.type == model.Kind.LIB and not showExe: - continue - - if ( - scopeInstance is not None - and component.id != scope - and component.id not in scopeInstance.resolved[target.id].required - ): - continue - - if component.resolved[target.id].enabled: - fillcolor = "lightgrey" if component.type == model.Kind.LIB else "lightblue" - shape = "plaintext" if not scope == component.id else "box" - - g.node( - component.id, - f"<{component.id}
{vt100.wordwrap(component.description, 40,newline='
')}>", - shape=shape, - style="filled", - fillcolor=fillcolor, - ) - - for req in component.requires: - g.edge(component.id, req) - - for req in component.provides: - isChosen = target.routing.get(req, None) == component.id - - g.edge( - req, - component.id, - arrowhead="none", - color=("blue" if isChosen else "black"), - ) - elif showDisabled: - g.node( - component.id, - f"<{component.id}
{vt100.wordwrap(component.description, 40,newline='
')}

{vt100.wordwrap(str(component.resolved[target.id].reason), 40,newline='
')}
>", - shape="plaintext", - style="filled", - fontcolor="#999999", - fillcolor="#eeeeee", - ) - - for req in component.requires: - g.edge(component.id, req, color="#aaaaaa") - - for req in component.provides: - g.edge(req, component.id, arrowhead="none", color="#aaaaaa") - - g.view(filename=os.path.join(target.builddir, "graph.gv")) - - -@cli.command("g", "graph", "Show the dependency graph") -def _(args: cli.Args): - registry = model.Registry.use(args) - target = model.Target.use(args) - - scope = cast(Optional[str], args.tryConsumeOpt("scope")) - onlyLibs = args.consumeOpt("only-libs", False) is True - showDisabled = args.consumeOpt("show-disabled", False) is True - - view(registry, target, scope=scope, showExe=not onlyLibs, showDisabled=showDisabled) diff --git a/cutekit/model.py b/cutekit/model.py index b076e22..b6dbbb4 100644 --- a/cutekit/model.py +++ b/cutekit/model.py @@ -7,11 +7,10 @@ from enum import Enum from typing import Any, Generator, Optional, Type, cast from pathlib import Path from dataclasses_json import DataClassJsonMixin -from typing import Union from cutekit import const, shell -from . import jexpr, compat, utils, cli, vt100 +from . import cli, jexpr, compat, utils, vt100 _logger = logging.getLogger(__name__) @@ -171,7 +170,7 @@ class Project(Manifest): Project.fetchs(project.extern) @staticmethod - def use(args: cli.Args) -> "Project": + def use() -> "Project": global _project if _project is None: _project = Project.ensure() @@ -179,29 +178,37 @@ class Project(Manifest): @cli.command("m", "model", "Manage the model") -def _(args: cli.Args): +def _(): pass @cli.command("i", "model/install", "Install required external packages") -def _(args: cli.Args): - project = Project.use(args) +def _(): + project = Project.use() Project.fetchs(project.extern) +class ModelInitArgs: + repo: str = cli.arg( + None, + "repo", + "The repository to fetch templates from", + default=const.DEFAULT_REPO_TEMPLATES, + ) + list: bool = cli.arg("l", "list", "List available templates") + template: str = cli.operand("template", "The template to use") + name: str = cli.operand("name", "The name of the project") + + @cli.command("I", "model/init", "Initialize a new project") -def _(args: cli.Args): +def _(args: ModelInitArgs): import requests - repo = args.consumeOpt("repo", const.DEFAULT_REPO_TEMPLATES) - list = args.consumeOpt("list") - - template = args.consumeArg() - name = args.consumeArg() - _logger.info("Fetching registry...") - r = requests.get(f"https://raw.githubusercontent.com/{repo}/main/registry.json") + r = requests.get( + f"https://raw.githubusercontent.com/{args.repo}/main/registry.json" + ) if r.status_code != 200: _logger.error("Failed to fetch registry") @@ -209,34 +216,34 @@ def _(args: cli.Args): registry = r.json() - if list: + if args.list: print( "\n".join(f"* {entry['id']} - {entry['description']}" for entry in registry) ) return - if not template: + if not args.template: raise RuntimeError("Template not specified") def template_match(t: jexpr.Json) -> str: - return t["id"] == template + return t["id"] == args.template if not any(filter(template_match, registry)): - raise LookupError(f"Couldn't find a template named {template}") + raise LookupError(f"Couldn't find a template named {args.template}") - if not name: - _logger.info(f"No name was provided, defaulting to {template}") - name = template + if not args.name: + _logger.info(f"No name was provided, defaulting to {args.template}") + args.name = args.template - if os.path.exists(name): - raise RuntimeError(f"Directory {name} already exists") + if os.path.exists(args.name): + raise RuntimeError(f"Directory {args.name} already exists") - print(f"Creating project {name} from template {template}...") - shell.cloneDir(f"https://github.com/{repo}", template, name) - print(f"Project {name} created\n") + print(f"Creating project {args.name} from template {args.template}...") + shell.cloneDir(f"https://github.com/{args.repo}", args.template, args.name) + print(f"Project {args.name} created\n") print("We suggest that you begin by typing:") - print(f" {vt100.GREEN}cd {name}{vt100.RESET}") + print(f" {vt100.GREEN}cd {args.name}{vt100.RESET}") print( f" {vt100.GREEN}cutekit install{vt100.BRIGHT_BLACK} # Install external packages{vt100.RESET}" ) @@ -263,6 +270,17 @@ DEFAULT_TOOLS: Tools = { } +class RegistryArgs: + props: dict[str, str] = cli.arg(None, "prop", "Set a property") + mixins: list[str] = cli.arg(None, "mixins", "Apply mixins") + + +class TargetArgs(RegistryArgs): + target: str = cli.arg( + None, "target", "The target to use", default="host-" + shell.uname().machine + ) + + @dt.dataclass class Target(Manifest): props: Props = dt.field(default_factory=dict) @@ -287,10 +305,9 @@ class Target(Manifest): return os.path.join(const.BUILD_DIR, f"{self.id}{postfix}") @staticmethod - def use(args: cli.Args, props: Props = {}) -> "Target": - registry = Registry.use(args, props) - targetSpec = str(args.consumeOpt("target", "host-" + shell.uname().machine)) - return registry.ensure(targetSpec, Target) + def use(args: TargetArgs) -> "Target": + registry = Registry.use(args) + return registry.ensure(args.target, Target) def route(self, componentSpec: str): """ @@ -536,18 +553,14 @@ class Registry(DataClassJsonMixin): return m @staticmethod - def use(args: cli.Args, props: Props = {}) -> "Registry": + def use(args: RegistryArgs) -> "Registry": global _registry if _registry is not None: return _registry - project = Project.use(args) - mixins = str(args.consumeOpt("mixins", "")).split(",") - if mixins == [""]: - mixins = [] - props |= cast(dict[str, str], args.consumePrefix("prop:")) - _registry = Registry.load(project, mixins, props) + project = Project.use() + _registry = Registry.load(project, args.mixins, args.props) return _registry @staticmethod @@ -608,11 +621,11 @@ class Registry(DataClassJsonMixin): ) else: victim.resolved[target.id].injected.append(c.id) - victim.resolved[ - target.id - ].required = utils.uniqPreserveOrder( - c.resolved[target.id].required - + victim.resolved[target.id].required + victim.resolved[target.id].required = ( + utils.uniqPreserveOrder( + c.resolved[target.id].required + + victim.resolved[target.id].required + ) ) # Resolve tooling @@ -639,9 +652,8 @@ class Registry(DataClassJsonMixin): @cli.command("l", "model/list", "List all components and targets") -def _(args: cli.Args): +def _(args: TargetArgs): registry = Registry.use(args) - components = list(registry.iter(Component)) targets = list(registry.iter(Target)) @@ -659,3 +671,103 @@ def _(args: cli.Args): else: print(vt100.p(", ".join(map(lambda m: m.id, targets)))) print() + + +def view( + registry: Registry, + target: Target, + scope: Optional[str] = None, + showExe: bool = True, + showDisabled: bool = False, +): + from graphviz import Digraph # type: ignore + + g = Digraph(target.id, filename="graph.gv") + + g.attr("graph", splines="ortho", rankdir="BT", ranksep="1.5") + g.attr("node", shape="ellipse") + g.attr( + "graph", + label=f"<{scope or 'Full Dependency Graph'}
{target.id}>", + labelloc="t", + ) + + scopeInstance = None + + if scope is not None: + scopeInstance = registry.lookup(scope, Component) + + for component in registry.iterEnabled(target): + if not component.type == Kind.LIB and not showExe: + continue + + if ( + scopeInstance is not None + and component.id != scope + and component.id not in scopeInstance.resolved[target.id].required + ): + continue + + if component.resolved[target.id].enabled: + fillcolor = "lightgrey" if component.type == model.Kind.LIB else "lightblue" + shape = "plaintext" if not scope == component.id else "box" + + g.node( + component.id, + f"<{component.id}
{vt100.wordwrap(component.description, 40,newline='
')}>", + shape=shape, + style="filled", + fillcolor=fillcolor, + ) + + for req in component.requires: + g.edge(component.id, req) + + for req in component.provides: + isChosen = target.routing.get(req, None) == component.id + + g.edge( + req, + component.id, + arrowhead="none", + color=("blue" if isChosen else "black"), + ) + elif showDisabled: + g.node( + component.id, + f"<{component.id}
{vt100.wordwrap(component.description, 40,newline='
')}

{vt100.wordwrap(str(component.resolved[target.id].reason), 40,newline='
')}
>", + shape="plaintext", + style="filled", + fontcolor="#999999", + fillcolor="#eeeeee", + ) + + for req in component.requires: + g.edge(component.id, req, color="#aaaaaa") + + for req in component.provides: + g.edge(req, component.id, arrowhead="none", color="#aaaaaa") + + g.view(filename=os.path.join(target.builddir, "graph.gv")) + + +class GraphArgs(TargetArgs): + onlyLibs: bool = cli.arg(False, "only-libs", "Show only libraries") + showDisabled: bool = cli.arg(False, "show-disabled", "Show disabled components") + scope: str = cli.arg( + None, "scope", "Show only the specified component and its dependencies" + ) + + +@cli.command("g", "model/graph", "Show the dependency graph") +def _(args: GraphArgs): + registry = Registry.use(args) + target = Target.use(args) + + view( + registry, + target, + scope=args.scope, + showExe=not args.onlyLibs, + showDisabled=args.showDisabled, + ) diff --git a/cutekit/plugins.py b/cutekit/plugins.py index 5d38187..7c7bda4 100644 --- a/cutekit/plugins.py +++ b/cutekit/plugins.py @@ -2,7 +2,7 @@ import logging import os import sys -from . import shell, model, const, cli +from . import cli, shell, model, const, vt100 import importlib.util as importlib @@ -24,7 +24,7 @@ def load(path: str): spec.loader.exec_module(module) except Exception as e: _logger.error(f"Failed to load plugin {path}: {e}") - cli.warning(f"Plugin {path} loading skipped due to error") + vt100.warning(f"Plugin {path} loading skipped due to error") def loadAll(): @@ -51,6 +51,10 @@ def loadAll(): load(os.path.join(pluginDir, files)) -def setup(args: cli.Args): - if not bool(args.consumeOpt("safemode", False)): +class PluginsArgs: + safemod: bool = cli.arg(None, "safemode", "disable plugin loading") + + +def setup(args: PluginsArgs): + if args.safemod: loadAll() diff --git a/cutekit/pods.py b/cutekit/pods.py index 67fe65d..30a9935 100644 --- a/cutekit/pods.py +++ b/cutekit/pods.py @@ -77,16 +77,42 @@ IMAGES: dict[str, Image] = { } -def setup(args: cli.Args): +class PodSetupArgs: + pod: str | bool | None = cli.arg( + None, "pod", "Reincarnate cutekit within the specified pod" + ) + + +class PodNameArg: + name: str = cli.arg(None, "name", "Name of the pod") + + +class PodImageArg: + image: str = cli.arg(None, "image", "Base image to use for the pod") + + +class PodCreateArgs(PodNameArg, PodImageArg): + pass + + +class PodKillArgs(PodNameArg): + all: bool = cli.arg("a", "all", "Kill all pods") + + +class PodExecArgs(PodNameArg): + cmd: str = cli.operand("cmd", "Command to execute") + args: list[str] = cli.extra("args", "Extra arguments to pass to the command") + + +def setup(args: PodSetupArgs): """ Reincarnate cutekit within a docker container, this is useful for cross-compiling """ - pod = args.consumeOpt("pod", False) - if not pod: + if not args.pod: return - if isinstance(pod, str): - pod = pod.strip() + if isinstance(args.pod, str): + pod = args.pod.strip() pod = podPrefix + pod if pod is True: pod = defaultPodName @@ -114,7 +140,7 @@ def setup(args: cli.Args): @cli.command("p", "pod", "Manage pods") -def _(args: cli.Args): +def _(): pass @@ -125,22 +151,22 @@ def tryDecode(data: Optional[bytes], default: str = "") -> str: @cli.command("c", "pod/create", "Create a new pod") -def _(args: cli.Args): +def _(args: PodCreateArgs): """ Create a new development pod with cutekit installed and the current project mounted at /project """ project = model.Project.ensure() - name = str(args.consumeOpt("name", defaultPodName)) + name = args.name if not name.startswith(podPrefix): name = f"{podPrefix}{name}" - image = IMAGES[str(args.consumeOpt("image", defaultPodImage))] + image = IMAGES[args.image] client = docker.from_env() try: existing = client.containers.get(name) - if cli.ask(f"Pod '{name[len(podPrefix):]}' already exists, kill it?", False): + if vt100.ask(f"Pod '{name[len(podPrefix):]}' already exists, kill it?", False): existing.stop() existing.remove() else: @@ -177,10 +203,12 @@ def _(args: cli.Args): @cli.command("k", "pod/kill", "Stop and remove a pod") -def _(args: cli.Args): +def _(args: PodKillArgs): client = docker.from_env() - name = str(args.consumeOpt("name", defaultPodName)) - all = args.consumeOpt("all", False) is True + + name = args.name + all = args.all + if not name.startswith(podPrefix): name = f"{podPrefix}{name}" @@ -191,25 +219,19 @@ def _(args: cli.Args): continue container.stop() container.remove() - print(f"Pod '{container.name[len(podPrefix) :]}' killed") + print(f"Pod '{args.name}' killed") return container = client.containers.get(name) container.stop() container.remove() - print(f"Pod '{name[len(podPrefix) :]}' killed") + print(f"Pod '{args.name}' killed") except docker.errors.NotFound: - raise RuntimeError(f"Pod '{name[len(podPrefix):]}' does not exist") - - -@cli.command("s", "pod/shell", "Open a shell in a pod") -def _(args: cli.Args): - args.args.insert(0, "/bin/bash") - podExecCmd(args) + raise RuntimeError(f"Pod '{args.name}' does not exist") @cli.command("l", "pod/list", "List all pods") -def _(args: cli.Args): +def _(): client = docker.from_env() hasPods = False for container in client.containers.list(all=True): @@ -223,16 +245,13 @@ def _(args: cli.Args): @cli.command("e", "pod/exec", "Execute a command in a pod") -def podExecCmd(args: cli.Args): - name = str(args.consumeOpt("name", defaultPodName)) +def podExecCmd(args: PodExecArgs): + name = args.name + if not name.startswith(podPrefix): name = f"{podPrefix}{name}" - cmd = args.consumeArg() - if cmd is None: - raise RuntimeError("Missing command to execute") - try: - shell.exec("docker", "exec", "-it", name, cmd, *args.extra) + shell.exec("docker", "exec", "-it", name, args.cmd, *args.args) except Exception: - raise RuntimeError(f"Pod '{name[len(podPrefix):]}' does not exist") + raise RuntimeError(f"Pod '{args.name}' does not exist") diff --git a/cutekit/requirements.txt b/cutekit/requirements.txt index 892a2bf..46604ce 100644 --- a/cutekit/requirements.txt +++ b/cutekit/requirements.txt @@ -2,3 +2,4 @@ requests ~= 2.31.0 graphviz ~= 0.20.1 dataclasses-json ~= 0.6.2 docker ~= 6.1.3 +asserts ~= 0.12.0 diff --git a/cutekit/shell.py b/cutekit/shell.py index 239019f..82a38e7 100644 --- a/cutekit/shell.py +++ b/cutekit/shell.py @@ -14,7 +14,7 @@ import dataclasses as dt from pathlib import Path from typing import Literal, Optional -from . import const, cli +from . import cli, const _logger = logging.getLogger(__name__) @@ -426,28 +426,53 @@ def compress(path: str, dest: Optional[str] = None, format: str = "zstd") -> str # --- Commands --------------------------------------------------------------- # -@cli.command("s", "scripts", "Manage scripts") -def _(args: cli.Args): +@cli.command("s", "shell", "Shell like commands") +def _(): pass -@cli.command("d", "debug", "Debug a program") -def _(args: cli.Args): - wait = args.consumeOpt("wait", False) is True - debugger = args.consumeOpt("debugger", "lldb") - command = [str(args.consumeArg()), *args.extra] - debug(command, debugger=str(debugger), wait=wait) +class CommandArgs: + cmd: str = cli.operand("command", "The command to debug") + args: list[str] = cli.extra("args", "The arguments to pass to the command") + + def fullCmd(self) -> list[str]: + return [self.cmd, *self.args] -@cli.command("p", "profile", "Profile a program") -def _(args: cli.Args): - command = [str(args.consumeArg()), *args.extra] - profile(command) +class DebugArgs: + wait: bool = cli.arg(None, "wait", "Wait for the debugger to attach") + debbuger: str = cli.arg(None, "debugger", "The debugger to use", default="lldb") -@cli.command("c", "compress", "Compress a file or directory") -def _(args: cli.Args): - path = str(args.consumeArg()) - dest = args.consumeOpt("dest", None) - format = args.consumeOpt("format", "zstd") - compress(path, dest, format) +class _DebugArgs(DebugArgs, CommandArgs): + pass + + +@cli.command("d", "shell/debug", "Debug a program") +def _(args: _DebugArgs): + debug(args.fullCmd(), debugger=str(args.debugger), wait=args.wait) + + +class ProfileArgs: + rate: int = cli.arg(None, "rate", "The sampling rate", default=1000) + what: str = cli.arg(None, "what", "What to profile (cpu or mem)", default="cpu") + + +class _ProfileArgs(ProfileArgs, CommandArgs): + pass + + +@cli.command("p", "shell/profile", "Profile a program") +def _(args: _ProfileArgs): + profile(args.fullCmd(), rate=args.rate, what=args.what) + + +class CompresseArgs: + format: str = cli.arg(None, "format", "The compression format", default="zstd") + dest: Optional[str] = cli.arg(None, "dest", "The destination file or directory") + path: str = cli.operand("path", "The file or directory to compress") + + +@cli.command("c", "shell/compress", "Compress a file or directory") +def _(args: CompresseArgs): + compress(args.path, dest=args.dest, format=args.format) diff --git a/cutekit/vt100.py b/cutekit/vt100.py index 96c7eea..0ce993e 100644 --- a/cutekit/vt100.py +++ b/cutekit/vt100.py @@ -1,3 +1,7 @@ +import sys +from typing import Optional + + BLACK = "\033[30m" RED = "\033[31m" GREEN = "\033[32m" @@ -48,8 +52,38 @@ def indent(text: str, indent: int = 4) -> str: def title(text: str): - print(f"{BOLD}{text}{RESET}:") + print(f"{BOLD+WHITE+UNDERLINE}{text}{RESET}") + + +def subtitle(text: str): + print(f"{BOLD+WHITE}{text}{RESET}:") def p(text: str): return indent(wordwrap(text)) + + +def error(msg: str) -> None: + print(f"{RED}Error:{RESET} {msg}\n", file=sys.stderr) + + +def warning(msg: str) -> None: + print(f"{YELLOW}Warning:{RESET} {msg}\n", file=sys.stderr) + + +def ask(msg: str, default: Optional[bool] = None) -> bool: + if default is None: + msg = f"{msg} [y/n] " + elif default: + msg = f"{msg} [Y/n] " + else: + msg = f"{msg} [y/N] " + + while True: + result = input(msg).lower() + if result in ("y", "yes"): + return True + elif result in ("n", "no"): + return False + elif result == "" and default is not None: + return default diff --git a/doc/extends.md b/doc/extends.md index 0075e5f..910f6b0 100644 --- a/doc/extends.md +++ b/doc/extends.md @@ -11,6 +11,6 @@ For example you can add a new command to the CLI: from cutekit import cli @cli.command("h", "hello", "Print hello world") -def _(args: cli.Args) -> None: +def _() -> None: print("Hello world!") ``` diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..92769e3 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,239 @@ +from cutekit import cli, utils +from asserts import ( + assert_is, + assert_true, + assert_equal, + assert_raises, + assert_is_instance, +) + +# --- Parse Values ----------------------------------------------------------- # + + +def test_parse_int_val(): + assert_equal(cli.parseValue("1"), 1) + assert_equal(cli.parseValue("2"), 2) + assert_equal(cli.parseValue("+2"), +2) + assert_equal(cli.parseValue("-2"), -2) + + +def test_parse_true_val(): + assert_equal(cli.parseValue("true"), True) + assert_equal(cli.parseValue("True"), True) + assert_equal(cli.parseValue("y"), True) + assert_equal(cli.parseValue("yes"), True) + assert_equal(cli.parseValue("Y"), True) + assert_equal(cli.parseValue("Yes"), True) + + +def test_parse_false_val(): + assert_equal(cli.parseValue("false"), False) + assert_equal(cli.parseValue("False"), False) + assert_equal(cli.parseValue("n"), False) + assert_equal(cli.parseValue("no"), False) + assert_equal(cli.parseValue("N"), False) + assert_equal(cli.parseValue("No"), False) + + +def test_parse_str_val(): + assert_equal(cli.parseValue("foo"), "foo") + assert_equal(cli.parseValue("'foo'"), "foo") + assert_equal(cli.parseValue('"foo"'), "foo") + + +def test_parse_list_val(): + assert_equal(cli.parseValue("foo,bar"), ["foo", "bar"]) + assert_equal(cli.parseValue("'foo','bar'"), ["foo", "bar"]) + assert_equal(cli.parseValue('"foo","bar"'), ["foo", "bar"]) + + +# --- Parse Args ------------------------------------------------------------- # + + +def test_parse_short_arg(): + args = cli.parseArg("-a") + assert_equal(len(args), 1) + arg = args[0] + assert_is_instance(arg, cli.ArgumentToken) + assert_equal(arg.key, "a") + assert_equal(arg.value, True) + + +def test_parse_short_args(): + args = cli.parseArg("-abc") + assert_equal(len(args), 3) + arg = args[0] + assert_is_instance(arg, cli.ArgumentToken) + assert_equal(arg.key, "a") + assert_equal(arg.value, True) + + arg = args[1] + assert_is_instance(arg, cli.ArgumentToken) + assert_equal(arg.key, "b") + assert_equal(arg.value, True) + + arg = args[2] + assert_is_instance(arg, cli.ArgumentToken) + assert_equal(arg.key, "c") + assert_equal(arg.value, True) + + +def test_parse_long_arg(): + args = cli.parseArg("--foo") + assert_equal(len(args), 1) + arg = args[0] + assert_is_instance(arg, cli.ArgumentToken) + assert_equal(arg.key, "foo") + assert_equal(arg.value, True) + + +def test_parse_long_arg_with_value(): + args = cli.parseArg("--foo=bar") + assert_equal(len(args), 1) + arg = args[0] + assert_is_instance(arg, cli.ArgumentToken) + assert_equal(arg.key, "foo") + assert_equal(arg.value, "bar") + + +def test_parse_long_arg_with_value_list(): + args = cli.parseArg("--foo=bar,baz") + assert_equal(len(args), 1) + arg = args[0] + assert_is_instance(arg, cli.ArgumentToken) + assert_equal(arg.key, "foo") + assert_equal(arg.value, ["bar", "baz"]) + + +def test_parse_key_subkey_arg(): + args = cli.parseArg("--foo:bar") + assert_equal(len(args), 1) + arg = args[0] + assert_is_instance(arg, cli.ArgumentToken) + assert_equal(arg.key, "foo") + assert_equal(arg.subkey, "bar") + assert_equal(arg.value, True) + + +def extractParse(type: type[utils.T], args: list[str]) -> utils.T: + schema = cli.Schema.extract(type) + return schema.parse(args) + + +class IntArg: + value: int = cli.arg(None, "value") + + +def test_cli_arg_int(): + + assert_equal(extractParse(IntArg, ["--value=-1"]).value, -1) + assert_equal(extractParse(IntArg, ["--value=0"]).value, 0) + assert_equal(extractParse(IntArg, ["--value=1"]).value, 1) + + +class StrArg: + value: str = cli.arg(None, "value") + + +def test_cli_arg_str1(): + assert_equal(extractParse(StrArg, ["--value=foo"]).value, "foo") + assert_equal(extractParse(StrArg, ["--value='foo, bar'"]).value, "foo, bar") + + +class BoolArg: + value: bool = cli.arg(None, "value") + + +def test_cli_arg_bool(): + assert_is(extractParse(BoolArg, ["--value"]).value, True) + + assert_is(extractParse(BoolArg, ["--value=true"]).value, True) + assert_is(extractParse(BoolArg, ["--value=True"]).value, True) + assert_is(extractParse(BoolArg, ["--value=y"]).value, True) + assert_is(extractParse(BoolArg, ["--value=yes"]).value, True) + assert_is(extractParse(BoolArg, ["--value=Y"]).value, True) + assert_is(extractParse(BoolArg, ["--value=Yes"]).value, True) + assert_is(extractParse(BoolArg, ["--value=1"]).value, True) + + assert_is(extractParse(BoolArg, ["--value=false"]).value, False) + assert_is(extractParse(BoolArg, ["--value=False"]).value, False) + assert_is(extractParse(BoolArg, ["--value=n"]).value, False) + assert_is(extractParse(BoolArg, ["--value=no"]).value, False) + assert_is(extractParse(BoolArg, ["--value=N"]).value, False) + assert_is(extractParse(BoolArg, ["--value=No"]).value, False) + assert_is(extractParse(BoolArg, ["--value=0"]).value, False) + + +class IntListArg: + value: list[int] = cli.arg(None, "value") + + +def test_cli_arg_list_int1(): + assert_equal(extractParse(IntListArg, []).value, []) + assert_equal(extractParse(IntListArg, ["--value=1", "--value=2"]).value, [1, 2]) + assert_equal(extractParse(IntListArg, ["--value=1,2"]).value, [1, 2]) + + +class StrListArg: + value: list[str] = cli.arg(None, "value") + + +def test_cli_arg_list_str(): + assert_equal(extractParse(StrListArg, []).value, []) + + assert_equal( + extractParse(StrListArg, ["--value=foo", "--value=bar"]).value, + [ + "foo", + "bar", + ], + ) + + assert_equal(extractParse(StrListArg, ["--value=foo,bar"]).value, ["foo", "bar"]) + assert_equal(extractParse(StrListArg, ["--value=foo,bar"]).value, ["foo", "bar"]) + assert_equal(extractParse(StrListArg, ["--value='foo,bar'"]).value, ["foo,bar"]) + assert_equal(extractParse(StrListArg, ["--value='foo, bar'"]).value, ["foo, bar"]) + assert_equal(extractParse(StrListArg, ['--value="foo, bar"']).value, ["foo, bar"]) + + +class StrDictArg: + value: dict[str, str] = cli.arg(None, "value") + + +def test_cli_arg_dict_str(): + assert_equal(extractParse(StrDictArg, ["--value:foo=bar"]).value, {"foo": "bar"}) + assert_equal( + extractParse(StrDictArg, ["--value:foo=bar", "--value:baz=qux"]).value, + { + "foo": "bar", + "baz": "qux", + }, + ) + + +class StrOptArg: + value: str | None = cli.arg(None, "value") + + +def test_cli_arg_str_opt(): + assert_equal(extractParse(StrOptArg, []).value, None) + assert_equal(extractParse(StrOptArg, ["--value=foo"]).value, "foo") + + +class FooArg: + foo: str = cli.arg(None, "foo") + + +class BazArg: + baz: str = cli.arg(None, "baz") + + +class BarArg(FooArg, BazArg): + bar: str = cli.arg(None, "bar") + + +def test_cli_arg_inheritance(): + res = extractParse(BarArg, ["--foo=foo", "--bar=bar", "--baz=baz"]) + assert_equal(res.foo, "foo") + assert_equal(res.bar, "bar") + assert_equal(res.baz, "baz")