From 415cdb768478370916b12c7229437d6feaf239b5 Mon Sep 17 00:00:00 2001 From: landmanbester Date: Mon, 26 Aug 2024 20:42:33 +0200 Subject: [PATCH] flip_v everywhere and correct fits crpix2 convention --- pfb/operators/gridder.py | 225 ++++----- pfb/operators/hessian.py | 176 +++---- pfb/opt/pcg.py | 4 +- pfb/utils/dist.py | 8 +- pfb/utils/fits.py | 22 +- pfb/utils/misc.py | 39 +- pfb/utils/stokes2im.py | 27 +- pfb/utils/weighting.py | 13 +- pfb/workers/degrid.py | 5 - pfb/workers/fluxmop.py | 2 - pfb/workers/fwdbwd.py | 946 +++++++++++++++++------------------ pfb/workers/grid.py | 29 +- pfb/workers/klean.py | 5 +- pfb/workers/model2comps.py | 3 + pfb/workers/sara.py | 4 +- tests/test_hessian_approx.py | 193 ++++--- tests/test_klean.py | 7 +- tests/test_sara.py | 65 +-- 18 files changed, 876 insertions(+), 897 deletions(-) diff --git a/pfb/operators/gridder.py b/pfb/operators/gridder.py index 8b3cb7873..88cb3181f 100644 --- a/pfb/operators/gridder.py +++ b/pfb/operators/gridder.py @@ -10,7 +10,7 @@ import xarray as xr import dask import dask.array as da -from ducc0.wgridder import vis2dirty, dirty2vis +from ducc0.wgridder.experimental import vis2dirty, dirty2vis from ducc0.fft import c2r, r2c, c2c from africanus.constants import c as lightspeed from quartical.utils.dask import Blocker @@ -22,6 +22,20 @@ Fs = np.fft.fftshift +def wgridder_conventions(l0, m0): + ''' + Returns + + flip_u, flip_v, flip_w, x0, y0 + + according to the conventions documented here https://github.com/mreineck/ducc/issues/34 + + Note that these conventions are stored as dataset attributes in order + to call the operators acting on datasets with a consistent convention. 
+ ''' + return False, True, False, -l0, -m0 + + def vis2im(uvw, freq, vis, @@ -29,10 +43,9 @@ def vis2im(uvw, mask, nx, ny, cellx, celly, - x0, y0, + l0, m0, epsilon, precision, - flip_v, do_wgridding, divide_by_n, nthreads, @@ -56,6 +69,8 @@ def vis2im(uvw, if mask is not None: mask = np.require(mask, dtype=np.uint8) + flip_u, flip_v, flip_w, x0, y0 = wgridder_conventions(l0, m0) + return vis2dirty(uvw=uvw, freq=freq, vis=vis, @@ -65,7 +80,9 @@ def vis2im(uvw, pixsize_x=cellx, pixsize_y=celly, center_x=x0, center_y=y0, epsilon=epsilon, + flip_u=flip_u, flip_v=flip_v, + flip_w=flip_w, do_wgridding=do_wgridding, divide_by_n=divide_by_n, nthreads=nthreads, @@ -80,15 +97,15 @@ def im2vis(uvw, celly, freq_bin_idx, freq_bin_counts, - x0=0, y0=0, + l0=0, m0=0, epsilon=1e-7, - flip_v=False, do_wgridding=True, divide_by_n=False, nthreads=1): # adjust for chunking # need a copy here if using multiple row chunks freq_bin_idx2 = freq_bin_idx - freq_bin_idx.min() + flip_u, flip_v, flip_w, x0, y0 = wgridder_conventions(l0, m0) nband, nx, ny = image.shape nrow = uvw.shape[0] nchan = freq.size @@ -103,12 +120,13 @@ def im2vis(uvw, pixsize_y=celly, center_x=x0, center_y=y0, + flip_u=flip_u, + flip_v=flip_v, + flip_w=flip_w, epsilon=epsilon, nthreads=nthreads, do_wgridding=do_wgridding, - divide_by_n=divide_by_n, - flip_v=flip_v - ) + divide_by_n=divide_by_n) return vis @@ -121,14 +139,9 @@ def comps2vis( tbin_idx, tbin_cnts, fbin_idx, fbin_cnts, mds, - # comps, - # Ix, Iy, modelf, tfunc, ffunc, - # nx, ny, - # cellx, celly, - # x0=0, y0=0, epsilon=1e-7, nthreads=1, do_wgridding=True, @@ -151,18 +164,9 @@ def comps2vis( fbin_idx, 'f', fbin_cnts, 'f', mds, None, - # comps, None, - # Ix, None, - # Iy, None, modelf, None, tfunc, None, ffunc, None, - # nx, None, - # ny, None, - # cellx, None, - # celly, None, - # x0, None, - # y0, None, epsilon, None, nthreads, None, do_wgridding, None, @@ -185,14 +189,9 @@ def _comps2vis( tbin_idx, tbin_cnts, fbin_idx, fbin_cnts, mds, - # comps, - # 
Ix, Iy, modelf, tfunc, ffunc, - # nx, ny, - # cellx, celly, - # x0=0, y0=0, epsilon=1e-7, nthreads=1, do_wgridding=True, @@ -208,14 +207,9 @@ def _comps2vis( tbin_idx, tbin_cnts, fbin_idx, fbin_cnts, mds, - # comps, - # Ix, Iy, modelf, tfunc, ffunc, - # nx, ny, - # cellx, celly, - # x0=x0, y0=y0, epsilon=epsilon, nthreads=nthreads, do_wgridding=do_wgridding, @@ -233,14 +227,9 @@ def _comps2vis_impl(uvw, tbin_idx, tbin_cnts, fbin_idx, fbin_cnts, mds, - # comps, - # Ix, Iy, modelf, tfunc, ffunc, - # nx, ny, - # cellx, celly, - # x0=0, y0=0, epsilon=1e-7, nthreads=1, do_wgridding=True, @@ -271,8 +260,12 @@ def _comps2vis_impl(uvw, celly = mds.cell_rad_x nx = mds.npix_x ny = mds.npix_y + # these are taken from dataset attrs to make sure they remain consistent x0 = mds.center_x y0 = mds.center_y + flip_u = mds.flip_u + flip_v = mds.flip_v + flip_w = mds.flip_w for t in range(ntime): indt = slice(tbin_idx2[t], tbin_idx2[t] + tbin_cnts[t]) # TODO - clean up this logic. row_mapping holds the number of rows per @@ -294,6 +287,9 @@ def _comps2vis_impl(uvw, dirty=image, pixsize_x=cellx, pixsize_y=celly, center_x=x0, center_y=y0, + flip_u=flip_u, + flip_v=flip_v, + flip_w=flip_w, epsilon=epsilon, do_wgridding=do_wgridding, divide_by_n=divide_by_n, @@ -312,12 +308,11 @@ def image_data_products(dsl, attrs, model=None, robustness=None, - x0=0.0, y0=0.0, + l0=0.0, m0=0.0, nthreads=1, epsilon=1e-7, do_wgridding=True, double_accum=True, - # divide_by_n=False, l2reweight_dof=None, do_dirty=True, do_psf=True, @@ -339,6 +334,8 @@ def image_data_products(dsl, sum of beam ''' + flip_u, flip_v, flip_w, x0, y0 = wgridder_conventions(l0, m0) + # TODO - assign ug,vg-coordinates x = (-nx/2 + np.arange(nx)) * cellx + x0 y = (-ny/2 + np.arange(ny)) * celly + y0 @@ -347,6 +344,7 @@ def image_data_products(dsl, 'y': y } + # expects a list if isinstance(dsl, str): dsl = [dsl] @@ -404,28 +402,27 @@ def image_data_products(dsl, center_y=y0, epsilon=epsilon, do_wgridding=do_wgridding, - flip_v=False, 
+ flip_u=flip_u, + flip_v=flip_v, + flip_w=flip_w, nthreads=nthreads, - divide_by_n=False, + divide_by_n=False, # incorporate in smooth beam sigma_min=1.1, sigma_max=3.0) residual_vis *= -1 # negate model residual_vis += vis - # apply mask for reweighting - # residual_vis *= mask if l2reweight_dof: - # careful mask needs to be bool here ressq = (residual_vis*residual_vis.conj()).real + # mask needs to be bool here ssq = ressq[mask>0].sum() ovar = ssq/mask.sum() - # ovar = np.var(residual_vis[mask]) chi2_dofp = np.mean(ressq[mask>0]*wgt[mask>0]) mean_dev = np.mean(ressq[mask>0]/ovar) if ovar: wgt = (l2reweight_dof + 1)/(l2reweight_dof + ressq/ovar) # now divide by ovar to scale to absolute units - # the chi2_dof after reweighting should be close to one + # the chi2_dof after reweighting should be closer to one wgt /= ovar chi2_dof = np.mean(ressq[mask>0]*wgt[mask>0]) print(f'Band {bandid} chi2-dof changed from {chi2_dofp} to {chi2_dof} with mean deviation of {mean_dev}') @@ -441,14 +438,18 @@ def image_data_products(dsl, nx, ny, cellx, celly, uvw.dtype, - ngrid=np.minimum(nthreads, 8)) # limit number of grids + ngrid=np.minimum(nthreads, 8), # limit number of grids + usign=1.0 if flip_u else -1.0, + vsign=1.0 if flip_v else -1.0) imwgt = counts_to_weights( counts, uvw, freq, nx, ny, cellx, celly, - robustness) + robustness, + usign=1.0 if flip_u else -1.0, + vsign=1.0 if flip_v else -1.0) if wgt is not None: wgt *= imwgt else: @@ -474,9 +475,11 @@ def image_data_products(dsl, pixsize_x=cellx, pixsize_y=celly, center_x=x0, center_y=y0, epsilon=epsilon, - flip_v=False, # hardcoded for now + flip_u=flip_u, + flip_v=flip_v, + flip_w=flip_w, do_wgridding=do_wgridding, - divide_by_n=False, # hardcoded for now + divide_by_n=False, # incorporate in smooth beam nthreads=nthreads, sigma_min=1.1, sigma_max=3.0, double_precision_accumulation=double_accum) @@ -505,11 +508,13 @@ def image_data_products(dsl, pixsize_y=celly, center_x=x0, center_y=y0, + flip_u=flip_u, + 
flip_v=flip_v, + flip_w=flip_w, epsilon=epsilon, do_wgridding=do_wgridding, nthreads=nthreads, - divide_by_n=False, - flip_v=False, # hardcoded for now + divide_by_n=False, # incorporate in smooth beam sigma_min=1.1, sigma_max=3.0) else: @@ -528,10 +533,12 @@ def image_data_products(dsl, npix_x=nx_psf, npix_y=ny_psf, pixsize_x=cellx, pixsize_y=celly, center_x=x0, center_y=y0, + flip_u=flip_u, + flip_v=flip_v, + flip_w=flip_w, epsilon=epsilon, - flip_v=False, # hardcoded for now do_wgridding=do_wgridding, - divide_by_n=False, # hardcoded for now + divide_by_n=False, # incorporate in smooth beam nthreads=nthreads, sigma_min=1.1, sigma_max=3.0, double_precision_accumulation=double_accum) @@ -560,10 +567,12 @@ def image_data_products(dsl, npix_x=nx, npix_y=ny, pixsize_x=cellx, pixsize_y=celly, center_x=x0, center_y=y0, + flip_u=flip_u, + flip_v=flip_v, + flip_w=flip_w, epsilon=epsilon, - flip_v=False, # hardcoded for now do_wgridding=do_wgridding, - divide_by_n=False, # hardcoded for now + divide_by_n=False, # incorporate in smooth beam nthreads=nthreads, sigma_min=1.1, sigma_max=3.0, double_precision_accumulation=double_accum) @@ -589,10 +598,12 @@ def image_data_products(dsl, npix_x=nx, npix_y=ny, pixsize_x=cellx, pixsize_y=celly, center_x=x0, center_y=y0, + flip_u=flip_u, + flip_v=flip_v, + flip_w=flip_w, epsilon=epsilon, - flip_v=False, # hardcoded for now do_wgridding=do_wgridding, - divide_by_n=False, # hardcoded for now + divide_by_n=False, # incorporate in smooth beam nthreads=nthreads, sigma_min=1.1, sigma_max=3.0, double_precision_accumulation=double_accum) @@ -605,7 +616,10 @@ def image_data_products(dsl, dso['BEAM'] = (('x', 'y'), np.ones((nx, ny), dtype=wgt.dtype)) # save - dso = dso.assign_attrs(wsum=wsum) + dso = dso.assign_attrs(wsum=wsum, x0=x0, y0=y0, l0=l0, m0=m0, + flip_u=flip_u, + flip_v=flip_v, + flip_w=flip_w) dso.to_zarr(output_name, mode='a') # return residual to report stats @@ -621,7 +635,6 @@ def compute_residual(dsl, cellx, celly, output_name, 
model, - x0=0.0, y0=0.0, nthreads=1, epsilon=1e-7, do_wgridding=True, @@ -629,7 +642,6 @@ def compute_residual(dsl, ''' Function to compute residual and write it to disk ''' - # expects a list if isinstance(dsl, str): dsl = [dsl] @@ -643,6 +655,11 @@ def compute_residual(dsl, beam = ds.BEAM.values dirty = ds.DIRTY.values freq = ds.FREQ.values + flip_u = ds.flip_u + flip_v = ds.flip_v + flip_w = ds.flip_w + x0 = ds.x0 + y0 = ds.y0 # do not apply weights in this direction model_vis = dirty2vis( @@ -653,11 +670,13 @@ def compute_residual(dsl, pixsize_y=celly, center_x=x0, center_y=y0, + flip_u=flip_u, + flip_v=flip_v, + flip_w=flip_w, epsilon=epsilon, do_wgridding=do_wgridding, - flip_v=False, nthreads=nthreads, - divide_by_n=False, + divide_by_n=False, # incorporate in smooth beam sigma_min=1.1, sigma_max=3.0) @@ -670,10 +689,12 @@ def compute_residual(dsl, npix_x=nx, npix_y=ny, pixsize_x=cellx, pixsize_y=celly, center_x=x0, center_y=y0, + flip_u=flip_u, + flip_v=flip_v, + flip_w=flip_w, epsilon=epsilon, - flip_v=False, # hardcoded for now do_wgridding=do_wgridding, - divide_by_n=False, # hardcoded for now + divide_by_n=False, # incorporate in smooth beam nthreads=nthreads, sigma_min=1.1, sigma_max=3.0, double_precision_accumulation=double_accum) @@ -689,77 +710,3 @@ def compute_residual(dsl, ds.to_zarr(output_name, mode='a') return residual - - -from pfb.operators.hessian import _hessian_impl -def estimate_epsilon(dsl, - nx, ny, - cellx, celly, - output_name, - x0=0.0, y0=0.0, - nthreads=1, - epsilon=1e-7, - do_wgridding=True, - double_accum=True): - ''' - Function to estimate epsilon required for stable preconditioning. - Done by minimising - - res.conj().T * res - - where - - res = (IR - hess(precond(IR)))[ix, iy] - - and ix and iy select out the image in the untapered part of the image. 
- - ''' - - def sse(sigma, IR, hess_approx, precond, slicex, slicey): - tmp = precond(IR, sigma) - res = hess_approx(res)[slicex, slicey] - return np.vdot(res, res) - - # expects a list - if isinstance(dsl, str): - dsl = [dsl] - - # currently only a single dds - ds = xds_from_list(dsl, nthreads=nthreads)[0] - - uvw = ds.UVW.values - wgt = ds.WEIGHT.values - mask = ds.MASK.values - beam = ds.BEAM.values - if 'RESIDUAL' in ds: - residual = ds.RESIDUAL.values - else: - residual = ds.DIRTY.values - freq = ds.FREQ.values - psfhat = ds.PSFHAT.values - wsum = ds.wsum - - hess = partial(_hessian_impl, - uvw=ds.UVW.values, - weight=ds.WEIGHT.values, - vis_mask=ds.MASK.values, - freq=ds.FREQ.values, - beam=ds.BEAM.values, - cell=ds.cell_rad, - x0=ds.x0, - y0=ds.y0, - do_wgridding=do_wgridding, - epsilon=epsilon, - double_accum=double_accum, - nthreads=nthreads, - sigmainvsq=0.0, - wsum=wsum) - - - - ds = ds.assign_attrs(sigma=sigma) - - # save - ds.to_zarr(output_name, mode='a') - - return sigma diff --git a/pfb/operators/hessian.py b/pfb/operators/hessian.py index 12c48ba1b..7741c7a92 100644 --- a/pfb/operators/hessian.py +++ b/pfb/operators/hessian.py @@ -1,7 +1,7 @@ import numpy as np import dask import dask.array as da -from ducc0.wgridder import vis2dirty, dirty2vis +from ducc0.wgridder.experimental import vis2dirty, dirty2vis from ducc0.fft import r2c, c2r from ducc0.misc import make_noncritical from uuid import uuid4 @@ -9,57 +9,6 @@ psf_convolve_cube) -def hessian_xds(x, xds, hessopts, wsum, sigmainv, mask, - compute=True, use_beam=True): - ''' - Vis space Hessian reduction over dataset. - Hessian will be applied to x - ''' - if not isinstance(x, da.Array): - x = da.from_array(x, chunks=(1, -1, -1), - name="x-" + uuid4().hex) - - if not isinstance(mask, da.Array): - mask = da.from_array(mask, chunks=(-1, -1), - name="mask-" + uuid4().hex) - - assert mask.ndim == 2 - - nband, nx, ny = x.shape - - # LB - what is the point of specifying name? 
- convims = [da.zeros((nx, ny), - chunks=(-1, -1), name="zeros-" + uuid4().hex) - for _ in range(nband)] - - for ds in xds: - wgt = ds.WEIGHT.data - vis_mask = ds.MASK.data - uvw = ds.UVW.data - freq = ds.FREQ.data - b = ds.bandid - if use_beam: - beam = ds.BEAM.data * mask - else: - # TODO - separate implementation without - # unnecessary beam application - beam = mask - - convim = hessian(x[b], uvw, wgt, vis_mask, freq, beam, hessopts) - - convims[b] += convim - - convim = da.stack(convims)/wsum - - if sigmainv: - convim += x * sigmainv**2 - - if compute: - return convim.compute() - else: - return convim - - def _hessian_impl(x, uvw=None, weight=None, @@ -68,6 +17,9 @@ def _hessian_impl(x, beam=None, x0=0.0, y0=0.0, + flip_u=False, + flip_v=True, + flip_w=False, cell=None, do_wgridding=None, epsilon=None, @@ -75,6 +27,17 @@ def _hessian_impl(x, nthreads=None, sigmainvsq=None, wsum=1.0): + ''' + Apply vis space Hessian approximation on a slice of an image. + + Important! + x0, y0, flip_u, flip_v and flip_w must be consistent with the + conventions defined in pfb.operators.gridder.wgridder_conventions + + These are inputs here to allow for testing but should generally be taken + from the attrs of the datasets produced by + pfb.operators.gridder.image_data_products + ''' if not x.any(): return np.zeros_like(x) nx, ny = x.shape @@ -86,27 +49,34 @@ def _hessian_impl(x, pixsize_y=cell, center_x=x0, center_y=y0, + flip_u=flip_u, + flip_v=flip_v, + flip_w=flip_w, epsilon=epsilon, nthreads=nthreads, do_wgridding=do_wgridding, divide_by_n=False) - convim = vis2dirty(uvw=uvw, - freq=freq, - vis=mvis, - wgt=weight, - mask=vis_mask, - npix_x=nx, - npix_y=ny, - pixsize_x=cell, - pixsize_y=cell, - center_x=x0, - center_y=y0, - epsilon=epsilon, - nthreads=nthreads, - do_wgridding=do_wgridding, - double_precision_accumulation=double_accum, - divide_by_n=False) + convim = vis2dirty( + uvw=uvw, + freq=freq, + vis=mvis, + wgt=weight, + mask=vis_mask, + npix_x=nx, + npix_y=ny, + 
pixsize_x=cell, + pixsize_y=cell, + center_x=x0, + center_y=y0, + flip_u=flip_u, + flip_v=flip_v, + flip_w=flip_w, + epsilon=epsilon, + nthreads=nthreads, + do_wgridding=do_wgridding, + double_precision_accumulation=double_accum, + divide_by_n=False) convim /= wsum if beam is not None: @@ -118,24 +88,25 @@ def _hessian_impl(x, return convim -def _hessian(x, uvw, weight, vis_mask, freq, beam, hessopts): - return _hessian_impl(x, uvw[0][0], weight[0][0], vis_mask[0][0], freq[0], - beam, **hessopts) +# Kept in case we need them in the future +# def _hessian(x, uvw, weight, vis_mask, freq, beam, hessopts): +# return _hessian_impl(x, uvw[0][0], weight[0][0], vis_mask[0][0], freq[0], +# beam, **hessopts) -def hessian(x, uvw, weight, vis_mask, freq, beam, hessopts): - if beam is None: - bout = None - else: - bout = ('nx', 'ny') - return da.blockwise(_hessian, ('nx', 'ny'), - x, ('nx', 'ny'), - uvw, ('row', 'three'), - weight, ('row', 'chan'), - vis_mask, ('row', 'chan'), - freq, ('chan',), - beam, bout, - hessopts, None, - dtype=x.dtype) +# def hessian(x, uvw, weight, vis_mask, freq, beam, hessopts): +# if beam is None: +# bout = None +# else: +# bout = ('nx', 'ny') +# return da.blockwise(_hessian, ('nx', 'ny'), +# x, ('nx', 'ny'), +# uvw, ('row', 'three'), +# weight, ('row', 'chan'), +# vis_mask, ('row', 'chan'), +# freq, ('chan',), +# beam, bout, +# hessopts, None, +# dtype=x.dtype) def _hessian_psf_slice( @@ -167,10 +138,10 @@ def _hessian_psf_slice( if wsum is not None: xout /= wsum - # if sigmainv: - # xout += x * sigmainv + if sigmainv: + xout += x * sigmainv - return xout + x * sigmainv + return xout def hessian_psf_cube( @@ -202,7 +173,10 @@ def hessian_psf_cube( if wsum is not None: xout /= wsum - return xout + x * sigmainv + if sigmainv: + xout += x * sigmainv + + return xout else: raise NotImplementedError @@ -220,7 +194,10 @@ def hess_direct(x, # input image, not overwritten mode='forward'): nband, nx, ny = x.shape xpad[...] 
= 0.0 - xpad[:, 0:nx, 0:ny] = x * taperxy[None] + if mode == 'forward': + xpad[:, 0:nx, 0:ny] = x / taperxy[None] + else: + xpad[:, 0:nx, 0:ny] = x * taperxy[None] r2c(xpad, out=xhat, axes=(1,2), forward=True, inorm=0, nthreads=nthreads) if mode=='forward': @@ -231,7 +208,11 @@ def hess_direct(x, # input image, not overwritten lastsize=lastsize, inorm=2, nthreads=nthreads, allow_overwriting_input=True) xout[...] = xpad[:, 0:nx, 0:ny] - return xout * taperxy[None] + if mode=='forward': + xout /= taperxy[None] + else: + xout *= taperxy[None] + return xout def hess_direct_slice(x, # input image, not overwritten @@ -254,16 +235,15 @@ def hess_direct_slice(x, # input image, not overwritten r2c(xpad, out=xhat, axes=(0,1), forward=True, inorm=0, nthreads=nthreads) if mode=='forward': - # xhat *= (psfhat + sigmainvsq) - xhat *= psfhat + xhat *= (psfhat + sigmainvsq) else: - # xhat /= (psfhat + sigmainvsq) - xhat /= psfhat + xhat /= (psfhat + sigmainvsq) c2r(xhat, axes=(0, 1), forward=False, out=xpad, lastsize=lastsize, inorm=2, nthreads=nthreads, allow_overwriting_input=True) xout[...] = xpad[0:nx, 0:ny] - if mode=='foward': - return xout / taperxy + if mode=='forward': + xout /= taperxy else: - return xout * taperxy + xout *= taperxy + return xout diff --git a/pfb/opt/pcg.py b/pfb/opt/pcg.py index ae797cdee..daf6bf197 100644 --- a/pfb/opt/pcg.py +++ b/pfb/opt/pcg.py @@ -292,7 +292,7 @@ def pcg_dds(ds_name, # set precond if PSF is present if 'PSFHAT' in ds and use_psf: - psfhat = np.abs(ds.PSFHAT.values)/wsum + sigma + psfhat = np.abs(ds.PSFHAT.values)/wsum ds.drop_vars(('PSFHAT')) nx_psf, nyo2 = psfhat.shape ny_psf = 2*(nyo2-1) # is this always the case? 
@@ -323,7 +323,7 @@ def pcg_dds(ds_name, taperxy=taperxy, lastsize=ny_psf, nthreads=nthreads, - sigmainvsq=1.0, # not used + sigmainvsq=sigma, mode='backward') x0 = precond(j) diff --git a/pfb/utils/dist.py b/pfb/utils/dist.py index adbc40ea1..43f864e5d 100644 --- a/pfb/utils/dist.py +++ b/pfb/utils/dist.py @@ -132,7 +132,12 @@ def __init__(self, xds_list, opts, bandid, cache_path, max_freq, uv_max): self.uv_max = uv_max nx, ny, nx_psf, ny_psf, cell_N, cell_rad = set_image_size(uv_max, max_freq, - opts) + opts.field_of_view, + opts.super_resolution_factor, + opts.cell_size, + opts.nx, + opts.ny, + opts.psf_oversize) cell_deg = np.rad2deg(cell_rad) cell_size = cell_deg * 3600 # print(f"Super resolution factor = {cell_N/cell_rad}", file=log) @@ -354,7 +359,6 @@ def set_residual(self, k, x=None): self.cell_rad, self.cell_rad, self.cache_path, # output_name (same as dsl names?) x, - x0=self.x0, y0=self.y0, nthreads=self.nthreads, epsilon=self.opts.epsilon, do_wgridding=self.opts.do_wgridding, diff --git a/pfb/utils/fits.py b/pfb/utils/fits.py index 422d49072..3d25a8397 100644 --- a/pfb/utils/fits.py +++ b/pfb/utils/fits.py @@ -1,7 +1,6 @@ import numpy as np from astropy.io import fits from astropy.wcs import WCS -from pfb.utils.misc import to4d import dask.array as da from dask import delayed from datetime import datetime @@ -10,6 +9,19 @@ from pfb.utils.naming import xds_from_list +def to4d(data): + if data.ndim == 4: + return data + elif data.ndim == 2: + return data[None, None] + elif data.ndim == 3: + return data[None] + elif data.ndim == 1: + return data[None, None, None] + else: + raise ValueError("Only arrays with ndim <= 4 can be broadcast to 4D.") + + def data_from_header(hdr, axis=3): npix = hdr['NAXIS' + str(axis)] refpix = hdr['CRPIX' + str(axis)] @@ -20,13 +32,13 @@ def data_from_header(hdr, axis=3): def load_fits(name, dtype=np.float32): data = fits.getdata(name) - data = np.transpose(to4d(data)[:, :, ::-1], axes=(0, 1, 3, 2)) + data = 
np.transpose(to4d(data), axes=(0, 1, 3, 2)) return np.require(data, dtype=dtype, requirements='C') def save_fits(data, name, hdr, overwrite=True, dtype=np.float32): hdu = fits.PrimaryHDU(header=hdr) - data = np.transpose(to4d(data), axes=(0, 1, 3, 2))[:, :, ::-1] + data = np.transpose(to4d(data), axes=(0, 1, 3, 2)) hdu.data = np.require(data, dtype=dtype, requirements='F') hdu.writeto(name, overwrite=overwrite) return @@ -62,9 +74,7 @@ def set_wcs(cell_x, cell_y, nx, ny, radec, freq, ref_freq = freq crpix3 = 1 w.wcs.crval = [radec[0]*180.0/np.pi, radec[1]*180.0/np.pi, ref_freq, 1] - # y axis treated differently because of wgridder convention? - # https://github.com/mreineck/ducc/issues/34 - w.wcs.crpix = [1 + nx//2, ny//2, crpix3, 1] + w.wcs.crpix = [1 + nx//2, 1 + ny//2, crpix3, 1] header = w.to_header() header['RESTFRQ'] = ref_freq diff --git a/pfb/utils/misc.py b/pfb/utils/misc.py index 338309669..f09cb1eb9 100644 --- a/pfb/utils/misc.py +++ b/pfb/utils/misc.py @@ -94,19 +94,6 @@ def kron_matvec2(A, b): return x -def to4d(data): - if data.ndim == 4: - return data - elif data.ndim == 2: - return data[None, None] - elif data.ndim == 3: - return data[None] - elif data.ndim == 1: - return data[None, None, None] - else: - raise ValueError("Only arrays with ndim <= 4 can be broadcast to 4D.") - - def Gaussian2D(xin, yin, GaussPar=(1., 1., 0.), normalise=True, nsigma=5): S0, S1, PA = GaussPar Smaj = S0 #np.maximum(S0, S1) @@ -1471,24 +1458,30 @@ def combine_columns(x, y, dc, dc1, dc2): return x -def set_image_size(uv_max, max_freq, opts): +def set_image_size( + uv_max, + max_freq, + field_of_view, + super_resolution_factor, + cell_size=None, + nx=None, ny=None, + psf_oversize=2.0): # max cell size cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed) - if opts.cell_size is not None: - cell_size = opts.cell_size + if cell_size is not None: cell_rad = cell_size * np.pi / 60 / 60 / 180 if cell_N / cell_rad < 1: raise ValueError("Requested cell size too large. 
" "Super resolution factor = ", cell_N / cell_rad) else: - cell_rad = cell_N / opts.super_resolution_factor + cell_rad = cell_N / super_resolution_factor cell_size = cell_rad * 60 * 60 * 180 / np.pi - if opts.nx is None: - fov = opts.field_of_view * 3600 + if nx is None: + fov = field_of_view * 3600 npix = int(fov / cell_size) npix = good_size(npix) while npix % 2: @@ -1497,18 +1490,18 @@ def set_image_size(uv_max, max_freq, opts): nx = npix ny = npix else: - nx = opts.nx - ny = opts.ny if opts.ny is not None else nx + nx = nx + ny = ny if ny is not None else nx cell_deg = np.rad2deg(cell_rad) fovx = nx*cell_deg fovy = ny*cell_deg - nx_psf = good_size(int(opts.psf_oversize * nx)) + nx_psf = good_size(int(psf_oversize * nx)) while nx_psf % 2: nx_psf += 1 nx_psf = good_size(nx_psf) - ny_psf = good_size(int(opts.psf_oversize * ny)) + ny_psf = good_size(int(psf_oversize * ny)) while ny_psf % 2: ny_psf += 1 ny_psf = good_size(ny_psf) diff --git a/pfb/utils/stokes2im.py b/pfb/utils/stokes2im.py index fdf23b4e5..6bd3f8566 100644 --- a/pfb/utils/stokes2im.py +++ b/pfb/utils/stokes2im.py @@ -9,8 +9,8 @@ weight_data, filter_extreme_counts) from pfb.utils.misc import eval_coeffs_to_slice from pfb.utils.fits import set_wcs, save_fits -from pfb.operators.gridder import im2vis -from ducc0.wgridder import vis2dirty, dirty2vis +from pfb.operators.gridder import wgridder_conventions +from ducc0.wgridder.experimental import vis2dirty from casacore.quanta import quantity from datetime import datetime from ducc0.fft import c2r, r2c, good_size @@ -203,9 +203,8 @@ def single_stokes_image( tcoords[0,1] = tdec coords0 = np.array((ds.ra, ds.dec)) lm0 = radec_to_lm(tcoords, coords0).squeeze() - # LB - why the negative? 
- x0 = -lm0[0] - y0 = -lm0[1] + x0 = lm0[0] + y0 = lm0[1] else: x0 = 0.0 y0 = 0.0 @@ -224,7 +223,7 @@ def single_stokes_image( mask = (~flag).astype(np.uint8) - # TODO - this subtraction woul dbe better to do inside weight_data + # TODO - this subtraction would be better to do inside weight_data if opts.model_column is not None: ne.evaluate('(data-model_vis)*mask', out=data) @@ -238,6 +237,8 @@ def single_stokes_image( else: weight = None + flip_u, flip_v, flip_w, x0, y0 = wgridder_conventions(x0, y0) + if opts.robustness is not None: counts = _compute_counts(uvw, freq, @@ -246,7 +247,9 @@ def single_stokes_image( nx, ny, cell_rad, cell_rad, uvw.dtype, - ngrid=1) + ngrid=1, + usign=1.0 if flip_u else -1.0, + vsign=1.0 if flip_v else -1.0) imwgt = counts_to_weights( counts, @@ -254,7 +257,9 @@ def single_stokes_image( freq, nx, ny, cell_rad, cell_rad, - opts.robustness) + opts.robustness, + usign=1.0 if flip_u else -1.0, + vsign=1.0 if flip_v else -1.0) if weight is not None: weight *= imwgt else: @@ -271,10 +276,12 @@ def single_stokes_image( npix_x=nx, npix_y=ny, pixsize_x=cell_rad, pixsize_y=cell_rad, center_x=x0, center_y=y0, + flip_u=flip_u, + flip_v=flip_v, + flip_w=flip_w, epsilon=opts.epsilon, - flip_v=False, # hardcoded for now do_wgridding=opts.do_wgridding, - divide_by_n=False, # hardcoded for now + divide_by_n=True, # no rephasing or smooth beam so do it here nthreads=opts.nthreads, sigma_min=1.1, sigma_max=3.0, double_precision_accumulation=opts.double_accum, diff --git a/pfb/utils/weighting.py b/pfb/utils/weighting.py index d2f65dcdd..808f9208e 100644 --- a/pfb/utils/weighting.py +++ b/pfb/utils/weighting.py @@ -64,7 +64,7 @@ def compute_counts(dsl, @njit(nogil=True, cache=True, parallel=True) def _compute_counts(uvw, freq, mask, wgt, nx, ny, cell_size_x, cell_size_y, dtype, - k=6, ngrid=1): # support hardcoded for now + k=6, ngrid=1, usign=1.0, vsign=-1.0): # support hardcoded for now # ufreq u_cell = 1/(nx*cell_size_x) # shifts fftfreq such that they 
start at zero @@ -105,8 +105,8 @@ def _compute_counts(uvw, freq, mask, wgt, nx, ny, continue # current uv coords chan_normfreq = normfreq[c] - u_tmp = uvw_row[0] * chan_normfreq - v_tmp = uvw_row[1] * chan_normfreq + u_tmp = uvw_row[0] * chan_normfreq * usign + v_tmp = uvw_row[1] * chan_normfreq * vsign # pixel coordinates ug = (u_tmp + umax)/u_cell vg = (v_tmp + vmax)/v_cell @@ -138,7 +138,8 @@ def _es_kernel(x, beta, k): @njit(nogil=True, cache=True, parallel=True) def counts_to_weights(counts, uvw, freq, nx, ny, - cell_size_x, cell_size_y, robust): + cell_size_x, cell_size_y, robust, + usign=1.0, vsign=-1.0): # ufreq u_cell = 1/(nx*cell_size_x) umax = np.abs(-1/cell_size_x/2 - u_cell/2) @@ -169,8 +170,8 @@ def counts_to_weights(counts, uvw, freq, nx, ny, for c in range(nchan): # get current uv chan_normfreq = normfreq[c] - u_tmp = uvw_row[0] * chan_normfreq - v_tmp = uvw_row[1] * chan_normfreq + u_tmp = uvw_row[0] * chan_normfreq * usign + v_tmp = uvw_row[1] * chan_normfreq * vsign # get u index u_idx = int(np.floor((u_tmp + umax)/u_cell)) # get v index diff --git a/pfb/workers/degrid.py b/pfb/workers/degrid.py index 5b01c9471..c0941068d 100644 --- a/pfb/workers/degrid.py +++ b/pfb/workers/degrid.py @@ -216,14 +216,9 @@ def _degrid(**kw): tidx, tcnts, fidx, fcnts, mds, - # coeffs, - # locx, locy, modelf, tfunc, ffunc, - # nx, ny, - # cell_rad, cell_rad, - # x0=x0, y0=y0, nthreads=opts.nthreads, epsilon=opts.epsilon, do_wgridding=opts.do_wgridding, diff --git a/pfb/workers/fluxmop.py b/pfb/workers/fluxmop.py index c497c8980..e1fd1122a 100644 --- a/pfb/workers/fluxmop.py +++ b/pfb/workers/fluxmop.py @@ -160,8 +160,6 @@ def _fluxmop(ddsi=None, **kw): from daskms.fsspec_store import DaskMSStore from pfb.utils.naming import xds_from_url, xds_from_list from pfb.utils.misc import init_mask, dds2cubes - from pfb.operators.hessian import hessian_xds, hessian_psf_cube - from pfb.operators.psf import psf_convolve_cube from pfb.opt.pcg import pcg_dds from ducc0.misc import 
resize_thread_pool, thread_pool_size from ducc0.fft import c2c diff --git a/pfb/workers/fwdbwd.py b/pfb/workers/fwdbwd.py index 9b97d0b96..18050e6c4 100644 --- a/pfb/workers/fwdbwd.py +++ b/pfb/workers/fwdbwd.py @@ -1,473 +1,473 @@ -# flake8: noqa -import os -from pathlib import Path -from contextlib import ExitStack -from pfb.workers.main import cli -from functools import partial -import click -from omegaconf import OmegaConf -import pyscilog -pyscilog.init('pfb') -log = pyscilog.get_logger('FWDBWD') - -from scabha.schema_utils import clickify_parameters -from pfb.parser.schemas import schema - -# create default parameters from schema -defaults = {} -for key in schema.fwdbwd["inputs"].keys(): - defaults[key.replace("-", "_")] = schema.fwdbwd["inputs"][key]["default"] - -@cli.command(context_settings={'show_default': True}) -@clickify_parameters(schema.fwdbwd) -def fwdbwd(**kw): - ''' - Minimises - - (V - R f(x)).H W (V - R f(x)) + sigma_{21} | psi.H x |_{2,1} - - where f: R^N -> R^N is some function, R is the degridding operator - and psi is an over-complete dictionary of functions. Here sigma_{21} - is the strength of the regulariser. 
- ''' - defaults.update(kw) - opts = OmegaConf.create(defaults) - import time - timestamp = time.strftime("%Y%m%d-%H%M%S") - ldir = Path(opts.log_directory).resolve() - ldir.mkdir(parents=True, exist_ok=True) - pyscilog.log_to_file(f'{str(ldir)}/fwdbwd_{timestamp}.log') - print(f'Logs will be written to {str(ldir)}/fwdbwd_{timestamp}.log', file=log) - - if opts.nworkers is None: - if opts.scheduler=='distributed': - opts.nworkers = opts.nband - else: - opts.nworkers = 1 - - OmegaConf.set_struct(opts, True) - - with ExitStack() as stack: - from pfb import set_client - opts = set_client(opts, stack, log, scheduler=opts.scheduler) - - # TODO - prettier config printing - print('Input Options:', file=log) - for key in opts.keys(): - print(' %25s = %s' % (key, opts[key]), file=log) - - _fwdbwd(**opts) - - print("All done here.", file=log) - -def _fwdbwd(ddsi=None, **kw): - opts = OmegaConf.create(kw) - OmegaConf.set_struct(opts, True) - - import numpy as np - import xarray as xr - import numexpr as ne - import dask - import dask.array as da - from dask.distributed import performance_report - from pfb.utils.fits import (set_wcs, save_fits, dds2fits, - dds2fits_mfs, load_fits) - from pfb.utils.misc import dds2cubes - from daskms.experimental.zarr import xds_from_zarr, xds_to_zarr - from pfb.opt.power_method import power_method - from pfb.opt.pcg import pcg - from pfb.opt.primal_dual import primal_dual_optimised as primal_dual - from pfb.utils.misc import l1reweight_func, setup_parametrisation - from pfb.operators.hessian import hessian_xds - from pfb.operators.psf import psf_convolve_cube - from pfb.operators.psi import im2coef - from pfb.operators.psi import coef2im - from copy import copy, deepcopy - from ducc0.misc import make_noncritical - from pfb.wavelets.wavelets import wavelet_setup - from pfb.prox.prox_21m import prox_21m_numba as prox_21 - # from pfb.prox.prox_21 import prox_21 - from pfb.utils.misc import fitcleanbeam - - basename = 
f'{opts.output_filename}_{opts.product.upper()}' - - dds_name = f'{basename}_{opts.suffix}.dds' - if ddsi is not None: - dds = [] - for ds in ddsi: - dds.append(ds.chunk({'row':-1, - 'chan':-1, - 'x':-1, - 'y':-1, - 'x_psf':-1, - 'y_psf':-1, - 'yo2':-1})) - else: - dds = xds_from_zarr(dds_name, chunks={'row':-1, - 'chan':-1, - 'x':-1, - 'y':-1, - 'x_psf':-1, - 'y_psf':-1, - 'yo2':-1}) - - if opts.memory_greedy: - dds = dask.persist(dds)[0] - - nx_psf, ny_psf = dds[0].x_psf.size, dds[0].y_psf.size - lastsize = ny_psf - - # stitch dirty/psf in apparent scale - print("Combining slices into cubes", file=log) - output_type = dds[0].DIRTY.dtype - dirty, model, residual, psf, psfhat, beam, wsums, dual = dds2cubes( - dds, - opts.nband, - apparent=False) - wsum = np.sum(wsums) - psf_mfs = np.sum(psf, axis=0) - assert (psf_mfs.max() - 1.0) < 2*opts.epsilon - dirty_mfs = np.sum(dirty, axis=0) - if residual is None: - residual = dirty.copy() - residual_mfs = dirty_mfs.copy() - else: - residual_mfs = np.sum(residual, axis=0) - - # for intermediary results (not currently written) - freq_out = [] - for ds in dds: - freq_out.append(ds.freq_out) - freq_out = np.unique(np.array(freq_out)) - nband = opts.nband - nx = dds[0].x.size - ny = dds[0].y.size - ra = dds[0].ra - dec = dds[0].dec - radec = [ra, dec] - cell_rad = dds[0].cell_rad - cell_deg = np.rad2deg(cell_rad) - ref_freq = np.mean(freq_out) - hdr_mfs = set_wcs(cell_deg, cell_deg, nx, ny, radec, ref_freq) - - # set up vis space Hessian - hessopts = {} - hessopts['cell'] = dds[0].cell_rad - hessopts['do_wgridding'] = opts.do_wgridding - hessopts['epsilon'] = opts.epsilon - hessopts['double_accum'] = opts.double_accum - hessopts['nthreads'] = opts.nthreads # nvthreads since dask parallel over band - # always clean in apparent scale so no beam - # mask is applied to residual after hessian application - hess = partial(hessian_xds, xds=dds, hessopts=hessopts, - wsum=wsum, sigmainv=0, - mask=np.ones((nx, ny), dtype=output_type), - 
compute=True, use_beam=False) - - - # image space hessian - # pre-allocate arrays for doing FFT's - xout = np.empty(dirty.shape, dtype=dirty.dtype, order='C') - xout = make_noncritical(xout) - xpad = np.empty(psf.shape, dtype=dirty.dtype, order='C') - xpad = make_noncritical(xpad) - xhat = np.empty(psfhat.shape, dtype=psfhat.dtype) - xhat = make_noncritical(xhat) - # We use nthreads = nvthreads*nthreads_dask because dask not involved - psf_convolve = partial(psf_convolve_cube, xpad, xhat, xout, psfhat, lastsize, - nthreads=opts.nthreads*opts.nthreads_dask) - - print("Setting up dictionary", file=log) - bases = tuple(opts.bases.split(',')) - nbasis = len(bases) - iy, sy, ntot, nmax = wavelet_setup( - np.zeros((1, nx, ny), dtype=dirty.dtype), - bases, opts.nlevels) - ntot = tuple(ntot) - - psiH = partial(im2coef, - bases=bases, - ntot=ntot, - nmax=nmax, - nlevels=opts.nlevels, - nthreads=opts.nthreads*opts.nthreads_dask) # nthreads = nvthreads*nthreads_dask because dask not involved - psi = partial(coef2im, - bases=bases, - ntot=ntot, - iy=iy, - sy=sy, - nx=nx, - ny=ny, - nthreads=opts.nthreads*opts.nthreads_dask) # nthreads = nvthreads*nthreads_dask because dask not involved - - def hesspsi(x): - tmpa = np.zeros((nband, nbasis, nmax), dtype=dirty.dtype) - psiH(x, tmpa) - tmpx = np.zeros((nband, nx, ny), dtype=dirty.dtype) - psi(tmpa, tmpx) - return tmpx - - psinorm, _ = power_method(hesspsi, (nband, nx, ny), - tol=opts.pm_tol, - maxit=opts.pm_maxit, - verbosity=opts.pm_verbose, - report_freq=opts.pm_report_freq) - - print(f"psinorm = {psinorm}", file=log) - - # get clean beam area to convert residual units during l1reweighting - # TODO - could refine this with comparison between dirty and restored - # if continuing the deconvolution - GaussPar = fitcleanbeam(psf_mfs[None], level=0.5, pixsize=1.0)[0] - pix_per_beam = GaussPar[0]*GaussPar[1]*np.pi/4 - print(f"Number of pixels per beam estimated as {pix_per_beam}", - file=log) - - # We do the following to set 
hyper-parameters in an intuitive way - # i) convert residual units so it is comparable to model - # ii) project residual into dual domain - # iii) compute the rms in the space where thresholding happens - psiHoutvar = np.zeros((nband, nbasis, nmax), dtype=dirty.dtype) - fsel = wsums > 0 - tmp2 = residual.copy() - tmp2[fsel] *= wsum/wsums[fsel, None, None] - psiH(tmp2/pix_per_beam, psiHoutvar) - rms_comps = np.std(np.sum(psiHoutvar, axis=0), - axis=-1)[:, None] # preserve axes - - import ipdb; ipdb.set_trace() - func, finv, dfunc, dhfunc = setup_parametrisation(mode=opts.parametrisation, - minval=np.median(model[model>0]), - sigma=1.0, - freq=freq_out, - lscale=1.0) - - def gradf(residual, x, dhf): - return -2*residual - - def hessian_psf(psfo, x0, sigmainv, df, dhf, v): - ''' - psfo is the convolution operator and x0 is the fixed value of x at which - we evaluate the operator. v is the vector to be acted on. - ''' - dx0 = df(x0) - return 2 * dhf(psfo(df(v))) + v*sigmainv - - def get_scaling(hessf): - tmpx = np.random.randn(*dirty.shape) - convx = hessf(tmpx) - convx[fsel] *= wsum/wsums[fsel, None, None] - return np.std(convx) - - if 'PARAM' in dds[0] and dds[0].parametrisation == opts.parametrisation and not opts.restart: - print("Found matching parametrisation for PARAM in dds", file=log) - x = [ds.PARAM.data for ds in dds] - x = da.stack(x).compute() - elif model.any() and not opts.restart: - print("Initialsing PARAM from MODEL in dds", file=log) - # fall back and compute param from model in this case - x = finv(model) - # finv is not necessarily exact so we need to recompute residual - print("Computing residual", file=log) - model = func(x) - convimage = hess(model) - ne.evaluate('dirty - convimage', out=residual, - casting='same_kind') - ne.evaluate('sum(residual, axis=0)', out=residual_mfs, - casting='same_kind') - # in this case the dual is also probably not useful - dual = None - else: - print("Initialising PARAM to all zeros", file=log) - x = 
np.zeros_like(dirty) - model = func(x) - residual = dirty.copy() - residual_mfs = dirty_mfs.copy() - dual = None # force reset - - if dual is None: - dual = np.zeros((nband, nbasis, nmax), dtype=dirty.dtype) - l1weight = np.ones((nbasis, nmax), dtype=dirty.dtype) - reweighter = None - else: - if opts.l1reweight_from == 0: - print('Initialising with L1 reweighted', file=log) - reweighter = partial(l1reweight_func, psiH, psiHoutvar, opts.rmsfactor, rms_comps) - l1weight = reweighter(x) - # l1weight[l1weight < 1.0] = 0.0 - else: - l1weight = np.ones((nbasis, nmax), dtype=dirty.dtype) - reweighter = None - - - # for generality the prox function only takes the - # array variable and step size as inputs - # prox21 = partial(prox_21, weight=l1weight) - - hessbeta = None - rms = np.std(residual_mfs) - rmax = np.abs(residual_mfs).max() - print(f"Iter 0: peak residual = {rmax:.3e}, rms = {rms:.3e}", - file=log) - for k in range(opts.niter): - xp = x.copy() - df = partial(dfunc, xp) - dhf = partial(dhfunc, xp) - j = -gradf(residual, xp, dhf) - print("Finding spectral norm of Hessian approximation", file=log) - # hessian depends on x and sigmainv so need to do this at every iteration - sigmainv = np.maximum(np.std(j), opts.sigmainv) - hesspsf = partial(hessian_psf, psf_convolve, xp, sigmainv, df, dhf) - hess_norm, hessbeta = power_method(hesspsf, (nband, nx, ny), - b0=hessbeta, - tol=opts.pm_tol, - maxit=opts.pm_maxit, - verbosity=opts.pm_verbose, - report_freq=opts.pm_report_freq) - - print(f"Solving forward step with sigmainv = {sigmainv}", file=log) - delx = pcg(hesspsf, - j, - tol=opts.cg_tol, - maxit=opts.cg_maxit, - minit=opts.cg_minit, - verbosity=opts.cg_verbose, - report_freq=opts.cg_report_freq, - backtrack=opts.backtrack) - - save_fits(np.mean(delx, axis=0), - basename + f'_{opts.suffix}_update_{k+1}.fits', - hdr_mfs) - - # compute scaling of the residual - rscale = get_scaling(hesspsf) - # get rms in space where thresholding happens - psiH(delx, psiHoutvar) - 
rmscomps = np.std(np.sum(psiHoutvar, axis=0)) - - if opts.sigma21 is None: - sigma21 = opts.rmsfactor*np.std(j/rscale) - # sigma21 = opts.rmsfactor*rmscomps - else: - sigma21 = opts.sigma21 - if sigma21: - print(f'Solving backward step with sig21 = {sigma21}', file=log) - data = xp + opts.gamma * delx - if opts.parametrisation != 'id': - if xp.any(): - bedges = np.histogram_bin_edges(xp.ravel(), bins='fd') - else: - bedges = np.histogram_bin_edges(data.ravel(), bins='fd') - dhist, _ = np.histogram(data.ravel(), bins=bedges) - dmax = dhist.argmax() - # dmode = (bedges[dmax] + bedges[dmax+1])/2.0 - dmode = bedges[dmax] - print(f"Removing mode = {dmode} prior to backward step", file=log) - data -= dmode - x -= dmode - grad21 = lambda v: hesspsf(v - data) - x, dual = primal_dual(x, - dual, - sigma21, - psi, - psiH, - hess_norm, - prox_21, - l1weight, - reweighter, - grad21, - nu=psinorm, - positivity=opts.positivity, - tol=opts.pd_tol, - maxit=opts.pd_maxit, - verbosity=opts.pd_verbose, - report_freq=opts.pd_report_freq, - gamma=opts.gamma) - if opts.parametrisation != 'id': - x += dmode - else: - x = xp + opts.gamma * delx - save_fits(np.mean(x, axis=0), - basename + f'_{opts.suffix}_param_{k+1}.fits', - hdr_mfs) - - model = func(x) - save_fits(np.mean(model, axis=0), - basename + f'_{opts.suffix}_model_{k+1}.fits', - hdr_mfs) - - print("Getting residual", file=log) - convimage = hess(model) - ne.evaluate('dirty - convimage', out=residual, - casting='same_kind') - ne.evaluate('sum(residual, axis=0)', out=residual_mfs, - casting='same_kind') - - save_fits(residual_mfs, - basename + f'_{opts.suffix}_residual_{k+1}.fits', - hdr_mfs) - - rms = np.std(residual_mfs) - rmax = np.abs(residual_mfs).max() - eps = np.linalg.norm(x - xp)/np.linalg.norm(x) - - print(f"Iter {k+1}: peak residual = {rmax:.3e}, " - f"rms = {rms:.3e}, eps = {eps:.3e}", - file=log) - - if k+1 >= opts.l1reweight_from: - print('Computing L1 weights', file=log) - # convert residual units so it is 
comparable to model - # tmp2[fsel] = residual[fsel] * wsum/wsums[fsel, None, None] - # psiH(tmp2/pix_per_beam, psiHoutvar) - # psiH(tmp2/rscale, psiHoutvar) - psiH(delx, psiHoutvar) - rms_comps = np.std(np.sum(psiHoutvar, axis=0), - axis=-1)[:, None] # preserve axes - # we redefine the reweighter here since the rms has changed - reweighter = partial(l1reweight_func, psiH, psiHoutvar, opts.rmsfactor, rms_comps) - l1weight = reweighter(x) - # l1weight[l1weight < 1.0] = 0.0 - # prox21 = partial(prox_21, weight=l1weight, axis=0) - - print("Updating results", file=log) - dds_out = [] - for ds in dds: - b = ds.bandid - r = da.from_array(residual[b]*wsum) - m = da.from_array(model[b]) - d = da.from_array(dual[b]) - xb = da.from_array(x[b]) - ds_out = ds.assign(**{'RESIDUAL': (('x', 'y'), r), - 'MODEL': (('x', 'y'), m), - 'DUAL': (('c', 'n'), d), - 'PARAM': (('x', 'y'), xb)}) - ds_out = ds_out.assign_attrs({'parametrisation': opts.parametrisation}) - dds_out.append(ds_out) - writes = xds_to_zarr(dds_out, dds_name, - columns=('RESIDUAL', 'MODEL', 'DUAL', 'PARAM'), - rechunk=True) - dask.compute(writes) - - if eps < opts.tol: - print(f"Converged after {k+1} iterations.", file=log) - break - - - dds = xds_from_zarr(dds_name, chunks={'x': -1, 'y': -1}) - - # convert to fits files - fitsout = [] - if opts.fits_mfs: - fitsout.append(dds2fits_mfs(dds, 'RESIDUAL', f'{basename}_{opts.suffix}', norm_wsum=True)) - fitsout.append(dds2fits_mfs(dds, 'MODEL', f'{basename}_{opts.suffix}', norm_wsum=False)) - fitsout.append(dds2fits_mfs(dds, 'PARAM', f'{basename}_{opts.suffix}', norm_wsum=False)) - - if opts.fits_cubes: - fitsout.append(dds2fits(dds, 'RESIDUAL', f'{basename}_{opts.suffix}', norm_wsum=True)) - fitsout.append(dds2fits(dds, 'MODEL', f'{basename}_{opts.suffix}', norm_wsum=False)) - fitsout.append(dds2fits(dds, 'PARAM', f'{basename}_{opts.suffix}', norm_wsum=False)) - - if len(fitsout): - print("Writing fits", file=log) - dask.compute(fitsout) +# # flake8: noqa +# import os +# 
from pathlib import Path +# from contextlib import ExitStack +# from pfb.workers.main import cli +# from functools import partial +# import click +# from omegaconf import OmegaConf +# import pyscilog +# pyscilog.init('pfb') +# log = pyscilog.get_logger('FWDBWD') + +# from scabha.schema_utils import clickify_parameters +# from pfb.parser.schemas import schema + +# # create default parameters from schema +# defaults = {} +# for key in schema.fwdbwd["inputs"].keys(): +# defaults[key.replace("-", "_")] = schema.fwdbwd["inputs"][key]["default"] + +# @cli.command(context_settings={'show_default': True}) +# @clickify_parameters(schema.fwdbwd) +# def fwdbwd(**kw): +# ''' +# Minimises + +# (V - R f(x)).H W (V - R f(x)) + sigma_{21} | psi.H x |_{2,1} + +# where f: R^N -> R^N is some function, R is the degridding operator +# and psi is an over-complete dictionary of functions. Here sigma_{21} +# is the strength of the regulariser. +# ''' +# defaults.update(kw) +# opts = OmegaConf.create(defaults) +# import time +# timestamp = time.strftime("%Y%m%d-%H%M%S") +# ldir = Path(opts.log_directory).resolve() +# ldir.mkdir(parents=True, exist_ok=True) +# pyscilog.log_to_file(f'{str(ldir)}/fwdbwd_{timestamp}.log') +# print(f'Logs will be written to {str(ldir)}/fwdbwd_{timestamp}.log', file=log) + +# if opts.nworkers is None: +# if opts.scheduler=='distributed': +# opts.nworkers = opts.nband +# else: +# opts.nworkers = 1 + +# OmegaConf.set_struct(opts, True) + +# with ExitStack() as stack: +# from pfb import set_client +# opts = set_client(opts, stack, log, scheduler=opts.scheduler) + +# # TODO - prettier config printing +# print('Input Options:', file=log) +# for key in opts.keys(): +# print(' %25s = %s' % (key, opts[key]), file=log) + +# _fwdbwd(**opts) + +# print("All done here.", file=log) + +# def _fwdbwd(ddsi=None, **kw): +# opts = OmegaConf.create(kw) +# OmegaConf.set_struct(opts, True) + +# import numpy as np +# import xarray as xr +# import numexpr as ne +# import dask +# 
import dask.array as da +# from dask.distributed import performance_report +# from pfb.utils.fits import (set_wcs, save_fits, dds2fits, +# dds2fits_mfs, load_fits) +# from pfb.utils.misc import dds2cubes +# from daskms.experimental.zarr import xds_from_zarr, xds_to_zarr +# from pfb.opt.power_method import power_method +# from pfb.opt.pcg import pcg +# from pfb.opt.primal_dual import primal_dual_optimised as primal_dual +# from pfb.utils.misc import l1reweight_func, setup_parametrisation +# from pfb.operators.hessian import hessian_xds +# from pfb.operators.psf import psf_convolve_cube +# from pfb.operators.psi import im2coef +# from pfb.operators.psi import coef2im +# from copy import copy, deepcopy +# from ducc0.misc import make_noncritical +# from pfb.wavelets.wavelets import wavelet_setup +# from pfb.prox.prox_21m import prox_21m_numba as prox_21 +# # from pfb.prox.prox_21 import prox_21 +# from pfb.utils.misc import fitcleanbeam + +# basename = f'{opts.output_filename}_{opts.product.upper()}' + +# dds_name = f'{basename}_{opts.suffix}.dds' +# if ddsi is not None: +# dds = [] +# for ds in ddsi: +# dds.append(ds.chunk({'row':-1, +# 'chan':-1, +# 'x':-1, +# 'y':-1, +# 'x_psf':-1, +# 'y_psf':-1, +# 'yo2':-1})) +# else: +# dds = xds_from_zarr(dds_name, chunks={'row':-1, +# 'chan':-1, +# 'x':-1, +# 'y':-1, +# 'x_psf':-1, +# 'y_psf':-1, +# 'yo2':-1}) + +# if opts.memory_greedy: +# dds = dask.persist(dds)[0] + +# nx_psf, ny_psf = dds[0].x_psf.size, dds[0].y_psf.size +# lastsize = ny_psf + +# # stitch dirty/psf in apparent scale +# print("Combining slices into cubes", file=log) +# output_type = dds[0].DIRTY.dtype +# dirty, model, residual, psf, psfhat, beam, wsums, dual = dds2cubes( +# dds, +# opts.nband, +# apparent=False) +# wsum = np.sum(wsums) +# psf_mfs = np.sum(psf, axis=0) +# assert (psf_mfs.max() - 1.0) < 2*opts.epsilon +# dirty_mfs = np.sum(dirty, axis=0) +# if residual is None: +# residual = dirty.copy() +# residual_mfs = dirty_mfs.copy() +# else: +# 
residual_mfs = np.sum(residual, axis=0) + +# # for intermediary results (not currently written) +# freq_out = [] +# for ds in dds: +# freq_out.append(ds.freq_out) +# freq_out = np.unique(np.array(freq_out)) +# nband = opts.nband +# nx = dds[0].x.size +# ny = dds[0].y.size +# ra = dds[0].ra +# dec = dds[0].dec +# radec = [ra, dec] +# cell_rad = dds[0].cell_rad +# cell_deg = np.rad2deg(cell_rad) +# ref_freq = np.mean(freq_out) +# hdr_mfs = set_wcs(cell_deg, cell_deg, nx, ny, radec, ref_freq) + +# # set up vis space Hessian +# hessopts = {} +# hessopts['cell'] = dds[0].cell_rad +# hessopts['do_wgridding'] = opts.do_wgridding +# hessopts['epsilon'] = opts.epsilon +# hessopts['double_accum'] = opts.double_accum +# hessopts['nthreads'] = opts.nthreads # nvthreads since dask parallel over band +# # always clean in apparent scale so no beam +# # mask is applied to residual after hessian application +# hess = partial(hessian_xds, xds=dds, hessopts=hessopts, +# wsum=wsum, sigmainv=0, +# mask=np.ones((nx, ny), dtype=output_type), +# compute=True, use_beam=False) + + +# # image space hessian +# # pre-allocate arrays for doing FFT's +# xout = np.empty(dirty.shape, dtype=dirty.dtype, order='C') +# xout = make_noncritical(xout) +# xpad = np.empty(psf.shape, dtype=dirty.dtype, order='C') +# xpad = make_noncritical(xpad) +# xhat = np.empty(psfhat.shape, dtype=psfhat.dtype) +# xhat = make_noncritical(xhat) +# # We use nthreads = nvthreads*nthreads_dask because dask not involved +# psf_convolve = partial(psf_convolve_cube, xpad, xhat, xout, psfhat, lastsize, +# nthreads=opts.nthreads*opts.nthreads_dask) + +# print("Setting up dictionary", file=log) +# bases = tuple(opts.bases.split(',')) +# nbasis = len(bases) +# iy, sy, ntot, nmax = wavelet_setup( +# np.zeros((1, nx, ny), dtype=dirty.dtype), +# bases, opts.nlevels) +# ntot = tuple(ntot) + +# psiH = partial(im2coef, +# bases=bases, +# ntot=ntot, +# nmax=nmax, +# nlevels=opts.nlevels, +# nthreads=opts.nthreads*opts.nthreads_dask) # 
nthreads = nvthreads*nthreads_dask because dask not involved +# psi = partial(coef2im, +# bases=bases, +# ntot=ntot, +# iy=iy, +# sy=sy, +# nx=nx, +# ny=ny, +# nthreads=opts.nthreads*opts.nthreads_dask) # nthreads = nvthreads*nthreads_dask because dask not involved + +# def hesspsi(x): +# tmpa = np.zeros((nband, nbasis, nmax), dtype=dirty.dtype) +# psiH(x, tmpa) +# tmpx = np.zeros((nband, nx, ny), dtype=dirty.dtype) +# psi(tmpa, tmpx) +# return tmpx + +# psinorm, _ = power_method(hesspsi, (nband, nx, ny), +# tol=opts.pm_tol, +# maxit=opts.pm_maxit, +# verbosity=opts.pm_verbose, +# report_freq=opts.pm_report_freq) + +# print(f"psinorm = {psinorm}", file=log) + +# # get clean beam area to convert residual units during l1reweighting +# # TODO - could refine this with comparison between dirty and restored +# # if continuing the deconvolution +# GaussPar = fitcleanbeam(psf_mfs[None], level=0.5, pixsize=1.0)[0] +# pix_per_beam = GaussPar[0]*GaussPar[1]*np.pi/4 +# print(f"Number of pixels per beam estimated as {pix_per_beam}", +# file=log) + +# # We do the following to set hyper-parameters in an intuitive way +# # i) convert residual units so it is comparable to model +# # ii) project residual into dual domain +# # iii) compute the rms in the space where thresholding happens +# psiHoutvar = np.zeros((nband, nbasis, nmax), dtype=dirty.dtype) +# fsel = wsums > 0 +# tmp2 = residual.copy() +# tmp2[fsel] *= wsum/wsums[fsel, None, None] +# psiH(tmp2/pix_per_beam, psiHoutvar) +# rms_comps = np.std(np.sum(psiHoutvar, axis=0), +# axis=-1)[:, None] # preserve axes + +# import ipdb; ipdb.set_trace() +# func, finv, dfunc, dhfunc = setup_parametrisation(mode=opts.parametrisation, +# minval=np.median(model[model>0]), +# sigma=1.0, +# freq=freq_out, +# lscale=1.0) + +# def gradf(residual, x, dhf): +# return -2*residual + +# def hessian_psf(psfo, x0, sigmainv, df, dhf, v): +# ''' +# psfo is the convolution operator and x0 is the fixed value of x at which +# we evaluate the operator. 
v is the vector to be acted on. +# ''' +# dx0 = df(x0) +# return 2 * dhf(psfo(df(v))) + v*sigmainv + +# def get_scaling(hessf): +# tmpx = np.random.randn(*dirty.shape) +# convx = hessf(tmpx) +# convx[fsel] *= wsum/wsums[fsel, None, None] +# return np.std(convx) + +# if 'PARAM' in dds[0] and dds[0].parametrisation == opts.parametrisation and not opts.restart: +# print("Found matching parametrisation for PARAM in dds", file=log) +# x = [ds.PARAM.data for ds in dds] +# x = da.stack(x).compute() +# elif model.any() and not opts.restart: +# print("Initialsing PARAM from MODEL in dds", file=log) +# # fall back and compute param from model in this case +# x = finv(model) +# # finv is not necessarily exact so we need to recompute residual +# print("Computing residual", file=log) +# model = func(x) +# convimage = hess(model) +# ne.evaluate('dirty - convimage', out=residual, +# casting='same_kind') +# ne.evaluate('sum(residual, axis=0)', out=residual_mfs, +# casting='same_kind') +# # in this case the dual is also probably not useful +# dual = None +# else: +# print("Initialising PARAM to all zeros", file=log) +# x = np.zeros_like(dirty) +# model = func(x) +# residual = dirty.copy() +# residual_mfs = dirty_mfs.copy() +# dual = None # force reset + +# if dual is None: +# dual = np.zeros((nband, nbasis, nmax), dtype=dirty.dtype) +# l1weight = np.ones((nbasis, nmax), dtype=dirty.dtype) +# reweighter = None +# else: +# if opts.l1reweight_from == 0: +# print('Initialising with L1 reweighted', file=log) +# reweighter = partial(l1reweight_func, psiH, psiHoutvar, opts.rmsfactor, rms_comps) +# l1weight = reweighter(x) +# # l1weight[l1weight < 1.0] = 0.0 +# else: +# l1weight = np.ones((nbasis, nmax), dtype=dirty.dtype) +# reweighter = None + + +# # for generality the prox function only takes the +# # array variable and step size as inputs +# # prox21 = partial(prox_21, weight=l1weight) + +# hessbeta = None +# rms = np.std(residual_mfs) +# rmax = np.abs(residual_mfs).max() +# 
print(f"Iter 0: peak residual = {rmax:.3e}, rms = {rms:.3e}", +# file=log) +# for k in range(opts.niter): +# xp = x.copy() +# df = partial(dfunc, xp) +# dhf = partial(dhfunc, xp) +# j = -gradf(residual, xp, dhf) +# print("Finding spectral norm of Hessian approximation", file=log) +# # hessian depends on x and sigmainv so need to do this at every iteration +# sigmainv = np.maximum(np.std(j), opts.sigmainv) +# hesspsf = partial(hessian_psf, psf_convolve, xp, sigmainv, df, dhf) +# hess_norm, hessbeta = power_method(hesspsf, (nband, nx, ny), +# b0=hessbeta, +# tol=opts.pm_tol, +# maxit=opts.pm_maxit, +# verbosity=opts.pm_verbose, +# report_freq=opts.pm_report_freq) + +# print(f"Solving forward step with sigmainv = {sigmainv}", file=log) +# delx = pcg(hesspsf, +# j, +# tol=opts.cg_tol, +# maxit=opts.cg_maxit, +# minit=opts.cg_minit, +# verbosity=opts.cg_verbose, +# report_freq=opts.cg_report_freq, +# backtrack=opts.backtrack) + +# save_fits(np.mean(delx, axis=0), +# basename + f'_{opts.suffix}_update_{k+1}.fits', +# hdr_mfs) + +# # compute scaling of the residual +# rscale = get_scaling(hesspsf) +# # get rms in space where thresholding happens +# psiH(delx, psiHoutvar) +# rmscomps = np.std(np.sum(psiHoutvar, axis=0)) + +# if opts.sigma21 is None: +# sigma21 = opts.rmsfactor*np.std(j/rscale) +# # sigma21 = opts.rmsfactor*rmscomps +# else: +# sigma21 = opts.sigma21 +# if sigma21: +# print(f'Solving backward step with sig21 = {sigma21}', file=log) +# data = xp + opts.gamma * delx +# if opts.parametrisation != 'id': +# if xp.any(): +# bedges = np.histogram_bin_edges(xp.ravel(), bins='fd') +# else: +# bedges = np.histogram_bin_edges(data.ravel(), bins='fd') +# dhist, _ = np.histogram(data.ravel(), bins=bedges) +# dmax = dhist.argmax() +# # dmode = (bedges[dmax] + bedges[dmax+1])/2.0 +# dmode = bedges[dmax] +# print(f"Removing mode = {dmode} prior to backward step", file=log) +# data -= dmode +# x -= dmode +# grad21 = lambda v: hesspsf(v - data) +# x, dual = primal_dual(x, +# 
dual, +# sigma21, +# psi, +# psiH, +# hess_norm, +# prox_21, +# l1weight, +# reweighter, +# grad21, +# nu=psinorm, +# positivity=opts.positivity, +# tol=opts.pd_tol, +# maxit=opts.pd_maxit, +# verbosity=opts.pd_verbose, +# report_freq=opts.pd_report_freq, +# gamma=opts.gamma) +# if opts.parametrisation != 'id': +# x += dmode +# else: +# x = xp + opts.gamma * delx +# save_fits(np.mean(x, axis=0), +# basename + f'_{opts.suffix}_param_{k+1}.fits', +# hdr_mfs) + +# model = func(x) +# save_fits(np.mean(model, axis=0), +# basename + f'_{opts.suffix}_model_{k+1}.fits', +# hdr_mfs) + +# print("Getting residual", file=log) +# convimage = hess(model) +# ne.evaluate('dirty - convimage', out=residual, +# casting='same_kind') +# ne.evaluate('sum(residual, axis=0)', out=residual_mfs, +# casting='same_kind') + +# save_fits(residual_mfs, +# basename + f'_{opts.suffix}_residual_{k+1}.fits', +# hdr_mfs) + +# rms = np.std(residual_mfs) +# rmax = np.abs(residual_mfs).max() +# eps = np.linalg.norm(x - xp)/np.linalg.norm(x) + +# print(f"Iter {k+1}: peak residual = {rmax:.3e}, " +# f"rms = {rms:.3e}, eps = {eps:.3e}", +# file=log) + +# if k+1 >= opts.l1reweight_from: +# print('Computing L1 weights', file=log) +# # convert residual units so it is comparable to model +# # tmp2[fsel] = residual[fsel] * wsum/wsums[fsel, None, None] +# # psiH(tmp2/pix_per_beam, psiHoutvar) +# # psiH(tmp2/rscale, psiHoutvar) +# psiH(delx, psiHoutvar) +# rms_comps = np.std(np.sum(psiHoutvar, axis=0), +# axis=-1)[:, None] # preserve axes +# # we redefine the reweighter here since the rms has changed +# reweighter = partial(l1reweight_func, psiH, psiHoutvar, opts.rmsfactor, rms_comps) +# l1weight = reweighter(x) +# # l1weight[l1weight < 1.0] = 0.0 +# # prox21 = partial(prox_21, weight=l1weight, axis=0) + +# print("Updating results", file=log) +# dds_out = [] +# for ds in dds: +# b = ds.bandid +# r = da.from_array(residual[b]*wsum) +# m = da.from_array(model[b]) +# d = da.from_array(dual[b]) +# xb = 
da.from_array(x[b]) +# ds_out = ds.assign(**{'RESIDUAL': (('x', 'y'), r), +# 'MODEL': (('x', 'y'), m), +# 'DUAL': (('c', 'n'), d), +# 'PARAM': (('x', 'y'), xb)}) +# ds_out = ds_out.assign_attrs({'parametrisation': opts.parametrisation}) +# dds_out.append(ds_out) +# writes = xds_to_zarr(dds_out, dds_name, +# columns=('RESIDUAL', 'MODEL', 'DUAL', 'PARAM'), +# rechunk=True) +# dask.compute(writes) + +# if eps < opts.tol: +# print(f"Converged after {k+1} iterations.", file=log) +# break + + +# dds = xds_from_zarr(dds_name, chunks={'x': -1, 'y': -1}) + +# # convert to fits files +# fitsout = [] +# if opts.fits_mfs: +# fitsout.append(dds2fits_mfs(dds, 'RESIDUAL', f'{basename}_{opts.suffix}', norm_wsum=True)) +# fitsout.append(dds2fits_mfs(dds, 'MODEL', f'{basename}_{opts.suffix}', norm_wsum=False)) +# fitsout.append(dds2fits_mfs(dds, 'PARAM', f'{basename}_{opts.suffix}', norm_wsum=False)) + +# if opts.fits_cubes: +# fitsout.append(dds2fits(dds, 'RESIDUAL', f'{basename}_{opts.suffix}', norm_wsum=True)) +# fitsout.append(dds2fits(dds, 'MODEL', f'{basename}_{opts.suffix}', norm_wsum=False)) +# fitsout.append(dds2fits(dds, 'PARAM', f'{basename}_{opts.suffix}', norm_wsum=False)) + +# if len(fitsout): +# print("Writing fits", file=log) +# dask.compute(fitsout) diff --git a/pfb/workers/grid.py b/pfb/workers/grid.py index 14ecaa0d2..68ff83902 100644 --- a/pfb/workers/grid.py +++ b/pfb/workers/grid.py @@ -251,7 +251,12 @@ def _grid(xdsi=None, **kw): nx, ny, nx_psf, ny_psf, cell_N, cell_rad = set_image_size( uv_max, max_freq, - opts + opts.field_of_view, + opts.super_resolution_factor, + opts.cell_size, + opts.nx, + opts.ny, + opts.psf_oversize ) cell_deg = np.rad2deg(cell_rad) cell_size = cell_deg * 3600 @@ -407,21 +412,19 @@ def _grid(xdsi=None, **kw): tcoords[0,1] = tdec coords0 = np.array((ra, dec)) lm0 = radec_to_lm(tcoords, coords0).squeeze() - # The negative stems from a wgridder convention - # https://github.com/mreineck/ducc/issues/34 - x0 = -lm0[0] - y0 = -lm0[1] + l0 = 
lm0[0] + m0 = lm0[1] else: - x0 = 0.0 - y0 = 0.0 + l0 = 0.0 + m0 = 0.0 tra = ds.ra tdec = ds.dec attrs = { 'ra': tra, 'dec': tdec, - 'x0': x0, - 'y0': y0, + 'l0': l0, + 'm0': m0, 'cell_rad': cell_rad, 'bandid': bandid, 'timeid': timeid, @@ -454,10 +457,12 @@ def _grid(xdsi=None, **kw): ) elif from_cache: - if opts.use_best_model: + if opts.use_best_model and 'BEST_MODEL' in out_ds: model = out_ds.MODEL_BEST.values - else: + elif 'MODEL' in out_ds: model = out_ds.MODEL.values + else: + model = None else: model = None @@ -471,7 +476,7 @@ def _grid(xdsi=None, **kw): attrs, model=model, robustness=opts.robustness, - x0=x0, y0=y0, + l0=l0, m0=m0, nthreads=opts.nthreads, epsilon=opts.epsilon, do_wgridding=opts.do_wgridding, diff --git a/pfb/workers/klean.py b/pfb/workers/klean.py index 43b4f8674..75e65fa20 100644 --- a/pfb/workers/klean.py +++ b/pfb/workers/klean.py @@ -279,6 +279,9 @@ def _klean(ddsi=None, **kw): 'fexpr': fexpr, 'center_x': dds[0].x0, 'center_y': dds[0].y0, + 'flip_u': dds[0].flip_u, + 'flip_v': dds[0].flip_v, + 'flip_v': dds[0].flip_v, 'ra': dds[0].ra, 'dec': dds[0].dec, 'stokes': opts.product, # I,Q,U,V, IQ/IV, IQUV @@ -305,7 +308,6 @@ def _klean(ddsi=None, **kw): cell_rad, cell_rad, ds_name, model[b], - x0=ds.x0, y0=ds.y0, nthreads=opts.nthreads, epsilon=opts.epsilon, do_wgridding=opts.do_wgridding, @@ -382,7 +384,6 @@ def _klean(ddsi=None, **kw): cell_rad, cell_rad, ds_name, model[b], - x0=ds.x0, y0=ds.y0, nthreads=opts.nthreads, epsilon=opts.epsilon, do_wgridding=opts.do_wgridding, diff --git a/pfb/workers/model2comps.py b/pfb/workers/model2comps.py index 477b594b9..a2879eb35 100644 --- a/pfb/workers/model2comps.py +++ b/pfb/workers/model2comps.py @@ -280,6 +280,9 @@ def _model2comps(ddsi=None, **kw): 'fexpr': fexpr, 'center_x': x0, 'center_y': y0, + 'flip_u': dds[0].flip_u, + 'flip_v': dds[0].flip_v, + 'flip_v': dds[0].flip_v, 'ra': dds[0].ra, 'dec': dds[0].dec, 'stokes': opts.product, # I,Q,U,V, IQ/IV, IQUV diff --git a/pfb/workers/sara.py 
b/pfb/workers/sara.py index f159c3775..1f66f4f85 100644 --- a/pfb/workers/sara.py +++ b/pfb/workers/sara.py @@ -434,6 +434,9 @@ def _sara(ddsi=None, **kw): 'fexpr': fexpr, 'center_x': dds[0].x0, 'center_y': dds[0].y0, + 'flip_u': dds[0].flip_u, + 'flip_v': dds[0].flip_v, + 'flip_v': dds[0].flip_v, 'ra': dds[0].ra, 'dec': dds[0].dec, 'stokes': opts.product, # I,Q,U,V, IQ/IV, IQUV @@ -485,7 +488,6 @@ def _sara(ddsi=None, **kw): cell_rad, cell_rad, ds_name, model[b], - x0=ds.x0, y0=ds.y0, nthreads=opts.nthreads, epsilon=opts.epsilon, do_wgridding=opts.do_wgridding, diff --git a/tests/test_hessian_approx.py b/tests/test_hessian_approx.py index 9c6ab4742..cb0cd0048 100644 --- a/tests/test_hessian_approx.py +++ b/tests/test_hessian_approx.py @@ -1,61 +1,140 @@ import itertools import numpy as np import pytest -pmp = pytest.mark.parametrize -from ducc0.wgridder import dirty2vis, vis2dirty +from pathlib import Path +from pfb.operators.gridder import wgridder_conventions +from pfb.operators.hessian import _hessian_impl as hessian +from pfb.operators.psf import psf_convolve_slice +from pfb.utils.misc import set_image_size +from ducc0.wgridder.experimental import vis2dirty, dirty2vis +from scipy.constants import c as lightspeed +from daskms import xds_from_ms, xds_from_table from ducc0.fft import c2r, r2c iFs = np.fft.ifftshift Fs = np.fft.fftshift -from pfb.operators.hessian import _hessian_impl as hessian -from pfb.operators.psf import psf_convolve_slice -from ducc0.misc import make_noncritical -from ducc0.wgridder import vis2dirty -from ducc0.wgridder import dirty2vis +pmp = pytest.mark.parametrize + +@pmp("center_offset", [(0.0, 0.0), (0.1, -0.17), (0.2, 0.5)]) +def test_psfvis(center_offset, ms_name): + test_dir = Path(ms_name).resolve().parent + xds = xds_from_ms(ms_name, + chunks={'row': -1, 'chan': -1})[0] + spw = xds_from_table(f'{ms_name}::SPECTRAL_WINDOW')[0] + uvw = xds.UVW.values + freq = spw.CHAN_FREQ.values.squeeze() + + # uvw = ms.getcol('UVW') + nrow = 
uvw.shape[0] + nchan = freq.size + + umax = np.abs(uvw[:, 0]).max() + vmax = np.abs(uvw[:, 1]).max() + uv_max = np.maximum(umax, vmax) + max_freq = freq.max() -''' -R.H W R x \approx Z.H F.H Ihat F Z x + nx, ny, nx_psf, ny_psf, cell_N, cell_rad = set_image_size( + uv_max, + max_freq, + 1.0, + 2.0) + x0, y0 = center_offset + flip_u, flip_v, flip_w, x0, y0 = wgridder_conventions(x0, y0) + epsilon = 1e-10 + signu = -1.0 if flip_u else 1.0 + signv = -1.0 if flip_v else 1.0 + # we need these in the test because of flipped wgridder convention + # https://github.com/mreineck/ducc/issues/34 + signx = -1.0 if flip_u else 1.0 + signy = -1.0 if flip_v else 1.0 + # produce PSF visibilities centered at x0, y0 + n = np.sqrt(1 - x0**2 - y0**2) + freqfactor = -2j*np.pi*freq[None, :]/lightspeed + psf_vis = np.exp(freqfactor*(signu*uvw[:, 0:1]*x0*signx + + signv*uvw[:, 1:2]*y0*signy - + uvw[:, 2:]*(n-1))) + x = np.zeros((nx, ny), dtype='f8') + x[nx//2, ny//2] = 1.0 + psf_vis2 = dirty2vis( + uvw=uvw, + freq=freq, + dirty=x, + pixsize_x=cell_rad, + pixsize_y=cell_rad, + center_x=x0, + center_y=y0, + flip_u=flip_u, + flip_v=flip_v, + flip_w=flip_w, + epsilon=epsilon, + nthreads=8, + do_wgridding=True, + divide_by_n=False) -''' + assert np.abs(psf_vis - psf_vis2).max() <= epsilon -# @pytest.mark.parametrize("center_offset", [(0.0, 0.0), (0.1, -0.17), (0.2, 0.5)]) -def test_hessian(): - # np.random.seed(42) - nx, ny = 128, 128 - nant = 10 - nchan = 2 - pixsize = 0.5 * np.pi / 180 / 3600. 
# 1 arcsec ~ 4 pixels / beam, so we'll avoid aliasing - l0, m0 = 0.0, 0.0 - dl = pixsize - dm = pixsize - ant1, ant2 = np.asarray(list(itertools.combinations(range(nant), 2))).T - antennas = 10e3 * np.random.normal(size=(nant, 3)) - antennas[:, 2] *= 0.001 - uvw = antennas[ant2] - antennas[ant1] +@pmp("center_offset", [(0.0, 0.0), (0.1, -0.17), (0.2, 0.5)]) +def test_hessian(center_offset, ms_name): + test_dir = Path(ms_name).resolve().parent + xds = xds_from_ms(ms_name, + chunks={'row': -1, 'chan': -1})[0] + spw = xds_from_table(f'{ms_name}::SPECTRAL_WINDOW')[0] + uvw = xds.UVW.values + freq = spw.CHAN_FREQ.values.squeeze() + nrow = uvw.shape[0] - freqs = np.linspace(700e6, 2000e6, nchan) + nchan = freq.size + + umax = np.abs(uvw[:, 0]).max() + vmax = np.abs(uvw[:, 1]).max() + uv_max = np.maximum(umax, vmax) + max_freq = freq.max() + + x0, y0 = center_offset + nx, ny, nx_psf, ny_psf, cell_N, cell_rad = set_image_size( + uv_max, + max_freq, + 1.5, + 2.0) + + flip_u, flip_v, flip_w, x0, y0 = wgridder_conventions(x0, y0) + epsilon = 1e-10 + signu = -1.0 if flip_u else 1.0 + signv = -1.0 if flip_v else 1.0 + # we need these in the test because of flipped wgridder convention + # https://github.com/mreineck/ducc/issues/34 + signx = -1.0 if flip_u else 1.0 + signy = -1.0 if flip_v else 1.0 + + # produce PSF visibilities centered at x0, y0 + n = np.sqrt(1 - x0**2 - y0**2) + freqfactor = -2j*np.pi*freq[None, :]/lightspeed + psf_vis = np.exp(freqfactor*(signu*uvw[:, 0:1]*x0*signx + + signv*uvw[:, 1:2]*y0*signy)) - epsilon = 1e-12 - uvwneg = uvw.copy() - uvwneg[:, 2] *= -1 + + x = np.zeros((nx, ny), dtype='f8') + x[nx//2, ny//2] = 1.0 psf = vis2dirty( - uvw=uvwneg, - freq=freqs, - vis=np.ones((nrow, nchan), dtype='c16'), + uvw=uvw, + freq=freq, + vis=psf_vis, wgt=None, - npix_x=2*nx, - npix_y=2*ny, - pixsize_x=dl, - pixsize_y=dm, - center_x=l0, - center_y=m0, + npix_x=nx_psf, + npix_y=ny_psf, + pixsize_x=cell_rad, + pixsize_y=cell_rad, + center_x=x0, + center_y=y0, + 
flip_u=flip_u, + flip_v=flip_v, + flip_w=flip_w, epsilon=epsilon, do_wgridding=False, - flip_v=False, - divide_by_n=False, # else we also need it in the PSF convolve - nthreads=1, + divide_by_n=False, # else we also need it in PSF convolve + nthreads=8, verbosity=0, ) @@ -63,37 +142,33 @@ def test_hessian(): nthreads=8, forward=True, inorm=0) - x = np.random.normal(size=(nx, ny)) - # x[...] = 0.0 - # x[nx//2, ny//2] = 1.0 - # x[nx//2-1, ny//2] = 0.5 - # x[nx//2, ny//2-1] = 0.5 - beam = np.ones((nx, ny)) + res1 = hessian( x, - uvw, - np.ones((nrow, nchan), dtype='f8'), - np.ones((nrow, nchan), dtype=np.uint8), - freqs, - beam, - x0=l0, y0=m0, - cell=pixsize, + uvw=uvw, + weight=np.ones((nrow, nchan), dtype='f8'), + vis_mask=np.ones((nrow, nchan), dtype=np.uint8), + freq=freq, + cell=cell_rad, + x0=x0, + y0=y0, + flip_u=flip_u, + flip_v=flip_v, + flip_w=flip_w, do_wgridding=False, epsilon=epsilon, double_accum=True, nthreads=8 ) - res2 = psf_convolve_slice(np.zeros((2*nx, 2*ny)), + res2 = psf_convolve_slice(np.zeros((nx_psf, ny_psf)), np.zeros_like(psfhat), np.zeros_like(x), psfhat, - 2*ny, - x*beam, + ny_psf, + x, nthreads=8) scale = np.abs(res2).max() - - diff = res2-res1 + diff = (res2-res1)/scale assert np.allclose(1 + diff, 1) - diff --git a/tests/test_klean.py b/tests/test_klean.py index 5a80b6eb2..d27f98f99 100644 --- a/tests/test_klean.py +++ b/tests/test_klean.py @@ -32,7 +32,8 @@ def test_klean(do_gains, ms_name): from daskms.experimental.zarr import xds_to_zarr from pfb.utils.naming import xds_from_url from africanus.constants import c as lightspeed - from ducc0.wgridder import dirty2vis + from ducc0.wgridder.experimental import dirty2vis + from pfb.operators.gridder import wgridder_conventions test_dir = Path(ms_name).resolve().parent @@ -90,6 +91,7 @@ def test_klean(do_gains, ms_name): # model vis epsilon = 1e-7 + flip_u, flip_v, flip_w, x0, y0 = wgridder_conventions(0.0, 0.0) model_vis = np.zeros((nrow, nchan, ncorr), dtype=np.complex128) for c in 
range(nchan): model_vis[:, c:c+1, 0] = dirty2vis(uvw=uvw, @@ -98,6 +100,9 @@ def test_klean(do_gains, ms_name): pixsize_x=cell_rad, pixsize_y=cell_rad, epsilon=epsilon, + flip_u=flip_u, + flip_v=flip_v, + flip_w=flip_w, do_wgridding=True, nthreads=8) model_vis[:, c, -1] = model_vis[:, c, 0] diff --git a/tests/test_sara.py b/tests/test_sara.py index 282fb1ec1..2c08a981a 100644 --- a/tests/test_sara.py +++ b/tests/test_sara.py @@ -33,7 +33,8 @@ def test_sara(ms_name): from pfb.utils.misc import Gaussian2D, give_edges from africanus.constants import c as lightspeed from ducc0.fft import good_size - from ducc0.wgridder import dirty2vis + from ducc0.wgridder.experimental import dirty2vis + from pfb.operators.gridder import wgridder_conventions from pfb.parser.schemas import schema from pfb.workers.init import _init from pfb.workers.grid import _grid @@ -109,6 +110,7 @@ def test_sara(ms_name): model[:, mx, my] += spectrum[:, None, None] * gauss[None, gx, gy] # model vis + flip_u, flip_v, flip_w, x0, y0 = wgridder_conventions(0.0, 0.0) epsilon = 1e-7 model_vis = np.zeros((nrow, nchan, ncorr), dtype=np.complex128) for c in range(nchan): @@ -120,7 +122,9 @@ def test_sara(ms_name): epsilon=epsilon, do_wgridding=True, divide_by_n=False, - flip_v=False, + flip_u=flip_u, + flip_v=flip_v, + flip_w=flip_w, nthreads=8, sigma_min=1.1, sigma_max=3.0) @@ -190,7 +194,8 @@ def test_sara(ms_name): _sara(**sara_args) - # get the inferred model + # the result computed by the grid worker should be identical to that + # computed in sara when passing in model dds = xds_from_url(dds_name) freqs_dds = [] times_dds = [] @@ -205,59 +210,7 @@ def test_sara(ms_name): ntime_dds = times_dds.size nfreq_dds = freqs_dds.size - model_inferred = np.zeros((ntime_dds, nfreq_dds, nx, ny)) - for ds in dds: - b = int(ds.bandid) - t = int(ds.timeid) - model_inferred[t, b, :, :] = ds.MODEL.values - - model2comps_args = {} - for key in schema.model2comps["inputs"].keys(): - model2comps_args[key.replace("-", "_")] =
schema.model2comps["inputs"][key]["default"] - model2comps_args["output_filename"] = outname - model2comps_args["nbasisf"] = nchan - model2comps_args["fit_mode"] = 'Legendre' - model2comps_args["overwrite"] = True - model2comps_args["use_wsum"] = False - model2comps_args["sigmasq"] = 1e-14 - _model2comps(**model2comps_args) - - mds_name = f'{outname}_main_model.mds' - mds = xr.open_zarr(mds_name) - - # grid spec - cell_rad = mds.cell_rad_x - cell_deg = np.rad2deg(cell_rad) - nx = mds.npix_x - ny = mds.npix_y - x0 = mds.center_x - y0 = mds.center_y - radec = (mds.ra, mds.dec) - - # model func - params = sm.symbols(('t','f')) - params += sm.symbols(tuple(mds.params.values)) - symexpr = parse_expr(mds.parametrisation) - modelf = lambdify(params, symexpr) - texpr = parse_expr(mds.texpr) - tfunc = lambdify(params[0], texpr) - fexpr = parse_expr(mds.fexpr) - ffunc = lambdify(params[1], fexpr) - - # model coeffs - coeffs = mds.coefficients.values - locx = mds.location_x.values - locy = mds.location_y.values - - model_test = np.zeros((ntime_dds, nfreq_dds, nx, ny), dtype=float) - for i in range(ntime_dds): - tout = tfunc(times_dds[i]) - for j in range(nchan): - fout = ffunc(freqs_dds[j]) - model_test[i,j,locx,locy] = modelf(tout, fout, *coeffs) - - # models need to match exactly - assert_allclose(1 + model_test, 1 + model_inferred) + # degrid from coeffs populating MODEL_DATA degrid_args = {}