This commit is contained in:
Sleepy Monax 2024-02-13 09:56:43 +01:00
parent a13c76cddf
commit 647043f4d3
3 changed files with 783 additions and 0 deletions

570
cutekit/cli2.py Normal file
View file

@ -0,0 +1,570 @@
import sys
from enum import Enum
import typing as tp
import dataclasses as dt
from typing import Any, Callable, Optional
T = tp.TypeVar("T")
# --- Scan -------------------------------------------------------------- #
class Scan:
_src: str
_off: int
_save: list[int]
def __init__(self, src: str, off: int = 0):
self._src = src
self._off = 0
self._save = []
def curr(self) -> str:
if self.eof():
return "\0"
return self._src[self._off]
def next(self) -> str:
if self.eof():
return "\0"
self._off += 1
return self.curr()
def peek(self, off: int = 1) -> str:
if self._off + off >= len(self._src):
return "\0"
return self._src[self._off + off]
def eof(self) -> bool:
return self._off >= len(self._src)
def skipStr(self, s: str) -> bool:
if self._src[self._off :].startswith(s):
self._off += len(s)
return True
return False
def isStr(self, s: str) -> bool:
self.save()
if self.skipStr(s):
self.restore()
return True
self.restore()
return False
def save(self) -> None:
self._save.append(self._off)
def restore(self) -> None:
self._off = self._save.pop()
def skipWhitespace(self) -> bool:
result = False
while not self.eof() and self.curr().isspace():
self.next()
result = True
return result
def skipSeparator(self, sep: str) -> bool:
self.save()
self.skipWhitespace()
if self.skipStr(sep):
self.skipWhitespace()
return True
self.restore()
return False
def isSeparator(self, sep: str) -> bool:
self.save()
self.skipWhitespace()
if self.skipStr(sep):
self.skipWhitespace()
self.restore()
return True
self.restore()
return False
def skipKeyword(self, keyword: str) -> bool:
self.save()
self.skipWhitespace()
if self.skipStr(keyword) and not self.curr().isalnum():
return True
self.restore()
return False
def isKeyword(self, keyword: str) -> bool:
self.save()
self.skipWhitespace()
if self.skipStr(keyword) and not self.curr().isalnum():
self.restore()
return True
self.restore()
return False
# --- Parser ------------------------------------------------------------ #
PrimitiveValue = str | bool | int
Object = dict[str, PrimitiveValue]
List = list[PrimitiveValue]
Value = str | bool | int | Object | List
@dt.dataclass
class Token:
pass
@dt.dataclass
class ArgumentToken(Token):
key: str
subkey: Optional[str]
value: Value
short: bool
@dt.dataclass
class OperandToken(Token):
value: str
@dt.dataclass
class ExtraToken(Token):
args: list[str]
def _parseIdent(s: Scan) -> str:
res = ""
while not s.eof() and (s.curr().isalnum() or s.curr() in "_-+"):
res += s.curr()
s.next()
return res
def _parseUntilComma(s: Scan) -> str:
res = ""
while not s.eof() and s.curr() != ",":
res += s.curr()
s.next()
return res
def _expectIdent(s: Scan) -> str:
res = _parseIdent(s)
if len(res) == 0:
raise RuntimeError("Expected identifier")
return res
def _parseString(s: Scan, quote: str) -> str:
s.skipStr(quote)
res = ""
escaped = False
while not s.eof():
c = s.curr()
if escaped:
res += c
escaped = False
elif c == "\\":
escaped = True
elif c == quote:
break
else:
res += c
s.next()
if not s.skipStr(quote):
raise RuntimeError("Unterminated string")
return res
def _tryParseInt(ident) -> Optional[int]:
try:
return int(ident)
except ValueError:
return None
def _parsePrimitive(s: Scan) -> PrimitiveValue:
if s.curr() == '"':
return _parseString(s, '"')
elif s.curr() == "'":
return _parseString(s, "'")
else:
ident = _parseUntilComma(s)
if ident in ("true", "True", "y", "yes", "Y", "Yes"):
return True
elif ident in ("false", "False", "n", "no", "N", "No"):
return False
elif n := _tryParseInt(ident):
return n
else:
return ident
def _parseValue(s: Scan) -> Value:
lhs = _parsePrimitive(s)
if s.eof():
return lhs
values: List = [lhs]
while not s.eof() and s.skipStr(","):
values.append(_parsePrimitive(s))
return values
def parseValue(s: str) -> Value:
return _parseValue(Scan(s))
def parseArg(arg: str) -> list[Token]:
s = Scan(arg)
if s.skipStr("--"):
key = _expectIdent(s)
subkey = ""
if s.skipStr(":"):
subkey = _expectIdent(s)
if s.skipStr("="):
value = _parseValue(s)
else:
value = True
return [ArgumentToken(key, subkey, value, False)]
elif s.skipStr("-"):
res = []
while not s.eof():
key = s.curr()
if not key.isalnum():
raise RuntimeError("Expected alphanumeric")
s.next()
res.append(ArgumentToken(key, None, True, True))
return res
else:
return [OperandToken(arg)]
def parseArgs(args: list[str]) -> list[Token]:
res: list[Token] = []
while len(args) > 0:
arg = args.pop(0)
if arg == "--":
res.append(ExtraToken(args))
break
else:
res.extend(parseArg(arg))
return res
# --- Schema ----------------------------------------------------------------- #
class FieldType(Enum):
FLAG = 0
OPERAND = 1
EXTRA = 2
@dt.dataclass
class Field:
type: FieldType
shortName: Optional[str]
longName: str
description: str = ""
default: Any = None
_fieldName: str | None = dt.field(init=False, default=None)
_fieldType: type | None = dt.field(init=False, default=None)
def bind(self, typ: type, name: str):
self._fieldName = name
self._fieldType = typ.__annotations__[name]
if self.longName is None:
self.longName = name
def defaultValue(self) -> Any:
if self._fieldType is None:
return None
if self.default is not None:
return self.default
if self._fieldType == bool:
return False
elif self._fieldType == int:
return 0
elif self._fieldType == str:
return ""
elif self._fieldType == list:
return []
elif self._fieldType == dict:
return {}
else:
return None
def setDefault(self, obj: Any):
if self._fieldName:
setattr(obj, self._fieldName, self.default)
def putValue(self, obj: Any, value: Any, subkey: Optional[str] = None):
setattr(obj, self._fieldName, value)
def getAttr(self, obj: Any) -> Any:
return getattr(obj, self._fieldName)
def arg(
shortName: str | None = None,
longName: str = "",
description: str = "",
default: Any = None,
) -> Any:
return Field(FieldType.FLAG, shortName, longName, description, default)
def operand(longName: str = "", description: str = "") -> Any:
return Field(FieldType.OPERAND, None, longName, description)
def extra(longName: str = "", description: str = "") -> Any:
return Field(FieldType.EXTRA, None, longName, description)
@dt.dataclass
class Schema:
typ: Optional[type] = None
args: list[Field] = dt.field(default_factory=list)
operands: list[Field] = dt.field(default_factory=list)
extras: Optional[Field] = None
@staticmethod
def extract(typ: type) -> "Schema":
s = Schema(typ)
for f in typ.__annotations__.keys():
field = getattr(typ, f, None)
if field is None:
raise ValueError(f"Field '{f}' is not defined")
if not isinstance(field, Field):
raise ValueError(f"Field '{f}' is not a Field")
field.bind(typ, f)
if field.type == FieldType.FLAG:
s.args.append(field)
elif field.type == FieldType.OPERAND:
s.operands.append(field)
elif field.type == FieldType.EXTRA:
if s.extras:
raise ValueError("Only one extra argument is allowed")
s.extras = field
return s
@staticmethod
def extractFromCallable(fn: tp.Callable) -> "Schema":
typ: type | None = (
None
if len(fn.__annotations__) == 0
else next(iter(fn.__annotations__.values()))
)
if typ is None:
return Schema()
return Schema.extract(typ)
def usage(self) -> str:
res = ""
for arg in self.args:
flag = ""
if arg.shortName:
flag += f"-{arg.shortName}"
if arg.longName:
if flag:
flag += ", "
flag += f"--{arg.longName}"
res += f"[{flag}] "
for operand in self.operands:
res += f"<{operand.longName}> "
if self.extras:
res += f"[-- {self.extras.longName}]"
return res
def _lookupArg(self, key: str, short: bool) -> Field:
for arg in self.args:
if short and arg.shortName == key:
return arg
elif not short and arg.longName == key:
return arg
raise ValueError(f"Unknown argument '{key}'")
def _setOperand(self, tok: OperandToken):
return
def _instanciate(self) -> Any:
if self.typ is None:
return None
res = self.typ()
for arg in self.args:
field = arg.field
if field:
field.put(res, arg.default or field.default())
else:
raise ValueError("Argument has no field")
return res
def parse(self, args: list[str]) -> Any:
res = self._instanciate()
if res is None:
if len(args) > 0:
raise ValueError("Unexpected arguments")
else:
return None
stack = args[:]
while len(stack) > 0:
if stack[0] == "--":
if not self.extras:
raise ValueError("Unexpected '--'")
self._setExtra(res, stack.pop(0))
break
toks = parseArg(stack.pop(0))
for tok in toks:
if isinstance(tok, ArgumentToken):
arg = self._lookupArg(tok.key, tok.short)
arg.putValue(res, tok.value, tok.subkey)
elif isinstance(tok, OperandToken):
self._setOperand(tok)
else:
raise ValueError(f"Unexpected token: {type(tok)}")
return res
@dt.dataclass
class Command:
shortName: Optional[str]
longName: str
description: str = ""
schema: Optional[Schema] = None
callable: Optional[tp.Callable] = None
subcommands: dict[str, "Command"] = dt.field(default_factory=dict)
populated: bool = False
def _spliceArgs(self, args: list[str]) -> tuple[list[str], list[str]]:
rest = args[:]
curr = []
if len(self.subcommands) > 0:
while len(rest) > 0 and rest[0].startswith("-") and rest[0] != "--":
curr.append(rest.pop(0))
else:
curr = rest
return curr, rest
def help(self):
pass
def usage(self) -> str:
res = " "
if self.schema:
res += self.schema.usage()
if len(self.subcommands) == 1:
sub = next(iter(self.subcommands.values()))
res += sub.longName + sub.usage()
elif len(self.subcommands) > 0:
res += "{"
first = True
for name, cmd in self.subcommands.items():
if not first:
res += " | "
else:
res += f"{name}"
res += "}"
res += "[args...]"
return res
def eval(self, args: list[str]):
cmd = args.pop(0)
curr, rest = self._spliceArgs(args)
if "-h" in curr or "--help" in curr:
self.help()
return
if "-u" in curr or "--usage" in curr:
print("Usage: " + cmd + self.usage(), end="\n\n")
return
try:
if self.schema and self.callable:
args = self.schema.parse(curr)
self.callable(args)
elif self.callable:
self.callable()
if self.subcommands:
if len(rest) == 0 and not self.populated:
raise ValueError("Expected subcommand")
elif rest[0] in self.subcommands:
print("eval", rest[0])
self.subcommands[rest[0]].eval(rest)
elif len(rest) > 0:
raise ValueError(f"Unknown subcommand '{rest[0]}'")
elif len(rest) > 0:
raise ValueError(f"Unknown operand '{rest[0]}'")
except ValueError as e:
print("Error: " + str(e))
print("Usage: " + cmd + self.usage(), end="\n\n")
return
_root = Command(None, sys.argv[0])
def _splitPath(path: str) -> list[str]:
if path == "/":
return []
return path.split("/")
def _resolvePath(path: list[str]) -> Command:
if path == "/":
return _root
cmd = _root
for name in path:
if name not in cmd.subcommands:
cmd.subcommands[name] = Command(None, name)
cmd = cmd.subcommands[name]
return cmd
def command(shortName: str, longName: str, description: str = "") -> Callable:
def wrap(fn: Callable):
schema = Schema.extractFromCallable(fn)
path = _splitPath(longName)
cmd = _resolvePath(path)
if cmd.populated:
raise ValueError(f"Command '{longName}' is already defined")
cmd.shortName = shortName
cmd.longName = path[-1]
cmd.description = description
cmd.schema = schema
cmd.callable = fn
cmd.populated = True
return fn
return wrap

View file

@ -2,3 +2,4 @@ requests ~= 2.31.0
graphviz ~= 0.20.1
dataclasses-json ~= 0.6.2
docker ~= 6.1.3
asserts ~= 0.12.0

212
tests/test_cli.py Normal file
View file

@ -0,0 +1,212 @@
from cutekit import cli2, utils
from asserts import (
assert_is,
assert_true,
assert_equal,
assert_raises,
assert_is_instance,
)
# --- Parse Values ----------------------------------------------------------- #
def test_parse_int_val():
assert_equal(cli2.parseValue("1"), 1)
assert_equal(cli2.parseValue("2"), 2)
assert_equal(cli2.parseValue("+2"), +2)
assert_equal(cli2.parseValue("-2"), -2)
def test_parse_true_val():
assert_equal(cli2.parseValue("true"), True)
assert_equal(cli2.parseValue("True"), True)
assert_equal(cli2.parseValue("y"), True)
assert_equal(cli2.parseValue("yes"), True)
assert_equal(cli2.parseValue("Y"), True)
assert_equal(cli2.parseValue("Yes"), True)
def test_parse_false_val():
assert_equal(cli2.parseValue("false"), False)
assert_equal(cli2.parseValue("False"), False)
assert_equal(cli2.parseValue("n"), False)
assert_equal(cli2.parseValue("no"), False)
assert_equal(cli2.parseValue("N"), False)
assert_equal(cli2.parseValue("No"), False)
def test_parse_str_val():
assert_equal(cli2.parseValue("foo"), "foo")
assert_equal(cli2.parseValue("'foo'"), "foo")
assert_equal(cli2.parseValue('"foo"'), "foo")
def test_parse_list_val():
assert_equal(cli2.parseValue("foo,bar"), ["foo", "bar"])
assert_equal(cli2.parseValue("'foo','bar'"), ["foo", "bar"])
assert_equal(cli2.parseValue('"foo","bar"'), ["foo", "bar"])
# --- Parse Args ------------------------------------------------------------- #
def test_parse_short_arg():
args = cli2.parseArg("-a")
assert_equal(len(args), 1)
arg = args[0]
assert_is_instance(arg, cli2.ArgumentToken)
assert_equal(arg.key, "a")
assert_equal(arg.value, True)
def test_parse_short_args():
args = cli2.parseArg("-abc")
assert_equal(len(args), 3)
arg = args[0]
assert_is_instance(arg, cli2.ArgumentToken)
assert_equal(arg.key, "a")
assert_equal(arg.value, True)
arg = args[1]
assert_is_instance(arg, cli2.ArgumentToken)
assert_equal(arg.key, "b")
assert_equal(arg.value, True)
arg = args[2]
assert_is_instance(arg, cli2.ArgumentToken)
assert_equal(arg.key, "c")
assert_equal(arg.value, True)
def test_parse_long_arg():
args = cli2.parseArg("--foo")
assert_equal(len(args), 1)
arg = args[0]
assert_is_instance(arg, cli2.ArgumentToken)
assert_equal(arg.key, "foo")
assert_equal(arg.value, True)
def test_parse_long_arg_with_value():
args = cli2.parseArg("--foo=bar")
assert_equal(len(args), 1)
arg = args[0]
assert_is_instance(arg, cli2.ArgumentToken)
assert_equal(arg.key, "foo")
assert_equal(arg.value, "bar")
def test_parse_long_arg_with_value_list():
args = cli2.parseArg("--foo=bar,baz")
assert_equal(len(args), 1)
arg = args[0]
assert_is_instance(arg, cli2.ArgumentToken)
assert_equal(arg.key, "foo")
assert_equal(arg.value, ["bar", "baz"])
def test_parse_key_subkey_arg():
args = cli2.parseArg("--foo:bar")
assert_equal(len(args), 1)
arg = args[0]
assert_is_instance(arg, cli2.ArgumentToken)
assert_equal(arg.key, "foo")
assert_equal(arg.subkey, "bar")
assert_equal(arg.value, True)
def extractParse(type: type[utils.T], args: list[str]) -> utils.T:
schema = cli2.Schema.extract(type)
return schema.parse(args)
def test_cli_arg_int():
class IntArg:
value: int = cli2.arg(None, "value")
assert_equal(extractParse(IntArg, ["--value=-1"]).value, -1)
assert_equal(extractParse(IntArg, ["--value=0"]).value, 0)
assert_equal(extractParse(IntArg, ["--value=1"]).value, 1)
class StrArg:
value: str = cli2.arg(None, "value")
def test_cli_arg_str1():
v = extractParse(StrArg, ["--value=foo"])
assert v.value == "foo"
def test_cli_arg_str2():
v = extractParse(StrArg, ["--value='foo, bar'"])
assert v.value == "foo, bar"
class BoolArg:
value: bool = cli2.arg(None, "value")
def test_cli_arg_bool():
assert_is(extractParse(BoolArg, ["--value"]).value, True)
assert_is(extractParse(BoolArg, ["--value=true"]).value, True)
assert_is(extractParse(BoolArg, ["--value=True"]).value, True)
assert_is(extractParse(BoolArg, ["--value=y"]).value, True)
assert_is(extractParse(BoolArg, ["--value=yes"]).value, True)
assert_is(extractParse(BoolArg, ["--value=Y"]).value, True)
assert_is(extractParse(BoolArg, ["--value=Yes"]).value, True)
assert_is(extractParse(BoolArg, ["--value=1"]).value, True)
assert extractParse(BoolArg, ["--value=false"]).value is False
assert extractParse(BoolArg, ["--value=False"]).value is False
assert extractParse(BoolArg, ["--value=n"]).value is False
assert extractParse(BoolArg, ["--value=no"]).value is False
assert extractParse(BoolArg, ["--value=N"]).value is False
assert extractParse(BoolArg, ["--value=No"]).value is False
print(extractParse(BoolArg, ["--value=0"]).value)
assert extractParse(BoolArg, ["--value=0"]).value is False
class IntListArg:
value: list[int] = cli2.arg(None, "value")
def test_cli_arg_list_int1():
assert_equal(extractParse(IntListArg, []).value, [])
assert_equal(extractParse(IntListArg, ["--value=1", "--value=2"]).value, [1, 2])
assert_equal(extractParse(IntListArg, ["--value=1,2"]).value, [1, 2])
class StrListArg:
value: list[str] = cli2.arg(None, "value")
def test_cli_arg_list_str():
assert_equal(
extractParse(StrListArg, ["--value=foo", "--value=bar"]).value,
[
"foo",
"bar",
],
)
assert_equal(extractParse(StrListArg, ["--value=foo,bar"]).value, ["foo", "bar"])
assert_equal(extractParse(StrListArg, ["--value=foo,bar"]).value, ["foo", "bar"])
assert_equal(extractParse(StrListArg, ["--value='foo,bar'"]).value, ["foo,bar"])
assert_equal(extractParse(StrListArg, ["--value='foo, bar'"]).value, ["foo, bar"])
assert_equal(extractParse(StrListArg, ['--value="foo, bar"']).value, ["foo, bar"])
class StrDictArg:
value: dict[str, str] = cli2.arg(None, "value")
def test_cli_arg_dict_str():
assert_equal(extractParse(StrDictArg, ["--value:foo=bar"]).value, {"foo": "bar"})
assert_equal(
extractParse(StrDictArg, ["--value:foo=bar", "--value:baz=qux"]).value,
{
"foo": "bar",
"baz": "qux",
},
)