Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformers_as_arg #22

Merged
merged 5 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions geo_inference/config/sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ arguments:
vec: False # Vector Conversion: bool
yolo: False # YOLO Conversion: bool
coco: False # COCO Conversion: bool
transformers : True
transformer_flip : False
transformer_rotate : True
device: "gpu" # cpu or gpu: str
gpu_id: 0
mgpu: False
Expand Down
27 changes: 26 additions & 1 deletion geo_inference/geo_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import threading
import numpy as np
import xarray as xr
import ttach as tta
from typing import Dict
from dask import config
import dask.array as da
Expand Down Expand Up @@ -83,6 +84,9 @@ def __init__(
gpu_id: int = 0,
num_classes: int = 5,
prediction_threshold : float = 0.3,
transformers : bool = False,
transformer_flip: bool = False,
transformer_rotate: bool = False,
):
self.work_dir: Path = get_directory(work_dir)
self.device = (
Expand All @@ -95,6 +99,23 @@ def __init__(
),
map_location=self.device,
)
if transformers:
if transformer_flip and transformer_rotate: # do all
transforms = tta.aliases.d4_transform()
elif transformer_rotate: # do rotate only
transforms = tta.Compose(
[
tta.Rotate90(angles=[90]),
]
)
elif transformer_flip: # do flip only
transforms = tta.Compose(
[
tta.HorizontalFlip(),
tta.VerticalFlip(),
]
)
self.model = tta.SegmentationTTAWrapper(self.model, transforms, merge_mode='mean')
self.mask_to_vec = mask_to_vec
self.mask_to_coco = mask_to_coco
self.mask_to_yolo = mask_to_yolo
Expand Down Expand Up @@ -363,7 +384,10 @@ def main() -> None:
device=arguments["device"],
gpu_id=arguments["gpu_id"],
num_classes=arguments["classes"],
prediction_threshold=arguments["prediction_threshold"]
prediction_threshold=arguments["prediction_threshold"],
transformers=arguments["transformers"],
transformer_flip=arguments["transformer_flip"],
transformer_rotate=arguments["transformer_rotate"],
)
inference_mask_layer_name = geo_inference(
inference_input=arguments["image"],
Expand All @@ -372,6 +396,7 @@ def main() -> None:
workers=arguments["workers"],
bbox=arguments["bbox"],
)
print(inference_mask_layer_name)



Expand Down
15 changes: 15 additions & 0 deletions geo_inference/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,10 @@ def cmd_interface(argv=None):

parser.add_argument("-pr", "--prediction_thr", type=float, nargs=1, help="Prediction Threshold")

parser.add_argument("-tr", "--transformers", nargs=1, help="Transformers Addition")
parser.add_argument("-tr_f", "--transformer_flip", nargs=1, help="Transformers Addition - Flip")
parser.add_argument("-tr_e", "--transformer_rotate", nargs=1, help="Transformers Addition - Rotate")

args = parser.parse_args()

if args.args:
Expand All @@ -452,6 +456,10 @@ def cmd_interface(argv=None):
classes = config["arguments"]["classes"]
patch_size = config["arguments"]["patch_size"]
prediction_threshold = config["arguments"]["prediction_thr"]
transformers = config["arguments"]["transformers"]
transformer_flip = config["arguments"]["transformer_flip"]
transformer_rotate = config["arguments"]["transformer_rotate"]

elif args.image:
image =args.image[0]
model = args.model[0] if args.model else None
Expand All @@ -468,6 +476,10 @@ def cmd_interface(argv=None):
classes = args.classes[0] if args.classes else 5
patch_size = args.patch_size[0] if args.patch_size else 1024
prediction_threshold = args.prediction_thr[0] if args.prediction_thr else 0.3
transformers = args.transformers[0] if args.transformers else False
transformer_flip = args.transformer_flip if args.transformer_flip else False
transformer_rotate = args.transformer_rotate if args.transformer_rotate else False

else:
print("use the help [-h] option for correct usage")
raise SystemExit
Expand All @@ -487,6 +499,9 @@ def cmd_interface(argv=None):
"gpu_id": gpu_id,
"patch_size": patch_size,
"prediction_threshold": prediction_threshold,
"transformers": transformers,
"transformer_flip": transformer_flip,
"transformer_rotate":transformer_rotate,
}
return arguments

Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ dask[distributed]>=2024.6.2
requests>=2.32.3
xarray>=2024.6.0
pystac>=1.10.1
rioxarray>=0.15.6
rioxarray>=0.15.6
ttach>=0.0.3
3 changes: 3 additions & 0 deletions tests/data/sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,7 @@ arguments:
classes : 5
n_workers: 20
prediction_thr : 0.3
transformers: False
transformer_flip : False
transformer_rotate : False
patch_size: 1024
9 changes: 9 additions & 0 deletions tests/utils/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def test_read_yaml(test_data_dir):
"mgpu": False,
"classes": 5,
"prediction_thr": 0.3,
"transformers": False,
"transformer_flip" : False,
"transformer_rotate" : False,
"n_workers": 20
}

Expand Down Expand Up @@ -145,6 +148,9 @@ def test_cmd_interface_with_args(monkeypatch, test_data_dir):
"classes": 5,
"multi_gpu": False,
"prediction_threshold": 0.3,
"transformers": False,
"transformer_flip" : False,
"transformer_rotate" : False,
"patch_size": 1024
}

Expand All @@ -169,6 +175,9 @@ def test_cmd_interface_with_image(monkeypatch):
"gpu_id": 0,
"classes": 5,
"prediction_threshold": 0.3,
"transformers": False,
"transformer_flip" : False,
"transformer_rotate" : False,
"multi_gpu": False,
}

Expand Down
Loading