Skip to content

Commit

Permalink
fixed bug with tiling with kornia update to 0.7.2
Browse files Browse the repository at this point in the history
  • Loading branch information
franioli committed Mar 15, 2024
1 parent 78ca0e2 commit c9e001a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies = [
"opencv-python",
"opencv-contrib-python",
"pydegensac",
"kornia>=0.7.1",
"kornia>=0.7.2",
"h5py",
"tqdm",
"easydict",
Expand Down
8 changes: 4 additions & 4 deletions src/deep_image_matching/utils/tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import torch


def recent_konria_version(base_version: str = "0.7.1"):
def konria_071(base_version: str = "0.7.1"):
try:
from packaging import version
except ImportError:
return False
return version.parse(K.__version__) >= version.parse(base_version)
return version.parse(K.__version__) == version.parse(base_version)


# TODO: add possibility to specify the number of rows and columns in the grid
Expand Down Expand Up @@ -131,7 +131,7 @@ def compute_tiles_by_size(
patches = patches.squeeze(0)

# Compute number of rows and columns
if recent_konria_version():
if konria_071():
n_rows = (H + 2 * padding[0] - window_size[0]) // stride[0] + 1
n_cols = (W + 2 * padding[1] - window_size[1]) // stride[1] + 1
else:
Expand All @@ -143,7 +143,7 @@ def compute_tiles_by_size(
for row in range(n_rows):
for col in range(n_cols):
tile_idx = np.ravel_multi_index((row, col), (n_rows, n_cols), order="C")
if recent_konria_version():
if konria_071():
x = -padding[1] + col * stride[1]
y = -padding[0] + row * stride[0]
else:
Expand Down
20 changes: 10 additions & 10 deletions tests/test_tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ def tiler():
return Tiler()


def recent_konria_version(base_version: str = "0.7.1"):
def konria_071(base_version: str = "0.7.1"):
try:
from packaging import version
except ImportError:
return False
return version.parse(kornia.__version__) >= version.parse(base_version)
return version.parse(kornia.__version__) == version.parse(base_version)


def test_compute_tiles_by_size_no_overlap_no_padding(tiler):
Expand All @@ -33,7 +33,7 @@ def test_compute_tiles_by_size_no_overlap_no_padding(tiler):
assert isinstance(tiles, dict)
assert isinstance(origins, dict)
assert isinstance(padding, tuple)
if recent_konria_version():
if konria_071():
assert len(padding) == 2
else:
assert len(padding) == 4
Expand All @@ -47,7 +47,7 @@ def test_compute_tiles_by_size_no_overlap_no_padding(tiler):
assert tile.shape == (window_size, window_size, 3)

# Assert the padding values
if recent_konria_version():
if konria_071():
assert padding == (0, 0)
else:
assert padding == (0, 0, 0, 0)
Expand All @@ -68,7 +68,7 @@ def test_compute_tiles_by_size_no_overlap_padding(tiler):
assert isinstance(tiles, dict)
assert isinstance(origins, dict)
assert isinstance(padding, tuple)
if recent_konria_version():
if konria_071():
assert len(padding) == 2
else:
assert len(padding) == 4
Expand All @@ -81,7 +81,7 @@ def test_compute_tiles_by_size_no_overlap_padding(tiler):
assert tile.shape == (window_size, window_size, 3)

# Assert the padding values
if recent_konria_version():
if konria_071():
assert padding == (10, 10)
else:
assert padding == (10, 10, 10, 10)
Expand All @@ -102,7 +102,7 @@ def test_compute_tiles_by_size_overlap_no_padding(tiler):
assert isinstance(tiles, dict)
assert isinstance(origins, dict)
assert isinstance(padding, tuple)
if recent_konria_version():
if konria_071():
assert len(padding) == 2
else:
assert len(padding) == 4
Expand All @@ -115,7 +115,7 @@ def test_compute_tiles_by_size_overlap_no_padding(tiler):
assert tile.shape == (window_size, window_size, 3)

# Assert the padding values
if recent_konria_version():
if konria_071():
assert padding == (0, 0)
else:
assert padding == (0, 0, 0, 0)
Expand All @@ -137,7 +137,7 @@ def test_compute_tiles_by_size_with_torch_tensor(tiler):
assert isinstance(tiles, dict)
assert isinstance(origins, dict)
assert isinstance(padding, tuple)
if recent_konria_version():
if konria_071():
assert len(padding) == 2
else:
assert len(padding) == 4
Expand All @@ -150,7 +150,7 @@ def test_compute_tiles_by_size_with_torch_tensor(tiler):
assert tile.shape == (window_size[0], window_size[1], channels)

# Assert the padding values
if recent_konria_version():
if konria_071():
assert padding == (0, 0)
else:
assert padding == (0, 0, 0, 0)
Expand Down

0 comments on commit c9e001a

Please sign in to comment.