Skip to content

Commit

Permalink
Update the comment of the num_classes parameter of deeplab v3 (#2071)
Browse files Browse the repository at this point in the history
* Update deeplab_v3_plus.py

Update the comment of the `num_classes`parameter which contains the background class and the classes from the data.

* Update deeplab_v3_plus_test.py

Update the tests following the updating of 'num_classes' parameter (now including the background class)
  • Loading branch information
aaudevart authored Sep 12, 2023
1 parent 144fbb6 commit 5cdae3b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class DeepLabV3Plus(Task):
somewhat sensible backbone to use in many cases is the
`keras_cv.models.ResNet50V2Backbone.from_preset("resnet50_v2_imagenet")`.
num_classes: int, the number of classes for the detection model. Note
that the `num_classes` doesn't contain the background class, and the
that the `num_classes` contains the background class, and the
classes from the data should be represented by integers with range
[0, `num_classes`).
projection_filters: int, number of filters in the convolution layer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
class DeepLabV3PlusTest(TestCase):
def test_deeplab_v3_plus_construction(self):
backbone = ResNet18V2Backbone(input_shape=[512, 512, 3])
model = DeepLabV3Plus(backbone=backbone, num_classes=1)
model = DeepLabV3Plus(backbone=backbone, num_classes=2)
model.compile(
optimizer="adam",
loss=keras.losses.BinaryCrossentropy(),
Expand All @@ -42,7 +42,7 @@ def test_deeplab_v3_plus_construction(self):
@pytest.mark.large
def test_deeplab_v3_plus_call(self):
backbone = ResNet18V2Backbone(input_shape=[512, 512, 3])
model = DeepLabV3Plus(backbone=backbone, num_classes=1)
model = DeepLabV3Plus(backbone=backbone, num_classes=2)
images = np.random.uniform(size=(2, 512, 512, 3))
_ = model(images)
_ = model.predict(images)
Expand Down Expand Up @@ -83,7 +83,7 @@ def test_saved_model(self, save_format, filename):
target_size = [512, 512, 3]

backbone = ResNet18V2Backbone(input_shape=target_size)
model = DeepLabV3Plus(backbone=backbone, num_classes=1)
model = DeepLabV3Plus(backbone=backbone, num_classes=2)

input_batch = np.ones(shape=[2] + target_size)
model_output = model(input_batch)
Expand Down

0 comments on commit 5cdae3b

Please sign in to comment.