moseq2_viz.model package
Model - Dist Module
Utility functions for estimating syllable similarity (behavioral distance).
- moseq2_viz.model.dist.get_behavioral_distance(index, model_file, whiten='all', distances=['ar[init]', 'scalars'], max_syllable=None, resample_idx=-1, dist_options={}, sort_labels_by_usage=True, count='usage')
Compute the behavioral distance (square) matrices with respect to a predefined set of variables.
Args: index (str): Path to index file model_file (str): Path to trained model whiten (str): Indicates whether to whiten all PCs at once or each one at a time. Options = [‘all’, ‘each’] distances (list or str): type of distance(s) to compute. Available options = [‘scalars’, ‘ar[init]’, ‘ar[dtw]’, ‘pca[dtw]’, ‘combined’] max_syllable (int): the index of the maximum number of syllables to include resample_idx (int): Indicates the parsing method according to the shape of the labels array. dist_options (dict): Dictionary holding each distance operation’s configurable parameters sort_labels_by_usage (bool): boolean flag that indicates whether to relabel syllables by count ordering count (str): method to compute syllable mean usage, either ‘usage’ or ‘frames’.
Returns: dist_dict (dict): Dictionary containing all computed behavioral square distance matrices
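A minimal usage sketch, assuming an existing index file and trained model (both paths below are placeholders)::

    from moseq2_viz.model.dist import get_behavioral_distance

    index_file = 'moseq2-index.yaml'  # placeholder path to the index file
    model_file = 'my_model.p'         # placeholder path to the trained model

    dist_dict = get_behavioral_distance(
        index_file,
        model_file,
        whiten='all',
        distances=['ar[init]', 'scalars'],
        max_syllable=40,
        count='usage',
    )
    # the returned dict is assumed to be keyed by the requested distance names
    ar_init_dist = dist_dict['ar[init]']  # square matrix, shape (max_syllable, max_syllable)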
- moseq2_viz.model.dist.get_behavioral_distance_ar(ar_mat, init_point=None, sim_points=10, max_syllable=40, dist='correlation', parallel=False)
Compute behavioral distance with respect to the model’s AutoRegressive (AR) matrices. The function affords either AR trajectory correlation distances or dynamically time-warped trajectory distances.
- Parameters:
ar_mat (numpy.ndarray) (Trained model AutoRegressive matrices; shape=(max_syllable, npcs, npcs*nlags+1))
init_point (list) (Initial values as a reference point for distance estimation)
sim_points (int) (number of time points to simulate)
max_syllable (int) (the index of the maximum number of syllables to include)
dist (str) (Distance operation to compute. Either ‘correlation’ or ‘dtw’.)
parallel (bool) (Boolean flag that indicates whether to use multiprocessing to compute dtw distances.)
- Returns:
ar_dist (2D numpy array) (Computed AR trajectory distances for each AR matrix/model state.)
shape=(max_syllable, max_syllable)
- moseq2_viz.model.dist.get_init_points(pca_scores, model_labels, max_syllable=40, nlags=3, npcs=10)
Compute initial AR trajectories based on a cumulative average of lagged-PC Scores over nlags.
Args: pca_scores (numpy.ndarray): Loaded PC Scores. Shape=(npcs, nsamples) model_labels (list): list of 1D numpy arrays of relabeled/sorted syllable labels max_syllable (int): the index of the maximum number of syllables to include nlags (int): Number of lagged frames. npcs (int): Number of PCs to use in computation.
Returns: syll_average (list): List containing 2D np arrays of average syllable trajectories over a nlag-strided PC scores array.
- moseq2_viz.model.dist.reformat_dtw_distances(full_mat, nsyllables, rescale=True)
Reduce full (max states) dynamically time-warped PC Score distance matrices to only include dimensions for a total of nsyllables.
Args: full_mat (np.ndarray): DTW distance matrices for all model states/syllables. nsyllables (int): Number of syllables to include in truncated DTW distance matrix. rescale (bool): Rescale truncated dtw-distance matrices to match output distribution.
Returns: rmat (2D np array): Reformatted-Truncated DTW Distance Matrix; shape = (nsyllables, nsyllables)
Model - embed Module
Functions to run PCA and LDA on syllable usages and scalars
- moseq2_viz.model.embed.plot_embedding(L, y, mapping, rev_mapping, output_file='embedding.pdf', embedding='PCA', x_dim=0, y_dim=1, symbols='o*v^s', plot_all_subjects=True)
Plot 2D embedding plot.
Args: L (2D np.array): the embedding representations of the mean syllable statistic to plot. y (1D list): list of group names corresponding to each row in L. mapping (dict): dictionary containing mappings from group string to integer for later embedding. rev_mapping (dict): inverse mapping dict to retrieve the group names given their mapped integer value. output_file (str): path to save the output figure embedding (str): type of embedding to run. Either [‘lda’, ‘pca’]. x_dim (int): component number to graph on x-axis y_dim (int): component number to graph on y-axis symbols (str): symbols to use to draw different groups. plot_all_subjects (bool): boolean flag that indicates whether to plot individual subject embeddings along with their respective group means.
Returns: fig (matplotlib.figure): figure containing plotted 2d embedding. ax (matplotlib.axes): axes instance for plotted figure.
- moseq2_viz.model.embed.run_2d_embedding(mean_df, stat='usage', output_file='2d_embedding.pdf', embedding='PCA', n_components=2, plot_all_subjects=True)
Compute a 2D embedding (PCA or LDA) of the mean syllable statistic of choice. The function will output a figure of the 2D representation of the embedding.
Args: mean_df (pd DataFrame): Dataframe of the mean syllable statistics for all sessions stat (str): name of statistic (column) in mean_df to embed. output_file (str): path to save the output figure embedding (str): type of embedding to run. Either [‘lda’, ‘pca’] n_components (int): Number of components to compute. plot_all_subjects (bool): indicates whether to plot individual subject embeddings along with their respective group means.
Returns: fig (matplotlib.figure): figure containing plotted 2d embedding. ax (matplotlib.axes): axes instance for plotted figure.
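A sketch of embedding mean syllable usages with PCA, assuming mean_df was produced by moseq2_viz.model.util.compute_behavioral_statistics (documented below)::

    from moseq2_viz.model.embed import run_2d_embedding

    # mean_df: per-session mean syllable statistics with a 'usage' column (assumed)
    fig, ax = run_2d_embedding(mean_df, stat='usage', embedding='PCA',
                               n_components=2, output_file='2d_embedding.pdf')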
- moseq2_viz.model.embed.run_2d_scalar_embedding(scalar_df, output_file='2d_scalar_embedding.pdf', embedding='PCA', n_components=2, plot_all_subjects=True)
Compute a 2D embedding (PCA or LDA) of the mean measured scalar values for all groups. The function will output a figure of the 2D representation of the embedding.
Args: scalar_df (pd DataFrame): Dataframe of the frame-by-frame scalar measurements for all sessions output_file (str): path to save the output figure embedding (str): type of embedding to run. Either [‘lda’, ‘pca’] n_components (int): Number of components to compute. plot_all_subjects (bool): indicates whether to plot individual subject embeddings along with their respective group means.
Returns: fig (matplotlib figure): figure containing plotted 2d embedding. ax (matplotlib axes): axes instance for plotted figure.
Model - Fingerprint and Classifier Module
Functions for creating fingerprint plots and linear classifier
- moseq2_viz.model.fingerprint_classifier.classifier_fingerprint(summary, features=['MoSeq'], preprocessor=None, classes=['group'], param_search=True, C_list=None, model_type='lr', cv='loo', n_splits=5)
Run a classifier using the fingerprint dataframe.
- Args:
summary (pandas.DataFrame): fingerprint dataframe features (list, optional): Features for the classifier. [‘MoSeq’] for MoSeq syllables or a list of MoSeq scalar values. Defaults to [‘MoSeq’]. preprocessor (sklearn.preprocessing object, optional): Scaler for scaling the data by feature. Defaults to None. classes (list, optional): labels the classifier predicts. Defaults to [‘group’]. param_search (bool, optional): run GridSearchCV to find the regularization param for the classifier. Defaults to True. C_list (numpy.array, optional): list of C regularization parameters to search through. Defaults to None. If None, C_list will search through np.logspace(-6, 3, 50) model_type (str, optional): name of the linear classifier. ‘lr’ for logistic regression or ‘svc’ for LinearSVC. Defaults to ‘lr’. cv (str, optional): cross-validation type. ‘loo’ for LeaveOneOut, ‘skf’ for StratifiedKFold. Defaults to ‘loo’. n_splits (int, optional): number of splits for StratifiedKFold. Defaults to 5.
- Returns:
y_true (np.array): array for true label y_pred (np.array): array for predicted label real_f1 (np.array): array for f1 score true_coef (np.array): array for model weights y_shuffle_true (np.array): array for shuffled label y_shuffle_pred (np.array): array for shuffled predicted label shuffle_f1 (np.array): array for shuffled f1 score shuffle_coef (np.array): array for shuffled model weights
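A sketch of classifying groups from the fingerprint dataframe, assuming summary comes from create_fingerprint_dataframe (documented next) and that the eight return values follow the order listed above::

    from moseq2_viz.model.fingerprint_classifier import classifier_fingerprint, plot_cm

    (y_true, y_pred, real_f1, true_coef,
     y_shuffle_true, y_shuffle_pred, shuffle_f1, shuffle_coef) = classifier_fingerprint(
        summary, features=['MoSeq'], model_type='lr', cv='loo')

    # compare real vs. shuffled-label performance as confusion matrices
    plot_cm(y_true, y_pred, y_shuffle_true, y_shuffle_pred)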
- moseq2_viz.model.fingerprint_classifier.create_fingerprint_dataframe(scalar_df, mean_df, stat_type='mean', n_bins=None, groupby_list=['group', 'uuid'], range_type='robust', scalars=['velocity_2d_mm', 'height_ave_mm', 'length_mm', 'dist_to_center_px'])
Create a fingerprint dataframe from scalar_df and mean_df.
- Args:
scalar_df (pandas.DataFrame): scalar summary dataframe generated from scalars_to_dataframe mean_df (pandas.DataFrame): syllable mean dataframe from compute_behavioral_statistics n_bins (int, optional): number of bins for the features. Defaults to None. groupby_list (list, optional): the list of levels the fingerprint dataframe should be grouped by. Defaults to [‘group’, ‘uuid’].
- Returns:
summary (pandas.DataFrame): fingerprint dataframe range_dict (dict): dictionary that holds the min and max values of the features
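A sketch of building and plotting the fingerprint, assuming scalar_df comes from scalars_to_dataframe and mean_df from compute_behavioral_statistics; the output directory and bin count are illustrative choices::

    from moseq2_viz.model.fingerprint_classifier import (
        create_fingerprint_dataframe, plotting_fingerprint)

    summary, range_dict = create_fingerprint_dataframe(scalar_df, mean_df, n_bins=100)
    plotting_fingerprint(summary, './fingerprint_plots', range_dict)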
- moseq2_viz.model.fingerprint_classifier.plot_cm(y_true, y_pred, y_shuffle_true, y_shuffle_pred)
Plot the confusion matrix.
- Args:
y_true (np.array): array for true label y_pred (np.array): array for predicted label y_shuffle_true (np.array): array for shuffled label y_shuffle_pred (np.array): array for shuffled predicted label
- moseq2_viz.model.fingerprint_classifier.plotting_fingerprint(summary, save_dir, range_dict, figsize=(20, 18), preprocessor=None, num_level=1, level_names=['Group'], vmin=None, vmax=None, plot_columns=['dist_to_center_px', 'velocity_2d_mm', 'height_ave_mm', 'length_mm', 'MoSeq'], col_names=[('Position', 'Dist. from center (px)'), ('Speed', 'Speed (mm/s)'), ('Height', 'Height (mm)'), ('Length', 'Length (mm)'), ('MoSeq', 'Syllable ID')])
Plot the fingerprint heatmap.
- Args:
summary (pandas.DataFrame): fingerprint dataframe range_dict (pandas.DataFrame): pd.DataFrame that holds the min and max values of the features preprocessor (sklearn.preprocessing object, optional): Scaler for scaling the data by session. Defaults to None. num_level (int, optional): the number of groupby levels. Defaults to 1. level_names (list, optional): list of names of the levels. Defaults to [‘Group’]. vmin (int, optional): min value the figure color map covers. Defaults to 0. vmax (float, optional): max value the figure color map covers. Defaults to 0.2. plot_columns (list, optional): columns to plot col_names (list, optional): list of (column name, x label) pairs
- Raises:
Exception: if num_level is greater than the number of existing levels.
- moseq2_viz.model.fingerprint_classifier.robust_max(v)
Find the robust (relative) maximum of a series.
- Args:
v (pandas.Series): the pandas.Series to find the robust max.
- Returns:
(numpy.float): the robust max in the series.
- moseq2_viz.model.fingerprint_classifier.robust_min(v)
Find the robust (relative) minimum of a series.
- Args:
v (pandas.Series): the pandas.Series to find the robust min.
- Returns:
(numpy.float): the robust min in the series.
Model - Stats Module
Functions for statistical tests for analyzing model results.
- moseq2_viz.model.stat.bootstrap_group_means(df, group1, group2, statistic='usage', max_syllable=40)
Compute bootstrapped group means.
Args: df (pandas.DataFrame): dataframe that contains syllable statistics per session (mean_df/stats_df) group1 (str): Name of group 1 to compare. group2 (str): Name of group 2 to compare. statistic (str): Syllable statistic to compute bootstrap means for. max_syllable (int): the index of the maximum number of syllables to include
Returns: boots (dictionary): dictionary of group name (keys) paired with their bootstrapped statistics numpy array.
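For example, bootstrapped means for two groups can be compared with ztest_vect (documented below); the group names here are placeholders::

    from moseq2_viz.model.stat import bootstrap_group_means, ztest_vect

    boots = bootstrap_group_means(mean_df, 'Control', 'Treatment',
                                  statistic='usage', max_syllable=40)
    # boots is keyed by group name
    pvals = ztest_vect(boots['Control'], boots['Treatment'])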
- moseq2_viz.model.stat.bootstrap_me(usages, n_iters=10000)
Bootstrap the input statistic data using random sampling with replacement.
Args: usages (np.array): Data to bootstrap n_iters (int): Number of samples to return.
Returns: boots (np.array): Bootstrapped samples drawn from the input array.
- moseq2_viz.model.stat.compute_pvalues_for_group_pairs(real_zs_within_group, null_zs, df_k_real, group_names, n_perm=10000, thresh=0.05, mc_method='fdr_bh', verbose=False)
Adjust the p-values from Dunn’s z-test statistics and compute the resulting significant syllables using the adjusted p-values.
Args: real_zs_within_group (dict): dict of group pair keys paired with vector of Dunn’s z-test statistics null_zs (dict): dict of group pair keys paired with vector of Dunn’s z-test statistics of the null hypothesis. df_k_real (pandas.DataFrame): DataFrame of KW test results. group_names (pd.Index): Index list of unique group names. n_perm (int): Number of permuted samples to generate. thresh (float): Alpha threshold to consider syllable significant. mc_method (str): Multiple Corrections method to use. verbose (bool): indicates whether to print out the significant syllable results
Returns: df_pval_corrected (pandas.DataFrame): DataFrame containing Dunn’s test results with corrected p-values. significant_syllables (list): List of corrected KW significant syllables (syllables with p-values < thresh)
- moseq2_viz.model.stat.dunns_z_test_permute_within_group_pairs(df_usage, vc, real_ranks, X_ties, N_m, group_names, rnd, n_perm)
Run Dunn’s z-test statistic on combinations of all group pairs, handling pre-computed tied ranks.
Args: df_usage (pandas.DataFrame): DataFrame containing only pre-computed syllable stats. vc (pd.Series): value counts of sessions in each group. real_ranks (np.array): Array of syllable ranks. X_ties (np.array): 1-D list of tied ranks, where if value > 0, then rank is tied N_m (int): Number of sessions. group_names (pd.Index): Index list of unique group names. rnd (np.random.RandomState): Pseudo-random number generator. n_perm (int): Number of permuted samples to generate.
Returns: null_zs_within_group (dict): dict of group pair keys paired with vector of Dunn’s z-test statistics of the null hypothesis. real_zs_within_group (dict): dict of group pair keys paired with vector of Dunn’s z-test statistics
- moseq2_viz.model.stat.get_session_mean_df(df, statistic='usage', max_syllable=40)
Compute a given mean syllable statistic grouped by groups and UUIDs.
Args: df (pandas.DataFrame): dataframe that contains average syllable statistics per session (mean_df/stats_df) statistic (str): statistic to compute mean for, (any of the columns in input df); for example: ‘usage’, ‘duration’, ‘velocity_2d_mm’, etc. max_syllable (int): the index of the maximum number of syllables to include
Returns: df_pivot (pandas.DataFrame): Mean syllable statistic per session.
- moseq2_viz.model.stat.get_sig_syllables(df_pvals, thresh=0.05, mc_method='fdr_bh', verbose=False)
Run a multiple-comparisons test on the p-values given a set alpha threshold, with multiple-test correction.
Args: df_pvals (pandas.DataFrame): dataframe listing raw p-values thresh (float): Alpha threshold to consider syllable significant. mc_method (str): Multiple Corrections method to use. verbose (bool): indicates whether to print out the significant syllable results
Returns: df_pvals (pandas.DataFrame): updated dataframe listing adjusted p-values
- moseq2_viz.model.stat.get_tie_correction(x, N_m)
Kruskal-Wallis helper function: assign tied rank values the average of the ranks they would have received had they not been tied.
Args: x (pd.Series): syllable usages for a single session. N_m (int): Number of total sessions.
Returns: corrected_rank (float): average of the inputted tied rank
- moseq2_viz.model.stat.mann_whitney(df, group1, group2, statistic='usage', max_syllable=40, verbose=False, **kwargs)
Runs a Mann-Whitney hypothesis test with multiple test corrections on two given groups to find significant syllables. Also runs multiple corrections test to find syllables to exclude.
Args: df (pandas.DataFrame): dataframe that contains syllable statistics per session (mean_df/stats_df) group1 (str): Name of first group group2 (str): Name of second group statistic (str): Name of statistic to compute z-test on. max_syllable (int): the index of the maximum number of syllables to include verbose (bool): indicates whether to print out the significant syllable results thresh (float): Alpha threshold to consider syllable significant. mc_method (str): Multiple Corrections method to use.
Returns: df_mw_real (pandas.DataFrame): DataFrame containing Mann-Whitney U corrected results. exclude_sylls (list): list of syllables that were excluded via multiple comparisons test.
- moseq2_viz.model.stat.plot_H_stat_significance(df_k_real, h_all, N_s)
Plot the assigned H-statistic for each syllable computed via manual KW test.
Args: df_k_real (pandas.DataFrame): the dataframe that contains Kruskal-Wallis p value, H stats and whether it is significant h_all (np.array): Array of H-stats computed for given n_syllables; shape = (n_perms, N_s) N_s (int): Number of syllables to plot
Returns: fig (pyplot.figure): plotted H-stats plot ax (pyplot.axis): plotted H-stats axis
- moseq2_viz.model.stat.run_kruskal(df, statistic='usage', max_syllable=40, n_perm=10000, seed=42, thresh=0.05, mc_method='fdr_bh', verbose=False)
Run Kruskal-Wallis Hypothesis test and Dunn’s posthoc multiple comparisons test for syllable statistic.
Args: df (pandas.DataFrame): dataframe that contains syllable statistics per session (mean_df/stats_df) statistic (str): statistic to compute mean for, (any of the columns in input df). max_syllable (int): the index of the maximum number of syllables to include n_perm (int): Number of permuted samples to generate. seed (int): Random seed used to initialize the pseudo-random number generator. thresh (float): Alpha threshold to consider syllable significant. mc_method (str): Multiple Corrections method to use. verbose (bool): indicates whether to print out the significant syllable results
Returns: df_k_real (pandas.DataFrame): DataFrame of KW test results. dunn_results_df (pandas.DataFrame): DataFrame of Dunn’s test results for permuted group pairs. intersect_sig_syllables (dict): dictionary containing intersecting significant syllables between KW and Dunn’s tests.
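A minimal sketch, assuming mean_df holds per-session syllable statistics with a 'usage' column::

    from moseq2_viz.model.stat import run_kruskal

    df_k_real, dunn_results_df, intersect_sig_syllables = run_kruskal(
        mean_df, statistic='usage', max_syllable=40, n_perm=10000, thresh=0.05)
    # intersect_sig_syllables: syllables significant under both KW and Dunn's tests
    print(intersect_sig_syllables)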
- moseq2_viz.model.stat.run_manual_KW_test(df_usage, merged_usages_all, num_groups, n_per_group, cum_group_idx, n_perm=10000, seed=0)
Run a manual Kruskal-Wallis test and verify the results agree with the scipy.stats implementation.
Args: df_usage (pandas.DataFrame): DataFrame containing only pre-computed syllable stats. shape = (N_m, n_syllables) merged_usages_all (np.array): numpy array format of the df_usage DataFrame. num_groups (int): Number of unique groups n_per_group (list): list of value counts for sessions per group. len == num_groups. cum_group_idx (list): list of indices for different groups. len == num_groups + 1. n_perm (int): Number of permuted samples to generate. seed (int): Random seed used to initialize the pseudo-random number generator.
Returns: h_all (np.array): Array of H-stats computed for given n_syllables; shape = (n_perms, N_s) real_ranks (np.array): Array of syllable ranks, shape = (N_m, n_syllables) X_ties (np.array): 1-D list of tied ranks, where if value > 0, then rank is tied. len(X_ties) = n_syllables
- moseq2_viz.model.stat.run_pairwise_stats(df, group1, group2, test_type='mw', verbose=False, **kwargs)
Run one of the pairwise hypothesis tests: Mann-Whitney, z-test, or t-test.
Args: df (pandas.DataFrame): dataframe that contains syllable statistics per session (mean_df/stats_df) group1 (str): Name of first group group2 (str): Name of second group test_type (str): which type of statistical test to run verbose (bool): boolean flag that indicates whether to print out the significant syllable results
- Returns:
df_pvals (pandas.DataFrame): Dataframe listing the p-values and which syllables are significant
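For example, comparing two placeholder groups with the default Mann-Whitney test::

    from moseq2_viz.model.stat import run_pairwise_stats

    df_pvals = run_pairwise_stats(mean_df, 'Control', 'Treatment', test_type='mw')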
- moseq2_viz.model.stat.ttest(df, group1, group2, statistic='usage', max_syllable=40, verbose=False, **kwargs)
Compute a t-test on two selected groups to find significant syllables, with multiple-test correction.
Args: df (pandas.DataFrame): dataframe that contains syllable statistics per session (mean_df/stats_df) group1 (str): Name of first group group2 (str): Name of second group statistic (str): Name of statistic to compute t-test on. max_syllable (int): the index of the maximum number of syllables to include verbose (bool): indicates whether to print out the significant syllable results. thresh (float): Alpha threshold to consider syllable significant. mc_method (str): Multiple Corrections method to use.
Returns: p (np.array): Computed array of p-values syllables_to_include (list): List of significant syllables after multiple corrections.
- moseq2_viz.model.stat.ztest(df, group1, group2, statistic='usage', max_syllable=40, verbose=False, **kwargs)
Compute a z-test on two (bootstrapped) selected groups, with multiple-test correction.
Args: df (pandas.DataFrame): dataframe that contains syllable statistics per session (mean_df/stats_df) group1 (str): Name of first group group2 (str): Name of second group statistic (str): Name of statistic to compute z-test on. max_syllable (int): the index of the maximum number of syllables to include verbose (bool): boolean flag that indicates whether to print out the significant syllable results thresh (float): Alpha threshold to consider syllable significant. mc_method (str): Multiple Corrections method to use.
Returns: pvals_ztest_boots (np.array): Computed array of p-values syllables_to_include (list): List of significant syllables after multiple corrections.
- moseq2_viz.model.stat.ztest_vect(d1, d2)
Perform a z-test on a pair of bootstrapped syllable statistics.
Args: d1 (np.array): bootstrapped syllable stat array from group 1 d2 (np.array): bootstrapped syllable stat array from group 2
Returns: p-values (np.array): array of computed p-values of the syllables.
Model - Transition Graph Module
Visualization and utility functions for transition matrices.
- moseq2_viz.model.trans_graph.compute_and_graph_grouped_TMs(config_data, labels, label_group, group)
Compute and plot a transition matrix for each given group.
Args: config_data (dict): configuration dictionary containing graphing parameters labels (list): list of 1D numpy arrays containing syllable labels per frame for every included session label_group (list): list of corresponding group names used to aggregate the transition plots group (list): unique list of groups to plot.
Returns: plt (pyplot.Figure): open transition graph figure to save
- moseq2_viz.model.trans_graph.convert_ebunch_to_graph(ebunch)
Convert transition matrices to transition DAGs.
Args: ebunch (list of tuples): syllable transition data
Returns: g (networkx.DiGraph): DAG object to graph
- moseq2_viz.model.trans_graph.convert_transition_matrix_to_ebunch(weights, transition_matrix, usages=None, usage_threshold=-0.1, speeds=None, speed_threshold=0, edge_threshold=-0.1, indices=None, keep_orphans=False, max_syllable=None)
Compute thresholded syllable transition data by usages and transition probabilities.
Args: weights (np.ndarray): syllable transition edge weights transition_matrix (np.ndarray): syllable transition matrix usages (list): list of syllable usages usage_threshold (float): threshold syllable usage to include a syllable in list of orphans speeds (np.array): list of syllable speeds speed_threshold (int): threshold value for syllable speeds to include edge_threshold (float): threshold transition probability to consider an edge part of the graph. indices (list): indices of syllable bigrams to plot keep_orphans (bool): indicate whether to graph orphan syllables max_syllable (int): the index of the maximum number of syllables to include
Returns: ebunch (list): syllable transition data. orphans (list): syllables with no edges.
- moseq2_viz.model.trans_graph.draw_graph(graph, width, pos, node_color, node_size, node_edge_colors, ax, arrows=False, font_size=12, edge_colors='k', title=None)
Draw transition graph to existing matplotlib axes.
Args: graph (nx.DiGraph): list of created nx.DiGraphs converted from transition matrices width (list): list of edge widths corresponding to each graph’s edges. pos (nx.Layout): nx.Layout type object holding position coordinates for the nodes. node_color (list): list of node colors for each graph. node_sizes (int or list): list of node sizes for each graph. node_edge_colors (list): list of node edge colors for each graph. ax (mpl.pyplot.axis): axis to draw graphs on. arrows (bool): whether to draw arrow edges font_size (int): Node label font size edge_colors (str): color of the transition edges drawn in the graph. title (str): title/group name of the transition graph.
Returns:
- moseq2_viz.model.trans_graph.get_group_trans_mats(labels, label_group, group, max_sylls, normalize='bigram')
Compute individual transition matrices for each given group.
Args: labels (np.ndarray): list of frame labels for each included session label_group (list): list of groups for each included session group (list): list of unique groups included max_sylls (int): Maximum number of syllables to include normalize (str): indicates how to normalize the computed transition matrices.
Returns: trans_mats (list of 2D np.ndarrays): list of transition matrices for each given group. usages (list of lists): list of corresponding usage statistics for each group.
- moseq2_viz.model.trans_graph.get_pos(graph_anchor, layout, nnodes)
Get node positions in the graph based on the graph anchor and a user selected layout.
Args: graph_anchor (nx.Digraph): graph to get node layout for layout (str): networkx layout type nnodes (int): number of nodes in the graph
Returns: pos (nx layout): computed node position layout
- moseq2_viz.model.trans_graph.get_trans_graph_groups(model_fit)
Get the groups and their respective session uuids to use in transition graph generation.
Args: model_fit (dict): trained model ARHMM containing training data UUIDs.
Returns: label_group (list): list of groups for each included session model_uuids (list): list of corresponding UUIDs for each included session in the model
- moseq2_viz.model.trans_graph.get_transition_matrix(labels, max_syllable=100, normalize='bigram', smoothing=0.0, combine=False, disable_output=False) → list
Compute the transition matrix from a set of model labels.
Args: labels (list of numpy.array): labels loaded from a model fit max_syllable (int): the index of the maximum number of syllables to include normalize (str): how to normalize transition matrix, ‘bigram’ or ‘rows’ or ‘columns’ smoothing (float): constant to add to transition_matrix pre-normalization to smooth counts combine (bool): flag for computing a separate transition matrix for each element (False) or combine across all arrays in the list (True) disable_output (bool): flag to disable the TQDM progress bar for the transition matrix computation.
Returns: transition_matrix (list or np.ndarray): list of 2d np.arrays that represent the transitions
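A sketch of computing per-session transition matrices from a trained model; the model path is a placeholder, and the 'labels' key on the parsed dict is an assumption based on parse_model_results (documented below)::

    from moseq2_viz.model.util import parse_model_results
    from moseq2_viz.model.trans_graph import get_transition_matrix

    model = parse_model_results('my_model.p', sort_labels_by_usage=True)
    # one transition matrix per session (combine=False)
    trans_mats = get_transition_matrix(model['labels'], max_syllable=40,
                                       normalize='bigram', combine=False)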
- moseq2_viz.model.trans_graph.get_transitions(label_sequence)
Computes syllable transitions.
Args: label_sequence (np.array): sequence of syllable labels for a session.
Returns: transitions (np.array): filtered label sequence containing only the syllable changes locs (np.array): list of all the indices where the syllable label changes
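A small self-contained sketch with a toy label sequence::

    import numpy as np
    from moseq2_viz.model.trans_graph import get_transitions

    labels = np.array([0, 0, 1, 1, 1, 2, 2, 0])
    transitions, locs = get_transitions(labels)
    # transitions: the syllable label at each change point
    # locs: the frame indices where the label changes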
- moseq2_viz.model.trans_graph.graph_transition_matrix(trans_mats, usages=None, groups=None, edge_threshold=0.0025, anchor=0, usage_threshold=0, layout='spring', edge_width_scale=100, fig=None, ax=None, width_per_group=8, headless=False, difference_threshold=0.0005, weights=None, usage_scale=10000.0, keep_orphans=False, max_syllable=None, orphan_weight=0, arrows=False, font_size=12, difference_edge_width_scale=500, **kwargs)
Create transition graph plot given a transition matrix and some metadata.
Args: trans_mats (np.ndarray): syllable transition matrix usages (list): list of syllable usage probabilities groups (list): list groups to graph transition graphs for. edge_threshold (float): threshold to include edge in graph anchor (int): syllable index as the base syllable usage_threshold (int): threshold to include syllable usages layout (str): layout format edge_width_scale (int): edge line width scaling factor fig (pyplot.figure): figure to plot to ax (pyplot.Axes): axes object width_per_group (int): graph width scaling factor per group headless (bool): exclude first node. difference_threshold (float): threshold to consider 2 graph elements different weights (list): list of edge weights usage_scale (float): syllable usage scaling factor keep_orphans (bool): plot orphans. max_syllable (int): the index of the maximum number of syllables to include orphan_weight (int): scaling factor to plot orphan node sizes arrows (bool): indicate whether to plot arrows as transitions. difference_edge_width_scale (float): difference graph edge line width scaling factor kwargs (dict): extra keyword arguments
Returns: fig (pyplot.figure): figure containing transition graphs. ax (pyplot.axis): figure axis object. pos (dict): dict figure information.
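A sketch chaining get_group_trans_mats with graph_transition_matrix; labels, label_group and group are assumed to come from a parsed model and get_trans_graph_groups (documented above)::

    from moseq2_viz.model.trans_graph import get_group_trans_mats, graph_transition_matrix

    trans_mats, usages = get_group_trans_mats(labels, label_group, group, max_sylls=40)
    fig, ax, pos = graph_transition_matrix(trans_mats, usages=usages, groups=group,
                                           edge_threshold=0.0025, layout='spring')
    fig.savefig('transition_graphs.pdf')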
- moseq2_viz.model.trans_graph.make_difference_graphs(trans_mats, usages, group, group_names, usage_kwargs, widths, pos, node_edge_colors, ax=None, node_sizes=[], indices=None, difference_threshold=0.0005, difference_edge_width_scale=500, font_size=12, usage_scale=50000.0, difference_graphs=[], scalars=None, arrows=False, speed_kwargs={})
Compute transition graph differences between groups.
Args: trans_mats (np.ndarray): syllable transition matrix. usages (list): list of syllable usage probabilities. group (list): list of groups to graph transition graphs for. group_names (list): list of group names to display with transition graphs. usage_kwargs (dict): kwargs for graph threshold settings using usage. Keys can be ‘usages’, and ‘usage_threshold’ widths (list): list of edge widths for each created single-group graph. pos (nx.Layout): nx.Layout type object holding position coordinates for the nodes. node_edge_colors (list): node edge colors (of type str). ax (np.ndarray matplotlib.pyplot.Axis): Optional axes to plot graphs in node_sizes (list): node size scaling factor (of type int) indices (list): list of in->out syllable indices to keep in graph difference_threshold (float): threshold to consider 2 graph elements different. difference_edge_width_scale (int): scaling factor for edge widths in difference transition graphs. font_size (int): indicates the size of the numbers drawn on the transition graph nodes. usage_scale (float): syllable usage scaling factor. difference_graphs (list): list of created difference transition graphs. scalars (dict): dict of syllable scalar data per transition graph arrows (bool): indicates whether to display arrows between node transitions speed_kwargs (dict): kwargs for graph threshold settings using speed. Keys can be ‘speeds’, and ‘speed_threshold’
Returns: usages (list): list of syllable usage probabilities including usages differences across groups group_names (list): list groups names to display with transition graphs including difference graphs. difference_graphs (list): list of computed difference graphs widths (list): list of edge widths for each created graph appended with difference weights node_sizes (list): lists of node sizes corresponding to each graph including difference graph node sizes node_edge_colors (list): lists of node colors corresponding to each graph including difference graph node sizes
- moseq2_viz.model.trans_graph.make_transition_graphs(trans_mats, usages, group, group_names, usage_kwargs, pos, orphans, edge_threshold=0.0025, difference_threshold=0.0005, orphan_weight=0, ax=None, edge_width_scale=100, usage_scale=100000.0, difference_edge_width_scale=500, speed_kwargs={}, indices=None, font_size=12, scalars=None, arrows=False)
Create transition graphs for all included groups, as well as their difference graphs.
Args: trans_mats (np.ndarray): syllable transition matrix. usages (list): list of syllable usage probabilities. group (list): list of groups to graph transition graphs for. group_names (list): list of group names to display with transition graphs. usage_kwargs (dict): kwargs for graph threshold settings using usage. Keys can be ‘usages’, and ‘usage_threshold’ pos (nx.Layout): nx.Layout type object holding position coordinates for the nodes. orphans (list): list of nodes with no edges. edge_threshold (float): threshold to include edge in graph. difference_threshold (float): threshold to consider 2 graph elements different. orphan_weight (int): scaling factor to plot orphan node sizes. ax (np.ndarray matplotlib.pyplot Axis): Optional axes to plot graphs in edge_width_scale (int): edge line width scaling factor. usage_scale (float): syllable usage scaling factor. difference_edge_width_scale (int): scaling factor for edge widths in difference transition graphs. speed_kwargs (dict): kwargs for graph threshold settings using speed. Keys can be ‘speeds’, and ‘speed_threshold’ indices (list): list of in->out syllable indices to keep in graph font_size (int): indicates the size of the numbers drawn on the transition graph nodes. scalars (dict): dict of syllable scalar data per transition graph arrows (bool): indicates whether to display arrows between node transitions
Returns: usages (list): list of syllable usage probabilities including possible appended difference usages. group_names (list): list groups names to display with transition graphs including difference graphs. widths (list): list of edge widths for each created graph appended with difference weights. node_sizes (2D list): lists of node sizes corresponding to each graph. node_edge_colors (2D list): lists of node colors corresponding to each graph including difference graph node sizes. graphs (list of nx.DiGraph): list of all group and difference transition graphs.
- moseq2_viz.model.trans_graph.n_gram_transition_matrix(labels, n=2, max_label=99)
Compute the transition count for a fixed syllable sequence length ‘n’.
Args: labels (list of np.array of ints): syllable labels loaded from a model fit. n (int): length of transition chain to compute transition probability for. max_label (int): max number of syllables to scan for in transition matrix.
Returns: trans_mat (np.ndarray): array of n-transition counts for given max_label.
- moseq2_viz.model.trans_graph.normalize_transition_matrix(init_matrix, normalize)
Normalize a transition matrix by given criteria.
Args: init_matrix (np.array): transition matrix to normalize. normalize (str): normalization criteria; [‘bigram’, ‘rows’, ‘columns’, or None]
Returns: init_matrix (np.array): normalized transition matrix
Model - Utilities Module
Utility functions for handling model data during pre and post processing.
- moseq2_viz.model.util.add_duration_column(scalar_df)
Add syllable duration column to scalar dataframe if the dataframe contains syllable labels.
Args: scalar_df (pandas.DataFrame): merged dataframe of scalar data and syllable data.
Returns: scalar_df (pandas.DataFrame): Same DataFrame with a new column titled “duration”.
- moseq2_viz.model.util.compute_behavioral_statistics(scalar_df, groupby=['group', 'uuid'], count='usage', fps=30, usage_normalization=True, syllable_key='labels (usage sort)')
Compute syllable statistics merged with the scalar features.
Args: scalar_df (pandas.DataFrame): Scalar measuresments for full dataset, including metadata for all the sessions. groupby (list of strings): list of columns to run the pandas groupby() on the scalar_df. count (str): method to compute syllable mean usage, either ‘usage’ or ‘frames’. fps (int): frames per second that the data was acquired in. usage_normalization (bool): indicates whether to normalize syllable usages by the value counts. syllable_key (str): column to rename to “syllable” for convenient referencing later on.
Returns: features (pandas.DataFrame): full feature Dataframe with scalars, metadata, and syllable statistics.
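For example, assuming scalar_df is the frame-level scalar dataframe (e.g. from scalars_to_dataframe, as referenced in the fingerprint module above) merged with syllable labels::

    from moseq2_viz.model.util import compute_behavioral_statistics

    mean_df = compute_behavioral_statistics(scalar_df, groupby=['group', 'uuid'],
                                            count='usage', fps=30)
    # mean_df now holds per-session syllable statistics merged with scalar features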
- moseq2_viz.model.util.compute_syllable_explained_variance(model, save_dir='/home/wingillis/dev/moseq/moseq2-viz/docs', n_explained=99)
Compute the maximum number of syllables to include that explain n_explained percent of all frames in the dataset.
Args: model (dict): ARHMM results dict n_explained (int): explained variance percentage threshold
Returns: max_sylls (int): the index of the maximum number of syllables to include that explain the given percentage of the variance
- moseq2_viz.model.util.compute_syllable_onset(labels)
Compute the onset index of each syllable label in a Series.
Args: labels (list or dict): label sequences loaded from a model fit
Returns: onsets (2D np.array): onset indices for each syllable for the given sessions.
- moseq2_viz.model.util.get_Xy_values(stat_means, unique_groups, stat='usage')
Compute the syllable or scalar mean statistics for each session, stored in X.
Args: stat_means (pd DataFrame): Dataframe of syllable or session-scalar mean statistics unique_groups (list): list of unique groups in the syll_means dataframe. stat (str or list): statistic column(s) to read from the syll_means df.
Returns: X (2D np.array): mean syllable or scalar statistics for each session. (nsessions x nsyllables) y (1D list): list of group names corresponding to each row in X. mapping (dict): dictionary containing mappings from group string to integer for later embedding. rev_mapping (dict): inverse mapping dict to retrieve the group names given their mapped integer value.
- moseq2_viz.model.util.get_best_fit(cp_path, model_results)
Return the model whose median syllable duration and duration distribution are closest to the model-free changepoints, given the objective.
Args: cp_path (str): Path to PCA Changepoints h5 file. model_results (dict): dict of pairs of model names paired with dict containing their respective changepoints.
Returns: info (dict): information about the best-fit models. pca_cps (1D array): pc score changepoint durations.
- moseq2_viz.model.util.get_mouse_syllable_slices(syllable: int, labels: ndarray) → Iterator[slice]
Return a list containing slices of syllable indices for a mouse.
Args: syllable (int): syllable label to get slices of. labels (np.ndarray): array of label predictions for a session.
Returns: slices (list): list of syllable label slices; e.g. [slice(3, 6, None), slice(9, 12, None)]
- moseq2_viz.model.util.get_normalized_syllable_usages(model_data, max_syllable=100, count='usage')
Compute syllable usages, normalize them to sum to 1, and return a 1D array of the corresponding usage values.
Args: model_data (dict): dict object of modeling results max_syllable (int): the index of the maximum number of syllables to include count (str): method to compute syllable mean usage, either ‘usage’ or ‘frames’.
Returns: syllable_usages (1D np array): array of sorted syllable usages for all syllables in model
- moseq2_viz.model.util.get_syllable_slices(syllable='__no__default__', labels='__no__default__', label_uuids='__no__default__', index='__no__default__', trim_nans: bool = True) → list
Get the indices that correspond to a specific syllable for each session in a modeling run.
Args: syllable (int): syllable number to get slices of. labels (np.ndarray): list of label predictions for each session. label_uuids (list): list of uuid keys corresponding to each session. index (dict): index file contents contained in a dict. trim_nans (bool): flag to use the pc scores file for removing time points that contain NaNs.
Returns: syllable_slices (list): a list of indices for syllable in the labels array. Each item in the list is a tuple of (slice, uuid, h5_file).
- moseq2_viz.model.util.get_syllable_statistics(data, fill_value=-5, max_syllable=100, count='usage')
Compute the usage and duration statistics from a set of model labels
Args: data (list of np.array of ints): labels loaded from a model fit. fill_value (int): lagged label values in the labels array to remove. max_syllable (int): the index of the maximum number of syllables to include count (str): method to compute syllable mean usage, either ‘usage’ or ‘frames’.
Returns: usages (OrderedDict): default dictionary of usages durations (OrderedDict): default dictionary of durations
- moseq2_viz.model.util.get_syllable_usages(data, max_syllable=100, count='usage')
Compute syllable usages for relabeled syllable labels.
Args: data (list): list of syllable frame-labels for each session. max_syllable (int): the index of the maximum number of syllables to include count (str): method to compute syllable mean usage, either ‘usage’ or ‘frames’.
Returns: usages (dict): dict object that contains usage frequency information.
- moseq2_viz.model.util.labels_to_changepoints(labels, fs=30)
Compute syllable durations and combine into a “changepoint” distribution.
Args: labels (list of np.ndarray of ints): labels loaded from a model fit. fs (float): sampling rate of camera.
Returns: cp_dist (np.ndarray of floats): list of block durations per element in labels list.
- moseq2_viz.model.util.make_separate_crowd_movies(config_data, sorted_index, group_keys, label_dict, output_dir, ordering, sessions=False)
Write syllable crowd movies for each given grouping found in group_keys, and return a dictionary of crowd movie file information.
Args: config_data (dict): Loaded crowd movie writing configuration parameters. sorted_index (dict): Loaded index file and sorted files in list. group_keys (dict): Dict of group/session name keys paired with UUIDS to match with labels. label_dict (dict): dict of corresponding session UUIDs for all sessions included in labels. output_dir (str): Path to output directory to save crowd movies in. ordering (list): ordering for the new mapping of the relabeled syllable usages. sessions (bool): indicates whether session crowd movies are being generated.
Returns: cm_paths (dict): group/session name keys paired with paths to their respectively generated syllable crowd movies.
- moseq2_viz.model.util.merge_models(model_dir, ext='p', count='usage', force_merge=False, cost_function='ar_norm')
WARNING: THIS IS EXPERIMENTAL. USE AT YOUR OWN RISK. Merge model states by using the Hungarian Algorithm: a minimum distance state matching algorithm. User inputs a directory containing models to merge, (and the name of the latest-trained model) to match other model states to.
Args: model_dir (str): path to directory containing all the models to merge. ext (str): model extension to search for. count (str): method to compute syllable mean usage, either ‘usage’ or ‘frames’. force_merge (bool): whether or not to force a merge. Keeping this false will protect you from merging models trained with different kappa values. cost_function (str): either ar_norm or label for the cost function in Hungarian Algorithm.
Returns: model_data (dict): a dictionary containing all the new keys and state-matched labels.
- moseq2_viz.model.util.normalize_pcs(pca_scores: dict, method: str = 'zscore') → dict
Normalize PC scores.
Args: pca_scores (dict): dict of uuid to PC-scores key-value pairs. method (str): the type of normalization to perform (demean, zscore, ind-zscore)
Returns: norm_scores (dict): a dictionary of normalized PC scores.
- moseq2_viz.model.util.normalize_usages(usage_dict)
Normalize syllable usages to frequency values from [0,1] instead of total counts.
Args: usage_dict (dict): dictionary containing syllable label keys pointing to total counts.
Returns: usage_dict (dict): dictionary containing syllable label keys pointing to usage frequencies.
- moseq2_viz.model.util.parse_batch_modeling(filename)
Reads model parameter scan training results into a single dictionary.
Args: filename (str): path to h5 manifest file containing all the model results.
Returns: results_dict (dict): dictionary containing each model’s training results, concatenated into a single list.
- moseq2_viz.model.util.parse_model_results(model_obj, restart_idx=0, resample_idx=-1, map_uuid_to_keys: bool = False, sort_labels_by_usage: bool = False, count: str = 'usage') → dict
Read a model file and return a dictionary containing the modeled results and some metadata.
Args: model_obj (str or results returned from joblib.load): path to the model fit or a loaded model fit restart_idx (int): Select which model restart to load. (Only change for models with multiple restarts used) resample_idx (int): parameter used to select labels from a specific sampling iteration. Default is the last iteration (-1) map_uuid_to_keys (bool): flag to create a label dictionary where each key->value pair contains the uuid and the labels for that session. sort_labels_by_usage (bool): sort and re-assign labels by their usages. count (str): method to compute syllable mean usage, either ‘usage’ or ‘frames’.
Returns: output_dict (dict): dictionary with labels and model parameters
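A minimal sketch of loading a trained model; the path is a placeholder and the 'labels' key is an assumption based on the description above::

    from moseq2_viz.model.util import parse_model_results

    model = parse_model_results('my_model.p', sort_labels_by_usage=True, count='usage')
    labels = model['labels']  # per-session label arrays (assumed key name)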
- moseq2_viz.model.util.prepare_model_dataframe(model_path, pca_path)
Creates a dataframe from syllable labels to be aligned with scalars.
Args: model_path (str): path to model to load label arrays from pca_path (str): path to pca_scores.h5 file.
Returns: _df (pandas.DataFrame): DataFrame object of timestamp aligned syllable label information.
- moseq2_viz.model.util.relabel_by_usage(labels, fill_value=-5, count='usage')
Resort model labels by their usages.
Args: labels (list or dict): label sequences loaded from a model fit fill_value (int): value prepended to modeling results to account for nlags count (str): method to compute syllable mean usage, either ‘usage’ or ‘frames’.
Returns: labels (list or dict): label sequences sorted by usage sorting (list): the new label sorting. The index corresponds to the new label, while the value corresponds to the old label.
- moseq2_viz.model.util.retrieve_pcs_from_slices(slices, pca_scores, max_dur=60, min_dur=3, max_samples=100, npcs=10, subsampling=None, remove_offset=False, **kwargs)
Subsample principal components from syllable slices.
Args: slices (np.ndarray): syllable slices or subarrays pca_scores (np.ndarray): PC scores for respective session. max_dur (int): maximum syllable length. min_dur (int): minimum syllable length. max_samples (int): maximum number of samples to retrieve. npcs (int): number of pcs to use. subsampling (int): number of syllable subsamples (defined through KMeans clustering). remove_offset (bool): indicate whether to remove initial offset from each PC score. kwargs (dict): used to capture certain arguments in other parts of the codebase.
Returns: syllable_matrix (np.ndarray): 3D matrix of subsampled PC projected syllable slices.
- moseq2_viz.model.util.simulate_ar_trajectory(ar_mat, init_points=None, sim_points=100)
Simulate auto-regressive trajectory matrices from a set of initialized points.
Args: ar_mat (2D np.ndarray): numpy array representing the autoregressive matrix of a model state with shape (npcs, npcs * nlags + 1) init_points (2D np.ndarray): pre-initialized array of shape (nlags, npcs) sim_points (int): number of time points to simulate.
Returns: sim_mat[nlags:] (np.ndarray): simulated AR trajectories, excluding the lagged initial values.
- moseq2_viz.model.util.sort_batch_results(data, averaging=True, filenames=None, **kwargs)
Sort modeling results from batch/parameter scan.
Args: data (np.ndarray): model AR-matrices. averaging (bool): return an average of all the model AR-matrices. filenames (list): list of paths to fit models. kwargs (dict): dict of extra keyword arguments.
Returns: new_matrix (np.ndarray): either average of all AR-matrices, or top sorted matrix param_dict (dict): model parameter dict filename_index (list): list of filenames associated with each model.
- moseq2_viz.model.util.sort_syllables_by_stat(complete_df, stat='usage', max_sylls=None)
Computes the sorted ordering of the given DataFrame with respect to the chosen stat.
Args: complete_df (pandas.DataFrame): dataframe containing the summary statistics about scalars and syllable data (mean_df/stats_df) stat (str): choice of statistic to order syllables by. max_sylls (int or None): the index of the maximum number of syllables to include
Returns: ordering (list): list of sorted syllables by stat. relabel_mapping (dict): a dict with key-value pairs {old_ordering: new_ordering}.
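For example, assuming mean_df is a summary statistics dataframe with a 'duration' column::

    from moseq2_viz.model.util import sort_syllables_by_stat

    ordering, relabel_mapping = sort_syllables_by_stat(mean_df, stat='duration')
    # ordering[i] is the syllable at rank i; relabel_mapping maps old -> new labels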
- moseq2_viz.model.util.sort_syllables_by_stat_difference(complete_df, ctrl_group, exp_group, max_sylls=None, stat='usage')
Compute the syllable ordering for the difference of the inputted groups (exp - ctrl) and sort the syllables by the differences.
Args: complete_df (pandas.DataFrame): dataframe containing the summary statistics about scalars and syllable data (mean_df/stats_df) ctrl_group (str): Control group. exp_group (str): Experimental group. max_sylls (int): the index of the maximum number of syllables to include stat (str): choice of statistic to order mutations by: {usage, duration, speed}.
Returns: ordering (list): list of array indices for the new label mapping.
- moseq2_viz.model.util.syll_duration(labels: ndarray) → ndarray
Compute the duration of each syllable.
Args: labels (np.ndarray): array of syllable labels for a session.
Returns: durations (np.ndarray): array of syllable durations.
- moseq2_viz.model.util.syll_id(labels: ndarray) → ndarray
Return the syllable label at each onset of a syllable transition.
Args: labels (np.ndarray): array of syllable labels for a mouse.
Returns: labels[onsets] (np.ndarray): an array of compressed labels.
- moseq2_viz.model.util.syll_onset(labels: ndarray) → ndarray
Find indices of syllable onsets.
Args: labels (np.ndarray): array of syllable labels for a mouse.
Returns: indices (np.ndarray): an array of indices denoting the beginning of each syllable.
- moseq2_viz.model.util.syllable_slices_from_dict(syllable: int = '__no__default__', labels: Dict[str, ndarray] = '__no__default__', index: Dict = '__no__default__', filter_nans: bool = True) → Dict[str, list]
Read a dictionary of syllable labels and return a dict of syllable slices.
Args: syllable (int): syllable label to get slices of. labels (dict): dict of label prediction arrays keyed by session uuid. index (dict): index file contents contained in a dict. filter_nans (bool): replace NaN values with 0.
Returns: vals (dict): key-value pairs of syllable slices per session uuid.
- moseq2_viz.model.util.whiten_pcs(pca_scores, method='all', center=True)
Whiten PC scores using Cholesky whitening.
Args: pca_scores (dict): dictionary where values are pca_scores. method (str): ‘all’ to whiten using the covariance estimated from all keys, or ‘each’ to whiten each separately center (bool): whether or not to center the data
Returns: whitened_scores (dict): dictionary of whitened pc scores
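A sketch assuming pca_scores is a dict mapping session uuids to their PC score arrays (e.g. loaded from the PCA scores h5 file)::

    from moseq2_viz.model.util import normalize_pcs, whiten_pcs

    whitened_scores = whiten_pcs(pca_scores, method='all', center=True)
    normed_scores = normalize_pcs(pca_scores, method='zscore')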