Skip to content

Commit 10a466f

Browse files
committed
Improving geometry eq/hash to enable sets
1 parent 25f4626 commit 10a466f

File tree

8 files changed

+113
-9
lines changed

8 files changed

+113
-9
lines changed

src/build123d/build_common.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1350,7 +1350,8 @@ def __call__(self, select: Select = Select.ALL) -> T2_covar: ...
13501350

13511351
def __gen_context_component_getter(
13521352
func: Callable[[Builder, Select], T2],
1353-
) -> ContextComponentGetter[T2]:
1353+
# ) -> ContextComponentGetter[T2]:
1354+
) -> Callable[[Select], T2]:
13541355
"""
13551356
Wraps a Builder method to automatically provide the Builder context.
13561357

src/build123d/geometry.py

+51-6
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
import warnings
4343

4444
from collections.abc import Iterable, Sequence
45-
from math import degrees, pi, radians, isclose
45+
from math import degrees, log10, pi, radians, isclose
4646
from typing import Any, overload, TypeAlias, TYPE_CHECKING
4747

4848
import OCP.TopAbs as TopAbs_ShapeEnum
@@ -88,6 +88,7 @@
8888
logger = logging.getLogger("build123d")
8989

9090
TOLERANCE = 1e-6
91+
TOL_DIGITS = abs(int(log10(TOLERANCE)))
9192
TOL = 1e-2
9293
DEG2RAD = pi / 180.0
9394
RAD2DEG = 180 / pi
@@ -445,11 +446,17 @@ def __eq__(self, other: object) -> bool:
445446
"""Vectors equal operator =="""
446447
if not isinstance(other, Vector):
447448
return NotImplemented
448-
return self.wrapped.IsEqual(other.wrapped, 0.00001, 0.00001)
449+
return self.wrapped.IsEqual(other.wrapped, TOLERANCE, TOLERANCE)
449450

450451
def __hash__(self) -> int:
451452
"""Hash of Vector"""
452-
return hash((round(self.X, 6), round(self.Y, 6), round(self.Z, 6)))
453+
return hash(
454+
(
455+
round(self.X, TOL_DIGITS - 1),
456+
round(self.Y, TOL_DIGITS - 1),
457+
round(self.Z, TOL_DIGITS - 1),
458+
)
459+
)
453460

454461
def __copy__(self) -> Vector:
455462
"""Return copy of self"""
@@ -690,6 +697,16 @@ def __deepcopy__(self, _memo) -> Axis:
690697
"""Return deepcopy of self"""
691698
return Axis(self.position, self.direction)
692699

700+
def __hash__(self) -> int:
701+
"""Hash of Axis"""
702+
return hash(
703+
(
704+
round(v, TOL_DIGITS - 1)
705+
for vector in [self.position, self.direction]
706+
for v in vector
707+
)
708+
)
709+
693710
def __repr__(self) -> str:
694711
"""Display self"""
695712
return f"({self.position.to_tuple()},{self.direction.to_tuple()})"
@@ -1660,7 +1677,25 @@ def __eq__(self, other: object) -> bool:
16601677
radians(other.orientation.Y),
16611678
radians(other.orientation.Z),
16621679
)
1663-
return self.position == other.position and quaternion1.IsEqual(quaternion2)
1680+
# Test quaternions with tolerance
1681+
q_values = [
1682+
[get_value() for get_value in (q.X, q.Y, q.Z, q.W)]
1683+
for q in (quaternion1, quaternion2)
1684+
]
1685+
quaternion_eq = all(
1686+
isclose(v1, v2, abs_tol=TOLERANCE) for v1, v2 in zip(*q_values)
1687+
)
1688+
return self.position == other.position and quaternion_eq
1689+
1690+
def __hash__(self) -> int:
1691+
"""Hash of Location"""
1692+
return hash(
1693+
(
1694+
round(v, TOL_DIGITS - 1)
1695+
for vector in [self.position, self.orientation]
1696+
for v in vector
1697+
)
1698+
)
16641699

16651700
def __neg__(self) -> Location:
16661701
"""Flip the orientation without changing the position operator -"""
@@ -2563,8 +2598,8 @@ def __eq__(self, other: object):
25632598
return NotImplemented
25642599

25652600
# equality tolerances
2566-
eq_tolerance_origin = 1e-6
2567-
eq_tolerance_dot = 1e-6
2601+
eq_tolerance_origin = TOLERANCE
2602+
eq_tolerance_dot = TOLERANCE
25682603

25692604
return (
25702605
# origins are the same
@@ -2575,6 +2610,16 @@ def __eq__(self, other: object):
25752610
and abs(self.x_dir.dot(other.x_dir) - 1) < eq_tolerance_dot
25762611
)
25772612

2613+
def __hash__(self) -> int:
2614+
"""Hash of Plane"""
2615+
return hash(
2616+
(
2617+
round(v, TOL_DIGITS - 1)
2618+
for vector in [self.origin, self.x_dir, self.z_dir]
2619+
for v in vector
2620+
)
2621+
)
2622+
25782623
def __neg__(self) -> Plane:
25792624
"""Reverse z direction of plane operator -"""
25802625
return Plane(self.origin, self.x_dir, -self.z_dir)

src/build123d/topology/shape_core.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -866,7 +866,7 @@ def __deepcopy__(self, memo) -> Self:
866866
if self.wrapped is not None:
867867
memo[id(self.wrapped)] = downcast(BRepBuilderAPI_Copy(self.wrapped).Shape())
868868
for key, value in self.__dict__.items():
869-
if key == 'topo_parent':
869+
if key == "topo_parent":
870870
result.topo_parent = value
871871
else:
872872
setattr(result, key, copy.deepcopy(value, memo))

src/build123d/topology/two_d.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,10 @@ def _curvature_sign(self) -> float:
563563
float: The signed value; positive indicates convexity, negative indicates concavity.
564564
Returns 0 if the geometry type is unsupported.
565565
"""
566-
if self.geom_type == GeomType.CYLINDER:
566+
if (
567+
self.geom_type == GeomType.CYLINDER
568+
and type(self.geom_adaptor()) != Geom_RectangularTrimmedSurface
569+
):
567570
axis = self.axis_of_rotation
568571
if axis is None:
569572
raise ValueError("Can't find curvature of empty object")

tests/test_direct_api/test_axis.py

+13
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,19 @@ def test_axis_not_equal(self):
230230
random_obj = object()
231231
self.assertNotEqual(Axis.X, random_obj)
232232

233+
def test_set(self):
234+
a0 = Axis((0, 1, 2), (3, 4, 5))
235+
for i in range(1, 8):
236+
for j in range(1, 8):
237+
a1 = Axis(
238+
(a0.position.X + 1.0 / (10**i), a0.position.Y, a0.position.Z),
239+
(a0.direction.X + 1.0 / (10**j), a0.direction.Y, a0.direction.Z),
240+
)
241+
if a0 == a1:
242+
self.assertEqual(len(set([a0, a1])), 1)
243+
else:
244+
self.assertEqual(len(set([a0, a1])), 2)
245+
233246
def test_position_property(self):
234247
axis = Axis.X
235248
axis.position = 1, 2, 3

tests/test_direct_api/test_location.py

+17
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,23 @@ def test_not_equal(self):
266266
self.assertNotEqual(loc, diff_orientation)
267267
self.assertNotEqual(loc, object())
268268

269+
def test_set(self):
270+
l0 = Location((0, 1, 2), (3, 4, 5))
271+
for i in range(1, 8):
272+
for j in range(1, 8):
273+
l1 = Location(
274+
(l0.position.X + 1.0 / (10**i), l0.position.Y, l0.position.Z),
275+
(
276+
l0.orientation.X + 1.0 / (10**j),
277+
l0.orientation.Y,
278+
l0.orientation.Z,
279+
),
280+
)
281+
if l0 == l1:
282+
self.assertEqual(len(set([l0, l1])), 1)
283+
else:
284+
self.assertEqual(len(set([l0, l1])), 2)
285+
269286
def test_neg(self):
270287
loc = Location((1, 2, 3), (0, 35, 127))
271288
n_loc = -loc

tests/test_direct_api/test_plane.py

+15
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,21 @@ def test_plane_not_equal(self):
408408
Plane(origin=(0, 0, 0), x_dir=(1, 0, 0), z_dir=(0, 1, 1)),
409409
)
410410

411+
def test_set(self):
412+
p0 = Plane((0, 1, 2), (3, 4, 5), (6, 7, 8))
413+
for i in range(1, 8):
414+
for j in range(1, 8):
415+
for k in range(1, 8):
416+
p1 = Plane(
417+
(p0.origin.X + 1.0 / (10**i), p0.origin.Y, p0.origin.Z),
418+
(p0.x_dir.X + 1.0 / (10**j), p0.x_dir.Y, p0.x_dir.Z),
419+
(p0.z_dir.X + 1.0 / (10**k), p0.z_dir.Y, p0.z_dir.Z),
420+
)
421+
if p0 == p1:
422+
self.assertEqual(len(set([p0, p1])), 1)
423+
else:
424+
self.assertEqual(len(set([p0, p1])), 2)
425+
411426
def test_to_location(self):
412427
loc = Plane(origin=(1, 2, 3), x_dir=(0, 1, 0), z_dir=(0, 0, 1)).location
413428
self.assertAlmostEqual(loc.position, (1, 2, 3), 5)

tests/test_direct_api/test_vector.py

+10
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,16 @@ def test_vector_not_equal(self):
156156
self.assertNotEqual(a, b)
157157
self.assertNotEqual(a, object())
158158

159+
def test_vector_sets(self):
160+
# Check that equal and hash work the same way to enable sets
161+
a = Vector(1, 2, 3)
162+
for i in range(1, 8):
163+
v = Vector(a.X + 1.0 / (10**i), a.Y, a.Z)
164+
if v == a:
165+
self.assertEqual(len(set([a, v])), 1)
166+
else:
167+
self.assertEqual(len(set([a, v])), 2)
168+
159169
def test_vector_distance(self):
160170
"""
161171
Test line distance from plane.

0 commit comments

Comments
 (0)