Skip to content

Commit

Permalink
Merge pull request #330 from activeloopai/fix/segmentation
Browse files Browse the repository at this point in the history
Fixed segmentation
  • Loading branch information
AbhinavTuli authored Dec 15, 2020
2 parents a701434 + 0f81310 commit 546a797
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
10 changes: 7 additions & 3 deletions hub/schema/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,14 @@ def get_attr_dict(self):
def __str__(self):
out = super().__str__()
out = "Segmentation" + out[6:-1]
out = out + ", names=" + self.names if self.names is not None else out
out = (
out + ", num_classes=" + self.num_classes
if self.num_classes is not None
out + ", names=" + str(self.class_labels._names)
if self.class_labels._names is not None
else out
)
out = (
out + ", num_classes=" + str(self.class_labels._num_classes)
if self.class_labels._num_classes is not None
else out
)
out += ")"
Expand Down
25 changes: 25 additions & 0 deletions hub/schema/tests/test_features.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from hub.schema import Segmentation
from hub.schema.class_label import ClassLabel, _load_names_from_file
from hub.schema.features import HubSchema, SchemaDict
import pytest
Expand Down Expand Up @@ -49,9 +50,33 @@ def test_feature_dict_repr():
assert expected_output == feature_dict_object.__repr__()


def test_segmentation_repr():
seg1 = Segmentation(shape=(3008, 3008), dtype="uint8", num_classes=5)
seg2 = Segmentation(
shape=(3008, 3008), dtype="uint8", names=["apple", "orange", "banana"]
)

text1 = "Segmentation(shape=(3008, 3008), dtype='uint8', num_classes=5)"
text2 = "Segmentation(shape=(3008, 3008), dtype='uint8', names=['apple', 'orange', 'banana'], num_classes=3)"
assert seg1.__repr__() == text1
assert seg2.__repr__() == text2


def test_classlabel_repr():
cl1 = ClassLabel(num_classes=5)
cl2 = ClassLabel(names=["apple", "orange", "banana"])

text1 = "ClassLabel(shape=(), dtype='int64', num_classes=5)"
text2 = "ClassLabel(shape=(), dtype='int64', names=['apple', 'orange', 'banana'], num_classes=3)"
assert cl1.__repr__() == text1
assert cl2.__repr__() == text2


if __name__ == "__main__":
test_load_names_from_file()
test_class_label()
test_hub_feature_flatten()
test_feature_dict_str()
test_feature_dict_repr()
test_classlabel_repr()
test_segmentation_repr()

0 comments on commit 546a797

Please sign in to comment.