752 lines
20 KiB
Python
752 lines
20 KiB
Python
from enum import Enum
|
|
import os
|
|
import sys
|
|
from types import GenericAlias
|
|
import typing as tp
|
|
import dataclasses as dt
|
|
import logging
|
|
|
|
from typing import Any, Callable, Optional
|
|
from cutekit import vt100, const, utils
|
|
|
|
# Module-level logger; used below to trace command registration.
_logger = logging.getLogger(__name__)
|
|
|
|
# --- Scan -------------------------------------------------------------- #
|
|
|
|
|
|
class Scan:
    """Character cursor over a string with a save/restore stack for backtracking.

    ``curr``/``next``/``peek`` return ``"\\0"`` instead of raising when the
    requested position is outside the input.
    """

    _src: str
    _off: int
    _save: list[int]

    def __init__(self, src: str, off: int = 0):
        self._src = src
        # Honour the requested starting offset (it used to be ignored).
        self._off = off
        self._save = []

    def curr(self) -> str:
        """Return the character under the cursor, or "\\0" at end of input."""
        if self.eof():
            return "\0"
        return self._src[self._off]

    def next(self) -> str:
        """Advance one character (unless at end) and return the new current one."""
        if self.eof():
            return "\0"
        self._off += 1
        return self.curr()

    def peek(self, off: int = 1) -> str:
        """Return the character ``off`` positions from the cursor, or "\\0"."""
        idx = self._off + off
        # Guard the low end too: a negative index would wrap to the end.
        if idx < 0 or idx >= len(self._src):
            return "\0"
        return self._src[idx]

    def eof(self) -> bool:
        """True once the cursor has reached the end of the input."""
        return self._off >= len(self._src)

    def skipStr(self, s: str) -> bool:
        """Consume ``s`` if the input continues with it; report whether it did."""
        # startswith with a start index avoids copying the remaining input.
        if self._src.startswith(s, self._off):
            self._off += len(s)
            return True
        return False

    def isStr(self, s: str) -> bool:
        """Report whether the input continues with ``s`` without consuming it."""
        return self._src.startswith(s, self._off)

    def save(self) -> None:
        """Push the current offset onto the save stack."""
        self._save.append(self._off)

    def restore(self) -> None:
        """Pop the most recently saved offset and move the cursor back to it."""
        self._off = self._save.pop()

    def skipWhitespace(self) -> bool:
        """Consume whitespace; return True if at least one character was skipped."""
        result = False
        while not self.eof() and self.curr().isspace():
            self.next()
            result = True
        return result

    def skipSeparator(self, sep: str) -> bool:
        """Consume ``sep`` surrounded by optional whitespace, or nothing at all."""
        self.save()
        self.skipWhitespace()
        if self.skipStr(sep):
            self.skipWhitespace()
            # Discard our save entry on success; leaving it behind made an
            # enclosing save()/restore() pair restore to the wrong offset.
            self._save.pop()
            return True
        self.restore()
        return False

    def isSeparator(self, sep: str) -> bool:
        """Report whether ``sep`` (after optional whitespace) comes next."""
        self.save()
        self.skipWhitespace()
        matched = self.isStr(sep)
        self.restore()
        return matched

    def skipKeyword(self, keyword: str) -> bool:
        """Consume ``keyword`` (after optional whitespace) if it is not
        immediately followed by an alphanumeric character."""
        self.save()
        self.skipWhitespace()
        if self.skipStr(keyword) and not self.curr().isalnum():
            # Same stale-save-entry fix as in skipSeparator.
            self._save.pop()
            return True
        self.restore()
        return False

    def isKeyword(self, keyword: str) -> bool:
        """Like skipKeyword, but never consumes input."""
        self.save()
        self.skipWhitespace()
        matched = self.skipStr(keyword) and not self.curr().isalnum()
        self.restore()
        return matched
|
|
|
|
|
|
# --- Parser ------------------------------------------------------------ #
|
|
|
|
PrimitiveValue = str | bool | int
|
|
Object = dict[str, PrimitiveValue]
|
|
List = list[PrimitiveValue]
|
|
Value = str | bool | int | Object | List
|
|
|
|
|
|
@dt.dataclass
class Token:
    """Base class for the tokens produced by parseArg / parseArgs."""

    pass


@dt.dataclass
class ArgumentToken(Token):
    """A flag: ``--key[:subkey][=value]`` (short=False) or one letter of ``-abc`` (short=True)."""

    key: str  # flag name without the leading dash(es)
    subkey: Optional[str]  # text after ":" for long flags ("" when absent); None for short flags
    value: Value  # value parsed after "=", or True when none was given
    short: bool  # True for single-dash (short) flags


@dt.dataclass
class OperandToken(Token):
    """A positional argument, i.e. one that was not parsed as a flag."""

    value: str  # the raw argument text


@dt.dataclass
class ExtraToken(Token):
    """The arguments that follow a bare "--", kept unparsed."""

    args: list[str]
|
|
|
|
|
|
def _parseIdent(s: Scan) -> str:
    """Consume and return the longest run of identifier characters.

    Identifier characters are alphanumerics plus "_", "-" and "+". Returns ""
    without consuming anything when the cursor is not on one.
    """
    chars: list[str] = []
    while not s.eof():
        c = s.curr()
        if not c.isalnum() and c not in "_-+":
            break
        chars.append(c)
        s.next()
    return "".join(chars)
|
|
|
|
|
|
def _parseUntilComma(s: Scan) -> str:
    """Consume and return the text up to (not including) the next "," or end of input."""
    chars: list[str] = []
    while not s.eof():
        c = s.curr()
        if c == ",":
            break
        chars.append(c)
        s.next()
    return "".join(chars)
|
|
|
|
|
|
def _expectIdent(s: Scan) -> str:
    """Parse an identifier via _parseIdent; raise RuntimeError if there is none."""
    ident = _parseIdent(s)
    if not ident:
        raise RuntimeError("Expected identifier")
    return ident
|
|
|
|
|
|
def _parseString(s: Scan, quote: str) -> str:
    """Parse a string delimited by ``quote`` starting at the cursor.

    A backslash makes the following character literal; no escape sequences
    are translated. Raises RuntimeError if the closing quote is missing.
    """
    s.skipStr(quote)  # opening delimiter
    out: list[str] = []
    pending_escape = False
    while not s.eof():
        ch = s.curr()
        if pending_escape:
            out.append(ch)
            pending_escape = False
            s.next()
            continue
        if ch == "\\":
            pending_escape = True
        elif ch == quote:
            break
        else:
            out.append(ch)
        s.next()

    if not s.skipStr(quote):
        raise RuntimeError("Unterminated string")
    return "".join(out)
|
|
|
|
|
|
def _tryParseInt(ident) -> Optional[int]:
|
|
try:
|
|
return int(ident)
|
|
except ValueError:
|
|
return None
|
|
|
|
|
|
def _parsePrimitive(s: Scan) -> PrimitiveValue:
    """Parse one scalar value at the cursor.

    Single- or double-quoted text is returned verbatim as a string. Otherwise
    the text up to the next "," is read as a boolean (true/True/y/yes/Y/Yes
    or false/False/n/no/N/No), then as an integer, and finally returned as a
    plain string.
    """
    if s.curr() in ('"', "'"):
        return _parseString(s, s.curr())

    ident = _parseUntilComma(s)
    if ident in ("true", "True", "y", "yes", "Y", "Yes"):
        return True
    if ident in ("false", "False", "n", "no", "N", "No"):
        return False

    # Compare with None explicitly: 0 is falsy, so the previous truthiness
    # test returned "0" as a string instead of the integer 0.
    n = _tryParseInt(ident)
    if n is not None:
        return n
    return ident
|
|
|
|
|
|
def _parseValue(s: Scan) -> Value:
    """Parse a scalar, or a comma-separated list of scalars.

    Returns the bare scalar when it consumed the whole input. Otherwise it
    returns a list of the scalars separated by ","; parsing stops at the
    first position that is not a ",", and any remaining text is ignored.
    """
    first = _parsePrimitive(s)
    if s.eof():
        return first

    items: List = [first]
    while not s.eof():
        if not s.skipStr(","):
            break
        items.append(_parsePrimitive(s))
    return items
|
|
|
|
|
|
def parseValue(s: str) -> Value:
    """Parse the command-line value text ``s`` into a scalar or list of scalars."""
    scanner = Scan(s)
    return _parseValue(scanner)
|
|
|
|
|
|
def parseArg(arg: str) -> list[Token]:
    """Tokenize a single command-line argument.

    - ``--key[:subkey][=value]`` -> one ArgumentToken (short=False); the
      value is True when no "=" is given.
    - ``-abc`` -> one ArgumentToken per letter (short=True, value True).
    - anything else, including a lone "-" -> a single OperandToken.

    Raises:
        RuntimeError: on a malformed flag (missing identifier, or a
            non-alphanumeric character in a short-flag group).
    """
    # A lone "-" used to take the short-flag branch, whose loop never ran, so
    # it produced no tokens and silently vanished. Treat it as an operand,
    # the usual command-line convention.
    if arg == "-":
        return [OperandToken(arg)]

    s = Scan(arg)
    if s.skipStr("--"):
        key = _expectIdent(s)
        subkey = ""
        if s.skipStr(":"):
            subkey = _expectIdent(s)
        value: Value = _parseValue(s) if s.skipStr("=") else True
        return [ArgumentToken(key, subkey, value, False)]

    if s.skipStr("-"):
        res: list[Token] = []
        while not s.eof():
            key = s.curr()
            if not key.isalnum():
                raise RuntimeError("Expected alphanumeric")
            s.next()
            res.append(ArgumentToken(key, None, True, True))
        return res

    return [OperandToken(arg)]
|
|
|
|
|
|
def parseArgs(args: list[str]) -> list[Token]:
    """Tokenize a full argument list with parseArg.

    Everything after the first bare "--" is collected, unparsed, into a
    single ExtraToken. ``args`` itself is left unmodified.
    """
    res: list[Token] = []
    for i, arg in enumerate(args):
        if arg == "--":
            # Copy the tail: the previous version popped the caller's list
            # and stored the remainder in the token by reference.
            res.append(ExtraToken(args[i + 1 :]))
            break
        res.extend(parseArg(arg))
    return res
|
|
|
|
|
|
# --- Schema ----------------------------------------------------------------- #
|
|
|
|
|
|
class FieldKind(Enum):
    """How a Field receives its value from the command line."""

    FLAG = 0  # a -x / --name option (created by arg())
    OPERAND = 1  # a positional argument (created by operand())
    EXTRA = 2  # the arguments after "--" (created by extra())
|
|
|
|
|
|
@dt.dataclass
class Field:
    """Declaration of one command-line field on a schema class.

    Instances are created by arg(), operand() and extra(), stored as class
    attributes, and later bound to their attribute by ``bind``.
    """

    kind: FieldKind
    shortName: Optional[str]
    longName: str
    description: str = ""
    default: Any = None

    # Attribute name and annotated type on the schema class; set by bind().
    _fieldName: str | None = dt.field(init=False, default=None)
    _fieldType: type | None = dt.field(init=False, default=None)

    def bind(self, typ: type, name: str):
        """Attach this field to attribute ``name`` of class ``typ``.

        Records the attribute's annotated type and, when no long name was
        given, uses the attribute name as the long name.
        """
        self._fieldName = name
        self._fieldType = typ.__annotations__[name]
        # arg()/operand()/extra() default longName to "" rather than None, so
        # the old `is None` test never fired: long flags could not be matched
        # and operands were shown as "<>" in usage.
        if not self.longName:
            self.longName = name

    def isBool(self) -> bool:
        """True if the field is annotated as ``bool``."""
        return self._fieldType == bool

    def isList(self) -> bool:
        """True for ``list[X]`` annotations (including ``typing.List[X]``)."""
        return tp.get_origin(self._fieldType) is list

    def isDict(self) -> bool:
        """True for ``dict[K, V]`` annotations (including ``typing.Dict[K, V]``)."""
        return tp.get_origin(self._fieldType) is dict

    def innerType(self) -> type:
        """Element type for lists, value type for dicts, else the field type."""
        args = tp.get_args(self._fieldType)
        if self.isList() and args:
            return args[0]
        if self.isDict() and len(args) == 2:
            return args[1]
        return self._fieldType

    def defaultValue(self) -> Any:
        """Return the value the field holds before any argument is applied.

        An explicit ``default`` wins; otherwise False/0/""/[]/{} is chosen
        from the annotated type, and None for any other type.
        """
        if self._fieldType is None:
            return None

        if self.default is not None:
            # Copy mutable defaults: putValue() appends/updates in place, which
            # would otherwise alter the shared default for every later parse.
            if isinstance(self.default, (list, dict)):
                return self.default.copy()
            return self.default

        if self._fieldType == bool:
            return False
        elif self._fieldType == int:
            return 0
        elif self._fieldType == str:
            return ""
        elif self.isList():
            return []
        elif self.isDict():
            return {}
        else:
            return None

    def setDefault(self, obj: Any):
        """Store the default value on ``obj`` (no-op if the field is unbound)."""
        if self._fieldName:
            setattr(obj, self._fieldName, self.defaultValue())

    def castValue(self, val: Any, subkey: Optional[str]):
        """Convert a parsed value to this field's (inner) type.

        Lists are converted element-wise. For dict fields the result is a
        one-entry dict: ``{subkey: value}`` when a subkey was given,
        otherwise ``{str(val): True}``.
        """
        if isinstance(val, list):
            return [self.castValue(v, subkey) for v in val]

        if self.isDict() and not subkey:
            # The value is used as a key here. Running it through the dict's
            # value type first turned e.g. "foo" into "True" for
            # dict[str, bool], and raised for dict[str, int].
            return {str(val): True}

        target = self.innerType()
        if not (target is str and isinstance(val, str)):
            # Integer-looking input is converted first so that e.g. "0" ends up
            # False for a bool field. Skipped for str -> str, where it mangled
            # quoted values and extras ("007" -> "7", " 5" -> "5").
            try:
                val = int(val)
            except (ValueError, TypeError):
                pass
        val = target(val)

        if self.isDict():
            return {subkey: val}
        return val

    def putValue(self, obj: Any, value: Any, subkey: Optional[str] = None):
        """Cast ``value`` and store it on ``obj``.

        List attributes are extended/appended and dict attributes updated;
        any other attribute is overwritten.
        """
        value = self.castValue(value, subkey)
        current = getattr(obj, self._fieldName)
        if isinstance(current, list):
            if isinstance(value, list):
                current.extend(value)
            else:
                current.append(value)
        elif isinstance(current, dict):
            # A comma-separated value casts to a list of one-entry dicts;
            # dict.update() cannot take that directly, so merge each entry.
            if isinstance(value, list):
                for entry in value:
                    current.update(entry)
            else:
                current.update(value)
        else:
            setattr(obj, self._fieldName, value)

    def getAttr(self, obj: Any) -> Any:
        """Return the current value of this field on ``obj``."""
        return getattr(obj, self._fieldName)
|
|
|
|
|
|
def arg(
    shortName: str | None = None,
    longName: str = "",
    description: str = "",
    default: Any = None,
) -> Any:
    """Declare a flag field (``-x`` and/or ``--name``) on a schema class.

    Returns a FLAG Field. The ``Any`` return annotation lets the result be
    assigned to an attribute annotated with any type.
    """
    return Field(
        kind=FieldKind.FLAG,
        shortName=shortName,
        longName=longName,
        description=description,
        default=default,
    )
|
|
|
|
|
|
def operand(longName: str = "", description: str = "", default: Any = None) -> Any:
    """Declare a positional (operand) field on a schema class; it has no short name."""
    return Field(
        kind=FieldKind.OPERAND,
        shortName=None,
        longName=longName,
        description=description,
        default=default,
    )
|
|
|
|
|
|
def extra(longName: str = "", description: str = "") -> Any:
    """Declare the field that receives every argument after a bare "--"."""
    return Field(
        kind=FieldKind.EXTRA,
        shortName=None,
        longName=longName,
        description=description,
    )
|
|
|
|
|
|
@dt.dataclass
class Schema:
    """Command-line schema derived from a class whose annotated attributes are Fields.

    ``args`` holds flag fields and ``operands`` positional fields (own class
    first, then base classes). ``extras`` is the optional field that
    receives everything after "--".
    """

    typ: Optional[type] = None
    args: list[Field] = dt.field(default_factory=list)
    operands: list[Field] = dt.field(default_factory=list)
    extras: Optional[Field] = None

    @staticmethod
    def extract(typ: type) -> "Schema":
        """Build a Schema from ``typ`` and, recursively, its base classes.

        Every annotated attribute of ``typ`` must hold a Field; each one is
        bound to its attribute.

        Raises:
            ValueError: if an annotated attribute is missing or not a Field,
                or if more than one extra field is declared.
        """
        s = Schema(typ)

        for f in typ.__annotations__.keys():
            field = getattr(typ, f, None)

            if field is None:
                raise ValueError(f"Field '{f}' is not defined")

            if not isinstance(field, Field):
                raise ValueError(f"Field '{f}' is not a Field")

            field.bind(typ, f)

            if field.kind == FieldKind.FLAG:
                s.args.append(field)
            elif field.kind == FieldKind.OPERAND:
                s.operands.append(field)
            elif field.kind == FieldKind.EXTRA:
                if s.extras:
                    raise ValueError("Only one extra argument is allowed")
                s.extras = field

        # Merge in the fields declared on base classes.
        for base in typ.__bases__:
            if base == object:
                continue
            baseSchema = Schema.extract(base)
            s.args.extend(baseSchema.args)
            s.operands.extend(baseSchema.operands)
            if not s.extras:
                s.extras = baseSchema.extras
            elif baseSchema.extras:
                raise ValueError("Only one extra argument is allowed")

        return s

    @staticmethod
    def extractFromCallable(fn: tp.Callable) -> Optional["Schema"]:
        """Build the schema from the type of ``fn``'s first annotated parameter.

        Returns None when no parameter is annotated (or it is annotated None).
        """
        # Skip the "return" key: for `def f() -> int` the old code tried to
        # extract a schema from `int` and crashed on `int.__annotations__`.
        params = [t for name, t in fn.__annotations__.items() if name != "return"]
        typ: type | None = params[0] if params else None

        if typ is None:
            return None

        return Schema.extract(typ)

    def usage(self) -> str:
        """Return a one-line usage summary of the flags, operands and extras."""
        res = ""
        for arg in self.args:
            flag = ""
            if arg.shortName:
                flag += f"-{arg.shortName}"

            if arg.longName:
                if flag:
                    flag += ", "
                flag += f"--{arg.longName}"
            res += f"[{flag}] "

        for operand in self.operands:
            res += f"<{operand.longName}> "

        if self.extras:
            res += f"[-- {self.extras.longName}]"

        return res

    def _lookupArg(self, key: str, short: bool) -> Field:
        """Find the flag whose short (or long) name equals ``key``; raise ValueError if none."""
        for arg in self.args:
            name = arg.shortName if short else arg.longName
            if name == key:
                return arg
        raise ValueError(f"Unknown argument '{key}'")

    def _setOperand(self, obj: Any, value: Any):
        """Put ``value`` into the first unfilled (or list) operand slot of ``obj``.

        Raises:
            ValueError: if no operand slot can take the value.
        """
        for operand in self.operands:
            if operand.getAttr(obj) is None or operand.isList():
                operand.putValue(obj, value)
                return

        # No operands at all, or every scalar slot is already filled: report
        # it instead of silently discarding the argument.
        raise ValueError(f"Unexpected operand '{value}'")

    def _instanciate(self) -> Any:
        """Create a ``typ`` instance with flags/extras at their defaults and
        operands empty (None, or [] for list operands). Returns None if there
        is no type."""
        if self.typ is None:
            return None
        res = self.typ()
        for arg in self.args:
            arg.setDefault(res)

        for operand in self.operands:
            if operand.isList():
                setattr(res, operand._fieldName, [])
            else:
                setattr(res, operand._fieldName, None)

        if self.extras:
            self.extras.setDefault(res)

        return res

    def parse(self, args: list[str]) -> Any:
        """Parse ``args`` into a fresh instance of ``typ``.

        Returns None when the schema has no type and ``args`` is empty.

        Raises:
            ValueError: on unknown flags, a missing short-flag value, an
                unexpected operand or "--", or arguments for a type-less schema.
        """
        res = self._instanciate()
        if res is None:
            if len(args) > 0:
                raise ValueError("Unexpected arguments")
            return None

        stack = args[:]
        while len(stack) > 0:
            if stack[0] == "--":
                if not self.extras:
                    raise ValueError("Unexpected '--'")
                self.extras.putValue(res, stack[1:])
                break

            for tok in parseArg(stack.pop(0)):
                if isinstance(tok, ArgumentToken):
                    arg = self._lookupArg(tok.key, tok.short)
                    if tok.short and not arg.isBool():
                        # Non-boolean short flags take the next argument as value.
                        if len(stack) == 0:
                            raise ValueError(
                                f"Expected value for argument '-{arg.shortName}'"
                            )

                        arg.putValue(res, parseValue(stack.pop(0)))
                    else:
                        arg.putValue(res, tok.value, tok.subkey)
                elif isinstance(tok, OperandToken):
                    self._setOperand(res, tok.value)
                else:
                    raise ValueError(f"Unexpected token: {type(tok)}")

        # operand(default=...) was otherwise never consulted: _instanciate
        # leaves operands as None so _setOperand can detect empty slots.
        for operand in self.operands:
            if (
                operand.default is not None
                and not operand.isList()
                and operand.getAttr(res) is None
            ):
                setattr(res, operand._fieldName, operand.default)

        return res
|
|
|
|
|
|
@dt.dataclass
class Command:
    """A node of the command tree: an optional runnable command plus subcommands.

    Nodes become ``populated`` when registered through the ``command``
    decorator; nodes created implicitly for intermediate path segments are not.
    """

    shortName: Optional[str]
    # Full invocation path, e.g. [ARGV0, "sub", "subsub"].
    path: list[str] = dt.field(default_factory=list)
    description: str = ""
    epilog: Optional[str] = None

    schema: Optional[Schema] = None
    callable: Optional[tp.Callable] = None
    subcommands: dict[str, "Command"] = dt.field(default_factory=dict)
    populated: bool = False

    @property
    def longName(self) -> str:
        """The last path component, i.e. the name this command is invoked by."""
        return self.path[-1]

    def _spliceArgs(self, args: list[str]) -> tuple[list[str], list[str]]:
        """Split ``args`` into (arguments for this command, rest).

        A command with subcommands keeps only the leading "-" flags (stopping
        at "--"); a leaf command keeps everything.
        """
        rest = args[:]
        curr = []
        if len(self.subcommands) > 0:
            while len(rest) > 0 and rest[0].startswith("-") and rest[0] != "--":
                curr.append(rest.pop(0))
        else:
            curr = rest
            rest = []
        return curr, rest

    def help(self):
        """Print the help page: title, usage, description, options,
        subcommands and epilog."""
        vt100.title(f"{self.longName}")
        print()

        vt100.subtitle("Usage")
        print(vt100.indent(f"{' '.join(self.path)}{self.usage()}"))
        print()

        vt100.subtitle("Description")
        print(vt100.indent(self.description))
        print()

        if self.schema and self.schema.args:
            vt100.subtitle("Options")
            for arg in self.schema.args:
                flag = ""
                if arg.shortName:
                    flag += f"-{arg.shortName}"

                if arg.longName:
                    if flag:
                        flag += ", "
                    flag += f"--{arg.longName}"

                if arg.description:
                    flag += f" {arg.description}"

                print(vt100.indent(flag))
            print()

        if self.subcommands:
            vt100.subtitle("Subcommands")
            for name, sub in self.subcommands.items():
                print(
                    vt100.indent(
                        f"{vt100.GREEN}{sub.shortName or ' '}{vt100.RESET} {name} - {sub.description}"
                    )
                )
            print()

        if self.epilog:
            print(self.epilog)
            print()

    def usage(self) -> str:
        """Return the usage suffix: schema usage plus subcommand placeholder."""
        res = " "
        if self.schema:
            res += self.schema.usage()

        if len(self.subcommands) == 1:
            res += "[subcommand] [args...]"

        elif len(self.subcommands) > 0:
            res += "{"
            res += "|".join(self.subcommands)
            res += "}"

            res += " [args...]"

        return res

    def lookupSubcommand(self, name: str) -> "Command":
        """Find a subcommand by long name, then by short name; raise ValueError if absent."""
        if name in self.subcommands:
            return self.subcommands[name]
        for sub in self.subcommands.values():
            if sub.shortName == name:
                return sub
        raise ValueError(f"Unknown subcommand '{name}'")

    def invoke(self, argv: list[str]):
        """Run this command's callable, parsing ``argv`` through its schema if any."""
        if self.callable:
            if self.schema:
                args = self.schema.parse(argv)
                self.callable(args)
            else:
                self.callable()

    def eval(self, args: list[str]):
        """Dispatch ``args``, whose first element is this command's own name.

        Handles -h/--help and -u/--usage, runs this command, then recurses into
        the selected subcommand. ValueErrors are reported together with a usage
        line instead of propagating. Pops ``args[0]`` from the caller's list.
        """
        cmd = args.pop(0)
        curr, rest = self._spliceArgs(args)

        # Only look for -h/-u before a "--": a leaf command also receives its
        # pass-through extras in ``curr``, and "cmd -- -h" must forward "-h"
        # to the extras instead of printing help.
        flags = curr[: curr.index("--")] if "--" in curr else curr

        if "-h" in flags or "--help" in flags:
            if len(self.path) == 1:
                # HACK: This is a special case for the root command
                # it need to always be run because it might
                # load some plugins that will register subcommands
                # that need to be displayed in the help.
                self.invoke([])
            self.help()
            return

        if "-u" in flags or "--usage" in flags:
            if len(self.path) == 1:
                # HACK: Same as the help flag, the root command needs to be
                # always run to load plugins
                self.invoke([])
            print("Usage: " + cmd + self.usage(), end="\n\n")
            return

        try:
            self.invoke(curr)

            if self.subcommands:
                if len(rest) > 0:
                    # NOTE(review): an unpopulated node (e.g. one created
                    # implicitly by _resolvePath) rejects its subcommands here
                    # even though a name was given - confirm this is intended.
                    if not self.populated:
                        raise ValueError("Expected subcommand")
                    else:
                        self.lookupSubcommand(rest[0]).eval(rest)
                else:
                    print("Usage: " + cmd + self.usage(), end="\n\n")
                    return
            elif len(rest) > 0:
                raise ValueError(f"Unknown operand '{rest[0]}'")

        except ValueError as e:
            vt100.error(str(e))
            print("Usage: " + cmd + self.usage(), end="\n\n")
            return
|
|
|
|
|
|
# Root of the command tree. ``path`` must be a list: passing the bare string
# made longName its last character and made help() join its characters.
_root = Command(None, [const.ARGV0])
|
|
|
|
|
|
def _splitPath(path: str) -> list[str]:
|
|
if path == "/":
|
|
return []
|
|
return path.split("/")
|
|
|
|
|
|
def _resolvePath(path: list[str]) -> Command:
    """Return the Command node at ``path`` below the root, creating any
    missing nodes along the way. An empty path (or "/") yields the root.
    """
    if path == "/":
        # Kept for callers passing the raw string; _splitPath maps "/" to [],
        # which the loop below also resolves to the root.
        return _root

    cmd = _root
    for depth, name in enumerate(path):
        if name not in cmd.subcommands:
            # Give implicit nodes a list path like the command decorator does;
            # passing the bare name made ``path`` a str, so longName returned
            # only its last character.
            cmd.subcommands[name] = Command(
                None, [const.ARGV0] + list(path[: depth + 1])
            )
        cmd = cmd.subcommands[name]
    return cmd
|
|
|
|
|
|
def command(shortName: Optional[str], longName: str, description: str = "") -> Callable:
    """Decorator registering ``fn`` as the command at slash path ``longName``.

    The argument schema comes from Schema.extractFromCallable(fn). The
    decorated function is returned unchanged.

    Raises:
        ValueError: if a command is already registered at that path.
    """

    def register(fn: Callable):
        schema = Schema.extractFromCallable(fn)
        path = _splitPath(longName)
        node = _resolvePath(path)

        _logger.info(f"Registering command '{'.'.join(path)}'")
        if node.populated:
            raise ValueError(f"Command '{longName}' is already defined")

        node.shortName, node.description = shortName, description
        node.schema, node.callable = schema, fn
        node.populated = True
        node.path = [const.ARGV0] + path
        return fn

    return register
|
|
|
|
|
|
def usage():
    """Print the top-level usage line of the root command."""
    head = f"Usage: {const.ARGV0}"
    print(f"{head} {_root.usage()}")
|
|
|
|
|
|
def exec():
    """Run the CLI on ``sys.argv``, prefixed by the words of ``CK_EXTRA_ARGS``.

    CK_EXTRA_ARGS is split on runs of whitespace. The previous
    ``split(" ")`` turned doubled, leading or trailing spaces into
    empty-string arguments.
    """
    extra = os.environ.get("CK_EXTRA_ARGS", None)
    args = [const.ARGV0] + (extra.split() if extra else []) + sys.argv[1:]
    _root.eval(args)
|
|
|
|
|
|
def defaults(typ: type[utils.T]) -> utils.T:
    """Return an instance of ``typ`` as Schema._instanciate builds it, with
    flags/extras at their defaults and operands empty (None, or [] for lists)."""
    instance = Schema.extract(typ)._instanciate()
    return tp.cast(utils.T, instance)
|