Source code for selectinf.truncated.base

"""
This module implements the class `truncated_gaussian` which 
performs (conditional) UMPU tests for Gaussians
restricted to a set of intervals.

"""
import numpy as np
from scipy.stats import chi
from mpmath import fsum

import warnings

from abc import ABCMeta, abstractmethod


[docs]class truncated(object): """ A distribution, truncated to a union of intervals HOW TO MAKE A SUBCLASS : You have to implement : __init__(self, args*) : It has to call the method from the base class Since the method is abstract, you can't have an instance of the subclass if the method __init__ is not implemented _cdf_notTruncated(self, a, b, dps) : With these two methods, you can use : -> cdf -> sf You should implement : _pdf_notTruncated(self, z, dps) : it allows you to use : -> pdf -> plt_pdf (if you also have _quantile_notTruncated) _quantile_notTruncated(self, q, tol) : it allows you to use : -> quantile -> rvs -> plt_cdf -> plt_pdf (if you also have _pdf_notTruncated) """ __metaclass__ = ABCMeta
[docs] @abstractmethod def __init__(self, intervals): """ Create a new truncated distribution object This method is abstract : it has to be overriden Parameters ---------- intervals : [(float, float)] The intervals the distribution is truncated to """ self.intervals = intervals dps = 15 not_precise = True while not_precise: Q = [self._cdf_notTruncated(a, b, dps) for a, b in intervals] dps *= 2 not_precise = (fsum(Q) == 0.) self._sumQ = fsum(Q) self._dps = dps self._Q = Q
@abstractmethod def _cdf_notTruncated(self, a, b, dps=15): """ Compute the probability of being in the interval (a, b) for a variable with this distribution (not truncated) Parameters ---------- a, b : float Bounds of the interval. Can be infinite. dps : int Decimal precision (decimal places). Used in mpmath Returns ------- p : float The probability of being in the intervals (a, b) P( a < X < b) for a non truncated variable WARNING : This is the fundamental method of the truncated class It has to be overriden for each distribution """ pass def _quantile_notTruncated(self, q, dps, tol=1.e-6): """ Compute the quantile for the non truncated distribution Parameters ---------- q : float quantile you want to compute. Between 0 and 1 tol : float precision for the output Returns ------- x : float x such that P(X < x) = q WARNING : This method is deprecated if not overriden. It can be very slow """ warnings.warn("""Deprecated to use the general quantile_notTruncated method : it should be overrriden""", DeprecationWarning) if dps is None: dps = self._dps f = lambda x: cdf_notTruncated(-np.inf, x, dps) lb, ub = -1.e8, 1.e8 x = find_root(f, y, lb, ub, tol=1e-6) return x
[docs] def rvs(self, size=1): """ Sample a random variable from the truncated disribution Parameters ---------- size : int Number of samples. Default : 1 Returns ------- X : np.array array of sample """ if not hasattr(self, '_quantile_notTruncated'): raise NotImplementedError( \ """The 'quantile_notTruncated' method should be implemented in order to use the truncated rvs method""" ) U = np.random.uniform(size=size) X = np.array([self.quantile(u) for u in U]) return X
[docs] def sf(self, z): """ Compute the survival function of the truncated distribution Parameters ---------- z : float Minimum bound of the interval Returns ------- sf : float The survival function of the truncated distribution sf(z) = P( X > z | X is in intervals ) """ intervals = self.intervals Q, sumQ = self._Q, self._sumQ N = len(Q) dps = self._dps k, (a, b) = min( (k, (a, b)) for k, (a, b) in enumerate(intervals) if b > z) sf = fsum(Q[k+1:]) + self._cdf_notTruncated(max(a, z), b, dps) sf /= sumQ return sf
[docs] def cdf(self, z): """ Compute the survival function of the truncated distribution Parameters ---------- z : float Minimum bound of the interval Returns ------- cdf : float function The cumulative distribution function of the truncated distribution cdf(z) = P( X < z | X is in intervals ) WARNING : This method only use the sf method : it is never going to be more precise """ cdf = 1. - self.sf(z) return cdf
[docs] def pdf(self, z): """ Compute the probability distribution funtion of the truncated distribution Parameters ---------- z : float Returns ------- p : float p(z) such that E[f(X)] = \int f(z)p(z)dz """ if not hasattr(self, '_pdf_notTruncated'): raise NotImplementedError( \ """The 'pdf_notTruncated' should be implemented in order to use the truncated pdf method""" ) intervals = self.intervals dps = self._dps if intervals(z): p = self._pdf_notTruncated(z, dps) p /= self._sumQ else: p = 0 return p
[docs] def quantile(self, q, tol=1.e-6): if not hasattr(self, '_quantile_notTruncated'): raise NotImplementedError( \ """The 'quantile_notTruncated' method should be implemented in order to use the truncated quantile method""" ) Q = self._Q sumQ = self._sumQ intervals = self.intervals dps = self._dps cum_sum = np.cumsum( Q ) k = min( i for i, c in enumerate(cum_sum) if c > q*sumQ ) a, b = intervals[k] q_notTruncated = q*sumQ + self._cdf_notTruncated(-np.inf, a, dps) if k>0: q_notTruncated -= cum_sum[k-1] q_notTruncated = float(q_notTruncated) z = self._quantile_notTruncated(q_notTruncated, tol) return z
[docs]def find_root(f, y, lb, ub, tol=1e-6): """ searches for solution to f(x) = y in (lb, ub), where f is a monotone decreasing function """ # make sure solution is in range a, b = lb, ub fa, fb = f(a), f(b) # assume a < b if fa > y and fb > y: while fb > y : b, fb = b + (b-a), f(b + (b-a)) elif fa < y and fb < y: while fa < y : a, fa = a - (b-a), f(a - (b-a)) # determine the necessary number of iterations try: max_iter = int( np.ceil( ( np.log(tol) - np.log(b-a) ) / np.log(0.5) ) ) except OverflowError: warnings.warn('root finding failed, returning np.nan') return np.nan # bisect (slow but sure) until solution is obtained for _ in range(max_iter): try: c, fc = (a+b)/2, f((a+b)/2) if fc > y: a = c elif fc < y: b = c except OverflowError: warnings.warn('root finding failed, returning np.nan') return np.nan return c