将字符串转换为Enum子类的对应实例的正确方法是什么?似乎getattr(YourEnumType, str)做的工作,但我不确定它是否足够安全。
举个例子,假设我有一个枚举
class BuildType(Enum):
debug = 200
release = 400
给定字符串'调试',我怎么能得到BuildType.debug作为结果?
将字符串转换为Enum子类的对应实例的正确方法是什么?似乎getattr(YourEnumType, str)做的工作,但我不确定它是否足够安全。
举个例子,假设我有一个枚举
class BuildType(Enum):
debug = 200
release = 400
给定字符串'调试',我怎么能得到BuildType.debug作为结果?
当前回答
更改你的类签名为:
class BuildType(str, Enum):
其他回答
我的类似java的解决方案。希望它能帮助到某人…
from enum import Enum, auto
class SignInMethod(Enum):
EMAIL = auto(),
GOOGLE = auto()
@classmethod
def value_of(cls, value):
for k, v in cls.__members__.items():
if k == value:
return v
else:
raise ValueError(f"'{cls.__name__}' enum not found for '{value}'")
sim = SignInMethod.value_of('EMAIL')
assert sim == SignInMethod.EMAIL
assert sim.name == 'EMAIL'
assert isinstance(sim, SignInMethod)
# SignInMethod.value_of("invalid sign-in method") # should raise `ValueError`
更改你的类签名为:
class BuildType(str, Enum):
由于MyEnum['dontexist']将导致错误KeyError: 'dontexist',您可能喜欢无声地失败(例如。返回None)。在这种情况下,你可以使用下面的静态方法:
class Statuses(enum.Enum):
Unassigned = 1
Assigned = 2
@staticmethod
def from_str(text):
statuses = [status for status in dir(
Statuses) if not status.startswith('_')]
if text in statuses:
return getattr(Statuses, text)
return None
Statuses.from_str('Unassigned')
另一个选择(特别有用,如果你的字符串不映射1-1到你的enum情况)是添加一个static方法到你的enum,例如:
class QuestionType(enum.Enum):
MULTI_SELECT = "multi"
SINGLE_SELECT = "single"
@staticmethod
def from_str(label):
if label in ('single', 'singleSelect'):
return QuestionType.SINGLE_SELECT
elif label in ('multi', 'multiSelect'):
return QuestionType.MULTI_SELECT
else:
raise NotImplementedError
然后你可以输入question_type = questiontype。from_str('singleSelect')
class LogLevel(IntEnum):
critical = logging.CRITICAL
fatal = logging.FATAL
error = logging.ERROR
warning = logging.WARNING
info = logging.INFO
debug = logging.DEBUG
notset = logging.NOTSET
def __str__(self):
return f'{self.__class__.__name__}.{self.name}'
@classmethod
def _missing_(cls, value):
if type(value) is str:
value = value.lower()
if value in dir(cls):
return cls[value]
raise ValueError("%r is not a valid %s" % (value, cls.__name__))
例子:
print(LogLevel('Info'))
print(LogLevel(logging.WARNING))
print(LogLevel(10)) # logging.DEBUG
print(LogLevel.fatal)
print(LogLevel(550))
输出:
LogLevel.info
LogLevel.warning
LogLevel.debug
LogLevel.critical
ValueError: 550 is not a valid LogLevel