Skip to content

Commit

Permalink
Refactor duplicated Operation logic into parent class
Browse files Browse the repository at this point in the history
  • Loading branch information
ValueRaider committed Jan 7, 2025
1 parent 835172b commit 8620a60
Showing 1 changed file with 49 additions and 220 deletions.
269 changes: 49 additions & 220 deletions yfinance/screener/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,50 +10,6 @@

class OperationBase(ABC):
def __init__(self, operator: str, operand: Union[numbers.Real, str, List['OperationBase']]):
self.operator = operator
self.operands = operand

@abstractmethod
def to_dict(self) -> Dict:
raise YFNotImplementedError('to_dict() needs to be implemented by children classes')

@abstractmethod
def __repr__(self) -> Dict:
raise YFNotImplementedError('__repr__() needs to be implemented by children classes')

@abstractmethod
def __str__(self) -> Dict:
raise YFNotImplementedError('__str__() needs to be implemented by children classes')


class EquityOperation(OperationBase):
"""
The `EquityOperation` class constructs filters for stocks based on specific criteria such as region, sector, exchange, and peer group.
The queries support operators: `GT` (greater than), `LT` (less than), `BTWN` (between), `EQ` (equals), and logical operators `AND` and `OR` for combining multiple conditions.
Example:
Operation for 'day_gainers':
.. code-block:: python
EquityOperation('and', [
EquityOperation('gt', ['percentchange', 3]),
EquityOperation('eq', ['region', 'us']),
EquityOperation('gte', ['intradaymarketcap', 2000000000]),
EquityOperation('gte', ['intradayprice', 5]),
EquityOperation('gt', ['dayvolume', 15000])
])
"""
def __init__(self, operator: str, operand: Union[numbers.Real, str, List['EquityOperation']]):
"""
.. seealso::
:attr:`EquityOperation.valid_operand_fields <yfinance.EquityOperation.valid_operand_fields>`
supported operand values for query
:attr:`EquityOperation.valid_eq_operand_map <yfinance.EquityOperation.valid_eq_operand_map>`
supported `EQ query operand parameters`
"""
operator = operator.upper()

if not isinstance(operand, list):
Expand All @@ -76,48 +32,38 @@ def __init__(self, operator: str, operand: Union[numbers.Real, str, List['Equity

self.operator = operator
self.operands = operand
self._valid_eq_operand_map = EQUITY_SCREENER_EQ_MAP
self._valid_operand_fields = EQUITY_SCREENER_FIELDS

@dynamic_docstring({"valid_eq_operand_map_table": generate_list_table_from_dict(EQUITY_SCREENER_EQ_MAP)})

@property
@abstractmethod
def valid_eq_operand_map(self) -> Dict:
"""
Valid Operand Map for Operator "EQ"
{valid_eq_operand_map_table}
"""
return self._valid_eq_operand_map
raise YFNotImplementedError('valid_eq_operand_map() needs to be implemented by child')

@dynamic_docstring({"valid_operand_fields_table": generate_list_table_from_dict(EQUITY_SCREENER_FIELDS)})
@property
def valid_operand_fields(self) -> Dict:
"""
Valid Operand Fields
{valid_operand_fields_table}
"""
return self._valid_operand_fields

def _validate_or_and_operand(self, operand: List['EquityOperation']) -> None:
@abstractmethod
def valid_operand_fields(self) -> List:
raise YFNotImplementedError('valid_operand_fields() needs to be implemented by child')

def _validate_or_and_operand(self, operand: List['OperationBase']) -> None:
if len(operand) <= 1:
raise ValueError('Operand must be length longer than 1')
if all(isinstance(e, EquityOperation) for e in operand) is False:
raise TypeError('Operand must be type EquityOperation for OR/AND')
if all(isinstance(e, OperationBase) for e in operand) is False:
raise TypeError(f'Operand must be type {type(self)} for OR/AND')

def _validate_eq_operand(self, operand: List[Union[str, numbers.Real]]) -> None:
if len(operand) != 2:
raise ValueError('Operand must be length 2 for EQ')

if not any(operand[0] in fields_by_type for fields_by_type in EQUITY_SCREENER_FIELDS.values()):
raise ValueError(f'Invalid field for EquityOperation "{operand[0]}"')
if operand[0] in EQUITY_SCREENER_EQ_MAP:
if operand[1] not in EQUITY_SCREENER_EQ_MAP[operand[0]]:
if not any(operand[0] in fields_by_type for fields_by_type in self.valid_operand_fields.values()):
raise ValueError(f'Invalid field for {type(self)} "{operand[0]}"')
if operand[0] in self.valid_eq_operand_map:
if operand[1] not in self.valid_eq_operand_map[operand[0]]:
raise ValueError(f'Invalid EQ value "{operand[1]}"')

def _validate_btwn_operand(self, operand: List[Union[str, numbers.Real]]) -> None:
if len(operand) != 3:
raise ValueError('Operand must be length 3 for BTWN')
if not any(operand[0] in fields_by_type for fields_by_type in EQUITY_SCREENER_FIELDS.values()):
raise ValueError('Invalid field for EquityOperation')
if not any(operand[0] in fields_by_type for fields_by_type in self.valid_operand_fields.values()):
raise ValueError(f'Invalid field for {type(self)}')
if isinstance(operand[1], numbers.Real) is False:
raise TypeError('Invalid comparison type for BTWN')
if isinstance(operand[2], numbers.Real) is False:
Expand All @@ -126,42 +72,39 @@ def _validate_btwn_operand(self, operand: List[Union[str, numbers.Real]]) -> Non
def _validate_gt_lt(self, operand: List[Union[str, numbers.Real]]) -> None:
if len(operand) != 2:
raise ValueError('Operand must be length 2 for GT/LT')
if not any(operand[0] in fields_by_type for fields_by_type in EQUITY_SCREENER_FIELDS.values()):
raise ValueError(f'Invalid field for EquityOperation "{operand[0]}"')
if not any(operand[0] in fields_by_type for fields_by_type in self.valid_operand_fields.values()):
raise ValueError(f'Invalid field for {type(self)} "{operand[0]}"')
if isinstance(operand[1], numbers.Real) is False:
raise TypeError('Invalid comparison type for GT/LT')

def _validate_isin_operand(self, operand: List['EquityOperation']) -> None:
def _validate_isin_operand(self, operand: List['OperationBase']) -> None:
if len(operand) < 2:
raise ValueError('Operand must be length 2+ for IS-IN')

if not any(operand[0] in fields_by_type for fields_by_type in EQUITY_SCREENER_FIELDS.values()):
raise ValueError(f'Invalid field for EquityOperation "{operand[0]}"')
if operand[0] in EQUITY_SCREENER_EQ_MAP:
if not any(operand[0] in fields_by_type for fields_by_type in self.valid_operand_fields.values()):
raise ValueError(f'Invalid field for {type(self)} "{operand[0]}"')
if operand[0] in self.valid_eq_operand_map:
for i in range(1, len(operand)):
if operand[i] not in EQUITY_SCREENER_EQ_MAP[operand[0]]:
if operand[i] not in self.valid_eq_operand_map[operand[0]]:
raise ValueError(f'Invalid EQ value "{operand[i]}"')

def to_dict(self) -> Dict:
if self.operator == 'IS-IN':
# Expand to OR of EQ queries
sub_queries = []
sub_ops = []
for v in self.operands[1:]:
sub_queries.append(EquityOperation('EQ', [self.operands[0], v]))
sub_ops.append(type(self)('EQ', [self.operands[0], v]))
self.operator = 'OR'
self.operands = sub_queries
self.operands = sub_ops
return {
"operator": self.operator,
"operands": [operand.to_dict() if isinstance(operand, EquityOperation) else operand for operand in self.operands]
"operands": [operand.to_dict() if isinstance(operand, type(self)) else operand for operand in self.operands]
}

def __repr__(self, root=True) -> str:
if root:
s = '"'
else:
s = ''
s = '"' if root else ''

s += f"EquityOperation({self.operator}, ["
s += f"{type(self).__name__}({self.operator}, ["
for i in range(len(self.operands)):
o = self.operands[i]
if isinstance(o, OperationBase):
Expand All @@ -179,68 +122,35 @@ def __str__(self) -> str:
return self.__repr__()


class FundOperation(OperationBase):
"""
The `FundOperation` class constructs filters for mutual funds based on specific criteria such as region, sector, exchange, and peer group.
The queries support operators: `GT` (greater than), `LT` (less than), `BTWN` (between), `EQ` (equals), and logical operators `AND` and `OR` for combining multiple conditions.
Example:
Operation for "solid_large_growth_funds":
.. code-block:: python
FundOperation('and', [
FundOperation('eq', ['categoryname', 'Large Growth']),
FundOperation('is-in', ['performanceratingoverall', 4, 5]),
FundOperation('lt', ['initialinvestment', 100001]),
FundOperation('lt', ['annualreturnnavy1categoryrank', 50]),
FundOperation('eq', ['exchange', 'NAS'])
])
"""
def __init__(self, operator: str, operand: Union[numbers.Real, str, List['FundOperation']]):
class EquityOperation(OperationBase):
@dynamic_docstring({"valid_eq_operand_map_table": generate_list_table_from_dict(EQUITY_SCREENER_EQ_MAP)})
@property
def valid_eq_operand_map(self) -> Dict:
"""
.. seealso::
:attr:`FundOperation.valid_operand_fields <yfinance.FundOperation.valid_operand_fields>`
supported operand values for query
:attr:`FundOperation.valid_eq_operand_map <yfinance.FundOperation.valid_eq_operand_map>`
supported `EQ query operand parameters`
Valid Operand Map for Operator "EQ"
{valid_eq_operand_map_table}
"""
operator = operator.upper()

if not isinstance(operand, list):
raise TypeError('Invalid operand type')
if len(operand) <= 0:
raise ValueError('Invalid field for Screener')

if operator == 'IS-IN':
self._validate_isin_operand(operand)
elif operator in {'OR','AND'}:
self._validate_or_and_operand(operand)
elif operator == 'EQ':
self._validate_eq_operand(operand)
elif operator == 'BTWN':
self._validate_btwn_operand(operand)
elif operator in {'GT','LT','GTE','LTE'}:
self._validate_gt_lt(operand)
else:
raise ValueError('Invalid Operator Value')
return EQUITY_SCREENER_EQ_MAP

self.operator = operator
self.operands = operand
self._valid_eq_operand_map = FUND_SCREENER_EQ_MAP
self._valid_operand_fields = FUND_SCREENER_FIELDS
@dynamic_docstring({"valid_operand_fields_table": generate_list_table_from_dict(EQUITY_SCREENER_FIELDS)})
@property
def valid_operand_fields(self) -> Dict:
"""
Valid Operand Fields
{valid_operand_fields_table}
"""
return EQUITY_SCREENER_FIELDS


class FundOperation(OperationBase):
@dynamic_docstring({"valid_eq_operand_map_table": generate_list_table_from_dict(FUND_SCREENER_EQ_MAP)})
@property
def valid_eq_operand_map(self) -> Dict:
"""
Valid Operand Map for Operator "EQ"
{valid_eq_operand_map_table}
"""
return self._valid_eq_operand_map
return FUND_SCREENER_EQ_MAP

@dynamic_docstring({"valid_operand_fields_table": generate_list_table_from_dict(FUND_SCREENER_FIELDS)})
@property
Expand All @@ -249,86 +159,5 @@ def valid_operand_fields(self) -> Dict:
Valid Operand Fields
{valid_operand_fields_table}
"""
return self._valid_operand_fields

def _validate_or_and_operand(self, operand: List['FundOperation']) -> None:
if len(operand) <= 1:
raise ValueError('Operand must be length longer than 1')
if all(isinstance(e, FundOperation) for e in operand) is False:
raise TypeError('Operand must be type FundOperation for OR/AND')

def _validate_eq_operand(self, operand: List[Union[str, numbers.Real]]) -> None:
if len(operand) != 2:
raise ValueError('Operand must be length 2 for EQ')

if not any(operand[0] in fields_by_type for fields_by_type in FUND_SCREENER_FIELDS.values()):
raise ValueError(f'Invalid field for FundOperation "{operand[0]}"')
if operand[0] in FUND_SCREENER_EQ_MAP:
if operand[1] not in FUND_SCREENER_EQ_MAP[operand[0]]:
raise ValueError(f'Invalid EQ value "{operand[1]}"')

def _validate_btwn_operand(self, operand: List[Union[str, numbers.Real]]) -> None:
if len(operand) != 3:
raise ValueError('Operand must be length 3 for BTWN')
if not any(operand[0] in fields_by_type for fields_by_type in FUND_SCREENER_FIELDS.values()):
raise ValueError('Invalid field for FundOperation')
if isinstance(operand[1], numbers.Real) is False:
raise TypeError('Invalid comparison type for BTWN')
if isinstance(operand[2], numbers.Real) is False:
raise TypeError('Invalid comparison type for BTWN')

def _validate_gt_lt(self, operand: List[Union[str, numbers.Real]]) -> None:
if len(operand) != 2:
raise ValueError('Operand must be length 2 for GT/LT')
if not any(operand[0] in fields_by_type for fields_by_type in FUND_SCREENER_FIELDS.values()):
raise ValueError(f'Invalid field for FundOperation "{operand[0]}"')
if isinstance(operand[1], numbers.Real) is False:
raise TypeError('Invalid comparison type for GT/LT')

def _validate_isin_operand(self, operand: List['FundOperation']) -> None:
if len(operand) < 2:
raise ValueError('Operand must be length 2+ for IS-IN')

if not any(operand[0] in fields_by_type for fields_by_type in FUND_SCREENER_FIELDS.values()):
raise ValueError(f'Invalid field for FundOperation "{operand[0]}"')
if operand[0] in FUND_SCREENER_EQ_MAP:
for i in range(1, len(operand)):
if operand[i] not in FUND_SCREENER_EQ_MAP[operand[0]]:
raise ValueError(f'Invalid EQ value "{operand[i]}"')

def to_dict(self) -> Dict:
if self.operator == 'IS-IN':
# Expand to OR of EQ queries
sub_queries = []
for v in self.operands[1:]:
sub_queries.append(FundOperation('EQ', [self.operands[0], v]))
self.operator = 'OR'
self.operands = sub_queries
return {
"operator": self.operator,
"operands": [operand.to_dict() if isinstance(operand, FundOperation) else operand for operand in self.operands]
}

def __repr__(self, root=True) -> str:
if root:
s = '"'
else:
s = ''

s += f"FundOperation({self.operator}, ["
for i in range(len(self.operands)):
o = self.operands[i]
if isinstance(o, OperationBase):
s += o.__repr__(root=False)
else:
s += o.__repr__()
if i < len(self.operands)-1:
s += ', '
s += ']'
if root:
s += '"'
return s

def __str__(self) -> str:
return self.__repr__()

return FUND_SCREENER_FIELDS

0 comments on commit 8620a60

Please sign in to comment.