Fixes various bugs and simplifies the usage
ADDS: _is_increasing to BernsteinBijector -> allows evaluation of quantiles
ADDS: _mean to Bernstein Flow
FIXES: clipping interpolation values to prevent INFs in inverse
FIXES: reshape_out with partially defined tensor shape now working
FIXES: images in notebooks
MODIFIES: constrain_theta activation now allows negative coefficients
MODIFIES: minimal distance of thetas to 1e-4
REMOVES: order argument from BernsteinBijector (infers the order from theta)
MArpogaus committed Dec 23, 2020
1 parent 824b87e commit fdfb7ef
Showing 13 changed files with 318 additions and 362 deletions.
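
With the `order` argument removed, the bijector reads both what the code calls `order` (the length of the last axis of `theta`) and the batch shape off the shape of `theta`. A minimal pure-Python sketch of that bookkeeping, mirroring the `shape[-1]` / `shape[:-1]` split in the `bernstein_bijector.py` diff; `infer_order_and_batch_shape` is a hypothetical helper name used only for illustration, not part of the package:

```python
def infer_order_and_batch_shape(theta_shape):
    """Split a theta shape into (order, batch_shape).

    Hypothetical helper for illustration: the last axis holds the
    Bernstein coefficients, all leading axes are batch dimensions.
    """
    theta_shape = tuple(theta_shape)
    if not theta_shape:
        raise ValueError("theta must have at least one axis")
    order = theta_shape[-1]
    batch_shape = theta_shape[:-1]
    return order, batch_shape


# A single set of 9 coefficients has an empty batch shape.
single = infer_order_and_batch_shape((9,))
# 32 independent sets of 9 coefficients form a batch of 32.
batched = infer_order_and_batch_shape((32, 9))
```

Callers therefore only pass `theta`; a mismatch between a separately supplied order and the actual number of coefficients can no longer occur.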
10 changes: 5 additions & 5 deletions README.md
@@ -7,7 +7,7 @@

# Bernstein-Polynomials as TensorFlow Probability Bijector

This Repository contains a implementation of a normalizing flow for conditional density estimation using Bernstein polynomials, as proposed in:
This Repository contains an implementation of a normalizing flow for conditional density estimation using Bernstein polynomials, as proposed in:

> Sick Beate, Hothorn Torsten and Dürr Oliver, *Deep transformation models: Tackling complex regression problems with neural network based transformation models*, 2020. [online](http://arxiv.org/abs/2004.00464)
@@ -32,9 +32,9 @@ However, the shape of the data distribution in many real use cases is much more

The following example of a classical data set containing the waiting time between eruptions of the [Old Faithful Geyser](https://en.wikipedia.org/wiki/Old_Faithful) in [Yellowstone National Park](https://en.wikipedia.org/wiki/Yellowstone_National_Park) is used as an example.

| Gaussian | Normalizing Flow |
|:-------------------------------------------------------------|:-------------------------------------------|
| ![gauss](gfx/gauss.png) | ![flow](gfx/flow.png) |
| Gaussian | Normalizing Flow |
|:--------------------------------|:------------------------------|
| ![gauss](./ipynb/gfx/gauss.png) | ![flow](./ipynb/gfx/flow.png) |

As shown in the left figure, the normality assumption is clearly violated by the bimodal nature of the data.
However, the proposed transformation model has the flexibility to adapt to this complexity.
@@ -51,7 +51,7 @@ Pull and install it directly from git using pip:
pip install git+https://github.com/MArpogaus/TensorFlow-Probability-Bernstein-Polynomial-Bijector.git
```

Or clone this repository an install it from there:
Or clone this repository and install it from there:

```bash
git clone https://github.com/MArpogaus/TensorFlow-Probability-Bernstein-Polynomial-Bijector.git ./bernstein_flow
40 changes: 24 additions & 16 deletions ipynb/Gaussian_vs_Transformation_Model.ipynb

Large diffs are not rendered by default.

260 changes: 107 additions & 153 deletions ipynb/TheoreticalBackground.ipynb

Large diffs are not rendered by default.

File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
122 changes: 54 additions & 68 deletions src/bernstein_flow/bijectors/bernstein_bijector.py
@@ -1,10 +1,17 @@
#!env python3
# AUTHOR INFORMATION ##########################################################
# file : bernstein_bijector.py
# brief : [Description]
# file : bernstein_bijector.py
# brief : [Description]
#
# author : Marcel Arpogaus
# created : 2020-09-11 14:14:24
# changed : 2020-12-07 16:29:11
# DESCRIPTION #################################################################
#
# This project is following the PEP8 style guide:
#
# https://www.python.org/dev/peps/pep-0008/)
#
# author : Marcel Arpogaus
# date : 2020-09-11 14:14:24
# COPYRIGHT ###################################################################
# Copyright 2020 Marcel Arpogaus
#
@@ -19,18 +26,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTES ######################################################################
#
# This project is following the
# [PEP8 style guide](https://www.python.org/dev/peps/pep-0008/)
#
# CHANGELOG ##################################################################
# modified by : Marcel Arpogaus
# modified time : 2020-10-14 20:24:44
# changes made : ...
# modified by : Marcel Arpogaus
# modified time : 2020-09-11 14:14:24
# changes made : newly written
###############################################################################

# REQUIRED PYTHON MODULES #####################################################
@@ -45,6 +40,7 @@
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.internal import tensorshape_util
from tensorflow_probability.python.internal import prefer_static


class BernsteinBijector(tfb.Bijector):
@@ -54,15 +50,12 @@ class BernsteinBijector(tfb.Bijector):
"""

def __init__(self,
order: int,
theta: tf.Tensor,
validate_args: bool = False,
name: str = 'bernstein_bijector'):
"""
Constructs a new instance of a Bernstein polynomial bijector.
:param order: The order of the Bernstein polynomial.
:type order: int
:param theta: The Bernstein coefficients.
:type theta: Tensor
:param validate_args: Whether to validate input with asserts.
@@ -78,12 +71,9 @@ def __init__(self,
self.theta = tensor_util.convert_nonref_to_tensor(
theta, dtype=dtype)

self.order = order

if tensorshape_util.rank(self.theta.shape) == 1:
self.batch_shape = tf.TensorShape([1])
else:
self.batch_shape = tf.TensorShape([self.theta.shape[0]])
shape = prefer_static.shape(self.theta)
self.order = shape[-1]
self.batch_shape = shape[:-1]

# Bernstein polynomials of order M,
# generated by the M + 1 beta-densities
@@ -109,31 +99,41 @@ def gen_inverse_interpolation(self) -> None:
"""
Generates the Spline Interpolation.
"""
y_fit = np.linspace(.0, 1, self.order * 10,
dtype=np.float32)[..., tf.newaxis]
n_points = 200
rank = tensorshape_util.rank(self.batch_shape)
shape = [...] + [tf.newaxis] * rank

y_fit = np.linspace(1e-5, 1-1e-5, n_points, dtype=np.float32)

z_fit = self.forward(y_fit)
z_fit = self.forward(y_fit[tuple(shape)])
z_fit = z_fit.numpy().reshape(n_points, -1)

self.z_min = np.min(z_fit, axis=0).reshape(-1, 1)
self.z_max = np.max(z_fit, axis=0).reshape(-1, 1)
self.z_min = np.min(z_fit, axis=0)
self.z_max = np.max(z_fit, axis=0)

ips = [I.interp1d(
x=np.squeeze(z_fit[..., i]),
y=np.squeeze(y_fit),
kind='cubic'
) for i in range(self.batch_shape[0])]
kind='cubic',
# bc_type='natural',
assume_sorted=True
) for i in range(z_fit.shape[-1])]

def ifn(z):
y = []
z_clip = np.clip(z, self.z_min + 1E-5, self.z_max - 1E-5)
z_clip = np.clip(z, self.z_min, self.z_max)
for i, ip in enumerate(ips):
y.append(ip(z_clip[:, i]).astype(np.float32))
y = np.stack(y, axis=1)

y.append(ip(z_clip[..., i]).astype(np.float32))
y = np.stack(y, axis=-1)
return y

self.interp = ifn

def reshape_out(self, sample_shape, y):
output_shape = prefer_static.broadcast_shape(
sample_shape, self.batch_shape)
return tf.reshape(y, output_shape)

def _inverse(self, z: tf.Tensor) -> tf.Tensor:
"""
Returns the inverse Bijector evaluation.
@@ -145,23 +145,18 @@ def _inverse(self, z: tf.Tensor) -> tf.Tensor:
:rtype: Tensor
"""
if tf.executing_eagerly():
if (tf.rank(z) == 0):
def reshape_out(y): return tf.squeeze(y)
elif z.shape == self.batch_shape:
# [sample_shape, batch_shape, event_shape]
z = z[tf.newaxis, ...]
def reshape_out(y): return y[0]
elif (tf.rank(z) == 2) and (z.shape[1] == self.batch_shape[0]):
# [sample_shape, batch_shape, event_shape]
z = z[..., tf.newaxis]
def reshape_out(y): return y.mean(axis=1) # [None,...]
batch_rank = tensorshape_util.rank(self.batch_shape)
sample_shape = z.shape

if sample_shape[-batch_rank:] == self.batch_shape:
shape = list(sample_shape[:-batch_rank]) + [-1]
z = tf.reshape(z, shape)
else:
z = z[..., tf.newaxis]
def reshape_out(y): return y[..., 0]
z = z[..., None]

if self.interp is None:
self.gen_inverse_interpolation()
y = reshape_out(self.interp(z))
y = self.reshape_out(sample_shape, self.interp(z))
else:
y = z

@@ -177,37 +172,25 @@ def _forward(self, y: tf.Tensor) -> tf.Tensor:
:returns: The forward Bijector evaluation.
:rtype: Tensor
"""
# if (tensorshape_util.rank(y.shape) == 1) and \
# (y.shape[0] != self.batch_shape):
# #y = tf.transpose(y, [1, 0])
# # [sample_shape, batch_shape, event_shape]
# y = y[..., tf.newaxis, tf.newaxis]
# def reshape_out(z): return z#tf.transpose(z, [1, 0])
# else:
sample_shape = prefer_static.shape(y)
y = y[..., tf.newaxis]

y = tf.clip_by_value(y, 1E-5, 1.0 - 1E-5)
y = tf.clip_by_value(y, 0, 1.0)
by = self.beta_dist_h.prob(y)
z = tf.reduce_mean(by * self.theta, axis=-1)

return z
return self.reshape_out(sample_shape, z)

def _forward_log_det_jacobian(self, y):
# if (tensorshape_util.rank(y.shape) == 1) and \
# (y.shape[0] != self.batch_shape):
# #y = tf.transpose(y, [1, 0])
# # [sample_shape, batch_shape, event_shape]
# y = y[..., tf.newaxis, tf.newaxis]
# def reshape_out(z): return z#tf.transpose(z, [1, 0])
# else:
sample_shape = prefer_static.shape(y)
y = y[..., tf.newaxis]

y = tf.clip_by_value(y, 1E-5, 1.0 - 1E-5)
y = tf.clip_by_value(y, 0, 1.0)
by = self.beta_dist_h_dash.prob(y)
dtheta = self.theta[..., 1:] - self.theta[..., 0:-1]
ldj = tf.math.log(tf.reduce_sum(by * dtheta, axis=-1))

return ldj
return self.reshape_out(sample_shape, ldj)

@classmethod
def constrain_theta(cls: type,
@@ -229,5 +212,8 @@ def constrain_theta(cls: type,
"""
d = tf.concat((tf.zeros_like(theta_unconstrained[..., :1]),
theta_unconstrained[..., :1],
fn(theta_unconstrained[..., 1:])), axis=-1)
fn(theta_unconstrained[..., 1:]) + 1e-4), axis=-1)
return tf.cumsum(d[..., 1:], axis=-1)

def _is_increasing(self, **kwargs):
return tf.reduce_all(self.theta[..., 1:] >= self.theta[..., :-1])
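
The interplay between `constrain_theta` and `_is_increasing` can be sketched without TensorFlow. The pure-Python version below follows the same steps for a single coefficient vector: keep the first value as the starting point, map the remaining values to positive increments, add the minimal distance of 1e-4, and take a cumulative sum. Softplus as the positive activation `fn` is an assumption (the default is not visible in this hunk), and the names ending in `_sketch` are hypothetical.

```python
import math
from itertools import accumulate


def softplus(x):
    # Numerically stable log(1 + exp(x)); assumed stand-in for `fn`.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def constrain_theta_sketch(theta_unconstrained, fn=softplus, min_dist=1e-4):
    """Turn unconstrained values into strictly increasing coefficients."""
    first = theta_unconstrained[0]  # left unconstrained, may be negative
    increments = [fn(t) + min_dist for t in theta_unconstrained[1:]]
    # Cumulative sum: theta[0] = first, theta[k] = theta[k-1] + increments[k-1].
    return list(accumulate([first] + increments))


def is_increasing_sketch(theta):
    """Mirror of the `_is_increasing` check for one coefficient vector."""
    return all(b >= a for a, b in zip(theta, theta[1:]))


theta = constrain_theta_sketch([-2.0, -5.0, 0.0, 3.0])
```

Because every increment is at least `min_dist` (up to floating-point rounding), the constrained vector should pass the monotonicity check for any finite real input, while the first coefficient stays free to be negative.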