Skip to content

Commit

Permalink
Merge pull request #56 from AllenInstitute/cleanup
Browse files Browse the repository at this point in the history
cleanup docstrings & imports
  • Loading branch information
RussTorres authored Mar 29, 2024
2 parents 18193c2 + 072923f commit eadf8e5
Show file tree
Hide file tree
Showing 19 changed files with 443 additions and 106 deletions.
144 changes: 107 additions & 37 deletions em_stitch/lens_correction/lens_correction_solver.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
import glob
import json
import logging
import os
import warnings

import cv2
import numpy as np

from argschema import ArgSchemaParser
from bigfeta import jsongz
import renderapi

from .schemas import LensCorrectionSchema
from ..utils.generate_EM_tilespecs_from_metafile import \
GenerateEMTileSpecsModule
from ..utils.generate_EM_tilespecs_from_metafile import (
GenerateEMTileSpecsModule)
from ..utils import utils as common_utils
from .mesh_and_solve_transform import MeshAndSolveTransform
from . import utils
from bigfeta import jsongz
import logging
import os
import glob
import json
import numpy as np
import renderapi
import cv2
import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)

logger = logging.getLogger(__name__)
Expand All @@ -39,6 +43,26 @@ class LensCorrectionException(Exception):


def one_file(fdir, fstub):
"""
Get a single file matching a directory and filename pattern.
Parameters
----------
fdir : str
The directory where the file should be located.
fstub : str
The filename pattern to match.
Returns
-------
str
The full path of the single file matching the pattern.
Raises
------
LensCorrectionException
If no file or more than one file is found matching the pattern.
"""
fullstub = os.path.join(fdir, fstub)
files = glob.glob(fullstub)
lf = len(files)
Expand All @@ -52,6 +76,27 @@ def one_file(fdir, fstub):

def tilespec_input_from_metafile(
metafile, mask_file, output_dir, log_level, compress):
"""
get tilespec input data from a metafile.
Parameters
----------
metafile : str
Path to the metafile.
mask_file : str
Path to the mask file.
output_dir : str
Directory where the output will be stored.
log_level : int
Log level for the operation.
compress : bool
Whether to compress the output.
Returns
-------
Dict[str, Union[str, int, bool]]
A dictionary containing the generated tilespec input data.
"""
result = {}
result['metafile'] = metafile

Expand All @@ -75,6 +120,35 @@ def filter_match_collection(
n_clusters=None, n_cluster_pts=20, ransacReprojThreshold=40.,
ignore_match_indices=(),
input_n_key="n_from_gpu", output_n_key="n_after_filter"):
"""
Filter a collection of matches based on specified criteria.
Parameters
----------
matches : List[Dict[str, Any]]
The collection of matches to filter.
threshold : float
Threshold value.
model : str, optional
Model type, by default "Similarity".
n_clusters : int, optional
Number of clusters, by default None.
n_cluster_pts : int, optional
Number of cluster points, by default 20.
ransacReprojThreshold : float, optional
RANSAC reprojection threshold, by default 40.0.
ignore_match_indices : Optional[Iterable[int]], optional
Indices of matches to ignore, by default None.
input_n_key : str, optional
Key to store input count in counts dictionary, by default "n_from_gpu".
output_n_key : str, optional
Key to store output count in counts dictionary, by default "n_after_filter".
Returns
-------
Tuple[List[Dict[str, Any]], List[Dict[str, int]]]
A tuple containing the filtered matches and their corresponding counts.
"""
ignore_match_indices = (set() if ignore_match_indices is None else ignore_match_indices)
ignore_match_indices = set(ignore_match_indices)

Expand Down Expand Up @@ -115,7 +189,27 @@ def make_collection_json(
thresh, # FIXME thresh not used in this version
compress,
ignore_match_indices=None):

"""
Create a JSON collection file from a template file.
Parameters
----------
template_file : str
Path to the template file.
output_dir : str
Directory where the output will be stored.
thresh : float
Threshold value.
compress : bool
Whether to compress the output.
ignore_match_indices : Optional[List[int]], optional
Indices of matches to ignore, by default None.
Returns
-------
Tuple[str, List[Dict[str, int]]]
A tuple containing the path to the collection file and a list of counts.
"""
with open(template_file, 'r') as f:
template_match_md = json.load(f)

Expand All @@ -124,31 +218,7 @@ def make_collection_json(
m, counts = filter_match_collection(
input_matches, thresh,
ignore_match_indices=ignore_match_indices
)

# counts = []
# for m in template_match_md['collection']:
# counts.append({})
# ind = np.arange(len(m['matches']['p'][0]))
# counts[-1]['n_from_gpu'] = ind.size

# _, _, w, _ = common_utils.pointmatch_filter(
# m,
# n_clusters=None,
# n_cluster_pts=20,
# ransacReprojThreshold=40.0,
# model='Similarity')

# m['matches']['w'] = w.tolist()

# counts[-1]['n_after_filter'] = np.count_nonzero(w)

# m = matches['collection']

# if ignore_match_indices:
# m = [match for i, match in enumerate(matches['collection'])
# if i not in ignore_match_indices]
# logger.warning("you are ignoring some point matches")
)

collection_file = os.path.join(output_dir, "collection.json")
collection_file = jsongz.dump(m, collection_file, compress=compress)
Expand Down
24 changes: 14 additions & 10 deletions em_stitch/lens_correction/mesh_and_solve_transform.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
import copy
import datetime
import logging
import os

from six.moves import urllib

import cv2
import numpy as np
import triangle
import scipy.optimize
from scipy.spatial import Delaunay
import scipy.sparse as sparse
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import factorized
import renderapi
import copy
import os
import datetime
import cv2
from six.moves import urllib
from scipy.spatial import Delaunay
import triangle

from argschema import ArgSchemaParser
from bigfeta import jsongz
import renderapi

from .schemas import MeshLensCorrectionSchema
from .utils import remove_weighted_matches
from bigfeta import jsongz
import logging

try:
# pandas unique is faster than numpy, use where appropriate
Expand Down
7 changes: 5 additions & 2 deletions em_stitch/lens_correction/schemas.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import warnings

from marshmallow.warnings import ChangedInMarshmallow3Warning
import marshmallow as mm

from argschema import ArgSchema
from argschema.schemas import DefaultSchema
from argschema.fields import (
Boolean, InputDir, InputFile, Float, List,
Int, OutputDir, Nested, Str, Dict)
Boolean, InputDir, InputFile, Float, List,
Int, OutputDir, Nested, Str, Dict)

warnings.simplefilter(
action='ignore',
category=ChangedInMarshmallow3Warning)
Expand Down
77 changes: 72 additions & 5 deletions em_stitch/lens_correction/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,35 @@
import numpy as np
import cv2
import renderapi
import time
from scipy import ndimage
import logging
from ..utils import utils as common_utils

import cv2
import numpy as np
from scipy import ndimage

import renderapi

from ..utils import utils as common_utils

logger = logging.getLogger(__name__)


def split_inverse_tform(tform, src, block_size):
"""
Split and inverse transform the source array using the provided transformation.
Parameters
----------
tform : Any
Transformation object.
src : np.ndarray
Source array.
block_size : int
Size of each block.
Returns
-------
np.ndarray
Inverse transformed array.
"""
nsplit = np.ceil(float(src.shape[0]) / float(block_size))
split_src = np.array_split(src, nsplit, axis=0)
dst = []
Expand All @@ -21,6 +40,28 @@ def split_inverse_tform(tform, src, block_size):


def maps_from_tform(tform, width, height, block_size=10000, res=32):
"""
Generate maps and a mask for remapping based on the provided transformation.
Parameters
----------
tform : Any
Transformation object.
width : int
Width of the map.
height : int
Height of the map.
block_size : int, optional
Size of each block, by default 10000.
res : int, optional
cell resolution, by default 32.
Returns
-------
Tuple[np.ndarray, np.ndarray, np.ndarray]
Tuple containing map1, map2, and mask arrays.
"""
t0 = time.time()

x = np.arange(0, width + res, res)
Expand Down Expand Up @@ -53,6 +94,22 @@ def maps_from_tform(tform, width, height, block_size=10000, res=32):


def estimate_stage_affine(t0, t1):
"""
Estimate affine transformation between two sets of translations.
Parameters
----------
t0 : Iterable
List of transformations (t0).
t1 : Iterable
List of transformations (t1).
Returns
-------
renderapi.transform.AffineModel
Estimated affine transformation.
"""
src = np.array([t.tforms[0].translation for t in t0])
dst = np.array([t.tforms[1].translation for t in t1])
aff = renderapi.transform.AffineModel()
Expand All @@ -61,6 +118,16 @@ def estimate_stage_affine(t0, t1):


def remove_weighted_matches(matches, weight=0.0):
"""
Remove matches with specified weight.
Parameters
----------
matches : List[Dict[str, Any]]
List of matches.
weight : float, optional
Weight threshold, by default 0.0.
"""
for m in matches:
ind = np.invert(np.isclose(np.array(m['matches']['w']), weight))
m['matches']['p'] = np.array(m['matches']['p'])[:, ind].tolist()
Expand Down
2 changes: 1 addition & 1 deletion em_stitch/montage/meta_to_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
'''

import argparse
from enum import IntEnum
import glob
import json
import os
import sys
from enum import IntEnum


# Position codes in metafile
Expand Down
Loading

0 comments on commit eadf8e5

Please sign in to comment.