wip
This commit is contained in:
parent
647043f4d3
commit
ad5cc391be
2 changed files with 21 additions and 15 deletions
|
@ -1,5 +1,7 @@
|
|||
import sys
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import typing as tp
|
||||
import dataclasses as dt
|
||||
|
||||
|
@ -266,7 +268,7 @@ def parseArgs(args: list[str]) -> list[Token]:
|
|||
# --- Schema ----------------------------------------------------------------- #
|
||||
|
||||
|
||||
class FieldType(Enum):
|
||||
class FieldKind(Enum):
|
||||
FLAG = 0
|
||||
OPERAND = 1
|
||||
EXTRA = 2
|
||||
|
@ -274,7 +276,7 @@ class FieldType(Enum):
|
|||
|
||||
@dt.dataclass
|
||||
class Field:
|
||||
type: FieldType
|
||||
kind: FieldKind
|
||||
shortName: Optional[str]
|
||||
longName: str
|
||||
description: str = ""
|
||||
|
@ -314,9 +316,13 @@ class Field:
|
|||
setattr(obj, self._fieldName, self.default)
|
||||
|
||||
def putValue(self, obj: Any, value: Any, subkey: Optional[str] = None):
|
||||
if self._fieldName is None:
|
||||
raise ValueError("Field type is not defined")
|
||||
setattr(obj, self._fieldName, value)
|
||||
|
||||
def getAttr(self, obj: Any) -> Any:
|
||||
if self._fieldName is None:
|
||||
raise ValueError("Field name is not defined")
|
||||
return getattr(obj, self._fieldName)
|
||||
|
||||
|
||||
|
@ -326,15 +332,15 @@ def arg(
|
|||
description: str = "",
|
||||
default: Any = None,
|
||||
) -> Any:
|
||||
return Field(FieldType.FLAG, shortName, longName, description, default)
|
||||
return Field(FieldKind.FLAG, shortName, longName, description, default)
|
||||
|
||||
|
||||
def operand(longName: str = "", description: str = "") -> Any:
|
||||
return Field(FieldType.OPERAND, None, longName, description)
|
||||
return Field(FieldKind.OPERAND, None, longName, description)
|
||||
|
||||
|
||||
def extra(longName: str = "", description: str = "") -> Any:
|
||||
return Field(FieldType.EXTRA, None, longName, description)
|
||||
return Field(FieldKind.EXTRA, None, longName, description)
|
||||
|
||||
|
||||
@dt.dataclass
|
||||
|
@ -359,11 +365,11 @@ class Schema:
|
|||
|
||||
field.bind(typ, f)
|
||||
|
||||
if field.type == FieldType.FLAG:
|
||||
if field.type == FieldKind.FLAG:
|
||||
s.args.append(field)
|
||||
elif field.type == FieldType.OPERAND:
|
||||
elif field.type == FieldKind.OPERAND:
|
||||
s.operands.append(field)
|
||||
elif field.type == FieldType.EXTRA:
|
||||
elif field.type == FieldKind.EXTRA:
|
||||
if s.extras:
|
||||
raise ValueError("Only one extra argument is allowed")
|
||||
s.extras = field
|
||||
|
@ -417,11 +423,7 @@ class Schema:
|
|||
return None
|
||||
res = self.typ()
|
||||
for arg in self.args:
|
||||
field = arg.field
|
||||
if field:
|
||||
field.put(res, arg.default or field.default())
|
||||
else:
|
||||
raise ValueError("Argument has no field")
|
||||
arg.setDefault(res)
|
||||
return res
|
||||
|
||||
def parse(self, args: list[str]) -> Any:
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
from cutekit import cli2, utils
|
||||
from asserts import (
|
||||
assert_is,
|
||||
assert_true,
|
||||
assert_equal,
|
||||
assert_raises,
|
||||
assert_is_instance,
|
||||
)
|
||||
|
||||
|
@ -115,6 +113,9 @@ def test_parse_key_subkey_arg():
|
|||
assert_equal(arg.value, True)
|
||||
|
||||
|
||||
# --- Extract Args ----------------------------------------------------------- #
|
||||
|
||||
|
||||
def extractParse(type: type[utils.T], args: list[str]) -> utils.T:
|
||||
schema = cli2.Schema.extract(type)
|
||||
return schema.parse(args)
|
||||
|
@ -129,6 +130,7 @@ def test_cli_arg_int():
|
|||
assert_equal(extractParse(IntArg, ["--value=1"]).value, 1)
|
||||
|
||||
|
||||
"""
|
||||
class StrArg:
|
||||
value: str = cli2.arg(None, "value")
|
||||
|
||||
|
@ -210,3 +212,5 @@ def test_cli_arg_dict_str():
|
|||
"baz": "qux",
|
||||
},
|
||||
)
|
||||
|
||||
"""
|
||||
|
|
Loading…
Reference in a new issue