This commit is contained in:
Sleepy Monax 2024-02-13 20:23:37 +01:00
parent 647043f4d3
commit ad5cc391be
2 changed files with 21 additions and 15 deletions

View file

@ -1,5 +1,7 @@
import sys import sys
from enum import Enum from enum import Enum
import typing as tp import typing as tp
import dataclasses as dt import dataclasses as dt
@ -266,7 +268,7 @@ def parseArgs(args: list[str]) -> list[Token]:
# --- Schema ----------------------------------------------------------------- # # --- Schema ----------------------------------------------------------------- #
class FieldType(Enum): class FieldKind(Enum):
FLAG = 0 FLAG = 0
OPERAND = 1 OPERAND = 1
EXTRA = 2 EXTRA = 2
@ -274,7 +276,7 @@ class FieldType(Enum):
@dt.dataclass @dt.dataclass
class Field: class Field:
type: FieldType kind: FieldKind
shortName: Optional[str] shortName: Optional[str]
longName: str longName: str
description: str = "" description: str = ""
@ -314,9 +316,13 @@ class Field:
setattr(obj, self._fieldName, self.default) setattr(obj, self._fieldName, self.default)
def putValue(self, obj: Any, value: Any, subkey: Optional[str] = None): def putValue(self, obj: Any, value: Any, subkey: Optional[str] = None):
if self._fieldName is None:
raise ValueError("Field type is not defined")
setattr(obj, self._fieldName, value) setattr(obj, self._fieldName, value)
def getAttr(self, obj: Any) -> Any: def getAttr(self, obj: Any) -> Any:
if self._fieldName is None:
raise ValueError("Field name is not defined")
return getattr(obj, self._fieldName) return getattr(obj, self._fieldName)
@ -326,15 +332,15 @@ def arg(
description: str = "", description: str = "",
default: Any = None, default: Any = None,
) -> Any: ) -> Any:
return Field(FieldType.FLAG, shortName, longName, description, default) return Field(FieldKind.FLAG, shortName, longName, description, default)
def operand(longName: str = "", description: str = "") -> Any: def operand(longName: str = "", description: str = "") -> Any:
return Field(FieldType.OPERAND, None, longName, description) return Field(FieldKind.OPERAND, None, longName, description)
def extra(longName: str = "", description: str = "") -> Any: def extra(longName: str = "", description: str = "") -> Any:
return Field(FieldType.EXTRA, None, longName, description) return Field(FieldKind.EXTRA, None, longName, description)
@dt.dataclass @dt.dataclass
@ -359,11 +365,11 @@ class Schema:
field.bind(typ, f) field.bind(typ, f)
if field.type == FieldType.FLAG: if field.type == FieldKind.FLAG:
s.args.append(field) s.args.append(field)
elif field.type == FieldType.OPERAND: elif field.type == FieldKind.OPERAND:
s.operands.append(field) s.operands.append(field)
elif field.type == FieldType.EXTRA: elif field.type == FieldKind.EXTRA:
if s.extras: if s.extras:
raise ValueError("Only one extra argument is allowed") raise ValueError("Only one extra argument is allowed")
s.extras = field s.extras = field
@ -417,11 +423,7 @@ class Schema:
return None return None
res = self.typ() res = self.typ()
for arg in self.args: for arg in self.args:
field = arg.field arg.setDefault(res)
if field:
field.put(res, arg.default or field.default())
else:
raise ValueError("Argument has no field")
return res return res
def parse(self, args: list[str]) -> Any: def parse(self, args: list[str]) -> Any:

View file

@ -1,9 +1,7 @@
from cutekit import cli2, utils from cutekit import cli2, utils
from asserts import ( from asserts import (
assert_is, assert_is,
assert_true,
assert_equal, assert_equal,
assert_raises,
assert_is_instance, assert_is_instance,
) )
@ -115,6 +113,9 @@ def test_parse_key_subkey_arg():
assert_equal(arg.value, True) assert_equal(arg.value, True)
# --- Extract Args ----------------------------------------------------------- #
def extractParse(type: type[utils.T], args: list[str]) -> utils.T: def extractParse(type: type[utils.T], args: list[str]) -> utils.T:
schema = cli2.Schema.extract(type) schema = cli2.Schema.extract(type)
return schema.parse(args) return schema.parse(args)
@ -129,6 +130,7 @@ def test_cli_arg_int():
assert_equal(extractParse(IntArg, ["--value=1"]).value, 1) assert_equal(extractParse(IntArg, ["--value=1"]).value, 1)
"""
class StrArg: class StrArg:
value: str = cli2.arg(None, "value") value: str = cli2.arg(None, "value")
@ -210,3 +212,5 @@ def test_cli_arg_dict_str():
"baz": "qux", "baz": "qux",
}, },
) )
"""