from __future__ import absolute_import
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
import numpy as np
from ._segmentation import _ve_step, _interaction_energy
# Default number of iterations used by ``Segmentation.run``.
NITERS = 10
# Default ``ngb_size`` argument of ``Segmentation``.
NGB_SIZE = 26
# Default Markov regularization parameter of ``Segmentation``.
BETA = 0.1


def nonzero(x):
    """Return `x` clipped from below at 1e-50 (element-wise for arrays)."""
    return np.maximum(x, 1e-50)


def log(x):
    """Return the natural logarithm of `x` after clipping it at 1e-50.

    The clipping keeps the result finite when `x` contains zeros.
    """
    return np.log(nonzero(x))
class Segmentation(object):
    """
    Multichannel Markov random field image segmentation using the
    variational EM algorithm.

    The object stores the masked image data, the class-specific
    Gaussian parameters (``mu``, ``sigma``) and the posterior
    probability maps (``ppm``), and alternates between updating the
    posteriors (``ve_step``) and the Gaussian parameters (``vm_step``).
    """

    def __init__(self, data, mask=None, mu=None, sigma=None,
                 ppm=None, prior=None, U=None,
                 ngb_size=NGB_SIZE, beta=BETA):
        """
        Class for multichannel Markov random field image segmentation
        using the variational EM algorithm. For details regarding the
        underlying algorithm, see:

        Roche et al, 2011. On the convergence of EM-like algorithms
        for image segmentation using Markov random fields. Medical
        Image Analysis (DOI: 10.1016/j.media.2011.05.002).

        Either (`mu`, `sigma`) or `ppm` must be provided, but not both
        `mu` and `ppm`.

        Parameters
        ----------
        data : array-like
          Input image array. After squeezing it must be 3D (single
          channel) or 4D (channels along the last axis).
        mask : array-like or tuple of array
          Input mask to restrict the segmentation. Either an array
          with the spatial shape of `data` (converted to boolean), or
          a tuple of index arrays selecting the voxels to segment.
          By default, all voxels are selected.
        mu : array-like
          Initial class-specific means
        sigma : array-like
          Initial class-specific variances
        ppm : array-like
          Initial posterior probability maps, with classes along the
          last axis. Used instead of `mu` and `sigma`.
        prior : array-like
          Optional prior probability maps; they are indexed with the
          mask and reshaped to (number of masked voxels, nclasses).
        U : array-like
          Optional class interaction matrix; defaults to a Potts
          model (see `set_markov_prior`).
        ngb_size : int
          Passed to the compiled routines (presumably the size of the
          voxel neighborhood system -- confirm against
          ``_segmentation``).
        beta : float
          Markov regularization parameter

        Raises
        ------
        ValueError
          If the image is not 3D or 4D, if the mask does not match the
          spatial shape of the image, or if the initialization
          arguments are missing or conflicting.
        """
        # Accept any array-like input, not only ndarrays
        data = np.asarray(data).squeeze()
        if data.ndim not in (3, 4):
            raise ValueError('Invalid input image')
        if data.ndim == 3:
            nchannels = 1
            space_shape = data.shape
        else:
            nchannels = data.shape[-1]
            space_shape = data.shape[0:-1]
        self.nchannels = nchannels

        # Make default mask (required by MRF regularization). This will
        # be passed to the _ve_step C-routine, which assumes a
        # contiguous int array and raise an error otherwise. Voxels on
        # the image borders are further rejected to avoid segmentation
        # faults.
        if mask is None:
            mask = np.ones(space_shape, dtype=bool)
        elif isinstance(mask, tuple):
            # Tuple of index arrays: convert to an equivalent boolean
            # mask so that all subsequent indexing is consistent
            coords = mask
            mask = np.zeros(space_shape, dtype=bool)
            mask[coords] = True
        else:
            # Force a boolean dtype: an integer 0/1 mask would
            # otherwise trigger integer fancy indexing below
            mask = np.asarray(mask, dtype=bool)
        if mask.shape != tuple(space_shape):
            raise ValueError('Mask shape does not match image shape')
        X, Y, Z = np.where(mask)
        XYZ = np.zeros((X.shape[0], 3), dtype='intp')
        XYZ[:, 0], XYZ[:, 1], XYZ[:, 2] = X, Y, Z
        self.XYZ = XYZ
        self.mask = mask
        self.data = data[mask]
        if nchannels == 1:
            # Always store the masked data as (nvoxels, nchannels)
            self.data = np.reshape(self.data, (self.data.shape[0], 1))

        # By default, the ppm is initialized as a collection of
        # uniform distributions
        if ppm is None:
            if mu is None or sigma is None:
                raise ValueError('missing information')
            nclasses = len(mu)
            self.ppm = np.zeros(list(space_shape) + [nclasses])
            self.ppm[mask] = 1. / nclasses
            self.is_ppm = False
            self.mu = np.array(mu, dtype='double').reshape(
                (nclasses, nchannels))
            self.sigma = np.array(sigma, dtype='double').reshape(
                (nclasses, nchannels, nchannels))
        elif mu is None:
            # Convert first so that array-like ppm inputs are accepted
            ppm = np.asarray(ppm)
            nclasses = ppm.shape[-1]
            self.ppm = ppm
            self.is_ppm = True
            # Parameters are estimated from the ppm on the first run
            self.mu = np.zeros((nclasses, nchannels))
            self.sigma = np.zeros((nclasses, nchannels, nchannels))
        else:
            raise ValueError('missing information')
        self.nclasses = nclasses

        if prior is not None:
            self.prior = np.asarray(prior)[self.mask].reshape(
                [self.data.shape[0], nclasses])
        else:
            self.prior = None

        self.ngb_size = int(ngb_size)
        self.set_markov_prior(beta, U=U)

    def set_markov_prior(self, beta, U=None):
        """
        Set the Markov prior parameters.

        Parameters
        ----------
        beta : float
          Markov regularization parameter
        U : array-like
          Class interaction matrix. If None, a Potts model is used:
          a (nclasses, nclasses) matrix of ones with zero diagonal.
        """
        if U is not None:  # make sure it's C-contiguous
            self.U = np.asarray(U).copy()
        else:  # Potts model
            U = np.ones((self.nclasses, self.nclasses))
            U[_diag_indices(self.nclasses)] = 0
            self.U = U
        self.beta = float(beta)

    def vm_step(self, freeze=()):
        """
        Update the class-specific means and variances from the current
        posterior probability maps.

        For each class, the mean and covariance of the masked data are
        computed using the class posterior probabilities as weights.

        Parameters
        ----------
        freeze : sequence of int
          Indices of the classes whose parameters are left untouched.
        """
        classes = list(range(self.nclasses))
        for i in freeze:
            classes.remove(i)
        for i in classes:
            P = self.ppm[..., i][self.mask].ravel()
            # Sum of weights, kept away from zero to avoid division
            # by zero for empty classes
            Z = nonzero(P.sum())
            tmp = self.data.T * P.T
            mu = tmp.sum(1) / Z
            mu_ = mu.reshape((len(mu), 1))
            # Weighted second moment minus outer product of the mean
            sigma = np.dot(tmp, self.data) / Z - np.dot(mu_, mu_.T)
            self.mu[i] = mu
            self.sigma[i] = sigma

    def log_external_field(self):
        """
        Compute the logarithm of the external field, where the
        external field is defined as the likelihood times the
        first-order component of the prior.

        Returns
        -------
        lef : array
          Array of shape (number of masked voxels, nclasses).
        """
        lef = np.zeros([self.data.shape[0], self.nclasses])
        for i in range(self.nclasses):
            centered_data = self.data - self.mu[i]
            if self.nchannels == 1:
                # Scalar variance: avoid a matrix inversion
                inv_sigma = 1. / nonzero(self.sigma[i])
                norm_factor = np.sqrt(inv_sigma.squeeze())
            else:
                inv_sigma = np.linalg.inv(self.sigma[i])
                norm_factor = 1. / np.sqrt(
                    nonzero(np.linalg.det(self.sigma[i])))
            # Squared Mahalanobis distance of each voxel to class i
            maha_dist = np.sum(
                centered_data * np.dot(inv_sigma, centered_data.T).T, 1)
            lef[:, i] = -.5 * maha_dist
            lef[:, i] += log(norm_factor)
        if self.prior is not None:
            lef += log(self.prior)
        return lef

    def normalized_external_field(self):
        """
        Return the external field normalized so that, for each voxel,
        the values sum to one across classes.

        Returns
        -------
        nef : array
          Array of shape (number of masked voxels, nclasses).
        """
        f = self.log_external_field().T
        # Subtract the per-voxel maximum before exponentiating
        f -= np.max(f, 0)
        np.exp(f, f)
        f /= f.sum(0)
        return f.T

    def ve_step(self):
        """
        Update the posterior probability maps.

        If `beta` is zero, the posteriors inside the mask are set to
        the normalized external field. Otherwise the update is
        delegated to the compiled `_ve_step` routine, whose return
        value replaces `self.ppm`.
        """
        nef = self.normalized_external_field()
        if self.beta == 0:
            self.ppm[self.mask] = np.reshape(
                nef, self.ppm[self.mask].shape)
        else:
            self.ppm = _ve_step(self.ppm, nef, self.XYZ,
                                self.U, self.ngb_size, self.beta)

    def run(self, niters=NITERS, freeze=()):
        """
        Run the variational EM algorithm.

        Parameters
        ----------
        niters : int
          Number of VE-step / VM-step iterations.
        freeze : sequence of int
          Indices of the classes whose parameters are kept fixed.
        """
        if self.is_ppm:
            # Parameters need to be estimated from the available ppm
            # before the first VE-step
            self.vm_step(freeze=freeze)
        for i in range(niters):
            self.ve_step()
            self.vm_step(freeze=freeze)
        self.is_ppm = True

    def map(self):
        """
        Return the maximum a posteriori label map
        """
        return map_from_ppm(self.ppm, self.mask)

    def free_energy(self, ppm=None):
        """
        Compute the free energy defined as:

        F(q, theta) = int q(x) log q(x)/p(x,y/theta) dx

        associated with input parameters mu,
        sigma and beta (up to an ignored constant).

        Parameters
        ----------
        ppm : array
          Posterior probability maps to evaluate; defaults to the
          current `self.ppm`.

        Returns
        -------
        The sum of the entropy term and the interaction term; the
        latter is zero unless `beta` is strictly positive.
        """
        if ppm is None:
            ppm = self.ppm
        q = ppm[self.mask]
        # Entropy term
        lef = self.log_external_field()
        f1 = np.sum(q * (log(q) - lef))
        # Interaction term
        if self.beta > 0.0:
            f2 = self.beta * _interaction_energy(ppm, self.XYZ,
                                                 self.U, self.ngb_size)
        else:
            f2 = 0.0
        return f1 + f2
def _diag_indices(n, ndim=2):
    """
    Return the indices of the main diagonal of an array with `ndim`
    dimensions of length `n`, as a tuple of `ndim` index arrays.

    Stand-in for the diag_indices function present in numpy 1.4 and
    later, kept for compatibility with numpy < 1.4.
    """
    axis_index = np.arange(n)
    # The same index array is repeated once per dimension
    return tuple(axis_index for _ in range(ndim))
def moment_matching(dat, mu, sigma, glob_mu, glob_sigma):
    """
    Moment matching strategy for parameter initialization to feed a
    segmentation algorithm.

    The template parameters are mapped through the affine intensity
    transformation that matches the template global mean and variance
    to the mean and variance of `dat`.

    Parameters
    ----------
    dat : array
      Image data.
    mu : array
      Template class-specific intensity means
    sigma : array
      Template class-specific intensity variances
    glob_mu : float
      Template global intensity mean
    glob_sigma : float
      Template global intensity variance

    Returns
    -------
    dat_mu : array
      Guess of class-specific intensity means
    dat_sigma : array
      Guess of class-specific intensity variances
    """
    observed_mean = float(np.mean(dat))
    observed_var = float(np.var(dat))
    # Intensity map x -> scale * x + shift from template to data
    scale = np.sqrt(observed_var / glob_sigma)
    shift = observed_mean - scale * glob_mu
    return scale * mu + shift, (scale ** 2) * sigma
def map_from_ppm(ppm, mask=None):
    """
    Compute a label map from posterior probability maps.

    Parameters
    ----------
    ppm : array
      Posterior probability maps, with classes along the last axis.
    mask : array or tuple of array
      Index selecting the voxels to label. By default, the voxels
      where at least one class probability is non-zero.

    Returns
    -------
    x : array of uint8
      Array with the spatial shape of `ppm`. Selected voxels hold the
      index of their most probable class plus one; all other voxels
      are zero.
    """
    ppm = np.asarray(ppm)
    x = np.zeros(ppm.shape[0:-1], dtype='uint8')
    if mask is None:
        # The mask must have the spatial shape of `x`, hence the
        # reduction over the class axis
        mask = np.any(ppm != 0, axis=-1)
    x[mask] = ppm[mask].argmax(-1) + 1
    return x
def binarize_ppm(q):
    """
    Binarize posterior probabilities by keeping, for each row, only
    the class with maximum probability.

    Assume input ppm is masked (ndim==2)

    Parameters
    ----------
    q : array-like
      Array of shape (number of voxels, number of classes).

    Returns
    -------
    bin_q : array
      Array with the same shape as `q`, holding 1 at the position of
      the row-wise maximum of `q` and 0 elsewhere.
    """
    q = np.asarray(q)
    bin_q = np.zeros(q.shape)
    # Pair each row index with its own argmax column; slicing the rows
    # instead would set every selected column in all rows
    bin_q[np.arange(q.shape[0]), np.argmax(q, axis=1)] = 1
    return bin_q