Skip to content

Replace inner1d bei einsum #163

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions sfs/fd/wfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def plot(d, selection, secondary_source):

"""
import numpy as _np
from numpy.core.umath_tests import inner1d as _inner1d
from scipy.special import hankel2 as _hankel2

from . import secondary_source_line as _secondary_source_line
Expand Down Expand Up @@ -91,7 +90,7 @@ def line_2d(omega, x0, n0, xs, *, c=None):
k = _util.wavenumber(omega, c)
ds = x0 - xs
r = _np.linalg.norm(ds, axis=1)
d = -1j/2 * k * _inner1d(ds, n0) / r * _hankel2(1, k * r)
d = -1j / 2 * k * _np.einsum('ij,ij->i', ds, n0) / r * _hankel2(1, k * r)
selection = _util.source_selection_line(n0, x0, xs)
return d, selection, _secondary_source_line(omega, c)

Expand Down Expand Up @@ -147,7 +146,8 @@ def _point(omega, x0, n0, xs, *, c=None):
k = _util.wavenumber(omega, c)
ds = x0 - xs
r = _np.linalg.norm(ds, axis=1)
d = 1j * k * _inner1d(ds, n0) / r ** (3 / 2) * _np.exp(-1j * k * r)
d = 1j * k * _np.einsum('ij,ij->i', ds, n0) / r**(3 / 2) * _np.exp(
-1j * k * r)
selection = _util.source_selection_point(n0, x0, xs)
return d, selection, _secondary_source_point(omega, c)

Expand Down Expand Up @@ -234,7 +234,7 @@ def point_25d(omega, x0, n0, xs, xref=[0, 0, 0], c=None, omalias=None):
preeq_25d(omega, omalias, c) *
_np.sqrt(8 * _np.pi) *
_np.sqrt((r * s) / (r + s)) *
_inner1d(n0, ds) / s *
_np.einsum('ij,ij->i', ds, n0) / s *
_np.exp(-1j * k * s) / (4 * _np.pi * s))
selection = _util.source_selection_point(n0, x0, xs)
return d, selection, _secondary_source_point(omega, c)
Expand Down Expand Up @@ -316,8 +316,8 @@ def point_25d_legacy(omega, x0, n0, xs, xref=[0, 0, 0], c=None, omalias=None):
r = _np.linalg.norm(ds, axis=1)
d = (
preeq_25d(omega, omalias, c) *
_np.sqrt(_np.linalg.norm(xref - x0)) * _inner1d(ds, n0) /
r ** (3 / 2) * _np.exp(-1j * k * r))
_np.sqrt(_np.linalg.norm(xref - x0)) * _np.einsum('ij,ij->i', ds, n0) /
r**(3 / 2) * _np.exp(-1j * k * r))
selection = _util.source_selection_point(n0, x0, xs)
return d, selection, _secondary_source_point(omega, c)

Expand Down Expand Up @@ -499,7 +499,8 @@ def _focused(omega, x0, n0, xs, ns, *, c=None):
k = _util.wavenumber(omega, c)
ds = x0 - xs
r = _np.linalg.norm(ds, axis=1)
d = 1j * k * _inner1d(ds, n0) / r ** (3 / 2) * _np.exp(1j * k * r)
d = 1j * k * _np.einsum('ij,ij->i', ds, n0) / r**(3 / 2) * _np.exp(
1j * k * r)
selection = _util.source_selection_focused(ns, x0, xs)
return d, selection, _secondary_source_point(omega, c)

Expand Down Expand Up @@ -569,8 +570,8 @@ def focused_25d(omega, x0, n0, xs, ns, *, xref=[0, 0, 0], c=None,
r = _np.linalg.norm(ds, axis=1)
d = (
preeq_25d(omega, omalias, c) *
_np.sqrt(_np.linalg.norm(xref - x0)) * _inner1d(ds, n0) /
r ** (3 / 2) * _np.exp(1j * k * r))
_np.sqrt(_np.linalg.norm(xref - x0)) * _np.einsum('ij,ij->i', ds, n0) /
r**(3 / 2) * _np.exp(1j * k * r))
selection = _util.source_selection_focused(ns, x0, xs)
return d, selection, _secondary_source_point(omega, c)

Expand Down Expand Up @@ -682,22 +683,22 @@ def soundfigure_3d(omega, x0, n0, figure, npw=[0, 0, 1], *, c=None):
figure = _np.fft.fftshift(figure, axes=(0, 1)) # sign of spatial DFT
figure = _np.fft.fft2(figure)
# wavenumbers
kx = _np.fft.fftfreq(nx, 1./nx)
ky = _np.fft.fftfreq(ny, 1./ny)
kx = _np.fft.fftfreq(nx, 1. / nx)
ky = _np.fft.fftfreq(ny, 1. / ny)
# shift spectrum due to desired plane wave
figure = _np.roll(figure, int(k*npw[0]), axis=0)
figure = _np.roll(figure, int(k*npw[1]), axis=1)
figure = _np.roll(figure, int(k * npw[0]), axis=0)
figure = _np.roll(figure, int(k * npw[1]), axis=1)
# search and iterate over propagating plane wave components
kxx, kyy = _np.meshgrid(kx, ky, sparse=True)
rho = _np.sqrt((kxx) ** 2 + (kyy) ** 2)
rho = _np.sqrt((kxx)**2 + (kyy)**2)
d = 0
for n in range(nx):
for m in range(ny):
if(rho[n, m] < k):
if (rho[n, m] < k):
# dispertion relation
kz = _np.sqrt(k**2 - rho[n, m]**2)
# normal vector of plane wave
npw = 1/k * _np.asarray([kx[n], ky[m], kz])
npw = 1 / k * _np.asarray([kx[n], ky[m], kz])
npw = npw / _np.linalg.norm(npw)
# driving function of plane wave with positive kz
d_component, selection, secondary_source = plane_3d(
Expand Down
15 changes: 7 additions & 8 deletions sfs/td/wfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def plot(d, selection, secondary_source, t=0):

"""
import numpy as _np
from numpy.core.umath_tests import inner1d as _inner1d

from . import apply_delays as _apply_delays
from . import secondary_source_point as _secondary_source_point
Expand Down Expand Up @@ -119,8 +118,8 @@ def plane_25d(x0, n0, n=[0, 1, 0], xref=[0, 0, 0], c=None):
n = _util.normalize_vector(n)
xref = _util.asarray_1d(xref)
g0 = _np.sqrt(2 * _np.pi * _np.linalg.norm(xref - x0, axis=1))
delays = _inner1d(n, x0) / c
weights = 2 * g0 * _inner1d(n, n0)
delays = _np.einsum('i,ji->j', n, x0) / c
weights = 2 * g0 * _np.einsum('i,ji->j', n, n0)
selection = _util.source_selection_plane(n0, n)
return delays, weights, selection, _secondary_source_point(c)

Expand Down Expand Up @@ -208,7 +207,7 @@ def point_25d(x0, n0, xs, xref=[0, 0, 0], c=None):
g0 *= _np.sqrt((x0xs_n*x0xref_n)/(x0xs_n+x0xref_n))

delays = x0xs_n/c
weights = g0*_inner1d(x0xs, n0)
weights = g0*_np.einsum('ij,ij->i', x0xs, n0)
selection = _util.source_selection_point(n0, x0, xs)
return delays, weights, selection, _secondary_source_point(c)

Expand Down Expand Up @@ -295,8 +294,8 @@ def point_25d_legacy(x0, n0, xs, xref=[0, 0, 0], c=None):
g0 = _np.sqrt(2 * _np.pi * _np.linalg.norm(xref - x0, axis=1))
ds = x0 - xs
r = _np.linalg.norm(ds, axis=1)
delays = r/c
weights = g0 * _inner1d(ds, n0) / (2 * _np.pi * r**(3/2))
delays = r / c
weights = g0 * _np.einsum('ij,ij->i', ds, n0) / (2 * _np.pi * r**(3 / 2))
selection = _util.source_selection_point(n0, x0, xs)
return delays, weights, selection, _secondary_source_point(c)

Expand Down Expand Up @@ -378,8 +377,8 @@ def focused_25d(x0, n0, xs, ns, xref=[0, 0, 0], c=None):
r = _np.linalg.norm(ds, axis=1)
g0 = _np.sqrt(_np.linalg.norm(xref - x0, axis=1)
/ (_np.linalg.norm(xref - x0, axis=1) + r))
delays = -r/c
weights = g0 * _inner1d(ds, n0) / (2 * _np.pi * r**(3/2))
delays = -r / c
weights = g0 * _np.einsum('ij,ij->i', ds, n0) / (2 * _np.pi * r**(3 / 2))
selection = _util.source_selection_focused(ns, x0, xs)
return delays, weights, selection, _secondary_source_point(c)

Expand Down
14 changes: 7 additions & 7 deletions sfs/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import collections
import numpy as np
from numpy.core.umath_tests import inner1d
from scipy.special import spherical_jn, spherical_yn
from . import default

Expand Down Expand Up @@ -51,7 +50,7 @@ def wavenumber(omega, c=None):
return omega / c


def direction_vector(alpha, beta=np.pi/2):
def direction_vector(alpha, beta=np.pi / 2):
"""Compute normal vector from azimuth, colatitude."""
return sph2cart(alpha, beta, 1)

Expand Down Expand Up @@ -503,19 +502,20 @@ def image_sources_for_box(x, L, N, *, prune=True):
number of reflections at individual walls for each source.

"""

def _images_1d_unit_box(x, N):
result = np.arange(-N, N + 1, dtype=x.dtype)
result[N % 2::2] += x
result[1 - (N % 2)::2] += 1 - x
return result

def _count_walls_1d(a):
b = np.floor(a/2)
c = np.ceil((a-1)/2)
b = np.floor(a / 2)
c = np.ceil((a - 1) / 2)
return np.abs(np.stack([b, c], axis=1)).astype(int)

L = asarray_1d(L)
x = asarray_1d(x)/L
x = asarray_1d(x) / L
D = len(x)
xs = [_images_1d_unit_box(coord, N) for coord in x]
xs = np.reshape(np.transpose(np.meshgrid(*xs, indexing='ij')), (-1, D))
Expand Down Expand Up @@ -576,7 +576,7 @@ def source_selection_point(n0, x0, xs):
x0 = asarray_of_rows(x0)
xs = asarray_1d(xs)
ds = x0 - xs
return inner1d(ds, n0) >= default.selection_tolerance
return np.einsum('ij,ij->i', ds, n0) >= default.selection_tolerance


def source_selection_line(n0, x0, xs):
Expand All @@ -598,7 +598,7 @@ def source_selection_focused(ns, x0, xs):
xs = asarray_1d(xs)
ns = normalize_vector(ns)
ds = xs - x0
return inner1d(ns, ds) >= default.selection_tolerance
return np.einsum('i,ji->j', ns, ds) >= default.selection_tolerance


def source_selection_all(N):
Expand Down