import matplotlib.pyplot as plt
import numpy as np
from scipy import optimize
from astropy.convolution import Gaussian1DKernel, convolve
from psutil import cpu_count
from joblib import parallel_backend
import arviz as az
import warnings
warnings.filterwarnings("ignore")
from seismolab.fourier import MultiHarmonicFitter
from matplotlib.collections import LineCollection
import joblib
from joblib import delayed
from tqdm.auto import tqdm
__all__ = ['TemplateFitter']
class ProgressParallel(joblib.Parallel):
    """``joblib.Parallel`` variant that mirrors its task counters on a tqdm bar.

    Parameters
    ----------
    total : int, optional
        Value shown as the total of the progress bar. When omitted, the total
        tracks the number of tasks dispatched so far.
    **kwds
        Forwarded unchanged to ``joblib.Parallel``.
    """

    def __init__(self, total=None, **kwds):
        self.total = total
        super().__init__(**kwds)

    def __call__(self, *args, **kwargs):
        # The bar exists only while one parallel run is in flight.
        with tqdm() as progress_bar:
            self._pbar = progress_bar
            return super().__call__(*args, **kwargs)

    def print_progress(self):
        """Synchronise the tqdm bar with the dispatched/completed counters.

        NOTE(review): presumably called by joblib whenever a task finishes --
        confirm against the joblib version in use.
        """
        progress_bar = self._pbar
        progress_bar.total = (
            self.n_dispatched_tasks if self.total is None else self.total
        )
        progress_bar.n = self.n_completed_tasks
        progress_bar.refresh()
def make_segments(x, y):
    """
    Build the segment array expected by ``LineCollection``.

    Consecutive (x, y) points are paired up, giving an array of shape
    numlines x (points per line) x 2 (x and y), where every line joins
    one point to the next.
    """
    vertices = np.transpose(np.array([x, y]))
    vertices = vertices.reshape(-1, 1, 2)
    # Segment k runs from vertex k to vertex k+1.
    starts, ends = vertices[:-1], vertices[1:]
    return np.concatenate((starts, ends), axis=1)
def colorline(x, y, z=None, cmap=plt.get_cmap('copper'), norm=plt.Normalize(0.0, 1.0), linewidth=3, alpha=1.0, ax=None):
    """
    Draw the polyline through (x, y) as a colour-mapped ``LineCollection``.

    Parameters
    ----------
    x, y : array-like
        Coordinates of the line vertices.
    z : array-like or number, optional
        Values mapped to colours. Defaults to values equally spaced on
        [0, 1], one per vertex. A single number is wrapped into a
        one-element array.
    cmap, norm : optional
        Colormap and normalisation handed to ``LineCollection``.
    linewidth, alpha : float, optional
        Line width and opacity handed to ``LineCollection``.
    ax : matplotlib axes, optional
        Axes to draw on; the current axes are used when omitted.

    Returns
    -------
    lc : LineCollection
        The collection that was added to the axes.
    """
    if z is None:
        colour_values = np.linspace(0.0, 1.0, len(x))
    elif hasattr(z, "__iter__"):
        colour_values = z
    else:
        # Non-iterable input is treated as a single number.
        colour_values = np.array([z])
    colour_values = np.asarray(colour_values)

    collection = LineCollection(
        make_segments(x, y),
        array=colour_values,
        cmap=cmap,
        norm=norm,
        linewidth=linewidth,
        alpha=alpha,
    )

    target_axes = plt.gca() if ax is None else ax
    target_axes.add_collection(collection)
    return collection
def splitthem(inputBJD, inputflux,fluxerror,span,step,n):
    """
    Cut the n-th chunk out of a light curve.

    The chunk starts ``step*n`` after the first time stamp and is ``span``
    long; the start is inclusive, the end exclusive.

    Returns
    -------
    midBJD : float
        Mid-time of the chunk window.
    dfbitlistBJD, dfbitlistflux : array
        Time and flux points falling inside the window.
    dfbitlistfluxerror : array or None
        Matching flux errors, or None when ``fluxerror`` is None.
    """
    window_start = inputBJD[0] + step * n
    inside = (inputBJD >= window_start) & (inputBJD < window_start + span)

    chunk_error = None if fluxerror is None else fluxerror[inside]
    return window_start + span / 2, inputBJD[inside], inputflux[inside], chunk_error
def modulated_lc_model(time, a0, a, dPhi, pfit, kind):
    """
    Sum of modulated sin/cos curves.

    Parameters
    ----------
    time : array
        Light curve time points.
    a0 : array
        Relative zero point variation.
    a : array
        Relative amplitude variation.
    dPhi : array
        Relative phase variation.
    pfit : array-like
        Array of fitted parameters. The main frequency, amplitudes and phases of the harmonics,
        and the zero point.
    kind : 'sin' or 'cos'
        Function type to construct template.

    Returns
    -------
    y : array
        Template evaluated at ``time``: every harmonic amplitude is scaled
        by ``a``, the phase of the i-th harmonic is shifted by ``i*dPhi``
        and the zero point is scaled by ``a0``.

    Raises
    ------
    ValueError
        If ``kind`` is neither 'sin' nor 'cos'. Previously such a value
        silently produced a model containing only the zero point term.
    """
    if kind == 'sin':
        wave = np.sin
    elif kind == 'cos':
        wave = np.cos
    else:
        raise ValueError("kind must be 'sin' or 'cos', got %r" % (kind,))

    best_freq = pfit[0]
    nparams = (len(pfit)-2)//2
    amps = pfit[1:1+nparams]
    phases = pfit[1+nparams:-1]
    const = pfit[-1]

    y = 0
    # Harmonic orders start at 1 (the main frequency itself).
    for order, (Amp, Phi) in enumerate(zip(amps, phases), start=1):
        y += a*Amp*wave(2*np.pi*order*best_freq*time + Phi+order*dPhi)
    y += a0*const

    return y
def unpack_az_statistics(az_stat):
    """
    Flatten a summary table into means and asymmetric errors.

    For each of 'a0', 'a' and 'psi' (in that order) three values are
    produced: the 'mean' entry, the distance from the mean up to the
    'hdi_97%' entry and the distance from the 'hdi_3%' entry up to the mean.

    Parameters
    ----------
    az_stat : mapping
        Indexed as ``az_stat[column][parameter]`` with columns 'mean',
        'hdi_3%' and 'hdi_97%'.

    Returns
    -------
    tuple
        (a0, a0ep, a0em, a, aep, aem, psi, psiep, psiem)
    """
    unpacked = []
    for name in ('a0', 'a', 'psi'):
        centre = az_stat['mean'][name]
        upper = az_stat['hdi_97%'][name] - centre
        lower = centre - az_stat['hdi_3%'][name]
        unpacked.extend((centre, upper, lower))
    return tuple(unpacked)
def smooth_data(a0values,avalues,psivalues,gapat,smoothness_factor,step):
    """
    Gaussian-smooth the zero point, amplitude and phase series piecewise.

    The series are cut at the indices in ``gapat`` and every piece is
    smoothed on its own. Each piece is padded on both sides with its
    mirror image before the convolution and only the central part is kept.

    Parameters
    ----------
    a0values, avalues, psivalues : array
        Series to be smoothed; all are cut at the same indices.
    gapat : array of int
        Indices at which a new piece starts.
    smoothness_factor : float
        Scales the width of the Gaussian kernel.
    step : float
        The kernel standard deviation is ``smoothness_factor * 1/step``.

    Returns
    -------
    a0values_out, avalues_out, psivalues_out : array
        Smoothed copies of the three input series.
    """
    edges = np.concatenate((np.array([0]), gapat, np.array([len(a0values)])))

    series_in = (a0values, avalues, psivalues)
    series_out = tuple(np.empty_like(values) for values in series_in)

    for start, stop in zip(edges[:-1], edges[1:]):
        kernel = Gaussian1DKernel(stddev = smoothness_factor * 1/step)
        for values, smoothed in zip(series_in, series_out):
            piece = values[start:stop]
            npts = piece.shape[0]
            # Mirror padding: reversed copy | piece | reversed copy.
            padded = np.concatenate((piece[::-1], piece, piece[::-1]))
            smoothed[start:stop] = convolve(padded, kernel)[npts:2*npts]

    return series_out
def fit_lightcurve_chunk(midBJD,bitBJD,bitflux,bitfluxerror,
                         LSPfreq,span,pfit,
                         duty_cycle,error_estimation,kind,
                         debug=False):
    """
    Fit relative zero point, amplitude and phase of the template to one chunk.

    Parameters
    ----------
    midBJD : float
        Mid-time of the chunk; returned unchanged on success.
    bitBJD, bitflux : array
        Time and flux points of the chunk.
    bitfluxerror : array or None
        Flux errors of the chunk.
    LSPfreq : float
        Main frequency; ``span/LSPfreq`` is the chunk length in time units.
    span : float
        Chunk length in cycles.
    pfit : array-like
        Template parameters handed to ``modulated_lc_model``.
    duty_cycle : float
        Fraction of the chunk length that the data must cover; also the
        largest tolerated gap between neighbouring points, as a fraction
        of the chunk length.
    error_estimation : str
        With 'montecarlo' the errors are replaced by MCMC based ones.
    kind : 'sin' or 'cos'
        Function type of the template.
    debug : bool, default: False
        Print diagnostics (and plot the MCMC posterior pairs).

    Returns
    -------
    list of seven NaN
        When the chunk is skipped or the least-squares fit fails.
    tuple
        (midBJD, a0, a0_err, a, a_err, psi, psi_err) otherwise.
    """
    npoints = len(bitBJD)

    # ---- Skip chunk if number of pts is low ----
    if debug: print('N points:',npoints)
    if npoints < 4:
        return [np.nan]*7

    min_extent = duty_cycle * span*1/LSPfreq
    covered = np.ptp(bitBJD)

    # ---- Skip chunk if duty cycle is low ----
    if debug: print('Duty cycle:', covered / (span*1/LSPfreq) )
    if covered < min_extent:
        return [np.nan]*7

    # ---- Skip chunk if there is a large gap ----
    if not np.all(np.diff(bitBJD) < min_extent):
        if debug: print('Skipping due to large gap...')
        return [np.nan]*7

    # ---- Fit zp, amp, phase ----
    def chunk_model(x, a0, a, psi):
        return modulated_lc_model(x, a0, a, psi, pfit, kind)

    try:
        best, covariance = optimize.curve_fit(chunk_model,
                                              bitBJD,bitflux,
                                              p0=[0.9,0.9,0.1],
                                              sigma=bitfluxerror,absolute_sigma=True,method='trf')
    except RuntimeError:
        return [np.nan]*7

    # Best-fit values and their analytic 1-sigma errors
    a0_val, a_val, psi_val = best[0], best[1], best[2]
    sigmas = np.sqrt(np.diag(covariance))
    a0_err, a_err, psi_err = sigmas[0], sigmas[1], sigmas[2]

    if error_estimation == 'montecarlo':
        if debug: print('Running MCMC...')

        import pymc as pm

        # Without flux errors the square root of the flux is used as noise.
        noise = bitfluxerror if bitfluxerror is not None else np.sqrt(bitflux)

        with pm.Model() as model:
            ## Uniform priors, started from the least-squares solution
            a0 = pm.Uniform("a0", 0, 2, initval=a0_val)
            a = pm.Uniform("a", 0, 2, initval=a_val)
            psi = pm.Uniform("psi", -1, 1, initval=psi_val)

            ## Model prediction at the chunk's time points
            yest = modulated_lc_model(bitBJD, a0, a, psi, pfit, kind)

            ## Normal likelihood with fixed, per-point sigma
            likelihood = pm.Normal("likelihood", mu=yest,
                                   sigma=noise,
                                   observed=bitflux)

            # Populate MCMC sampler
            traces = pm.sample(1000,chains=1,cores=1)

        summary = az.summary(traces, kind="stats")
        _, a0ep, a0em, _, aep, aem, _, psiep, psiem = unpack_az_statistics(summary)

        if debug:
            az.plot_pair(traces,
                         var_names=['a0', 'a', 'psi'],
                         kind='kde',
                         divergences=True,
                         marginals=True,
                         textsize=18)

        # The larger side of each asymmetric interval becomes the error
        a0_err = max(a0ep, a0em)
        a_err = max(aep, aem)
        psi_err = max(psiep, psiem)

    return midBJD, a0_val, a0_err, a_val, a_err, psi_val, psi_err
class TemplateFitter:
    """
    Track zero point, amplitude and phase variations of a light curve.

    A multi-harmonic template is fitted to the whole light curve, then the
    template is re-fitted to consecutive chunks with only the relative zero
    point, amplitude and phase left free.
    """

    def __init__(self, time,flux,fluxerror=None):
        """
        Store the light curve, dropping points with non-finite values.

        Parameters
        ----------
        time : array
            Light curve time points.
        flux : array
            Corresponding flux/mag values.
        fluxerror : array, optional
            Corresponding flux/mag error values.
        """
        time = np.asarray(time,dtype=float)
        flux = np.asarray(flux,dtype=float)

        goodpts = np.isfinite(time)
        goodpts &= np.isfinite(flux)
        if fluxerror is not None:
            fluxerror = np.asarray(fluxerror,dtype=float)
            goodpts &= np.isfinite(fluxerror)
            fluxerror = fluxerror[goodpts]

        self.time = time[goodpts]
        self.flux = flux[goodpts]
        self.fluxerror = fluxerror

    def _store_results(self, BJDmidP, avalues, aerrorvalues,
                       psivalues, psierrorvalues, a0values, a0errorvalues):
        """Keep the latest variation curves as attributes for ``get_lc_model``."""
        self.oc_time = np.array(BJDmidP)
        self.avalues = avalues
        self.aerrorvalues = aerrorvalues
        self.psivalues = psivalues
        self.psierrorvalues = psierrorvalues
        self.a0values = a0values
        self.a0errorvalues = a0errorvalues

    def _plot_results(self, period, gapatorig, gapat, BJDmidP,
                      a0values, a0errorvalues, avalues, aerrorvalues,
                      psivalues, psierrorvalues,
                      scale, showerrorbar, saveplot, plotting, filename):
        """Draw the four-panel summary figure; optionally save and/or show it."""

        def with_gaps(values):
            # NaNs inserted at the gap positions break the plotted lines there.
            return np.insert(values, gapat, np.nan)

        def errors(values):
            return with_gaps(values) if showerrorbar else None

        BJDmodP_extended = self.time%period/period
        BJDmodP_extended = np.concatenate((BJDmodP_extended,1+BJDmodP_extended))
        flux_extended = np.tile(self.flux,2)

        fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(20,4),
                               gridspec_kw = {'width_ratios':[1, 2, 2, 2]})

        # Panel 0: phase-folded light curve
        ax[0].scatter(BJDmodP_extended,flux_extended,s=1)
        ax[0].set_xlim(0,1.5)
        ax[0].set_xlabel("Phase")
        ax[0].set_ylabel("Brightness")
        if scale == 'mag': ax[0].invert_yaxis()

        # Panel 1: light curve with the zero point curve on top
        ax[1].plot(np.insert(self.time,gapatorig,np.nan),np.insert(self.flux,gapatorig,np.nan))
        ax[1].set_xlabel("Time")
        ax[1].errorbar(with_gaps(BJDmidP),with_gaps(a0values),
                       errors(a0errorvalues),c='C1',alpha=0.5,label="ZP",ecolor='lightgray')
        if scale == 'mag': ax[1].invert_yaxis()

        # Panel 2: amplitude (left axis) and phase (right axis) against time
        ax[2].errorbar(with_gaps(BJDmidP),with_gaps(avalues),
                       errors(aerrorvalues),c='r',label="Amp",ecolor='lightgray')
        ax[2].set_xlabel("Time")
        ax[2].set_title('P='+str(round(period,8))+' days')
        ylmin,ylmax = ax[2].get_ylim()
        ylmin = min(ylmin,avalues.mean()-10*aerrorvalues.mean())
        ylmax = max(ylmax,avalues.mean()+10*aerrorvalues.mean())
        ax[2].set_ylim(ylmin,ylmax)

        ax2b=ax[2].twinx()
        ax2b.errorbar(with_gaps(BJDmidP),with_gaps(psivalues),
                      errors(psierrorvalues),label=r"$\Phi$",ecolor='lightgray',c='C0')
        ylmin,ylmax = ax2b.get_ylim()
        ylmin = min(ylmin,psivalues.mean()-0.10)
        ylmax = max(ylmax,psivalues.mean()+0.10)
        ax2b.set_ylim(ylmin,ylmax)

        ax[2].legend(loc='upper left')
        ax2b.legend(loc='upper right')

        # Panel 3: phase against amplitude, coloured along the sequence
        colorline( with_gaps(avalues),with_gaps(psivalues),ax=ax[3])
        ax[3].errorbar(avalues,psivalues,
                       xerr=aerrorvalues if showerrorbar else None,
                       yerr=psierrorvalues if showerrorbar else None,
                       ecolor='lightgray',fmt='.',ms=0,zorder=0)
        ax[3].set_xlim(np.min(avalues) - 0.05*np.ptp(avalues), np.max(avalues) + 0.05*np.ptp(avalues))
        ax[3].set_ylim(np.min(psivalues) - 0.05*np.ptp(psivalues), np.max(psivalues) + 0.05*np.ptp(psivalues))
        ax[3].set_xlabel("Amp")
        ax[3].set_ylabel("Phase")

        plt.tight_layout()
        if saveplot:
            plt.savefig(filename+'_template_OC.jpg')
        if plotting:
            plt.show()
        plt.close(fig)

    def fit(self,
            span = 3,
            step = 1,
            error_estimation='analytic',
            maxharmonics = 10,
            minimum_frequency=None,
            maximum_frequency=None,
            nyquist_factor=1,
            samples_per_peak=100,
            kind='sin',
            plotting=False,
            scale='flux',
            saveplot=False,
            saveresult=False,
            filename='result',
            showerrorbar=True,
            smoothness_factor=0.5,
            duty_cycle = 0.6,
            debug=False,
            best_freq=None
            ):
        """
        Compute amplitude/phase/zero point variation based on template fitting.

        Parameters
        ----------
        span : float, default: 3
            Number of puls cycles to be fitted.
        step : float, default: 1
            Steps in number of puls cycle.
        error_estimation : 'analytic' or 'montecarlo', default 'analytic'
            Type of error estimation for results.
        maxharmonics : int, default: 10
            Max number of harmonics to be used in template.
        minimum_frequency : float, optional
            If specified, then use this minimum frequency rather than one
            chosen based on the size of the baseline.
        maximum_frequency : float, optional
            If specified, then use this maximum frequency rather than one
            chosen based on the average nyquist frequency.
        nyquist_factor : float, optional, default: 1
            The multiple of the average nyquist frequency used to choose the
            maximum frequency if maximum_frequency is not provided.
        samples_per_peak : float, optional, default: 100
            The approximate number of desired samples across the typical peak.
        kind : 'sin' or 'cos', default: 'sin'
            Function type to construct template.
        plotting : bool, default: False
            Show result.
        scale: 'mag' or 'flux', default: 'flux'
            Lightcurve plot scale.
        saveplot : bool, default: False
            Save the result figure as '<filename>_template_OC.jpg'.
        saveresult : bool, default: False
            Save the results as '<filename>_template_OC.txt'.
        filename : str, default: 'result'
            Beginning of the output file names.
        showerrorbar : bool, default: True
            Plot errorbars as well.
        smoothness_factor : float, optional, default: 0.5
            Level of Gaussian smoothing of amp/phase/zp values.
            0: no smoothing, 0.5-1: slight smoothing, >=1: significant smoothing
        duty_cycle : float, optional, default: 0.6
            Minimum duty cycle that is needed in case of each light curve chunk.
            Should be between 0-1.
        debug : bool, default False
            Verbose output. Chunks are then fitted one by one in this process
            and each fit is plotted, instead of being fitted in parallel.
        best_freq : float, default: None
            If given, then this frequency will be used as the basis of the harmonics,
            instead of calculating a Lomb-Scargle spectrum to get a frequency.

        Returns
        -------
        times : array
            Time points.
        amp : array
            Amplitude variation.
        amperr : array
            Amplitude variation error.
        phase : array
            Phase variation.
        phaseerr : array
            Phase variation error.
        zp : array
            Zero point variation.
        zperr : array
            Zero point variation error.

        Raises
        ------
        TypeError
            If ``error_estimation`` is not 'analytic' or 'montecarlo'.
        """
        # TypeError (rather than ValueError) is kept for backward compatibility.
        if error_estimation not in ['analytic','montecarlo']:
            raise TypeError('%s method is not supported! Please set \'error_estimation\' to \'analytic\' or \'montecarlo\'.' % str(error_estimation))

        # Initialize Fourier fitter by passing your light curve
        if debug: print('Calculating Lomb-Scargle...')
        fitter = MultiHarmonicFitter(self.time,self.flux)
        pfit,perr = fitter.fit_harmonics(maxharmonics = maxharmonics,
                                         plotting = debug,
                                         minimum_frequency=minimum_frequency,
                                         maximum_frequency=maximum_frequency,
                                         nyquist_factor=nyquist_factor,
                                         samples_per_peak=samples_per_peak,
                                         kind=kind,
                                         best_freq=best_freq)

        # Make Fourier results attribute
        self.pfit = pfit
        self.perr = perr
        self.kind = kind

        LSPfreq = pfit[0]

        if debug:
            plt.title('Template')
            plt.scatter(self.time, self.flux)
            plt.plot(self.time,fitter.lc_model(self.time,*pfit),c='C1')
            plt.xlim(self.time.max()-3/pfit[0],self.time.max())
            plt.show()

        if debug: print('Splitting them...')

        # Chunk length and stride converted from cycles to time units
        chunk_span = span*1/LSPfreq
        chunk_step = step*1/LSPfreq

        # np.ptp instead of ndarray.ptp, which was removed in NumPy 2.0
        nchunks = int(np.ceil( np.ptp(self.time) / chunk_step ))

        results = []
        params = []
        for counter in range(nchunks):
            # ---- Get chunk ----
            midBJD,bitBJD,bitflux,bitfluxerror = splitthem(
                self.time,self.flux,self.fluxerror,
                span=chunk_span,
                step=chunk_step,
                n=counter)

            chunk_args = [midBJD,bitBJD,bitflux,bitfluxerror,
                          LSPfreq,span,pfit,
                          duty_cycle,error_estimation,kind,
                          debug]

            if not debug:
                # Only collected here; fitted in parallel after the loop
                params.append(chunk_args)
                continue

            result = fit_lightcurve_chunk(*chunk_args)
            results.append(result)

            if not np.all(np.isnan(bitBJD)):
                plt.figure()
                plt.title('Fit to subsample %d' % (counter+1))
                plt.scatter(bitBJD,bitflux)
                xxxx = np.linspace(min(bitBJD),max(bitBJD),1000)
                # Full fit, and the same fit without the phase shift
                plt.plot(xxxx,modulated_lc_model(xxxx, result[1], result[3], result[5], pfit, kind))
                plt.plot(xxxx,modulated_lc_model(xxxx, result[1], result[3], 0, pfit, kind))
                plt.show()

        if not debug:
            ncores = cpu_count(logical=False)
            with parallel_backend('multiprocessing'):
                results = ProgressParallel(n_jobs=ncores,total=len(params))(delayed(fit_lightcurve_chunk)(*par) for par in params)

        # One row per chunk; the reshape keeps the column layout valid even
        # when no chunk was produced at all.
        results = np.asarray(results,dtype=float).reshape(-1,7)
        (BJDmidP,
         a0values, a0errorvalues,
         avalues, aerrorvalues,
         psivalues, psierrorvalues) = results.T

        # Skipped chunks are marked by a NaN mid-time
        goodpts = np.isfinite(BJDmidP)
        BJDmidP = BJDmidP[goodpts]
        a0values = a0values[goodpts]
        a0errorvalues = a0errorvalues[goodpts]
        avalues = avalues[goodpts]
        aerrorvalues = aerrorvalues[goodpts]
        psivalues = psivalues[goodpts]
        psierrorvalues = psierrorvalues[goodpts]

        # ----- Convert relative variations using the template parameters -----
        nharmonics = (len(pfit)-2)//2

        a0values = pfit[-1] * a0values
        avalues = pfit[1] * avalues
        psivalues = pfit[1+nharmonics] + psivalues

        # ----- Add Fourier errorbars to OC errors -----
        # NOTE(review): zero point and amplitude values are scaled by the
        # template parameters above, while their errors are only offset by
        # the template errors -- confirm that this is intended.
        a0errorvalues = perr[-1] + a0errorvalues
        aerrorvalues = perr[1] + aerrorvalues
        psierrorvalues = perr[1+nharmonics] + psierrorvalues

        # ----- Nothing to smooth or plot if every chunk was skipped -----
        if len(a0values) == 0:
            self._store_results(BJDmidP, avalues,aerrorvalues,
                                psivalues,psierrorvalues, a0values,a0errorvalues)
            return np.array(BJDmidP), avalues,aerrorvalues, psivalues,psierrorvalues, a0values,a0errorvalues

        # ----- Find large gaps in dataset -----
        gapatorig = np.where( np.diff(self.time) > max(4.,200*np.nanmedian(np.diff(self.time))) )[0]
        gapatorig += 1

        if len(gapatorig) == 0:
            gapsize = np.inf
        else:
            gapsize = np.min( self.time[gapatorig] - self.time[gapatorig-1] )
        gapsize = max( gapsize, 2*np.median(np.diff(BJDmidP)) )

        gapat = np.where( np.diff(BJDmidP) > gapsize )[0]
        gapat += 1

        # ----- Smoothing OC curve -----
        a0values_smooth,avalues_smooth,psivalues_smooth = smooth_data(
            a0values,avalues,psivalues,gapat,smoothness_factor,step)

        # Keep the unsmoothed values if smoothing produced only NaNs
        if not np.all(np.isnan(avalues_smooth)):
            a0values = a0values_smooth
            avalues = avalues_smooth
            psivalues = psivalues_smooth

        period = 1/LSPfreq

        # ----- Plot results -----
        if plotting or saveplot:
            self._plot_results(period, gapatorig, gapat, BJDmidP,
                               a0values, a0errorvalues, avalues, aerrorvalues,
                               psivalues, psierrorvalues,
                               scale, showerrorbar, saveplot, plotting, filename)

        # ----- Save OC -----
        if saveresult:
            np.savetxt(filename+'_template_OC.txt',
                       np.c_[BJDmidP,avalues,aerrorvalues,psivalues,psierrorvalues,a0values,a0errorvalues],
                       fmt='%.8f',
                       header='Calculated with period of %.6f\nTIME AVALS AVALS_ERR PSIVALS PSIVALS_ERR ZP ZP_ERR' % period)

        # Store OC curve results
        self._store_results(BJDmidP, avalues,aerrorvalues,
                            psivalues,psierrorvalues, a0values,a0errorvalues)

        return np.array(BJDmidP), avalues,aerrorvalues, psivalues,psierrorvalues, a0values,a0errorvalues

    def get_lc_model(self, time=None, amp=None, phase=None, zp=None):
        """
        Get modulated model light curve.

        When ``time`` is omitted, the results of the last ``fit`` call are
        used for all four arguments.

        Parameters
        ----------
        time : array
            Time points where modulated model light curve is desired.
        amp : array
            Amplitude variation by time.
        phase : array
            Phase variation by time.
        zp : array
            Zero point variation by time.

        Returns
        -------
        ymodel : array
            Modulated model light curve. None (after a warning) if ``fit``
            has not been run yet.
        """
        if not hasattr(self,"pfit"):
            warnings.warn("Please run \'fit\' first!")
            return None

        nharmonics = (len(self.pfit)-2)//2

        if time is None:
            time = self.oc_time
            amp = self.avalues
            phase = self.psivalues
            zp = self.a0values

        # The model expects variations relative to the template parameters
        ymodel = modulated_lc_model( time,
                                     zp/self.pfit[-1], amp/self.pfit[1], phase-self.pfit[nharmonics+1],
                                     self.pfit, self.kind)

        return ymodel

    def get_lc_model_interp(self,kind='slinear'):
        """
        Get modulated model light curve interpolated at the original time points.

        Parameters
        ----------
        kind : str or int, optional
            Specifies the kind of interpolation. Default is 'slinear'.
            See `scipy.interpolate.interp1d` for the kinds.

        Returns
        -------
        ymodel : array
            Modulated model light curve interpolated at the original time points.
            None (after a warning) if ``fit`` has not been run yet.
        """
        # Same guard as in get_lc_model, instead of an AttributeError below
        if not hasattr(self,"pfit") or not hasattr(self,"oc_time"):
            warnings.warn("Please run \'fit\' first!")
            return None

        from scipy.interpolate import interp1d

        def interpolate_at_observations(values):
            # Non-finite values are left out; points outside the fitted
            # range are extrapolated.
            goodpts = np.isfinite(values)
            interpolator = interp1d(self.oc_time[goodpts],values[goodpts],
                                    kind=kind,fill_value="extrapolate")
            return interpolator(self.time)

        amp = interpolate_at_observations(self.avalues)
        phase = interpolate_at_observations(self.psivalues)
        zp = interpolate_at_observations(self.a0values)

        ymodel = self.get_lc_model(self.time, amp, phase, zp)

        return ymodel