Skip to content

Commit 83e5fca

Browse files
authored
Merge pull request #2574 from crytic/dev-echidna-values
Echidna printer Improve values extraction
2 parents 79619f6 + f6b2509 commit 83e5fca

File tree

4 files changed

+272
-40
lines changed

4 files changed

+272
-40
lines changed

slither/printers/guidance/echidna.py

+65-34
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from slither.core.expressions import NewContract
1414
from slither.core.slither_core import SlitherCore
1515
from slither.core.solidity_types import TypeAlias
16+
from slither.core.source_mapping.source_mapping import SourceMapping
1617
from slither.core.variables.state_variable import StateVariable
1718
from slither.core.variables.variable import Variable
1819
from slither.printers.abstract_printer import AbstractPrinter
@@ -179,29 +180,74 @@ class ConstantValue(NamedTuple): # pylint: disable=inherit-non-class,too-few-pu
179180
type: str
180181

181182

182-
def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-nested-blocks
183+
def _extract_constant_from_read(
184+
ir: Operation,
185+
r: SourceMapping,
186+
all_cst_used: List[ConstantValue],
187+
all_cst_used_in_binary: Dict[str, List[ConstantValue]],
188+
context_explored: Set[Node],
189+
) -> None:
190+
var_read = r.points_to_origin if isinstance(r, ReferenceVariable) else r
191+
# Do not report struct_name in a.struct_name
192+
if isinstance(ir, Member):
193+
return
194+
if isinstance(var_read, Variable) and var_read.is_constant:
195+
# In case of type conversion we use the destination type
196+
if isinstance(ir, TypeConversion):
197+
if isinstance(ir.type, TypeAlias):
198+
value_type = ir.type.type
199+
else:
200+
value_type = ir.type
201+
else:
202+
value_type = var_read.type
203+
try:
204+
value = ConstantFolding(var_read.expression, value_type).result()
205+
all_cst_used.append(ConstantValue(str(value), str(value_type)))
206+
except NotConstant:
207+
pass
208+
if isinstance(var_read, Constant):
209+
all_cst_used.append(ConstantValue(str(var_read.value), str(var_read.type)))
210+
if isinstance(var_read, StateVariable):
211+
if var_read.node_initialization:
212+
if var_read.node_initialization.irs:
213+
if var_read.node_initialization in context_explored:
214+
return
215+
context_explored.add(var_read.node_initialization)
216+
_extract_constants_from_irs(
217+
var_read.node_initialization.irs,
218+
all_cst_used,
219+
all_cst_used_in_binary,
220+
context_explored,
221+
)
222+
223+
224+
def _extract_constant_from_binary(
225+
ir: Binary,
226+
all_cst_used: List[ConstantValue],
227+
all_cst_used_in_binary: Dict[str, List[ConstantValue]],
228+
):
229+
for r in ir.read:
230+
if isinstance(r, Constant):
231+
all_cst_used_in_binary[str(ir.type)].append(ConstantValue(str(r.value), str(r.type)))
232+
if isinstance(ir.variable_left, Constant) or isinstance(ir.variable_right, Constant):
233+
if ir.lvalue:
234+
try:
235+
type_ = ir.lvalue.type
236+
cst = ConstantFolding(ir.expression, type_).result()
237+
all_cst_used.append(ConstantValue(str(cst.value), str(type_)))
238+
except NotConstant:
239+
pass
240+
241+
242+
def _extract_constants_from_irs(
183243
irs: List[Operation],
184244
all_cst_used: List[ConstantValue],
185245
all_cst_used_in_binary: Dict[str, List[ConstantValue]],
186246
context_explored: Set[Node],
187247
) -> None:
188248
for ir in irs:
189249
if isinstance(ir, Binary):
190-
for r in ir.read:
191-
if isinstance(r, Constant):
192-
all_cst_used_in_binary[str(ir.type)].append(
193-
ConstantValue(str(r.value), str(r.type))
194-
)
195-
if isinstance(ir.variable_left, Constant) or isinstance(
196-
ir.variable_right, Constant
197-
):
198-
if ir.lvalue:
199-
try:
200-
type_ = ir.lvalue.type
201-
cst = ConstantFolding(ir.expression, type_).result()
202-
all_cst_used.append(ConstantValue(str(cst.value), str(type_)))
203-
except NotConstant:
204-
pass
250+
_extract_constant_from_binary(ir, all_cst_used, all_cst_used_in_binary)
205251
if isinstance(ir, TypeConversion):
206252
if isinstance(ir.variable, Constant):
207253
if isinstance(ir.type, TypeAlias):
@@ -222,24 +268,9 @@ def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-n
222268
except ValueError: # index could fail; should never happen in working solidity code
223269
pass
224270
for r in ir.read:
225-
var_read = r.points_to_origin if isinstance(r, ReferenceVariable) else r
226-
# Do not report struct_name in a.struct_name
227-
if isinstance(ir, Member):
228-
continue
229-
if isinstance(var_read, Constant):
230-
all_cst_used.append(ConstantValue(str(var_read.value), str(var_read.type)))
231-
if isinstance(var_read, StateVariable):
232-
if var_read.node_initialization:
233-
if var_read.node_initialization.irs:
234-
if var_read.node_initialization in context_explored:
235-
continue
236-
context_explored.add(var_read.node_initialization)
237-
_extract_constants_from_irs(
238-
var_read.node_initialization.irs,
239-
all_cst_used,
240-
all_cst_used_in_binary,
241-
context_explored,
242-
)
271+
_extract_constant_from_read(
272+
ir, r, all_cst_used, all_cst_used_in_binary, context_explored
273+
)
243274

244275

245276
def _extract_constants(

slither/visitors/expression/constants_folding.py

+166-6
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
TupleExpression,
1414
TypeConversion,
1515
CallExpression,
16+
MemberAccess,
1617
)
18+
from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression
1719
from slither.core.variables import Variable
1820
from slither.utils.integer_conversion import convert_string_to_fraction, convert_string_to_int
1921
from slither.visitors.expression.expression import ExpressionVisitor
@@ -27,7 +29,13 @@ class NotConstant(Exception):
2729
KEY = "ConstantFolding"
2830

2931
CONSTANT_TYPES_OPERATIONS = Union[
30-
Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion
32+
Literal,
33+
BinaryOperation,
34+
UnaryOperation,
35+
Identifier,
36+
TupleExpression,
37+
TypeConversion,
38+
MemberAccess,
3139
]
3240

3341

@@ -69,6 +77,9 @@ def result(self) -> "Literal":
6977
# pylint: disable=import-outside-toplevel
7078
def _post_identifier(self, expression: Identifier) -> None:
7179
from slither.core.declarations.solidity_variables import SolidityFunction
80+
from slither.core.declarations.enum import Enum
81+
from slither.core.solidity_types.type_alias import TypeAlias
82+
from slither.core.declarations.contract import Contract
7283

7384
if isinstance(expression.value, Variable):
7485
if expression.value.is_constant:
@@ -77,7 +88,14 @@ def _post_identifier(self, expression: Identifier) -> None:
7788
# Everything outside of literal
7889
if isinstance(
7990
expr,
80-
(BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion),
91+
(
92+
BinaryOperation,
93+
UnaryOperation,
94+
Identifier,
95+
TupleExpression,
96+
TypeConversion,
97+
MemberAccess,
98+
),
8199
):
82100
cf = ConstantFolding(expr, self._type)
83101
expr = cf.result()
@@ -88,20 +106,41 @@ def _post_identifier(self, expression: Identifier) -> None:
88106
elif isinstance(expression.value, SolidityFunction):
89107
set_val(expression, expression.value)
90108
else:
91-
raise NotConstant
109+
# Enum: We don't want to raise an error for a direct access to an Enum as they can be converted to a constant value
110+
# We can't handle it here because we don't have the field accessed so we do it in _post_member_access
111+
# TypeAlias: Support when a .wrap() is done with a constant
112+
# Contract: Support when a constatn is use from a different contract
113+
if not isinstance(expression.value, (Enum, TypeAlias, Contract)):
114+
raise NotConstant
92115

93116
# pylint: disable=too-many-branches,too-many-statements
94117
def _post_binary_operation(self, expression: BinaryOperation) -> None:
95118
expression_left = expression.expression_left
96119
expression_right = expression.expression_right
97120
if not isinstance(
98121
expression_left,
99-
(Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion),
122+
(
123+
Literal,
124+
BinaryOperation,
125+
UnaryOperation,
126+
Identifier,
127+
TupleExpression,
128+
TypeConversion,
129+
MemberAccess,
130+
),
100131
):
101132
raise NotConstant
102133
if not isinstance(
103134
expression_right,
104-
(Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion),
135+
(
136+
Literal,
137+
BinaryOperation,
138+
UnaryOperation,
139+
Identifier,
140+
TupleExpression,
141+
TypeConversion,
142+
MemberAccess,
143+
),
105144
):
106145
raise NotConstant
107146
left = get_val(expression_left)
@@ -205,6 +244,34 @@ def _post_assignement_operation(self, expression: expressions.AssignmentOperatio
205244
raise NotConstant
206245

207246
def _post_call_expression(self, expression: expressions.CallExpression) -> None:
247+
from slither.core.declarations.solidity_variables import SolidityFunction
248+
from slither.core.declarations.enum import Enum
249+
from slither.core.solidity_types import TypeAlias
250+
251+
# pylint: disable=too-many-boolean-expressions
252+
if (
253+
isinstance(expression.called, Identifier)
254+
and expression.called.value == SolidityFunction("type()")
255+
and len(expression.arguments) == 1
256+
and (
257+
isinstance(expression.arguments[0], ElementaryTypeNameExpression)
258+
or isinstance(expression.arguments[0], Identifier)
259+
and isinstance(expression.arguments[0].value, Enum)
260+
)
261+
):
262+
# Returning early to support type(ElemType).max/min or type(MyEnum).max/min
263+
return
264+
if (
265+
isinstance(expression.called.expression, Identifier)
266+
and isinstance(expression.called.expression.value, TypeAlias)
267+
and isinstance(expression.called, MemberAccess)
268+
and expression.called.member_name == "wrap"
269+
and len(expression.arguments) == 1
270+
):
271+
# Handle constants in .wrap of user defined type
272+
set_val(expression, get_val(expression.arguments[0]))
273+
return
274+
208275
called = get_val(expression.called)
209276
args = [get_val(arg) for arg in expression.arguments]
210277
if called.name == "keccak256(bytes)":
@@ -220,12 +287,104 @@ def _post_conditional_expression(self, expression: expressions.ConditionalExpres
220287
def _post_elementary_type_name_expression(
221288
self, expression: expressions.ElementaryTypeNameExpression
222289
) -> None:
223-
raise NotConstant
290+
# We don't have to raise an exception to support type(uint112).max or similar
291+
pass
224292

225293
def _post_index_access(self, expression: expressions.IndexAccess) -> None:
226294
raise NotConstant
227295

296+
# pylint: disable=too-many-locals
228297
def _post_member_access(self, expression: expressions.MemberAccess) -> None:
298+
from slither.core.declarations import (
299+
SolidityFunction,
300+
Contract,
301+
EnumContract,
302+
EnumTopLevel,
303+
Enum,
304+
)
305+
from slither.core.solidity_types import UserDefinedType, TypeAlias
306+
307+
# pylint: disable=too-many-nested-blocks
308+
if isinstance(expression.expression, CallExpression) and expression.member_name in [
309+
"min",
310+
"max",
311+
]:
312+
if isinstance(expression.expression.called, Identifier):
313+
if expression.expression.called.value == SolidityFunction("type()"):
314+
assert len(expression.expression.arguments) == 1
315+
type_expression_found = expression.expression.arguments[0]
316+
type_found: Union[ElementaryType, UserDefinedType]
317+
if isinstance(type_expression_found, ElementaryTypeNameExpression):
318+
type_expression_found_type = type_expression_found.type
319+
assert isinstance(type_expression_found_type, ElementaryType)
320+
type_found = type_expression_found_type
321+
value = (
322+
type_found.max if expression.member_name == "max" else type_found.min
323+
)
324+
set_val(expression, value)
325+
return
326+
# type(enum).max/min
327+
# Case when enum is in another contract e.g. type(C.E).max
328+
if isinstance(type_expression_found, MemberAccess):
329+
contract = type_expression_found.expression.value
330+
assert isinstance(contract, Contract)
331+
for enum in contract.enums:
332+
if enum.name == type_expression_found.member_name:
333+
type_found_in_expression = enum
334+
type_found = UserDefinedType(enum)
335+
break
336+
else:
337+
assert isinstance(type_expression_found, Identifier)
338+
type_found_in_expression = type_expression_found.value
339+
assert isinstance(type_found_in_expression, (EnumContract, EnumTopLevel))
340+
type_found = UserDefinedType(type_found_in_expression)
341+
value = (
342+
type_found_in_expression.max
343+
if expression.member_name == "max"
344+
else type_found_in_expression.min
345+
)
346+
set_val(expression, value)
347+
return
348+
elif isinstance(expression.expression, Identifier) and isinstance(
349+
expression.expression.value, Enum
350+
):
351+
# Handle direct access to enum field
352+
set_val(expression, expression.expression.value.values.index(expression.member_name))
353+
return
354+
elif isinstance(expression.expression, Identifier) and isinstance(
355+
expression.expression.value, TypeAlias
356+
):
357+
# User defined type .wrap call handled in _post_call_expression
358+
return
359+
elif (
360+
isinstance(expression.expression.value, Contract)
361+
and expression.member_name in expression.expression.value.variables_as_dict
362+
and expression.expression.value.variables_as_dict[expression.member_name].is_constant
363+
):
364+
# Handles when a constant is accessed on another contract
365+
variables = expression.expression.value.variables_as_dict
366+
if isinstance(variables[expression.member_name].expression, MemberAccess):
367+
self._post_member_access(variables[expression.member_name].expression)
368+
set_val(expression, get_val(variables[expression.member_name].expression))
369+
return
370+
371+
# If the variable is a Literal we convert its value to int
372+
if isinstance(variables[expression.member_name].expression, Literal):
373+
value = convert_string_to_int(
374+
variables[expression.member_name].expression.converted_value
375+
)
376+
# If the variable is a UnaryOperation we need convert its value to int
377+
# and replacing possible spaces
378+
elif isinstance(variables[expression.member_name].expression, UnaryOperation):
379+
value = convert_string_to_int(
380+
str(variables[expression.member_name].expression).replace(" ", "")
381+
)
382+
else:
383+
value = variables[expression.member_name].expression
384+
385+
set_val(expression, value)
386+
return
387+
229388
raise NotConstant
230389

231390
def _post_new_array(self, expression: expressions.NewArray) -> None:
@@ -272,6 +431,7 @@ def _post_type_conversion(self, expression: expressions.TypeConversion) -> None:
272431
TupleExpression,
273432
TypeConversion,
274433
CallExpression,
434+
MemberAccess,
275435
),
276436
):
277437
raise NotConstant
+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from pathlib import Path
2+
3+
from slither import Slither
4+
from slither.printers.guidance.echidna import _extract_constants, ConstantValue
5+
6+
TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data"
7+
8+
9+
def test_enum_max_min(solc_binary_path) -> None:
10+
solc_path = solc_binary_path("0.8.19")
11+
slither = Slither(Path(TEST_DATA_DIR, "constantfolding.sol").as_posix(), solc=solc_path)
12+
13+
contracts = slither.get_contract_from_name("A")
14+
15+
constants = _extract_constants(contracts)[0]["A"]["use()"]
16+
17+
assert set(constants) == {
18+
ConstantValue(value="2", type="uint256"),
19+
ConstantValue(value="10", type="uint256"),
20+
ConstantValue(value="100", type="uint256"),
21+
ConstantValue(value="4294967295", type="uint32"),
22+
}

0 commit comments

Comments
 (0)