cutekit/cutekit/cli2.py
2024-02-13 09:56:43 +01:00

570 lines
14 KiB
Python

import sys
from enum import Enum
import typing as tp
import dataclasses as dt
from typing import Any, Callable, Optional
# Generic type variable. NOTE(review): not referenced by any code in this listing.
T = tp.TypeVar("T")
# --- Scan -------------------------------------------------------------- #
class Scan:
    """Character cursor over a source string with a stack of saved offsets.

    `save()` pushes the current offset and `restore()` pops it back, which lets
    the `skip*`/`is*` helpers backtrack. Past the end of input, the character
    accessors return a NUL character instead of raising.
    """

    _src: str
    _off: int
    _save: list[int]

    def __init__(self, src: str, off: int = 0):
        self._src = src
        # Honour the requested starting offset (it used to be forced to 0).
        self._off = off
        self._save = []

    def curr(self) -> str:
        """Return the character under the cursor, or NUL at end of input."""
        if self.eof():
            return "\0"
        return self._src[self._off]

    def next(self) -> str:
        """Advance one character and return the new current character (NUL at end)."""
        if self.eof():
            return "\0"
        self._off += 1
        return self.curr()

    def peek(self, off: int = 1) -> str:
        """Return the character ``off`` positions from the cursor, or NUL if out of range."""
        idx = self._off + off
        # Reject negative indices too; Python would otherwise wrap to the end.
        if not 0 <= idx < len(self._src):
            return "\0"
        return self._src[idx]

    def eof(self) -> bool:
        """Return True when the cursor is at or past the end of input."""
        return self._off >= len(self._src)

    def skipStr(self, s: str) -> bool:
        """Consume ``s`` if the input continues with it; return whether it did."""
        # startswith(prefix, start) avoids copying the rest of the input.
        if self._src.startswith(s, self._off):
            self._off += len(s)
            return True
        return False

    def isStr(self, s: str) -> bool:
        """Return whether the input continues with ``s``, without consuming it."""
        return self._src.startswith(s, self._off)

    def save(self) -> None:
        """Push the current offset onto the save stack."""
        self._save.append(self._off)

    def restore(self) -> None:
        """Pop the most recently saved offset and move the cursor back to it."""
        self._off = self._save.pop()

    def commit(self) -> None:
        """Pop the most recently saved offset, keeping the current position."""
        self._save.pop()

    def skipWhitespace(self) -> bool:
        """Consume whitespace; return True if any was consumed."""
        result = False
        while not self.eof() and self.curr().isspace():
            self.next()
            result = True
        return result

    def skipSeparator(self, sep: str) -> bool:
        """Consume ``sep`` with surrounding whitespace, or leave the cursor unchanged."""
        self.save()
        self.skipWhitespace()
        if self.skipStr(sep):
            self.skipWhitespace()
            # Drop our saved offset so an outer save()/restore() pair stays balanced.
            self.commit()
            return True
        self.restore()
        return False

    def isSeparator(self, sep: str) -> bool:
        """Return whether ``sep`` (after optional whitespace) comes next, without consuming."""
        self.save()
        self.skipWhitespace()
        found = self.isStr(sep)
        self.restore()
        return found

    def skipKeyword(self, keyword: str) -> bool:
        """Consume ``keyword`` (after optional whitespace) if it is not followed by an alphanumeric."""
        self.save()
        self.skipWhitespace()
        if self.skipStr(keyword) and not self.curr().isalnum():
            # Drop our saved offset so an outer save()/restore() pair stays balanced.
            self.commit()
            return True
        self.restore()
        return False

    def isKeyword(self, keyword: str) -> bool:
        """Return whether `skipKeyword` would succeed, without consuming anything."""
        self.save()
        self.skipWhitespace()
        matched = self.skipStr(keyword) and not self.curr().isalnum()
        self.restore()
        return matched
# --- Parser ------------------------------------------------------------ #
PrimitiveValue = str | bool | int
Object = dict[str, PrimitiveValue]
List = list[PrimitiveValue]
Value = str | bool | int | Object | List
@dt.dataclass
class Token:
    """Base class of the tokens produced by `parseArg` and `parseArgs`."""
@dt.dataclass
class ArgumentToken(Token):
    """An option token: ``--key[:subkey][=value]`` or one letter of ``-abc``."""

    key: str  # option name; a single character for short options
    subkey: str | None  # text after ":" in a long option, if any
    value: Value  # parsed value; True when no "=value" was given
    short: bool  # True for "-x" style options
@dt.dataclass
class OperandToken(Token):
    """A positional argument, kept verbatim."""

    value: str  # the raw argument text
@dt.dataclass
class ExtraToken(Token):
    """The arguments that followed ``--``, passed through unparsed."""

    args: list[str]  # remaining raw arguments, in order
def _parseIdent(s: Scan) -> str:
    """Consume and return the longest run of alphanumerics, ``_``, ``-`` and ``+``.

    Returns an empty string when the cursor is not on such a character.
    """
    chars: list[str] = []
    while not s.eof():
        ch = s.curr()
        if not (ch.isalnum() or ch in "_-+"):
            break
        chars.append(ch)
        s.next()
    return "".join(chars)
def _parseUntilComma(s: Scan) -> str:
    """Consume and return everything up to (not including) the next ``,`` or end of input."""
    chars: list[str] = []
    while not s.eof() and (ch := s.curr()) != ",":
        chars.append(ch)
        s.next()
    return "".join(chars)
def _expectIdent(s: Scan) -> str:
    """Like `_parseIdent`, but raise RuntimeError when no identifier is present."""
    ident = _parseIdent(s)
    if not ident:
        raise RuntimeError("Expected identifier")
    return ident
def _parseString(s: Scan, quote: str) -> str:
    """Parse a string delimited by ``quote`` and return its contents.

    The opening quote is skipped if present. A backslash makes the following
    character literal (it is kept as-is, without translation).

    Raises:
        RuntimeError: if the closing quote is missing.
    """
    s.skipStr(quote)
    out: list[str] = []
    escaped = False
    while not s.eof():
        ch = s.curr()
        if escaped:
            escaped = False
        elif ch == "\\":
            # Drop the backslash itself; the next character is taken literally.
            escaped = True
            s.next()
            continue
        elif ch == quote:
            break
        out.append(ch)
        s.next()
    if not s.skipStr(quote):
        raise RuntimeError("Unterminated string")
    return "".join(out)
def _tryParseInt(ident) -> Optional[int]:
try:
return int(ident)
except ValueError:
return None
def _parsePrimitive(s: Scan) -> PrimitiveValue:
    """Parse one scalar value at the cursor.

    Recognises, in order:
    - a single- or double-quoted string;
    - a boolean word (true/True/y/yes/Y/Yes or false/False/n/no/N/No);
    - an integer;
    - otherwise the bare text.

    Unquoted values extend up to the next ``,``.
    """
    quote = s.curr()
    if quote in ('"', "'"):
        return _parseString(s, quote)

    ident = _parseUntilComma(s)
    if ident in ("true", "True", "y", "yes", "Y", "Yes"):
        return True
    if ident in ("false", "False", "n", "no", "N", "No"):
        return False
    # Compare against None explicitly: "0" parses to the falsy int 0, which a
    # truthiness test would reject, returning the string "0" instead.
    number = _tryParseInt(ident)
    if number is not None:
        return number
    return ident
def _parseValue(s: Scan) -> Value:
    """Parse a primitive, or a comma-separated list of primitives.

    The bare primitive is returned when the input ends right after it;
    otherwise the primitives separated by ``,`` are collected into a list.
    """
    first = _parsePrimitive(s)
    if s.eof():
        return first
    items: List = [first]
    while not s.eof():
        if not s.skipStr(","):
            break
        items.append(_parsePrimitive(s))
    return items
def parseValue(s: str) -> Value:
    """Parse a command-line value string into a primitive or a list of primitives."""
    scanner = Scan(s)
    return _parseValue(scanner)
def parseArg(arg: str) -> list[Token]:
    """Tokenize a single command-line argument.

    - ``--key``, ``--key=value``, ``--key:subkey=value`` become one long
      `ArgumentToken`. The value is True when no ``=`` is given, and subkey
      is None when no ``:`` is given.
    - ``-abc`` becomes one short `ArgumentToken` (value True) per character.
    - Anything else, including a lone ``-``, becomes an `OperandToken`.

    Raises:
        RuntimeError: on a malformed long option (missing identifier or
            unexpected trailing characters) or a non-alphanumeric short flag.
    """
    s = Scan(arg)
    if s.skipStr("--"):
        key = _expectIdent(s)
        subkey: Optional[str] = None
        if s.skipStr(":"):
            subkey = _expectIdent(s)
        value: Value = True
        if s.skipStr("="):
            value = _parseValue(s)
        if not s.eof():
            # Reject leftovers such as "--foo.bar" or '--x="a"b' instead of
            # silently dropping them.
            raise RuntimeError(f"Unexpected character '{s.curr()}' in '{arg}'")
        return [ArgumentToken(key, subkey, value, False)]

    if arg == "-":
        # A lone dash is conventionally an operand; it used to vanish silently.
        return [OperandToken(arg)]

    if s.skipStr("-"):
        res: list[Token] = []
        for key in arg[1:]:
            if not key.isalnum():
                raise RuntimeError("Expected alphanumeric")
            res.append(ArgumentToken(key, None, True, True))
        return res

    return [OperandToken(arg)]
def parseArgs(args: list[str]) -> list[Token]:
    """Tokenize a full argument vector with `parseArg`.

    Everything after the first ``--`` is collected verbatim into a single
    `ExtraToken`. The input list is left unmodified. It used to be consumed
    with repeated ``pop(0)``, which was also quadratic.
    """
    res: list[Token] = []
    for i, arg in enumerate(args):
        if arg == "--":
            res.append(ExtraToken(args[i + 1 :]))
            break
        res.extend(parseArg(arg))
    return res
# --- Schema ----------------------------------------------------------------- #
class FieldType(Enum):
    """How a schema field is supplied on the command line."""

    FLAG = 0  # named option, looked up by short/long name
    OPERAND = 1  # positional argument
    EXTRA = 2  # the arguments following "--"
@dt.dataclass
class Field:
    """Declaration of one command-line field.

    Instances are created with `arg`, `operand` or `extra` and placed as class
    attributes of a user type. `bind` then ties a field to the attribute it was
    assigned to.
    """

    type: FieldType
    shortName: Optional[str]
    longName: str
    description: str = ""
    default: Any = None

    # Filled in by bind(): the attribute name and its annotated type.
    _fieldName: str | None = dt.field(init=False, default=None)
    _fieldType: type | None = dt.field(init=False, default=None)

    def bind(self, typ: type, name: str):
        """Attach this field to attribute ``name`` of ``typ``.

        Records the attribute's annotation. When no long name was given, the
        attribute name is used as the long name.
        """
        self._fieldName = name
        self._fieldType = typ.__annotations__[name]
        # The factories default longName to "" (never None), so test for any
        # empty value; otherwise `--name` could never match such a field.
        if not self.longName:
            self.longName = name

    def defaultValue(self) -> Any:
        """Return the value an unset field should hold.

        The explicit ``default`` wins. Otherwise the result is an empty value
        of the annotated type (False, 0, "", [], {}), or None for any other
        type. An unbound field always yields None.
        """
        if self._fieldType is None:
            return None
        if self.default is not None:
            return self.default
        # get_origin maps parameterised generics such as list[str] to list.
        base = tp.get_origin(self._fieldType) or self._fieldType
        for kind in (bool, int, str, list, dict):
            if base is kind:
                return kind()
        return None

    def setDefault(self, obj: Any):
        """Set the bound attribute on ``obj`` to `defaultValue()`; no-op if unbound."""
        if self._fieldName:
            setattr(obj, self._fieldName, self.defaultValue())

    def putValue(self, obj: Any, value: Any, subkey: Optional[str] = None):
        """Store ``value`` in the bound attribute of ``obj``.

        ``subkey`` is accepted but currently ignored.
        """
        setattr(obj, self._fieldName, value)

    def getAttr(self, obj: Any) -> Any:
        """Return the bound attribute's current value on ``obj``."""
        return getattr(obj, self._fieldName)
def arg(
    shortName: str | None = None,
    longName: str = "",
    description: str = "",
    default: Any = None,
) -> Any:
    """Declare a named option (``-s`` / ``--long``) field.

    The return type is ``Any`` so the field can be assigned to an attribute
    of any annotated type.
    """
    return Field(
        type=FieldType.FLAG,
        shortName=shortName,
        longName=longName,
        description=description,
        default=default,
    )
def operand(longName: str = "", description: str = "") -> Any:
    """Declare a positional operand field."""
    return Field(
        type=FieldType.OPERAND,
        shortName=None,
        longName=longName,
        description=description,
    )
def extra(longName: str = "", description: str = "") -> Any:
    """Declare the field that receives the arguments following ``--``."""
    return Field(
        type=FieldType.EXTRA,
        shortName=None,
        longName=longName,
        description=description,
    )
@dt.dataclass
class Schema:
    """Command-line layout derived from a user-defined argument class.

    ``typ`` is a class whose annotated attributes are `Field` declarations.
    `parse` builds an instance of ``typ`` from a list of argument strings.
    """

    typ: Optional[type] = None
    args: list[Field] = dt.field(default_factory=list)
    operands: list[Field] = dt.field(default_factory=list)
    extras: Optional[Field] = None

    @staticmethod
    def extract(typ: type) -> "Schema":
        """Build a schema from the annotated attributes of ``typ``.

        Raises:
            ValueError: if an annotated attribute has no value, is not a
                `Field`, or more than one extra field is declared.
        """
        s = Schema(typ)
        for f in typ.__annotations__.keys():
            field = getattr(typ, f, None)
            if field is None:
                raise ValueError(f"Field '{f}' is not defined")
            if not isinstance(field, Field):
                raise ValueError(f"Field '{f}' is not a Field")
            field.bind(typ, f)
            if field.type == FieldType.FLAG:
                s.args.append(field)
            elif field.type == FieldType.OPERAND:
                s.operands.append(field)
            elif field.type == FieldType.EXTRA:
                if s.extras:
                    raise ValueError("Only one extra argument is allowed")
                s.extras = field
        return s

    @staticmethod
    def extractFromCallable(fn: tp.Callable) -> "Schema":
        """Build a schema from the type of ``fn``'s first annotated parameter.

        An empty schema is returned when no parameter is annotated.
        """
        # __annotations__ also holds the return annotation under "return";
        # it must not be mistaken for a parameter type.
        params = [t for name, t in fn.__annotations__.items() if name != "return"]
        typ = params[0] if params else None
        if typ is None:
            return Schema()
        return Schema.extract(typ)

    def usage(self) -> str:
        """Return a one-line summary of the flags, operands and extras."""
        res = ""
        for arg in self.args:
            flag = ""
            if arg.shortName:
                flag += f"-{arg.shortName}"
            if arg.longName:
                if flag:
                    flag += ", "
                flag += f"--{arg.longName}"
            res += f"[{flag}] "
        for operand in self.operands:
            res += f"<{operand.longName}> "
        if self.extras:
            res += f"[-- {self.extras.longName}]"
        return res

    def _lookupArg(self, key: str, short: bool) -> Field:
        """Find the flag field named ``key`` (short or long form), or raise ValueError."""
        for arg in self.args:
            if short and arg.shortName == key:
                return arg
            elif not short and arg.longName == key:
                return arg
        raise ValueError(f"Unknown argument '{key}'")

    def _fields(self) -> list[Field]:
        """Return every declared field: flags, operands and the extra field."""
        fields = [*self.args, *self.operands]
        if self.extras:
            fields.append(self.extras)
        return fields

    def _setOperand(self, res: Any, index: int, tok: OperandToken):
        """Store operand ``tok`` in the ``index``-th operand field of ``res``."""
        if index >= len(self.operands):
            raise ValueError(f"Unexpected operand '{tok.value}'")
        self.operands[index].putValue(res, tok.value)

    def _setExtra(self, res: Any, extra: list[str]):
        """Store the arguments that followed ``--`` in the extra field of ``res``."""
        if not self.extras:
            raise ValueError("Unexpected '--'")
        self.extras.putValue(res, extra)

    def _instanciate(self) -> Any:
        """Create an instance of ``typ`` with every field set to its default.

        Returns None when the schema has no type.
        """
        if self.typ is None:
            return None
        res = self.typ()
        # Without this, unset attributes would resolve to the class-level
        # Field declarations instead of values.
        for field in self._fields():
            field.putValue(res, field.defaultValue())
        return res

    def parse(self, args: list[str]) -> Any:
        """Parse ``args`` into a new instance of ``typ``; ``args`` is not modified.

        Flags are matched by short or long name. Operands fill the operand
        fields in declaration order. Everything after ``--`` is stored in the
        extra field as a list of strings.

        Raises:
            ValueError: on unknown flags, surplus operands, an unexpected
                ``--``, malformed arguments, or any arguments when ``typ`` is None.
        """
        res = self._instanciate()
        if res is None:
            if len(args) > 0:
                raise ValueError("Unexpected arguments")
            return None

        operandIndex = 0
        for i, raw in enumerate(args):
            if raw == "--":
                self._setExtra(res, args[i + 1 :])
                break
            try:
                toks = parseArg(raw)
            except RuntimeError as e:
                # The tokenizer reports malformed input as RuntimeError; surface
                # it as ValueError like every other parse error here.
                raise ValueError(str(e)) from e
            for tok in toks:
                if isinstance(tok, ArgumentToken):
                    field = self._lookupArg(tok.key, tok.short)
                    field.putValue(res, tok.value, tok.subkey)
                elif isinstance(tok, OperandToken):
                    self._setOperand(res, operandIndex, tok)
                    operandIndex += 1
                else:
                    raise ValueError(f"Unexpected token: {type(tok)}")
        return res
@dt.dataclass
class Command:
    """A node of the command tree.

    Each node holds an optional handler, the schema of its arguments and its
    named subcommands.
    """

    shortName: Optional[str]
    longName: str
    description: str = ""
    schema: Optional[Schema] = None
    callable: Optional[tp.Callable] = None
    subcommands: dict[str, "Command"] = dt.field(default_factory=dict)
    populated: bool = False  # True once a handler was registered for this node

    def _spliceArgs(self, args: list[str]) -> tuple[list[str], list[str]]:
        """Split ``args`` into (this command's arguments, the remainder).

        With subcommands, only the leading options (up to the first
        non-option or ``--``) belong to this command; the remainder starts
        with the subcommand name. Without subcommands, everything belongs to
        this command.
        """
        if not self.subcommands:
            # Both halves used to alias the same list, so every argument of a
            # leaf command was also reported as an "Unknown operand".
            return args[:], []
        split = 0
        while split < len(args) and args[split].startswith("-") and args[split] != "--":
            split += 1
        return args[:split], args[split:]

    def help(self):
        """Print the description, usage line, options and subcommands."""
        if self.description:
            print(self.description, end="\n\n")
        print("Usage: " + self.longName + self.usage(), end="\n\n")
        if self.schema and self.schema.args:
            print("Options:")
            for field in self.schema.args:
                flags = []
                if field.shortName:
                    flags.append(f"-{field.shortName}")
                if field.longName:
                    flags.append(f"--{field.longName}")
                print(f"  {', '.join(flags)}  {field.description}")
            print()
        if self.subcommands:
            print("Subcommands:")
            for name, sub in self.subcommands.items():
                print(f"  {name}  {sub.description}")
            print()

    def usage(self) -> str:
        """Return the usage text that follows the command name (leading space included)."""
        res = " "
        if self.schema:
            res += self.schema.usage()
        if len(self.subcommands) == 1:
            sub = next(iter(self.subcommands.values()))
            res += sub.longName + sub.usage()
        elif self.subcommands:
            # Alternatives used to be concatenated without any separator.
            res += "{" + " | ".join(self.subcommands) + "} [args...]"
        return res

    def eval(self, args: list[str]):
        """Run this command; ``args[0]`` is the command's own name.

        Handles ``-h``/``--help`` and ``-u``/``--usage``. Otherwise it parses
        this command's arguments with its schema, calls the handler, and then
        dispatches the remaining arguments to a subcommand. A ValueError raised
        on the way is printed with the usage line. ``args`` is not modified.
        """
        cmd = args[0]
        curr, rest = self._spliceArgs(args[1:])

        if "-h" in curr or "--help" in curr:
            self.help()
            return

        if "-u" in curr or "--usage" in curr:
            print("Usage: " + cmd + self.usage(), end="\n\n")
            return

        try:
            if self.schema and self.callable:
                self.callable(self.schema.parse(curr))
            elif self.callable:
                self.callable()

            if self.subcommands:
                if not rest:
                    # A command with its own handler may be run without a subcommand.
                    if not self.populated:
                        raise ValueError("Expected subcommand")
                elif rest[0] in self.subcommands:
                    self.subcommands[rest[0]].eval(rest)
                else:
                    raise ValueError(f"Unknown subcommand '{rest[0]}'")
        except ValueError as e:
            print("Error: " + str(e))
            print("Usage: " + cmd + self.usage(), end="\n\n")
# Root of the command tree, named after the running program.
_root = Command(shortName=None, longName=sys.argv[0])
def _splitPath(path: str) -> list[str]:
if path == "/":
return []
return path.split("/")
def _resolvePath(path: list[str]) -> Command:
    """Return the command at ``path`` below the root, creating placeholders as needed.

    An empty path resolves to the root. A ``/``-separated string is also
    accepted; it is split into segments instead of being walked character by
    character.
    """
    if isinstance(path, str):
        path = [segment for segment in path.split("/") if segment]
    cmd = _root
    for name in path:
        if name not in cmd.subcommands:
            # Intermediate nodes start unpopulated until a handler is registered.
            cmd.subcommands[name] = Command(None, name)
        cmd = cmd.subcommands[name]
    return cmd
def command(shortName: Optional[str], longName: str, description: str = "") -> Callable:
    """Return a decorator that registers a function as a command handler.

    ``longName`` is a ``/``-separated path below the root; ``"/"`` is the root
    command itself. The handler's argument schema is taken from the type
    annotation of its first parameter.

    Raises:
        ValueError: (when decorating) if a handler is already registered
            for that path.
    """

    def wrap(fn: Callable):
        schema = Schema.extractFromCallable(fn)
        path = _splitPath(longName)
        cmd = _resolvePath(path)
        if cmd.populated:
            raise ValueError(f"Command '{longName}' is already defined")
        cmd.shortName = shortName
        # The root has no path segment (path[-1] used to raise IndexError);
        # it keeps its own name.
        if path:
            cmd.longName = path[-1]
        cmd.description = description
        cmd.schema = schema
        cmd.callable = fn
        cmd.populated = True
        return fn

    return wrap