570 lines
14 KiB
Python
570 lines
14 KiB
Python
import sys
|
|
from enum import Enum
|
|
import typing as tp
|
|
import dataclasses as dt
|
|
|
|
from typing import Any, Callable, Optional
|
|
|
|
|
|
# Generic type variable; nothing in the visible part of this file references it.
T = tp.TypeVar("T")
|
|
|
|
# --- Scan -------------------------------------------------------------- #
|
|
|
|
|
|
class Scan:
    """Cursor over a string with a stack of saved offsets for backtracking.

    Reading at or past the end of the input yields the NUL character instead
    of raising, so callers can test characters without bounds checks.
    """

    _src: str
    _off: int
    _save: list[int]

    def __init__(self, src: str, off: int = 0):
        """Scan *src* starting at offset *off*."""
        self._src = src
        # The offset argument used to be ignored (always reset to 0).
        self._off = off
        self._save = []

    def curr(self) -> str:
        """Return the character under the cursor, or NUL at end of input."""
        if self.eof():
            return "\0"
        return self._src[self._off]

    def next(self) -> str:
        """Advance one character and return the new current character."""
        if self.eof():
            return "\0"

        self._off += 1
        return self.curr()

    def peek(self, off: int = 1) -> str:
        """Return the character *off* positions from the cursor, or NUL.

        Positions outside the input (on either side) yield NUL; a negative
        index must not wrap around to the end of the string.
        """
        idx = self._off + off
        if idx < 0 or idx >= len(self._src):
            return "\0"

        return self._src[idx]

    def eof(self) -> bool:
        """Return True once the cursor is at or past the end of the input."""
        return self._off >= len(self._src)

    def skipStr(self, s: str) -> bool:
        """Consume *s* if the input continues with it; report whether it did."""
        # startswith() with a start offset avoids copying the remaining input.
        if self._src.startswith(s, self._off):
            self._off += len(s)
            return True

        return False

    def isStr(self, s: str) -> bool:
        """Return True if the input continues with *s*, without consuming it."""
        self.save()
        matched = self.skipStr(s)
        self.restore()
        return matched

    def save(self) -> None:
        """Push the current offset onto the backtracking stack."""
        self._save.append(self._off)

    def restore(self) -> None:
        """Pop the most recently saved offset and move the cursor back to it."""
        self._off = self._save.pop()

    def skipWhitespace(self) -> bool:
        """Consume a run of whitespace; return True if anything was consumed."""
        result = False
        while not self.eof() and self.curr().isspace():
            self.next()
            result = True
        return result

    def skipSeparator(self, sep: str) -> bool:
        """Consume *sep* together with any whitespace around it.

        On failure the cursor is left where it was.
        """
        self.save()
        self.skipWhitespace()
        if self.skipStr(sep):
            self.skipWhitespace()
            # Drop the saved offset: leaving it on the stack would make a
            # later restore() pop this stale entry instead of its own.
            self._save.pop()
            return True

        self.restore()
        return False

    def isSeparator(self, sep: str) -> bool:
        """Return True if skipSeparator(sep) would succeed, without consuming."""
        self.save()
        matched = self.skipSeparator(sep)
        self.restore()
        return matched

    def skipKeyword(self, keyword: str) -> bool:
        """Consume optional whitespace then *keyword* as a whole word.

        The keyword only matches when the following character is not
        alphanumeric. On failure the cursor is left where it was.
        """
        self.save()
        self.skipWhitespace()
        if self.skipStr(keyword) and not self.curr().isalnum():
            # Same stack hygiene as skipSeparator().
            self._save.pop()
            return True

        self.restore()
        return False

    def isKeyword(self, keyword: str) -> bool:
        """Return True if skipKeyword(keyword) would succeed, without consuming."""
        self.save()
        matched = self.skipKeyword(keyword)
        self.restore()
        return matched
|
|
|
|
|
|
# --- Parser ------------------------------------------------------------ #
|
|
|
|
PrimitiveValue = str | bool | int
|
|
Object = dict[str, PrimitiveValue]
|
|
List = list[PrimitiveValue]
|
|
Value = str | bool | int | Object | List
|
|
|
|
|
|
@dt.dataclass
class Token:
    """Base class of the tokens produced by parseArg() and parseArgs()."""

    pass
|
|
|
|
|
|
@dt.dataclass
class ArgumentToken(Token):
    """A flag argument, written either as ``--key[:subkey][=value]`` or ``-k``."""

    # Flag name without its leading dashes (a single character for short flags).
    key: str
    # Identifier following ':' in ``--key:subkey``, when one was given.
    subkey: Optional[str]
    # Parsed value following '='; parseArg() uses True when no value is given.
    value: Value
    # True when the token comes from a single-dash cluster such as ``-abc``.
    short: bool
|
|
|
|
|
|
@dt.dataclass
class OperandToken(Token):
    """A positional argument, i.e. one that does not start with a dash."""

    # The raw argument text, unparsed.
    value: str
|
|
|
|
|
|
@dt.dataclass
class ExtraToken(Token):
    """The arguments that follow a bare ``--`` on the command line."""

    # Remaining raw arguments, left unparsed.
    args: list[str]
|
|
|
|
|
|
def _parseIdent(s: Scan) -> str:
    """Consume and return the longest run of identifier characters.

    Identifier characters are alphanumerics plus ``_``, ``-`` and ``+``. The
    result is empty when the scanner is not positioned on such a character.
    """
    chars = []
    while not s.eof():
        ch = s.curr()
        if not (ch.isalnum() or ch in "_-+"):
            break
        chars.append(ch)
        s.next()
    return "".join(chars)
|
|
|
|
|
|
def _parseUntilComma(s: Scan) -> str:
    """Consume and return everything up to the next ``,`` or end of input.

    The comma itself is left unconsumed.
    """
    chars = []
    while not s.eof():
        ch = s.curr()
        if ch == ",":
            break
        chars.append(ch)
        s.next()
    return "".join(chars)
|
|
|
|
|
|
def _expectIdent(s: Scan) -> str:
    """Parse an identifier, raising RuntimeError when none is present."""
    ident = _parseIdent(s)
    if not ident:
        raise RuntimeError("Expected identifier")
    return ident
|
|
|
|
|
|
def _parseString(s: Scan, quote: str) -> str:
    """Parse a quoted string delimited by *quote* and return its contents.

    A backslash makes the following character literal (the backslash itself
    is dropped). Raises RuntimeError when the closing quote is missing.
    """
    # Opening quote; the result is deliberately not checked.
    s.skipStr(quote)

    chars = []
    pendingEscape = False
    while not s.eof():
        ch = s.curr()
        if pendingEscape:
            # Character after a backslash is taken verbatim.
            chars.append(ch)
            pendingEscape = False
        elif ch == "\\":
            pendingEscape = True
        elif ch == quote:
            # Leave the closing quote for the skipStr() check below.
            break
        else:
            chars.append(ch)
        s.next()

    if not s.skipStr(quote):
        raise RuntimeError("Unterminated string")
    return "".join(chars)
|
|
|
|
|
|
def _tryParseInt(ident) -> Optional[int]:
|
|
try:
|
|
return int(ident)
|
|
except ValueError:
|
|
return None
|
|
|
|
|
|
def _parsePrimitive(s: Scan) -> PrimitiveValue:
    """Parse one scalar value: a quoted string, a boolean, an int or a bare word.

    Quoted input (single or double quotes) is returned as a string. Unquoted
    input is read up to the next comma and then interpreted, in order, as a
    boolean word, an integer, or otherwise returned unchanged as a string.
    """
    quote = s.curr()
    if quote == '"' or quote == "'":
        return _parseString(s, quote)

    ident = _parseUntilComma(s)

    if ident in ("true", "True", "y", "yes", "Y", "Yes"):
        return True
    if ident in ("false", "False", "n", "no", "N", "No"):
        return False

    # Compare against None explicitly: a truthiness test would treat the
    # valid integer 0 as "not a number" and hand back the string "0".
    number = _tryParseInt(ident)
    if number is not None:
        return number
    return ident
|
|
|
|
|
|
def _parseValue(s: Scan) -> Value:
    """Parse a value: a single scalar, or a list of comma-separated scalars.

    When the first scalar consumes the whole input it is returned as-is;
    otherwise a list is returned, extended once per following comma.
    """
    first = _parsePrimitive(s)
    if s.eof():
        return first

    items: List = [first]
    while True:
        if s.eof() or not s.skipStr(","):
            break
        items.append(_parsePrimitive(s))
    return items
|
|
|
|
|
|
def parseValue(s: str) -> Value:
    """Parse the string *s* as a value (see _parseValue)."""
    scanner = Scan(s)
    return _parseValue(scanner)
|
|
|
|
|
|
def parseArg(arg: str) -> list[Token]:
    """Turn one raw command-line argument into tokens.

    * ``--key[:subkey][=value]`` gives one long ArgumentToken; the value is
      True when no ``=value`` part is present.
    * ``-abc`` gives one short ArgumentToken per character, each valued True.
    * Anything else, including a lone ``-``, gives one OperandToken.

    Raises:
        RuntimeError: if a long flag has no identifier after ``--`` or ``:``,
            or a short flag cluster contains a non-alphanumeric character.
    """
    s = Scan(arg)

    if s.skipStr("--"):
        key = _expectIdent(s)
        # None (not "") marks an absent subkey, matching the Optional[str]
        # annotation on ArgumentToken and the short-flag branch below.
        subkey = _expectIdent(s) if s.skipStr(":") else None
        value = _parseValue(s) if s.skipStr("=") else True
        return [ArgumentToken(key, subkey, value, False)]

    # A lone "-" is an operand; treating it as an empty flag cluster would
    # return no tokens and silently drop the argument.
    if arg != "-" and s.skipStr("-"):
        res = []
        while not s.eof():
            key = s.curr()
            if not key.isalnum():
                raise RuntimeError("Expected alphanumeric")
            s.next()
            res.append(ArgumentToken(key, None, True, True))
        return res

    return [OperandToken(arg)]
|
|
|
|
|
|
def parseArgs(args: list[str]) -> list[Token]:
    """Tokenize a whole argument list.

    Each argument is expanded with parseArg(). A bare ``--`` stops parsing:
    everything after it is collected, unparsed, into a single ExtraToken.
    The input list is left unmodified.
    """
    res: list[Token] = []
    for index, arg in enumerate(args):
        if arg == "--":
            # Copy the tail so the token does not alias the caller's list.
            res.append(ExtraToken(args[index + 1 :]))
            break
        res.extend(parseArg(arg))
    return res
|
|
|
|
|
|
# --- Schema ----------------------------------------------------------------- #
|
|
|
|
|
|
class FieldType(Enum):
    """Kind of command-line input a Field describes."""

    FLAG = 0  # created by arg(); collected into Schema.args
    OPERAND = 1  # created by operand(); collected into Schema.operands
    EXTRA = 2  # created by extra(); at most one per schema (Schema.extras)
|
|
|
|
|
|
@dt.dataclass
class Field:
    """Description of one command-line field, bound to an attribute of a class.

    A Field is created unbound (see arg(), operand(), extra()) and later
    attached to a class attribute with bind(), which records the attribute
    name and its annotation.
    """

    type: FieldType
    shortName: Optional[str]
    longName: str
    description: str = ""
    default: Any = None

    # Set by bind(): attribute name and annotation on the owning class.
    _fieldName: str | None = dt.field(init=False, default=None)
    _fieldType: type | None = dt.field(init=False, default=None)

    def bind(self, typ: type, name: str):
        """Attach this field to attribute *name* of class *typ*.

        The attribute name becomes the long name when none was given.
        """
        self._fieldName = name
        self._fieldType = typ.__annotations__[name]
        # The factories default longName to "", so test for emptiness rather
        # than None; otherwise the fallback to the attribute name never fires.
        if not self.longName:
            self.longName = name

    def defaultValue(self) -> Any:
        """Return the value to use when the field is not given.

        Unbound fields yield None. An explicit ``default`` wins; otherwise
        bool, int, str, list and dict annotations (parameterised or not)
        yield their empty value, and any other annotation yields None.
        """
        if self._fieldType is None:
            return None

        if self.default is not None:
            return self.default

        # Reduce parameterised annotations such as list[str] to their origin
        # so they get the same empty value as the bare type.
        origin = tp.get_origin(self._fieldType) or self._fieldType
        if origin in (bool, int, str, list, dict):
            # Calling the type builds a fresh False / 0 / "" / [] / {}.
            return origin()
        return None

    def setDefault(self, obj: Any):
        """Store this field's default on *obj*; does nothing when unbound."""
        if self._fieldName:
            # Go through defaultValue() so fields without an explicit default
            # get their type's empty value rather than None.
            setattr(obj, self._fieldName, self.defaultValue())

    def putValue(self, obj: Any, value: Any, subkey: Optional[str] = None):
        """Store *value* on the bound attribute of *obj*.

        NOTE(review): *subkey* is accepted but currently unused.
        """
        setattr(obj, self._fieldName, value)

    def getAttr(self, obj: Any) -> Any:
        """Read the bound attribute from *obj*."""
        return getattr(obj, self._fieldName)
|
|
|
|
|
|
def arg(
    shortName: str | None = None,
    longName: str = "",
    description: str = "",
    default: Any = None,
) -> Any:
    """Declare a flag field (FieldType.FLAG) with the given names and default."""
    return Field(
        type=FieldType.FLAG,
        shortName=shortName,
        longName=longName,
        description=description,
        default=default,
    )
|
|
|
|
|
|
def operand(longName: str = "", description: str = "") -> Any:
    """Declare a positional field (FieldType.OPERAND); it has no short name."""
    return Field(
        type=FieldType.OPERAND,
        shortName=None,
        longName=longName,
        description=description,
    )
|
|
|
|
|
|
def extra(longName: str = "", description: str = "") -> Any:
    """Declare the field for arguments after ``--`` (FieldType.EXTRA)."""
    return Field(
        type=FieldType.EXTRA,
        shortName=None,
        longName=longName,
        description=description,
    )
|
|
|
|
|
|
@dt.dataclass
class Schema:
    """Set of fields describing the arguments accepted by one command.

    Built from a class whose annotated attributes are Field instances; parse()
    fills an instance of that class from a raw argument list.
    """

    typ: Optional[type] = None
    args: list[Field] = dt.field(default_factory=list)
    operands: list[Field] = dt.field(default_factory=list)
    extras: Optional[Field] = None

    @staticmethod
    def extract(typ: type) -> "Schema":
        """Build a schema from the annotated Field attributes of *typ*.

        Raises:
            ValueError: if an annotated attribute has no value, is not a
                Field, or more than one EXTRA field is declared.
        """
        s = Schema(typ)

        for f in typ.__annotations__:
            field = getattr(typ, f, None)

            if field is None:
                raise ValueError(f"Field '{f}' is not defined")

            if not isinstance(field, Field):
                raise ValueError(f"Field '{f}' is not a Field")

            field.bind(typ, f)

            if field.type == FieldType.FLAG:
                s.args.append(field)
            elif field.type == FieldType.OPERAND:
                s.operands.append(field)
            elif field.type == FieldType.EXTRA:
                if s.extras:
                    raise ValueError("Only one extra argument is allowed")
                s.extras = field

        return s

    @staticmethod
    def extractFromCallable(fn: tp.Callable) -> "Schema":
        """Build a schema from the first annotated parameter of *fn*.

        A callable without parameter annotations (or whose first one is
        None) yields an empty schema.
        """
        # Skip the return annotation: for ``def f() -> int`` it would
        # otherwise be mistaken for the options type.
        paramTypes = [
            annotation
            for name, annotation in fn.__annotations__.items()
            if name != "return"
        ]
        typ: type | None = paramTypes[0] if paramTypes else None

        if typ is None:
            return Schema()

        return Schema.extract(typ)

    def usage(self) -> str:
        """Return the one-line usage summary of this schema's fields."""
        parts = []
        for arg in self.args:
            names = []
            if arg.shortName:
                names.append(f"-{arg.shortName}")
            if arg.longName:
                names.append(f"--{arg.longName}")
            parts.append(f"[{', '.join(names)}] ")
        for operand in self.operands:
            parts.append(f"<{operand.longName}> ")
        if self.extras:
            parts.append(f"[-- {self.extras.longName}]")
        return "".join(parts)

    def _lookupArg(self, key: str, short: bool) -> Field:
        """Find the flag field named *key* (short or long); ValueError if none."""
        for arg in self.args:
            if short and arg.shortName == key:
                return arg
            elif not short and arg.longName == key:
                return arg
        raise ValueError(f"Unknown argument '{key}'")

    def _setOperand(self, tok: OperandToken, obj: Any = None, index: int = 0):
        """Store operand *tok* into the *index*-th operand field of *obj*.

        Called without *obj* this is a no-op, as it was before operands
        were wired up.

        Raises:
            ValueError: if more operands are given than the schema declares.
        """
        if obj is None:
            return
        if index >= len(self.operands):
            raise ValueError(f"Unexpected operand '{tok.value}'")
        self.operands[index].putValue(obj, tok.value)

    def _instanciate(self) -> Any:
        """Create the options object with every field set to its default.

        Returns None when the schema has no backing type.
        """
        if self.typ is None:
            return None
        res = self.typ()
        # Without this the instance attributes would resolve to the
        # class-level Field objects themselves.
        fields = [*self.args, *self.operands]
        if self.extras:
            fields.append(self.extras)
        for field in fields:
            field.putValue(res, field.defaultValue())
        return res

    def parse(self, args: list[str]) -> Any:
        """Parse *args* into a new options object.

        Flags go to the matching field, operands fill the operand fields in
        declaration order, and everything after a bare ``--`` goes to the
        extras field. A schema without a backing type accepts no arguments
        and returns None. The input list is not modified.

        Raises:
            ValueError: on unknown flags, surplus operands, an unexpected
                ``--``, or any argument given to a schema without a type.
        """
        res = self._instanciate()
        if res is None:
            if args:
                raise ValueError("Unexpected arguments")
            return None

        operandIndex = 0
        for index, raw in enumerate(args):
            if raw == "--":
                if not self.extras:
                    raise ValueError("Unexpected '--'")
                # Everything after the separator, excluding "--" itself.
                self.extras.putValue(res, args[index + 1 :])
                break

            for tok in parseArg(raw):
                if isinstance(tok, ArgumentToken):
                    arg = self._lookupArg(tok.key, tok.short)
                    arg.putValue(res, tok.value, tok.subkey)
                elif isinstance(tok, OperandToken):
                    self._setOperand(tok, res, operandIndex)
                    operandIndex += 1
                else:
                    raise ValueError(f"Unexpected token: {type(tok)}")

        return res
|
|
|
|
|
|
@dt.dataclass
class Command:
    """Node of the command tree: an optional callable plus named subcommands."""

    shortName: Optional[str]
    longName: str
    description: str = ""
    schema: Optional[Schema] = None
    callable: Optional[tp.Callable] = None
    subcommands: dict[str, "Command"] = dt.field(default_factory=dict)
    # True once a callable has been registered on this node (see command()).
    populated: bool = False

    def _spliceArgs(self, args: list[str]) -> tuple[list[str], list[str]]:
        """Split *args* into (arguments for this command, remainder).

        With subcommands, this command takes the leading flags (items that
        start with ``-``, stopping at a bare ``--``) and the remainder starts
        at the subcommand name. Without subcommands it takes everything.
        """
        rest = args[:]
        curr: list[str] = []
        if len(self.subcommands) > 0:
            while len(rest) > 0 and rest[0].startswith("-") and rest[0] != "--":
                curr.append(rest.pop(0))
        else:
            # A leaf consumes every argument. Returning the same list as both
            # halves made eval() report the first argument as unknown.
            curr, rest = rest, []
        return curr, rest

    def help(self):
        """Print help for this command (not implemented)."""
        pass

    def usage(self) -> str:
        """Return the usage text that follows the command name."""
        res = " "
        if self.schema:
            res += self.schema.usage()

        if len(self.subcommands) == 1:
            sub = next(iter(self.subcommands.values()))
            res += sub.longName + sub.usage()
        elif len(self.subcommands) > 0:
            # Join the names; the previous loop never emitted the separator.
            res += "{" + " | ".join(self.subcommands) + "}"

            res += "[args...]"

        return res

    def eval(self, args: list[str]):
        """Run this command on *args*, whose first item is the command name.

        Handles ``-h/--help`` and ``-u/--usage``, calls the registered
        callable (with the parsed options when a schema is set), then hands
        the remaining arguments to the selected subcommand. ValueError is
        reported on stdout together with the usage line. The input list is
        not modified.
        """
        cmd = args[0]
        curr, rest = self._spliceArgs(args[1:])
        if "-h" in curr or "--help" in curr:
            self.help()
            return
        if "-u" in curr or "--usage" in curr:
            print("Usage: " + cmd + self.usage(), end="\n\n")
            return

        try:
            if self.schema and self.callable:
                self.callable(self.schema.parse(curr))
            else:
                if curr and not self.subcommands:
                    # A leaf without a schema accepts no arguments.
                    raise ValueError(f"Unknown operand '{curr[0]}'")
                if self.callable:
                    self.callable()

            if self.subcommands:
                if len(rest) == 0:
                    # Check emptiness before indexing rest[0]: a populated
                    # command may legitimately run without a subcommand.
                    if not self.populated:
                        raise ValueError("Expected subcommand")
                elif rest[0] in self.subcommands:
                    self.subcommands[rest[0]].eval(rest)
                else:
                    raise ValueError(f"Unknown subcommand '{rest[0]}'")

        except ValueError as e:
            print("Error: " + str(e))
            print("Usage: " + cmd + self.usage(), end="\n\n")
            return
|
|
|
|
|
|
# Root of the command tree, named after the invoked program (sys.argv[0]).
# _resolvePath() starts from this node when looking up or creating commands.
_root = Command(None, sys.argv[0])
|
|
|
|
|
|
def _splitPath(path: str) -> list[str]:
|
|
if path == "/":
|
|
return []
|
|
return path.split("/")
|
|
|
|
|
|
def _resolvePath(path: list[str]) -> Command:
    """Walk *path* down from the root command, creating missing nodes.

    Returns the command at the end of the path; an empty path gives the root.
    """
    # NOTE(review): this compares against the string "/" although *path* is
    # annotated as a list, so it only matches if a caller passes the raw string.
    if path == "/":
        return _root

    node = _root
    for segment in path:
        children = node.subcommands
        if segment not in children:
            children[segment] = Command(None, segment)
        node = children[segment]
    return node
|
|
|
|
|
|
def command(shortName: str, longName: str, description: str = "") -> Callable:
    """Decorator registering a function as the command at path *longName*.

    *longName* is a ``/``-separated path in the command tree; missing
    intermediate commands are created. The schema is derived from the
    function's annotations. The decorated function is returned unchanged.

    Raises:
        ValueError: (at decoration time) if the path already has a command.
    """

    def wrap(fn: Callable):
        schema = Schema.extractFromCallable(fn)
        path = _splitPath(longName)
        cmd = _resolvePath(path)
        if cmd.populated:
            raise ValueError(f"Command '{longName}' is already defined")
        cmd.shortName = shortName
        # The root path ("/") has no segments: keep the root's existing
        # name instead of failing on path[-1].
        if path:
            cmd.longName = path[-1]
        cmd.description = description
        cmd.schema = schema
        cmd.callable = fn
        cmd.populated = True
        return fn

    return wrap
|