Skip to content

Commit

Permalink
Remove diaSources with NaN centroids
Browse files Browse the repository at this point in the history
Downstream code will break in multiple places if there is a single diaSource with a NaN value from .getCentroid()
  • Loading branch information
isullivan committed Dec 15, 2023
1 parent 6cd1a0d commit fa1262a
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 2 deletions.
12 changes: 10 additions & 2 deletions python/lsst/ip/diffim/detectAndMeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,9 +410,17 @@ def removeBadSources(self, diaSources):
self.log.info("Found and removed %d unphysical sources with flag %s.", nBad, flag)
selector &= ~flags
nBadTotal += nBad
diaSources = diaSources[selector].copy(deep=True)
self.metadata.add("nRemovedBadFlaggedSources", nBadTotal)
return diaSources
# Use slot_Centroid_x/y here instead of getX() method, since the former
# works on non-contiguous source tables and the latter does not.
centroidFlag = np.isfinite(diaSources["slot_Centroid_x"]) & np.isfinite(diaSources["slot_Centroid_y"])
nBad = np.count_nonzero(~centroidFlag)
if nBad > 0:
self.log.info("Found and removed %d unphysical sources with non-finite centroid.", nBad)
self.metadata.add("nRemovedBadCentroidSources", nBadTotal)
nBadTotal += nBad
selector &= centroidFlag
return diaSources[selector].copy(deep=True)

def addSkySources(self, diaSources, mask, seed):
"""Add sources in empty regions of the difference image
Expand Down
49 changes: 49 additions & 0 deletions tests/test_detectAndMeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,55 @@ def test_remove_unphysical(self):
self.assertEqual(nBad2, 0)
self.assertEqual(len(diaSources2), len(diaSources) - nSetBad)

def test_remove_nan_centroid(self):
"""Check that sources with non-finite centroids are removed from the catalog.
"""
# Set up the simulated images
noiseLevel = 1.
staticSeed = 1
xSize = 256
ySize = 256
kwargs = {"psfSize": 2.4, "xSize": xSize, "ySize": ySize}
science, sources = makeTestImage(seed=staticSeed, noiseLevel=noiseLevel, noiseSeed=6,
nSrc=1, **kwargs)
matchedTemplate, _ = makeTestImage(seed=staticSeed, noiseLevel=noiseLevel/4, noiseSeed=7,
nSrc=1, **kwargs)
difference = science.clone()
bbox = difference.getBBox()
difference.maskedImage -= matchedTemplate.maskedImage

# Configure the detection Task, and do not remove unphysical sources
detectionTask = self._setup_detection(doForcedMeasurement=False, doSkySources=True, nSkySources=20,
badSourceFlags=["base_PixelFlags_flag_offimage", ])

# Run detection and check the results
diaSources = detectionTask.run(science, matchedTemplate, difference).diaSources
badDiaSrc0 = ~bbox.contains(diaSources.getX(), diaSources.getY())
nBad0 = np.count_nonzero(badDiaSrc0)
# Verify that all sources are physical
self.assertEqual(nBad0, 0)
# Set a few centroids outside the image bounding box
nSetBad = 5
for i, src in enumerate(diaSources[0: nSetBad]):
if i % 3 == 0:
src["slot_Centroid_x"] = np.nan
elif i % 3 == 1:
src["slot_Centroid_y"] = np.nan
elif i % 3 == 2:
src["slot_Centroid_x"] = np.nan
src["slot_Centroid_y"] = np.nan
# Verify that these sources are outside the image
badDiaSrc1 = ~bbox.contains(diaSources.getX(), diaSources.getY())
nBad1 = np.count_nonzero(badDiaSrc1)
self.assertEqual(nBad1, nSetBad)
diaSources2 = detectionTask.removeBadSources(diaSources)
badDiaSrc2 = ~bbox.contains(diaSources2.getX(), diaSources2.getY())
nBad2 = np.count_nonzero(badDiaSrc2)

# Verify that no sources outside the image bounding box remain
self.assertEqual(nBad2, 0)
self.assertEqual(len(diaSources2), len(diaSources) - nSetBad)

def test_detect_transients(self):
"""Run detection on a difference image containing transients.
"""
Expand Down

0 comments on commit fa1262a

Please sign in to comment.