This commit is contained in:
Sleepy Monax 2024-02-13 20:23:37 +01:00
parent 647043f4d3
commit ad5cc391be
2 changed files with 21 additions and 15 deletions

View file

@ -1,5 +1,7 @@
import sys
from enum import Enum
import typing as tp
import dataclasses as dt
@ -266,7 +268,7 @@ def parseArgs(args: list[str]) -> list[Token]:
# --- Schema ----------------------------------------------------------------- #
class FieldType(Enum):
class FieldKind(Enum):
FLAG = 0
OPERAND = 1
EXTRA = 2
@ -274,7 +276,7 @@ class FieldType(Enum):
@dt.dataclass
class Field:
type: FieldType
kind: FieldKind
shortName: Optional[str]
longName: str
description: str = ""
@ -314,9 +316,13 @@ class Field:
setattr(obj, self._fieldName, self.default)
def putValue(self, obj: Any, value: Any, subkey: Optional[str] = None):
if self._fieldName is None:
raise ValueError("Field type is not defined")
setattr(obj, self._fieldName, value)
def getAttr(self, obj: Any) -> Any:
if self._fieldName is None:
raise ValueError("Field name is not defined")
return getattr(obj, self._fieldName)
@ -326,15 +332,15 @@ def arg(
description: str = "",
default: Any = None,
) -> Any:
return Field(FieldType.FLAG, shortName, longName, description, default)
return Field(FieldKind.FLAG, shortName, longName, description, default)
def operand(longName: str = "", description: str = "") -> Any:
return Field(FieldType.OPERAND, None, longName, description)
return Field(FieldKind.OPERAND, None, longName, description)
def extra(longName: str = "", description: str = "") -> Any:
return Field(FieldType.EXTRA, None, longName, description)
return Field(FieldKind.EXTRA, None, longName, description)
@dt.dataclass
@ -359,11 +365,11 @@ class Schema:
field.bind(typ, f)
if field.type == FieldType.FLAG:
if field.type == FieldKind.FLAG:
s.args.append(field)
elif field.type == FieldType.OPERAND:
elif field.type == FieldKind.OPERAND:
s.operands.append(field)
elif field.type == FieldType.EXTRA:
elif field.type == FieldKind.EXTRA:
if s.extras:
raise ValueError("Only one extra argument is allowed")
s.extras = field
@ -417,11 +423,7 @@ class Schema:
return None
res = self.typ()
for arg in self.args:
field = arg.field
if field:
field.put(res, arg.default or field.default())
else:
raise ValueError("Argument has no field")
arg.setDefault(res)
return res
def parse(self, args: list[str]) -> Any:

View file

@ -1,9 +1,7 @@
from cutekit import cli2, utils
from asserts import (
assert_is,
assert_true,
assert_equal,
assert_raises,
assert_is_instance,
)
@ -115,6 +113,9 @@ def test_parse_key_subkey_arg():
assert_equal(arg.value, True)
# --- Extract Args ----------------------------------------------------------- #
def extractParse(type: type[utils.T], args: list[str]) -> utils.T:
schema = cli2.Schema.extract(type)
return schema.parse(args)
@ -129,6 +130,7 @@ def test_cli_arg_int():
assert_equal(extractParse(IntArg, ["--value=1"]).value, 1)
"""
class StrArg:
value: str = cli2.arg(None, "value")
@ -210,3 +212,5 @@ def test_cli_arg_dict_str():
"baz": "qux",
},
)
"""