moseq2_model.train Package

Train - Model Module

ARHMM initialization utilities.

moseq2_model.train.models.ARHMM(data_dict, kappa=1000000.0, gamma=999, nlags=3, alpha=5.7, K_0_scale=10.0, S_0_scale=0.01, max_states=100, empirical_bayes=True, affine=True, model_hypparams={}, obs_hypparams={}, sticky_init=False, separate_trans=False, groups=None, robust=False, silent=False)

Initialize an ARHMM and add the data and group labels to it.

Args:
    data_dict (OrderedDict): training data to add to the model
    kappa (float): hyperparameter that sets syllable duration; a larger kappa gives longer syllable durations
    gamma (float): scaling parameter for the hierarchical Dirichlet process (avoid changing it)
    nlags (int): number of lag frames to add to sessions
    alpha (float): scaling parameter for the hierarchical Dirichlet process (avoid changing it)
    K_0_scale (float): standard deviation of the lagged data
    S_0_scale (float): scale for the standard deviation initialization (avoid changing it)
    max_states (int): maximum number of model states
    empirical_bayes (bool): if True, use empirical Bayes to initialize sigma
    affine (bool): if True, use an affine transformation in the AR processes
    model_hypparams (dict): other model hyperparameters (avoid changing them)
    obs_hypparams (dict): observation hyperparameters nu_0, S_0, M_0, and K_0 (avoid changing them)
    sticky_init (bool): if True, initialize the model with random states
    separate_trans (bool): if True, use a separate transition matrix for each group
    groups (list): list of groups to model
    robust (bool): if True, use a Student's t-distributed AR model
    silent (bool): if True, suppress printing of model information

Returns:
    model (ARHMM): initialized model object
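The nlags and affine parameters describe the autoregressive observation model: each state predicts the current frame from the preceding nlags frames, optionally plus a constant offset. The following is a minimal NumPy sketch of how such a lagged regressor matrix can be built. make_lagged_design is a hypothetical helper, not part of moseq2_model, and the library's internal data layout may differ.

```python
import numpy as np


def make_lagged_design(pcs, nlags=3, affine=True):
    """Hypothetical helper: build AR regressors and targets from PC scores.

    Row i of the returned X holds frames t-1, ..., t-nlags (t = nlags + i),
    and row i of y is frame t itself.
    """
    n_frames = pcs.shape[0]
    # The first nlags frames have no full history, so they are not targets.
    y = pcs[nlags:]
    # Most recent lag first: [x_{t-1}, x_{t-2}, ..., x_{t-nlags}].
    X = np.hstack([pcs[nlags - k:n_frames - k] for k in range(1, nlags + 1)])
    if affine:
        # A constant column lets each AR process learn an offset term.
        X = np.hstack([X, np.ones((X.shape[0], 1))])
    return X, y


rng = np.random.default_rng(0)
pcs = rng.standard_normal((100, 10))  # 100 frames, 10 PCs (synthetic data)
X, y = make_lagged_design(pcs, nlags=3, affine=True)
print(X.shape, y.shape)  # (97, 31) (97, 10)
```

Each row of X pairs with the same row of y, so in this formulation a state's AR dynamics amount to a linear map from rows of X to rows of y; the first nlags frames are dropped because they lack a full history.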

Train - General Utilities Module

ARHMM utility functions

moseq2_model.train.util.apply_model(model, whitening_params, data_dict, metadata, whiten='all')

Apply a pre-trained model to data_dict. Note that this function may produce unexpected behavior if the model was trained with separate transition matrices for different groups of sessions.

Args:
    model (ARHMM): pre-trained model
    whitening_params (namedtuple or dict): whitening parameters
    data_dict (OrderedDict): data to apply the model to
    metadata (dict): metadata for data_dict

Returns:
    labels (dict): dictionary of labels predicted per session after modeling

moseq2_model.train.util.get_crosslikes(arhmm, frame_by_frame=False)

Get the cross-likelihoods, a measure of confidence in the label segmentation, for each pair of model labels.

Args:
    arhmm (ARHMM): the ARHMM model object
    frame_by_frame (bool): if True, compute the cross-likelihoods for each frame

Returns:
    All_CLs (dict): dictionary containing the cross-likelihoods for each syllable pair
    CL (np.ndarray): average cross-likelihood for each syllable pair
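A cross-likelihood compares how well each state's observation model explains frames assigned to another state. The sketch below assumes one plausible definition: CL[i, j] is the mean log-likelihood, under state j, of frames labeled i. The (labels, frame_lls) input format and the function name cross_likelihoods are hypothetical, and the sketch does not reproduce the frame_by_frame option of get_crosslikes.

```python
from collections import defaultdict

import numpy as np


def cross_likelihoods(sessions, n_states):
    """Sketch of a cross-likelihood computation under an assumed definition.

    sessions: hypothetical list of (labels, frame_lls) pairs, where labels has
    shape (T,) and frame_lls[t, j] is the log-likelihood of frame t under state j.
    """
    all_cls = defaultdict(list)
    for labels, frame_lls in sessions:
        for i in range(n_states):
            mask = labels == i
            if not mask.any():
                continue
            for j in range(n_states):
                # Log-likelihoods of frames labeled i, scored by state j.
                all_cls[(i, j)].append(frame_lls[mask, j])
    # States that never occur keep NaN rows.
    cl = np.full((n_states, n_states), np.nan)
    for (i, j), chunks in all_cls.items():
        cl[i, j] = np.concatenate(chunks).mean()
    return dict(all_cls), cl
```

Under this definition, a well-separated model would typically have each row's diagonal entry noticeably larger than its off-diagonal entries, since frames are best explained by the state they were assigned to.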

moseq2_model.train.util.get_labels_from_model(model)

Get the model labels for each training dataset and place them in a list.

Args:
    model (ARHMM): trained ARHMM

Returns:
    labels (list): list of arrays of predicted syllable labels, one per training session

moseq2_model.train.util.get_model_summary(model, groups, train_data, val_data, separate_trans)

Compute the log-likelihoods of train_data and of val_data (if val_data is not None). During training, these are computed only when verbose is True.

Args:
    model (ARHMM): model used to compute the log-likelihoods
    groups (list): list of session group names
    train_data (OrderedDict): ordered dict of training data
    val_data (OrderedDict or None): ordered dict of validation (held-out) data
    separate_trans (bool): flag that indicates whether to separate log-likelihoods for each group

Returns:
    train_ll (float): normalized average training log-likelihood across all recording sessions
    val_ll (float): normalized average held-out log-likelihood across all recording sessions
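The docstring does not define how the log-likelihoods are normalized. One common convention, assumed here rather than taken from the library, is to divide each session's total log-likelihood by its number of frames and then average across sessions, which keeps long and short sessions on the same scale. normalized_average_ll is a hypothetical helper.

```python
import numpy as np


def normalized_average_ll(session_lls, session_lengths):
    """Hypothetical helper: mean per-frame log-likelihood across sessions.

    Assumes session_lls[i] is the total log-likelihood of session i and
    session_lengths[i] is its number of frames.
    """
    lls = np.asarray(session_lls, dtype=float)
    lengths = np.asarray(session_lengths, dtype=float)
    # Per-frame normalization puts sessions of different length on one scale.
    return float(np.mean(lls / lengths))


print(normalized_average_ll([-100.0, -300.0], [10, 20]))  # -12.5
```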

moseq2_model.train.util.rleslices(seq)

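Get the slices spanning each run of identical consecutive labels (the segments between changepoints).

Args:
    seq (list): list of labels

Returns:
    (map): iterator of slices, one per run of identical consecutive labels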
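The name rleslices suggests run-length encoding of the label sequence. A minimal NumPy sketch of that idea follows: the slice boundaries are the indices where the label changes. This is an illustrative re-implementation, not necessarily the library's code.

```python
import numpy as np


def rleslices(seq):
    """Return one slice per run of identical consecutive values in seq."""
    seq = np.asarray(seq)
    # Indices where the value changes, shifted to the first frame of each new run.
    change_points = np.where(np.diff(seq) != 0)[0] + 1
    bounds = np.concatenate(([0], change_points, [len(seq)]))
    return map(slice, bounds[:-1], bounds[1:])


labels = [3, 3, 3, 1, 1, 4]
# (syllable label, duration in frames) for each run
runs = [(labels[s.start], int(s.stop - s.start)) for s in rleslices(labels)]
print(runs)  # [(3, 3), (1, 2), (4, 1)]
```

Because each slice covers exactly one syllable bout, s.stop - s.start gives that bout's duration in frames and seq[s.start] gives its label.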

moseq2_model.train.util.run_e_step(arhmm)

Compute the expected state sequences for the sessions in the training dataset and return them in a list.

Args:
    arhmm (ARHMM): model to compute expected states from

Returns:
    e_states (list): list of expected states, one entry per session
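For an HMM-style model, the expected states of a session are usually the per-frame posterior state probabilities computed by the forward-backward algorithm during the E step. The sketch below implements scaled forward-backward in NumPy for a single session. pi0, trans, and frame_lls are assumed inputs (initial distribution, transition matrix, and per-frame log-likelihoods under each state), and whether run_e_step returns exactly these marginals is an assumption.

```python
import numpy as np


def expected_states(pi0, trans, frame_lls):
    """Posterior state marginals p(z_t = k | data) via scaled forward-backward.

    pi0: (K,) initial state distribution
    trans: (K, K) row-stochastic transition matrix
    frame_lls: (T, K) log-likelihood of each frame under each state
    """
    T, K = frame_lls.shape
    # Subtract each frame's max log-likelihood to avoid underflow; these
    # per-frame constants cancel when the marginals are normalized.
    lik = np.exp(frame_lls - frame_lls.max(axis=1, keepdims=True))
    alpha = np.zeros((T, K))
    beta = np.ones((T, K))
    # Forward pass: alpha[t] is proportional to p(z_t, x_1..x_t).
    alpha[0] = pi0 * lik[0]
    alpha[0] /= alpha[0].sum()
    for t in range(1, T):
        alpha[t] = (alpha[t - 1] @ trans) * lik[t]
        alpha[t] /= alpha[t].sum()
    # Backward pass: beta[t] is proportional to p(x_{t+1}..x_T | z_t).
    for t in range(T - 2, -1, -1):
        beta[t] = trans @ (lik[t + 1] * beta[t + 1])
        beta[t] /= beta[t].sum()
    post = alpha * beta
    return post / post.sum(axis=1, keepdims=True)


trans = np.array([[0.9, 0.1], [0.1, 0.9]])  # sticky two-state chain
frame_lls = np.array([[0.0, -50.0], [0.0, -50.0], [-50.0, 0.0]])
print(expected_states(np.array([0.5, 0.5]), trans, frame_lls).argmax(axis=1))  # [0 0 1]
```

Normalizing at every step keeps the recursions numerically stable without working in log space, at the cost of not returning the total log-likelihood.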

moseq2_model.train.util.slices_from_indicators(indseq)

Compute start and stop indices (slices) for each contiguous sequence of True values in indseq.

Args:
    indseq (list): indicator array containing True and False values

Returns:
    (list): list of slices, one per contiguous run of True values in indseq
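Finding contiguous runs of True values reduces to detecting rising and falling edges in the zero-padded indicator sequence. The following NumPy sketch shows the technique; it is an illustration, not the library's implementation.

```python
import numpy as np


def slices_from_indicators(indseq):
    """Return a slice for each contiguous run of True values in indseq."""
    ind = np.asarray(indseq, dtype=int)
    # Zero-pad both ends so runs touching an edge still yield a +1/-1 transition.
    edges = np.diff(np.concatenate(([0], ind, [0])))
    starts = np.where(edges == 1)[0]   # False -> True transitions
    stops = np.where(edges == -1)[0]   # True -> False transitions
    return [slice(int(a), int(b)) for a, b in zip(starts, stops)]


labels = np.array([5, 5, 2, 5, 1, 1, 5])
print(slices_from_indicators(labels == 5))  # [slice(0, 2, None), slice(3, 4, None), slice(6, 7, None)]
```

Passing a boolean mask such as labels == 5 yields every bout of syllable 5, which is a convenient way to collect per-syllable segments.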

moseq2_model.train.util.train_model(model, num_iter=100, ncpus=1, checkpoint_freq=None, checkpoint_file=None, start=0, progress_kwargs={}, train_data=None, val_data=None, separate_trans=False, groups=None, verbose=False, check_every=2)

Train the ARHMM for the specified number of resampling iterations.

Args:
    model (ARHMM): model object to train
    num_iter (int): total number of resampling iterations
    ncpus (int): number of CPUs used to resample the model
    checkpoint_freq (int): frequency (in iterations) at which to save a model checkpoint
    checkpoint_file (str): path for saving new checkpoint files
    start (int): starting iteration index, used to resume modeling (default 0)
    progress_kwargs (dict): keyword arguments for the progress bar
    train_data (OrderedDict): training data used to compute log-likelihoods if verbose is True
    val_data (OrderedDict): validation data used to compute validation log-likelihoods if verbose is True
    separate_trans (bool): if True, use a separate transition matrix for each group
    groups (list): list of groups included in modeling, used to compute log-likelihoods if verbose is True
    verbose (bool): if True, compute log-likelihoods every check_every iterations
    check_every (int): frequency (in iterations) at which to record training/validation log-likelihoods

Returns:
    model (ARHMM): trained model
    log_likelihood (list): list of training log-likelihoods per session after modeling
    labels (list): list of labels predicted per session after modeling
    iter_lls (list): list of training log-likelihoods, one per check_every interval
    iter_holls (list): list of held-out log-likelihoods, one per check_every interval
    interrupt (bool): flag that tells the caller whether a keyboard interrupt occurred
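The interplay of num_iter, start, check_every, and the returned interrupt flag can be illustrated with a small training loop. ToyModel and train_loop are hypothetical stand-ins: the sketch assumes that num_iter counts total iterations (so a resumed run covers range(start, num_iter)) and that log-likelihoods are recorded after every check_every-th iteration. It does not implement checkpointing, multiprocessing, or held-out log-likelihoods.

```python
class ToyModel:
    """Hypothetical stand-in for an ARHMM whose log-likelihood rises each sweep."""

    def __init__(self):
        self.ll = -1000.0

    def resample_model(self):
        self.ll += 10.0

    def log_likelihood(self):
        return self.ll


def train_loop(model, num_iter=100, start=0, check_every=2, verbose=False):
    """Hypothetical training loop mirroring train_model's bookkeeping."""
    iter_lls = []
    interrupt = False
    try:
        # Assumes num_iter is the total count, so a resumed run does the rest.
        for itr in range(start, num_iter):
            model.resample_model()
            # Record the log-likelihood only every check_every iterations.
            if verbose and (itr + 1) % check_every == 0:
                iter_lls.append(model.log_likelihood())
    except KeyboardInterrupt:
        # Keep the partially trained model and tell the caller about it.
        interrupt = True
    return model, iter_lls, interrupt


model, iter_lls, interrupt = train_loop(ToyModel(), num_iter=10, check_every=2, verbose=True)
print(iter_lls, interrupt)  # [-980.0, -960.0, -940.0, -920.0, -900.0] False
```

Catching KeyboardInterrupt lets the caller keep the partially trained model instead of discarding it, which is the situation the interrupt return value signals.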

moseq2_model.train.util.training_checkpoint(model, itr, checkpoint_file)

Format the model checkpoint filename and save the model checkpoint.

Args:
    model (ARHMM): model object being trained
    itr (int): current modeling iteration
    checkpoint_file (str): model checkpoint filename

moseq2_model.train.util.whiten_all(data_dict, center=True)

Whiten the PC scores with a Cholesky decomposition, using all the data to compute the covariance matrix.

Args:
    data_dict (OrderedDict): training dataset
    center (bool): if True, center the data by subtracting the mean PC score

Returns:
    data_dict (OrderedDict): whitened training data dictionary
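Cholesky whitening factors the pooled covariance as Sigma = L L^T and maps each frame x to z = L^{-1}(x - mu), which gives the whitened data an identity covariance. The following sketch applies one transform, estimated from all sessions pooled, to every session. whiten_all_sketch is a hypothetical name, and details such as the covariance estimator (here np.cov with its default ddof=1) may differ from the library.

```python
from collections import OrderedDict

import numpy as np


def whiten_all_sketch(data_dict, center=True):
    """Whiten every session with one Cholesky transform estimated from pooled data."""
    pooled = np.concatenate(list(data_dict.values()), axis=0)
    mu = pooled.mean(axis=0) if center else np.zeros(pooled.shape[1])
    # Sample covariance of the pooled PC scores (np.cov uses ddof=1 here).
    cov = np.cov(pooled, rowvar=False)
    L = np.linalg.cholesky(cov)  # lower-triangular, cov = L @ L.T
    # Solving L z = (x - mu) for every frame gives cov(z) = I.
    return OrderedDict(
        (key, np.linalg.solve(L, (x - mu).T).T) for key, x in data_dict.items()
    )


rng = np.random.default_rng(1)
mixing = np.array([[2.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.5, -0.3, 0.7]])
data = OrderedDict((f"session{i}", rng.standard_normal((200, 3)) @ mixing.T + 5.0) for i in range(3))
pooled_white = np.concatenate(list(whiten_all_sketch(data).values()))
print(np.allclose(np.cov(pooled_white, rowvar=False), np.eye(3)))  # True
```

Because one transform is shared, differences in mean or scale between sessions are preserved after whitening; whiten_each instead whitens each training dataset separately.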

moseq2_model.train.util.whiten_each(data_dict, center=True)

Whiten the PC scores for each training dataset separately.

Args:
    data_dict (OrderedDict): training dataset
    center (bool): if True, center the data by subtracting the mean PC score

Returns:
    data_dict (OrderedDict): whitened training data dictionary

moseq2_model.train.util.zscore_all(data_dict, npcs=10, center=True)

Z-score the PC scores using a single mean and standard deviation computed over all sessions together.

Args:
    data_dict (OrderedDict): training dictionary
    npcs (int): number of PCs included
    center (bool): if True, center the data by subtracting the mean PC score

Returns:
    data_dict (OrderedDict): z-scored training data dictionary
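Z-scoring with pooled statistics means that one mean and one standard deviation per PC, computed over all sessions together, are applied to every session. The sketch below also keeps only the first npcs columns, which is an assumed reading of the npcs parameter; zscore_all_sketch is a hypothetical name.

```python
from collections import OrderedDict

import numpy as np


def zscore_all_sketch(data_dict, npcs=10, center=True):
    """Z-score the first npcs PCs of every session using pooled statistics."""
    # Assumed reading of npcs: keep only the leading npcs columns.
    trimmed = OrderedDict((key, x[:, :npcs]) for key, x in data_dict.items())
    pooled = np.concatenate(list(trimmed.values()), axis=0)
    mu = pooled.mean(axis=0) if center else 0.0
    sd = pooled.std(axis=0)  # one standard deviation per PC, over all sessions
    return OrderedDict((key, (x - mu) / sd) for key, x in trimmed.items())
```

Individual sessions therefore need not have zero mean or unit variance afterwards; only the pooled data does.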

moseq2_model.train.util.zscore_each(data_dict, center=True)

Z-score each session's PC scores separately.

Args:
    data_dict (OrderedDict): training dictionary
    center (bool): if True, center the data by subtracting the mean PC score

Returns:
    data_dict (OrderedDict): z-scored training data dictionary
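Per-session z-scoring computes the mean and standard deviation from each session alone, so every session ends up with unit variance per PC, and with zero mean when center is True. A minimal sketch, with the hypothetical name zscore_each_sketch:

```python
from collections import OrderedDict

import numpy as np


def zscore_each_sketch(data_dict, center=True):
    """Z-score every session with its own mean and standard deviation."""
    out = OrderedDict()
    for key, x in data_dict.items():
        mu = x.mean(axis=0) if center else 0.0
        # Statistics come from this session only, never from the other sessions.
        out[key] = (x - mu) / x.std(axis=0)
    return out
```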