moseq2_viz.model package

Model - Dist Module

Utility functions for estimating syllable similarity (behavioral distance).

moseq2_viz.model.dist.get_behavioral_distance(index, model_file, whiten='all', distances=['ar[init]', 'scalars'], max_syllable=None, resample_idx=-1, dist_options={}, sort_labels_by_usage=True, count='usage')

Compute the behavioral distance (square) matrices with respect to a predefined set of variables.

Args: index (str): Path to index file model_file (str): Path to trained model whiten (str): Indicates whether to whiten all PCs at once or each one at a time. Options = [‘all’, ‘each’] distances (list or str): type of distance(s) to compute. Available options = [‘scalars’, ‘ar[init]’, ‘ar[dtw]’, ‘pca[dtw]’, ‘combined’] max_syllable (int): the index of the maximum number of syllables to include resample_idx (int): Indicates the parsing method according to the shape of the labels array. dist_options (dict): Dictionary holding each distance operations configurable parameters sort_labels_by_usage (bool): boolean flag that indicates whether to relabel syllables by count ordering count (str): method to compute syllable mean usage, either ‘usage’ or ‘frames’.

Returns: dist_dict (dict): Dictionary containing all computed behavioral square distance matrices
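As a usage sketch (the index and model paths below are hypothetical, and the assumption that the returned dictionary is keyed by distance name follows from the description above):

    from moseq2_viz.model.dist import get_behavioral_distance

    # hypothetical paths to a moseq2 index file and a trained model
    index_file = 'moseq2-index.yaml'
    model_file = 'model.p'

    dist_dict = get_behavioral_distance(index_file, model_file,
                                        distances=['ar[init]', 'scalars'],
                                        max_syllable=40)
    # assuming the dict is keyed by distance name; each value is a
    # (max_syllable, max_syllable) square distance matrix
    ar_dist = dist_dict['ar[init]']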

moseq2_viz.model.dist.get_behavioral_distance_ar(ar_mat, init_point=None, sim_points=10, max_syllable=40, dist='correlation', parallel=False)

Compute behavioral distance with respect to the model’s autoregressive (AR) matrices. The function computes either AR trajectory correlation distances or dynamically time-warped (DTW) trajectory distances.

Parameters:
  • ar_mat (numpy.ndarray) (Trained model AutoRegressive matrices; shape=(max_syllable, npcs, npcs*nlags+1))

  • init_point (list) (Initial values as a reference point for distance estimation)

  • sim_points (int) (number of time points to simulate)

  • max_syllable (int) (the index of the maximum number of syllables to include)

  • dist (str) (Distance operation to compute. Either ‘correlation’ or ‘dtw’.)

  • parallel (bool) (Boolean flag that indicates whether to use multiprocessing to compute dtw distances.)

Returns:

  • ar_dist (2D numpy array) (Computed AR trajectory distances for each AR matrix/model state.)

  • shape=(max_syllable, max_syllable)

moseq2_viz.model.dist.get_init_points(pca_scores, model_labels, max_syllable=40, nlags=3, npcs=10)

Compute initial AR trajectories based on a cumulative average of lagged-PC Scores over nlags.

Args: pca_scores (numpy.ndarray): Loaded PC Scores. Shape=(npcs, nsamples) model_labels (list): list of 1D numpy arrays of relabeled/sorted syllable labels max_syllable (int): the index of the maximum number of syllables to include nlags (int): Number of lagged frames. npcs (int): Number of PCs to use in computation.

Returns: syll_average (list): List containing 2D np arrays of average syllable trajectories over a nlag-strided PC scores array.

moseq2_viz.model.dist.reformat_dtw_distances(full_mat, nsyllables, rescale=True)

Reduce full (max states) dynamically time-warped PC Score distance matrices to only include dimensions for a total of nsyllables.

Args: full_mat (np.ndarray): DTW distance matrices for all model states/syllables. nsyllables (int): Number of syllables to include in truncated DTW distance matrix. rescale (bool): Rescale truncated dtw-distance matrices to match output distribution.

Returns: rmat (2D np array): Reformatted-Truncated DTW Distance Matrix; shape = (nsyllables, nsyllables)

Model - embed Module

Functions to run PCA and LDA on syllable usages and scalars

moseq2_viz.model.embed.plot_embedding(L, y, mapping, rev_mapping, output_file='embedding.pdf', embedding='PCA', x_dim=0, y_dim=1, symbols='o*v^s', plot_all_subjects=True)

Plot 2D embedding plot.

Args: L (2D np.array): the embedding representations of the mean syllable statistic to plot. y (1D list): list of group names corresponding to each row in L. mapping (dict): dictionary containing mappings from group string to integer for later embedding. rev_mapping (dict): inverse mapping dict to retrieve the group names given their mapped integer value. output_file (str): path to saved outputted figure embedding (str): type of embedding to run. Either [‘lda’, ‘pca’]. x_dim (int): component number to graph on x-axis y_dim (int): component number to graph on y-axis symbols (str): symbols to use to draw different groups. plot_all_subjects (bool): boolean flag that indicates whether to plot individual subject embeddings along with their respective group means.

Returns: fig (matplotlib.figure): figure containing plotted 2d embedding. ax (matplotlib.axes): axes instance for plotted figure.

moseq2_viz.model.embed.run_2d_embedding(mean_df, stat='usage', output_file='2d_embedding.pdf', embedding='PCA', n_components=2, plot_all_subjects=True)

Compute a 2D embedding (PCA or LDA) of the mean syllable statistic of choice. The function will output a figure of the 2D representation of the embedding.

Args: mean_df (pd DataFrame): Dataframe of the mean syllable statistics for all sessions stat (str): name of statistic (column) in mean_df to embed. output_file (str): path to saved outputted figure embedding (str): type of embedding to run. Either [‘lda’, ‘pca’] n_components (int): Number of components to compute. plot_all_subjects (bool): indicates whether to plot individual subject embeddings along with their respective group means.

Returns: fig (matplotlib.figure): figure containing plotted 2d embedding. ax (matplotlib.axes): axes instance for plotted figure.
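A minimal usage sketch, assuming mean_df is a per-session mean syllable statistics dataframe (e.g. produced by compute_behavioral_statistics in the Utilities module below) and the output path is hypothetical:

    from moseq2_viz.model.embed import run_2d_embedding

    # mean_df: per-session mean syllable statistics dataframe (assumed precomputed)
    fig, ax = run_2d_embedding(mean_df, stat='usage', embedding='PCA',
                               n_components=2, output_file='2d_embedding.pdf')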

moseq2_viz.model.embed.run_2d_scalar_embedding(scalar_df, output_file='2d_scalar_embedding.pdf', embedding='PCA', n_components=2, plot_all_subjects=True)

Compute a 2D embedding (PCA or LDA) of the mean measured scalar values for all groups. The function will output a figure of the 2D representation of the embedding.

Args: scalar_df (pd DataFrame): Dataframe of the frame-by-frame scalar measurements for all sessions output_file (str): path to saved outputted figure embedding (str): type of embedding to run. Either [‘lda’, ‘pca’] n_components (int): Number of components to compute. plot_all_subjects (bool): indicates whether to plot individual subject embeddings along with their respective group means.

Returns: fig (matplotlib figure): figure containing plotted 2d embedding. ax (matplotlib axes): axes instance for plotted figure.

Model - Fingerprint and Classifier Module

Functions for creating fingerprint plots and linear classifier

moseq2_viz.model.fingerprint_classifier.classifier_fingerprint(summary, features=['MoSeq'], preprocessor=None, classes=['group'], param_search=True, C_list=None, model_type='lr', cv='loo', n_splits=5)

run classifier using the fingerprint dataframe

Args:

summary (pandas.DataFrame): fingerprint dataframe features (list, optional): Features for the classifier. [‘MoSeq’] for MoSeq syllables or a list of MoSeq scalar values. Defaults to [‘MoSeq’]. preprocessor (sklearn.preprocessing object, optional): Scaler for scaling the data by feature. Defaults to None. classes (list, optional): labels the classifier predicts. Defaults to [‘group’]. param_search (bool, optional): run GridSearchCV to find the regularization parameter for the classifier. Defaults to True. C_list (numpy.array, optional): list of C regularization parameters to search through. Defaults to None; if None, C_list will search through np.logspace(-6, 3, 50). model_type (str, optional): name of the linear classifier. ‘lr’ for logistic regression or ‘svc’ for LinearSVC. Defaults to ‘lr’. cv (str, optional): cross-validation type. ‘loo’ for LeaveOneOut, ‘skf’ for StratifiedKFold. Defaults to ‘loo’. n_splits (int, optional): number of splits for StratifiedKFold. Defaults to 5.

Returns:

y_true (np.array): array for true label y_pred (np.array): array for predicted label real_f1 (np.array): array for f1 score true_coef (np.array): array for model weights y_shuffle_true (np.array): array for shuffled label y_shuffle_pred (np.array): array for shuffled predicted label shuffle_f1 (np.array): array for shuffled f1 score shuffle_coef (np.array): array for shuffled model weights

moseq2_viz.model.fingerprint_classifier.create_fingerprint_dataframe(scalar_df, mean_df, stat_type='mean', n_bins=None, groupby_list=['group', 'uuid'], range_type='robust', scalars=['velocity_2d_mm', 'height_ave_mm', 'length_mm', 'dist_to_center_px'])

create fingerprint dataframe from scalar_df and mean_df

Args:

scalar_df (pandas.DataFrame): scalar summary dataframe generated from scalars_to_dataframe mean_df (pandas.DataFrame): syllable mean dataframe from compute_behavioral_statistics n_bins (int, optional): number of bins for the features. Defaults to None. groupby_list (list, optional): the list of levels the fingerprint dataframe should be grouped by. Defaults to [‘group’, ‘uuid’].

Returns:

summary ([pandas.DataFrame]): fingerprint dataframe range_dict ([dict]): dictionary that holds the min and max values of the features

moseq2_viz.model.fingerprint_classifier.plot_cm(y_true, y_pred, y_shuffle_true, y_shuffle_pred)

plot confusion matrix

Args:

y_true (np.array): array for true label y_pred (np.array): array for predicted label y_shuffle_true (np.array): array for shuffled label y_shuffle_pred (np.array): array for shuffled predicted label

moseq2_viz.model.fingerprint_classifier.plotting_fingerprint(summary, save_dir, range_dict, figsize=(20, 18), preprocessor=None, num_level=1, level_names=['Group'], vmin=None, vmax=None, plot_columns=['dist_to_center_px', 'velocity_2d_mm', 'height_ave_mm', 'length_mm', 'MoSeq'], col_names=[('Position', 'Dist. from center (px)'), ('Speed', 'Speed (mm/s)'), ('Height', 'Height (mm)'), ('Length', 'Length (mm)'), ('MoSeq', 'Syllable ID')])

plot the fingerprint heatmap

Args:

summary (pandas.DataFrame): fingerprint dataframe range_dict (pandas.DataFrame): pd.DataFrame that holds the min and max values of the features preprocessor (sklearn.preprocessing object, optional): Scaler for scaling the data by session. Defaults to None. num_level (int, optional): the number of groupby levels. Defaults to 1. level_names (list, optional): list of names of the levels. Defaults to [‘Group’]. vmin (int, optional): min value the figure color map covers. Defaults to 0. vmax (float, optional): max value the figure color map covers. Defaults to 0.2. plot_columns (list, optional): columns to plot col_names (list, optional): list of (column name, x label) pairs

Raises:

Exception: raised if num_level is greater than the number of existing levels
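A sketch of the fingerprint workflow, assuming scalar_df and mean_df were produced by scalars_to_dataframe and compute_behavioral_statistics respectively, that the eight classifier outputs follow the order listed above, and that the save directory is hypothetical:

    from moseq2_viz.model.fingerprint_classifier import (
        create_fingerprint_dataframe, plotting_fingerprint,
        classifier_fingerprint, plot_cm)

    # build the fingerprint dataframe and per-feature ranges
    summary, range_dict = create_fingerprint_dataframe(scalar_df, mean_df)

    # heatmap of the fingerprint per group (save directory is hypothetical)
    plotting_fingerprint(summary, 'fingerprint_plots', range_dict)

    # linear classifier on the fingerprint; unpacking assumes the return order documented above
    (y_true, y_pred, real_f1, true_coef,
     y_shuffle_true, y_shuffle_pred, shuffle_f1, shuffle_coef) = classifier_fingerprint(
        summary, features=['MoSeq'], classes=['group'], model_type='lr', cv='loo')
    plot_cm(y_true, y_pred, y_shuffle_true, y_shuffle_pred)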

moseq2_viz.model.fingerprint_classifier.robust_max(v)

find relative max

Args:

v (pandas.Series): the pandas.Series to find the robust max.

Returns:

(numpy.float): the robust max in the series.

moseq2_viz.model.fingerprint_classifier.robust_min(v)

find relative min

Args:

v (pandas.Series): the pandas.Series to find the robust min.

Returns:

(numpy.float): the robust min in the series.

Model - Stats Module

Functions for statistical tests for analyzing model results.

moseq2_viz.model.stat.bootstrap_group_means(df, group1, group2, statistic='usage', max_syllable=40)

compute bootstrapped group means

Args: df (pandas.DataFrame): dataframe that contains syllable statistics per session (mean_df/stats_df) group1 (str): Name of group 1 to compare. group2 (str): Name of group 2 to compare. statistic (str): Syllable statistic to compute bootstrap means for. max_syllable (int): the index of the maximum number of syllables to include

Returns: boots (dictionary): dictionary of group name (keys) paired with their bootstrapped statistics numpy array.
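For example, a sketch comparing two hypothetical groups with bootstrapped means and the vectorized z-test defined later in this module (stats_df is assumed to be a per-session syllable statistics dataframe):

    from moseq2_viz.model.stat import bootstrap_group_means, ztest_vect

    # stats_df: per-session syllable statistics dataframe (mean_df/stats_df)
    boots = bootstrap_group_means(stats_df, 'control', 'treatment',
                                  statistic='usage', max_syllable=40)
    # boots is keyed by group name; ztest_vect returns one p-value per syllable
    pvals = ztest_vect(boots['control'], boots['treatment'])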

moseq2_viz.model.stat.bootstrap_me(usages, n_iters=10000)

Bootstrap the inputted stat data using random sampling with replacement.

Args: usages (np.array): Data to bootstrap n_iters (int): Number of samples to return.

Returns: boots (np.array): Bootstrapped array containing n_iters resampled values.

moseq2_viz.model.stat.compute_pvalues_for_group_pairs(real_zs_within_group, null_zs, df_k_real, group_names, n_perm=10000, thresh=0.05, mc_method='fdr_bh', verbose=False)

Adjust the p-values from Dunn’s z-test statistics and compute the resulting significant syllables with the adjusted p-values.

Args: real_zs_within_group (dict): dict of group pair keys paired with vector of Dunn’s z-test statistics null_zs (dict): dict of group pair keys paired with vector of Dunn’s z-test statistics of the null hypothesis. df_k_real (pandas.DataFrame): DataFrame of KW test results. group_names (pd.Index): Index list of unique group names. n_perm (int): Number of permuted samples to generate. thresh (float): Alpha threshold to consider syllable significant. mc_method (str): Multiple Corrections method to use. verbose (bool): indicates whether to print out the significant syllable results

Returns: df_pval_corrected (pandas.DataFrame): DataFrame containing Dunn’s test results with corrected p-values. significant_syllables (list): List of corrected KW significant syllables (syllables with p-values < thresh)

moseq2_viz.model.stat.dunns_z_test_permute_within_group_pairs(df_usage, vc, real_ranks, X_ties, N_m, group_names, rnd, n_perm)

Run Dunn’s z-test statistic on combinations of all group pairs, handling pre-computed tied ranks.

Args: df_usage (pandas.DataFrame): DataFrame containing only pre-computed syllable stats. vc (pd.Series): value counts of sessions in each group. real_ranks (np.array): Array of syllable ranks. X_ties (np.array): 1-D list of tied ranks, where if value > 0, then rank is tied N_m (int): Number of sessions. group_names (pd.Index): Index list of unique group names. rnd (np.random.RandomState): Pseudo-random number generator. n_perm (int): Number of permuted samples to generate.

Returns: null_zs_within_group (dict): dict of group pair keys paired with vector of Dunn’s z-test statistics of the null hypothesis. real_zs_within_group (dict): dict of group pair keys paired with vector of Dunn’s z-test statistics

moseq2_viz.model.stat.get_session_mean_df(df, statistic='usage', max_syllable=40)

Compute a given mean syllable statistic grouped by groups and UUIDs.

Args: df (pandas.DataFrame): dataframe that contains average syllable statistics per session (mean_df/stats_df) statistic (str): statistic to compute mean for, (any of the columns in input df); for example: ‘usage’, ‘duration’, ‘velocity_2d_mm’, etc. max_syllable (int): the index of the maximum number of syllables to include

Returns: df_pivot (pandas.DataFrame): Mean syllable statistic per session.

moseq2_viz.model.stat.get_sig_syllables(df_pvals, thresh=0.05, mc_method='fdr_bh', verbose=False)

Runs multiple p-value comparison tests given a set alpha threshold, with multiple-test correction.

Args: df_pvals (pandas.DataFrame): dataframe listing raw p-values thresh (float): Alpha threshold to consider syllable significant. mc_method (str): Multiple Corrections method to use. verbose (bool): indicates whether to print out the significant syllable results

Returns: df_pvals (pandas.DataFrame): updated dataframe listing adjusted p-values

moseq2_viz.model.stat.get_tie_correction(x, N_m)

Kruskal-Wallis helper function: assign tied rank values to the average of the ranks they would have received had they not been tied.

Args: x (pd.Series): syllable usages for a single session. N_m (int): Number of total sessions.

Returns: corrected_rank (float): average of the inputted tied rank

moseq2_viz.model.stat.mann_whitney(df, group1, group2, statistic='usage', max_syllable=40, verbose=False, **kwargs)

Runs a Mann-Whitney hypothesis test with multiple test corrections on two given groups to find significant syllables. Also runs multiple corrections test to find syllables to exclude.

Args: df (pandas.DataFrame): dataframe that contains syllable statistics per session (mean_df/stats_df) group1 (str): Name of first group group2 (str): Name of second group statistic (str): Name of statistic to run the Mann-Whitney test on. max_syllable (int): the index of the maximum number of syllables to include verbose (bool): indicates whether to print out the significant syllable results thresh (float): Alpha threshold to consider syllable significant. mc_method (str): Multiple Corrections method to use.

Returns: df_mw_real (pandas.DataFrame): DataFrame containing Mann-Whitney U corrected results. exclude_sylls (list): list of syllables that were excluded via multiple comparisons test.

moseq2_viz.model.stat.plot_H_stat_significance(df_k_real, h_all, N_s)

Plot the assigned H-statistic for each syllable computed via manual KW test.

Args: df_k_real (pandas.DataFrame): the dataframe that contains Kruskal-Wallis p value, H stats and whether it is significant h_all (np.array): Array of H-stats computed for given n_syllables; shape = (n_perms, N_s) N_s (int): Number of syllables to plot

Returns: fig (pyplot.figure): plotted H-stats plot ax (pyplot.axis): plotted H-stats axis

moseq2_viz.model.stat.run_kruskal(df, statistic='usage', max_syllable=40, n_perm=10000, seed=42, thresh=0.05, mc_method='fdr_bh', verbose=False)

Run Kruskal-Wallis Hypothesis test and Dunn’s posthoc multiple comparisons test for syllable statistic.

Args: df (pandas.DataFrame): dataframe that contains syllable statistics per session (mean_df/stats_df) statistic (str): statistic to compute mean for, (any of the columns in input df). max_syllable (int): the index of the maximum number of syllables to include n_perm (int): Number of permuted samples to generate. seed (int): Random seed used to initialize the pseudo-random number generator. thresh (float): Alpha threshold to consider syllable significant. mc_method (str): Multiple Corrections method to use. verbose (bool): indicates whether to print out the significant syllable results

Returns: df_k_real (pandas.DataFrame): DataFrame of KW test results. dunn_results_df (pandas.DataFrame): DataFrame of Dunn’s test results for permuted group pairs. intersect_sig_syllables (dict): dictionary containing intersecting significant syllables between KW and Dunn’s tests.
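A minimal sketch of running the Kruskal-Wallis and Dunn’s posthoc workflow (stats_df is assumed to be a per-session syllable statistics dataframe):

    from moseq2_viz.model.stat import run_kruskal

    df_k_real, dunn_results_df, intersect_sig_syllables = run_kruskal(
        stats_df, statistic='usage', max_syllable=40,
        n_perm=10000, thresh=0.05, mc_method='fdr_bh')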

moseq2_viz.model.stat.run_manual_KW_test(df_usage, merged_usages_all, num_groups, n_per_group, cum_group_idx, n_perm=10000, seed=0)

Run a manual Kruskal-Wallis test and check that the results agree with the scipy.stats implementation.

Args: df_usage (pandas.DataFrame): DataFrame containing only pre-computed syllable stats. shape = (N_m, n_syllables) merged_usages_all (np.array): numpy array format of the df_usage DataFrame. num_groups (int): Number of unique groups n_per_group (list): list of value counts for sessions per group. len == num_groups. cum_group_idx (list): list of indices for different groups. len == num_groups + 1. n_perm (int): Number of permuted samples to generate. seed (int): Random seed used to initialize the pseudo-random number generator.

Returns: h_all (np.array): Array of H-stats computed for given n_syllables; shape = (n_perms, N_s) real_ranks (np.array): Array of syllable ranks, shape = (N_m, n_syllables) X_ties (np.array): 1-D list of tied ranks, where if value > 0, then rank is tied. len(X_ties) = n_syllables

moseq2_viz.model.stat.run_pairwise_stats(df, group1, group2, test_type='mw', verbose=False, **kwargs)

Run one of the pairwise hypothesis testing functions: Mann-Whitney, z-test, or t-test.

Args: df (pandas.DataFrame): dataframe that contains syllable statistics per session (mean_df/stats_df) group1 (str): Name of first group group2 (str): Name of second group test_type (str): specifying which type of statistical test verbose (bool): boolean flag that indicates whether to print out the significant syllable results

Returns:

df_pvals (pandas.DataFrame): Dataframe listing the p-values and which syllables are significant

moseq2_viz.model.stat.ttest(df, group1, group2, statistic='usage', max_syllable=40, verbose=False, **kwargs)

Computes a t-test on 2 selected groups to find significant syllables, with multiple test correction.

Args: df (pandas.DataFrame): dataframe that contains syllable statistics per session (mean_df/stats_df) group1 (str): Name of first group group2 (str): Name of second group statistic (str): Name of statistic to compute t-test on. max_syllable (int): the index of the maximum number of syllables to include verbose (bool): indicates whether to print out the significant syllable results. thresh (float): Alpha threshold to consider syllable significant. mc_method (str): Multiple Corrections method to use.

Returns: p (np.array): Computed array of p-values syllables_to_include (list): List of significant syllables after multiple corrections.

moseq2_viz.model.stat.ztest(df, group1, group2, statistic='usage', max_syllable=40, verbose=False, **kwargs)

Computes a z-test on 2 (bootstrapped) selected groups, with multiple test correction.

Args: df (pandas.DataFrame): dataframe that contains syllable statistics per session (mean_df/stats_df) group1 (str): Name of first group group2 (str): Name of second group statistic (str): Name of statistic to compute z-test on. max_syllable (int): the index of the maximum number of syllables to include verbose (bool): boolean flag that indicates whether to print out the significant syllable results thresh (float): Alpha threshold to consider syllable significant. mc_method (str): Multiple Corrections method to use.

Returns: pvals_ztest_boots (np.array): Computed array of p-values syllables_to_include (list): List of significant syllables after multiple corrections.

moseq2_viz.model.stat.ztest_vect(d1, d2)

Perform a z-test on a pair of bootstrapped syllable statistics.

Args: d1 (np.array): bootstrapped syllable stat array from group 1 d2 (np.array): bootstrapped syllable stat array from group 2

Returns: p-values (np.array): array of computed p-values of the syllables.

Model - Transition Graph Module

Visualization and utility functions for transition matrices.

moseq2_viz.model.trans_graph.compute_and_graph_grouped_TMs(config_data, labels, label_group, group)

compute a transition matrix for each given group.

Args: config_data (dict): configuration dictionary containing graphing parameters labels (list): list of 1D numpy arrays containing syllable labels per frame for every included session label_group (list): list of corresponding group names used to plot aggregated transition plots group (list): unique list of groups to plot.

Returns: plt (pyplot.Figure): open transition graph figure to save

moseq2_viz.model.trans_graph.convert_ebunch_to_graph(ebunch)

Convert transition matrices to transition DAGs.

Args: ebunch (list of tuples): syllable transition data

Returns: g (networkx.DiGraph): DAG object to graph

moseq2_viz.model.trans_graph.convert_transition_matrix_to_ebunch(weights, transition_matrix, usages=None, usage_threshold=-0.1, speeds=None, speed_threshold=0, edge_threshold=-0.1, indices=None, keep_orphans=False, max_syllable=None)

Compute thresholded syllable transition data by usages and transition probabilities.

Args: weights (np.ndarray): syllable transition edge weights transition_matrix (np.ndarray): syllable transition matrix usages (list): list of syllable usages usage_threshold (float): threshold syllable usage to include a syllable in list of orphans speeds (np.array): list of syllable speeds speed_threshold (int): threshold value for syllable speeds to include edge_threshold (float): threshold transition probability to consider an edge part of the graph. indices (list): indices of syllable bigrams to plot keep_orphans (bool): indicate whether to graph orphan syllables max_syllable (int): the index of the maximum number of syllables to include

Returns: ebunch (list): syllable transition data. orphans (list): syllables with no edges.

moseq2_viz.model.trans_graph.draw_graph(graph, width, pos, node_color, node_size, node_edge_colors, ax, arrows=False, font_size=12, edge_colors='k', title=None)

Draw transition graph to existing matplotlib axes.

Args: graph (nx.DiGraph): list of created nx.DiGraphs converted from transition matrices width (list): list of edge widths corresponding to each graph’s edges. pos (nx.Layout): nx.Layout type object holding position coordinates for the nodes. node_color (list): list of node colors for each graph. node_sizes (int or list): list of node sizes for each graph. node_edge_colors (list): list of node edge colors for each graph. ax (mpl.pyplot.axis): axis to draw graphs on. arrows (bool): whether to draw arrow edges font_size (int): Node label font size edge_colors (str): color of the transition edges drawn in the graph. title (str): title/group name of the transition graph.

Returns:

moseq2_viz.model.trans_graph.get_group_trans_mats(labels, label_group, group, max_sylls, normalize='bigram')

Compute individual transition matrices for each given group.

Args: labels (np.ndarray): list of frame labels for each included session label_group (list): list of groups for each included session group (list): list of unique groups included max_sylls (int): Maximum number of syllables to include normalize (str): indicates how to normalize the computed transition matrices.

Returns: trans_mats (list of 2D np.ndarrays): list of transition matrices for each given group. usages (list of lists): list of corresponding usage statistics for each group.

moseq2_viz.model.trans_graph.get_pos(graph_anchor, layout, nnodes)

Get node positions in the graph based on the graph anchor and a user selected layout.

Args: graph_anchor (nx.Digraph): graph to get node layout for layout (str): networkx layout type nnodes (int): number of nodes in the graph

Returns: pos (nx layout): computed node position layout

moseq2_viz.model.trans_graph.get_trans_graph_groups(model_fit)

Get the groups and their respective session uuids to use in transition graph generation.

Args: model_fit (dict): trained model ARHMM containing training data UUIDs.

Returns: label_group (list): list of groups for each included session model_uuids (list): list of corresponding UUIDs for each included session in the model

moseq2_viz.model.trans_graph.get_transition_matrix(labels, max_syllable=100, normalize='bigram', smoothing=0.0, combine=False, disable_output=False) → list

Compute the transition matrix from a set of model labels.

Args: labels (list of numpy.array): labels loaded from a model fit max_syllable (int): the index of the maximum number of syllables to include normalize (str): how to normalize transition matrix, ‘bigram’ or ‘rows’ or ‘columns’ smoothing (float): constant to add to transition_matrix pre-normalization to smooth counts combine (bool): flag for computing a separate transition matrix for each element (False) or combine across all arrays in the list (True) disable_output (bool): flag to disable the TQDM progress bar during transition matrix computation.

Returns: transition_matrix (list or np.ndarray): list of 2d np.arrays that represent the transitions
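As a usage sketch, assuming labels is a list of per-session label arrays (e.g. pulled out of parse_model_results in the Utilities module below):

    from moseq2_viz.model.trans_graph import get_transition_matrix

    # labels: list of 1D label arrays, one per session (assumed precomputed)
    trans_mats = get_transition_matrix(labels, max_syllable=40,
                                       normalize='bigram', combine=False)
    # combine=False returns one transition matrix per session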

moseq2_viz.model.trans_graph.get_transitions(label_sequence)

Computes syllable transitions.

Args: label_sequence (np.ndarray): array of syllable labels for a session.

Returns: transitions (np.array): filtered label sequence containing only the syllable changes locs (np.array): list of all the indices where the syllable label changes

moseq2_viz.model.trans_graph.graph_transition_matrix(trans_mats, usages=None, groups=None, edge_threshold=0.0025, anchor=0, usage_threshold=0, layout='spring', edge_width_scale=100, fig=None, ax=None, width_per_group=8, headless=False, difference_threshold=0.0005, weights=None, usage_scale=10000.0, keep_orphans=False, max_syllable=None, orphan_weight=0, arrows=False, font_size=12, difference_edge_width_scale=500, **kwargs)

Create transition graph plot given a transition matrix and some metadata.

Args: trans_mats (np.ndarray): syllable transition matrix usages (list): list of syllable usage probabilities groups (list): list groups to graph transition graphs for. edge_threshold (float): threshold to include edge in graph anchor (int): syllable index as the base syllable usage_threshold (int): threshold to include syllable usages layout (str): layout format edge_width_scale (int): edge line width scaling factor fig (pyplot.figure): figure to plot to ax (pyplot.Axes): axes object width_per_group (int): graph width scaling factor per group headless (bool): exclude first node. difference_threshold (float): threshold to consider 2 graph elements different weights (list): list of edge weights usage_scale (float): syllable usage scaling factor keep_orphans (bool): plot orphans. max_syllable (int): the index of the maximum number of syllables to include orphan_weight (int): scaling factor to plot orphan node sizes arrows (bool): indicate whether to plot arrows as transitions. difference_edge_width_scale (float): difference graph edge line width scaling factor kwargs (dict): extra keyword arguments

Returns: fig (pyplot.figure): figure containing transition graphs. ax (pyplot.axis): figure axis object. pos (dict): dict figure information.
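A sketch of plotting grouped transition graphs, assuming labels, label_group, and group were obtained from a parsed model and get_trans_graph_groups (the output path is hypothetical):

    from moseq2_viz.model.trans_graph import get_group_trans_mats, graph_transition_matrix

    trans_mats, usages = get_group_trans_mats(labels, label_group, group, max_sylls=40)
    fig, ax, pos = graph_transition_matrix(trans_mats, usages=usages, groups=group,
                                           edge_threshold=0.0025, layout='spring')
    fig.savefig('transition_graphs.pdf')  # hypothetical output path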

moseq2_viz.model.trans_graph.make_difference_graphs(trans_mats, usages, group, group_names, usage_kwargs, widths, pos, node_edge_colors, ax=None, node_sizes=[], indices=None, difference_threshold=0.0005, difference_edge_width_scale=500, font_size=12, usage_scale=50000.0, difference_graphs=[], scalars=None, arrows=False, speed_kwargs={})

compute transition graph differences between two groups.

Args: trans_mats (np.ndarray): syllable transition matrix. usages (list): list of syllable usage probabilities. group (list): list groups to graph transition graphs for. group_names (list): list groups names to display with transition graphs. usage_kwargs (dict): kwargs for graph threshold settings using usage. Keys can be ‘usages’, and ‘usage_threshold’ widths (list): list of edge widths for each created single-group graph. pos (nx.Layout): nx.Layout type object holding position coordinates for the nodes. node_edge_colors (list): node edge colors (of type str). ax (np.ndarray matplotlib.pyplot.Axis): Optional axes to plot graphs in node_sizes (list): node size scaling factor (of type int) indices (list): list of in->out syllable indices to keep in graph difference_threshold (float): threshold to consider 2 graph elements different. difference_edge_width_scale (int): scaling factor for edge widths in difference transition graphs. font_size (int): indicates the size of the numbers drawn on the transition graph nodes. usage_scale (float): syllable usage scaling factor. difference_graphs (list): list of created difference transition graphs. scalars (dict): dict of syllable scalar data per transition graph arrows (bool): indicates whether to display arrows between node transitions speed_kwargs (dict): kwargs for graph threshold settings using usage. Keys can be ‘speeds’, and ‘speed_threshold’

Returns: usages (list): list of syllable usage probabilities including usages differences across groups group_names (list): list groups names to display with transition graphs including difference graphs. difference_graphs (list): list of computed difference graphs widths (list): list of edge widths for each created graph appended with difference weights node_sizes (list): lists of node sizes corresponding to each graph including difference graph node sizes node_edge_colors (list): lists of node colors corresponding to each graph including difference graph node sizes

moseq2_viz.model.trans_graph.make_transition_graphs(trans_mats, usages, group, group_names, usage_kwargs, pos, orphans, edge_threshold=0.0025, difference_threshold=0.0005, orphan_weight=0, ax=None, edge_width_scale=100, usage_scale=100000.0, difference_edge_width_scale=500, speed_kwargs={}, indices=None, font_size=12, scalars=None, arrows=False)

create transition graphs for all included groups, as well as their difference graphs.

Args: trans_mats (np.ndarray): syllable transition matrix. usages (list): list of syllable usage probabilities. group (list): list groups to graph transition graphs for. group_names (list): list groups names to display with transition graphs. usage_kwargs (dict): kwargs for graph threshold settings using usage. Keys can be ‘usages’, and ‘usage_threshold’ pos (nx.Layout): nx.Layout type object holding position coordinates for the nodes. orphans (list): list of nodes with no edges. edge_threshold (float): threshold to include edge in graph. difference_threshold (float): threshold to consider 2 graph elements different. orphan_weight (int): scaling factor to plot orphan node sizes. ax (np.ndarray matplotlib.pyplot Axis): Optional axes to plot graphs in edge_width_scale (int): edge line width scaling factor. usage_scale (float): syllable usage scaling factor. difference_edge_width_scale (int): scaling factor for edge widths in difference transition graphs. speed_kwargs (dict): kwargs for graph threshold settings using usage. Keys can be ‘speeds’, and ‘speeds_threshold’ indices (list): list of in->out syllable indices to keep in graph font_size (int): indicates the size of the numbers drawn on the transition graph nodes. scalars (dict): dict of syllable scalar data per transition graph arrows (bool): indicates whether to display arrows between node transitions

Returns: usages (list): list of syllable usage probabilities including possible appended difference usages. group_names (list): list groups names to display with transition graphs including difference graphs. widths (list): list of edge widths for each created graph appended with difference weights. node_sizes (2D list): lists of node sizes corresponding to each graph. node_edge_colors (2D list): lists of node colors corresponding to each graph including difference graph node sizes. graphs (list of nx.DiGraph): list of all group and difference transition graphs.

moseq2_viz.model.trans_graph.n_gram_transition_matrix(labels, n=2, max_label=99)

Compute the transition count for a fixed syllable sequence length ‘n’.

Args: labels (list of np.array of ints): syllable labels loaded from a model fit. n (int): length of transition chain to compute transition probability for. max_label (int): max number of syllables to scan for in transition matrix.

Returns: trans_mat (np.ndarray): array of n-transition counts for given max_label.

moseq2_viz.model.trans_graph.normalize_transition_matrix(init_matrix, normalize)

Normalize a transition matrix by given criteria.

Args: init_matrix (np.array): transition matrix to normalize. normalize (str): normalization criteria; [‘bigram’, ‘rows’, ‘columns’, or None]

Returns: init_matrix (np.array): normalized transition matrix

Model - Utilities Module

Utility functions for handling model data during pre and post processing.

moseq2_viz.model.util.add_duration_column(scalar_df)

Add syllable duration column to scalar dataframe if the dataframe contains syllable labels.

Args: scalar_df (pandas.DataFrame): merged dataframe of scalar data and syllable data.

Returns: scalar_df (pandas.DataFrame): Same DataFrame with a new column titled “duration”.

moseq2_viz.model.util.compute_behavioral_statistics(scalar_df, groupby=['group', 'uuid'], count='usage', fps=30, usage_normalization=True, syllable_key='labels (usage sort)')

Compute syllable statistics merged with the scalar features.

Args: scalar_df (pandas.DataFrame): Scalar measurements for full dataset, including metadata for all the sessions. groupby (list of strings): list of columns to run the pandas groupby() on the scalar_df. count (str): method to compute syllable mean usage, either ‘usage’ or ‘frames’. fps (int): frames per second that the data was acquired in. usage_normalization (bool): indicates whether to normalize syllable usages by the value counts. syllable_key (str): column to rename to “syllable” for convenient referencing later on.

Returns: features (pandas.DataFrame): full feature Dataframe with scalars, metadata, and syllable statistics.
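A minimal sketch, assuming scalar_df is the frame-level scalar dataframe merged with syllable labels (e.g. from scalars_to_dataframe run with a model file):

    from moseq2_viz.model.util import compute_behavioral_statistics

    mean_df = compute_behavioral_statistics(scalar_df, groupby=['group', 'uuid'],
                                            count='usage', fps=30)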

moseq2_viz.model.util.compute_syllable_explained_variance(model, save_dir='/home/wingillis/dev/moseq/moseq2-viz/docs', n_explained=99)

Compute the maximum number of syllables to include that explain n_explained percent of all frames in the dataset.

Args: model (dict): ARHMM results dict n_explained (int): explained variance percentage threshold

Returns: max_sylls (int): the index of the maximum number of syllables to include that explain the given percentage of the variance

moseq2_viz.model.util.compute_syllable_onset(labels)

Computes the onset index of each syllable label in a Series.

Args: labels (list or dict): label sequences loaded from a model fit

Returns: onsets (2D np.array): onset indices for each syllable for the given sessions.

moseq2_viz.model.util.get_Xy_values(stat_means, unique_groups, stat='usage')

Compute the syllable or scalar mean statistics for each session, stored in X.

Args: stat_means (pd DataFrame): Dataframe of syllable or session-scalar mean statistics unique_groups (list): list of unique groups in the syll_means dataframe. stat (str or list): statistic column(s) to read from the syll_means df.

Returns: X (2D np.array): mean syllable or scalar statistics for each session. (nsessions x nsyllables) y (1D list): list of group names corresponding to each row in X. mapping (dict): dictionary containing mappings from group string to integer for later embedding. rev_mapping (dict): inverse mapping dict to retrieve the group names given their mapped integer value.

moseq2_viz.model.util.get_best_fit(cp_path, model_results)

Return the model with the closest median syllable duration and closest duration distribution to the model free changepoints given the objective.

Args: cp_path (str): Path to PCA Changepoints h5 file. model_results (dict): dict of pairs of model names paired with dict containing their respective changepoints.

Returns: info (dict): information about the best-fit models. pca_cps (1D array): pc score changepoint durations.

moseq2_viz.model.util.get_mouse_syllable_slices(syllable: int, labels: ndarray) → Iterator[slice]

Return a list containing slices of syllable indices for a mouse.

Args: syllable (int): syllable to get slices of. labels (np.ndarray): list of label predictions for each session.

Returns: slices (list): list of syllable label slices; e.g. [slice(3, 6, None), slice(9, 12, None)]

moseq2_viz.model.util.get_normalized_syllable_usages(model_data, max_syllable=100, count='usage')

Compute syllable usages, normalize them to sum to 1, and return a 1D array of the corresponding usage values.

Args: model_data (dict): dict object of modeling results max_syllable (int): the index of the maximum number of syllables to include count (str): method to compute syllable mean usage, either ‘usage’ or ‘frames’.

Returns: syllable_usages (1D np array): array of sorted syllable usages for all syllables in model

moseq2_viz.model.util.get_syllable_slices(syllable='__no__default__', labels='__no__default__', label_uuids='__no__default__', index='__no__default__', trim_nans: bool = True) → list

Get the indices that correspond to a specific syllable for each session in a modeling run.

Args: syllable (int): syllable number to get slices of. labels (np.ndarray): list of label predictions for each session. label_uuids (list): list of uuid keys corresponding to each session. index (dict): index file contents contained in a dict. trim_nans (bool): flag to use the pc scores file for removing time points that contain NaNs.

Returns: syllable_slices (list): a list of indices for syllable in the labels array. Each item in the list is a tuple of (slice, uuid, h5_file).

moseq2_viz.model.util.get_syllable_statistics(data, fill_value=-5, max_syllable=100, count='usage')

Compute the usage and duration statistics from a set of model labels

Args: data (list of np.array of ints): labels loaded from a model fit. fill_value (int): lagged label values in the labels array to remove. max_syllable (int): the index of the maximum number of syllables to include count (str): method to compute syllable mean usage, either ‘usage’ or ‘frames’.

Returns: usages (OrderedDict): default dictionary of usages durations (OrderedDict): default dictionary of durations
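For example, a sketch computing usage and duration dictionaries from a list of per-session label arrays (assumed precomputed):

    from moseq2_viz.model.util import get_syllable_statistics

    usages, durations = get_syllable_statistics(labels, max_syllable=100, count='usage')
    # usages and durations are OrderedDicts keyed by syllable label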

moseq2_viz.model.util.get_syllable_usages(data, max_syllable=100, count='usage')

Compute syllable usages for relabeled syllable labels.

Args: data (list): list of syllable frame-labels for each session. max_syllable (int): the index of the maximum number of syllables to include count (str): method to compute syllable mean usage, either ‘usage’ or ‘frames’.

Returns: usages (dict): dict object that contains usage frequency information.

moseq2_viz.model.util.labels_to_changepoints(labels, fs=30)

Compute syllable durations and combine into a “changepoint” distribution.

Args: labels (list of np.ndarray of ints): labels loaded from a model fit. fs (float): sampling rate of camera.

Returns: cp_dist (np.ndarray of floats): list of block durations per element in labels list.

moseq2_viz.model.util.make_separate_crowd_movies(config_data, sorted_index, group_keys, label_dict, output_dir, ordering, sessions=False)

write syllable crowd movies for each given grouping found in group_keys, and return a dictionary of crowd movie file information.

Args: config_data (dict): Loaded crowd movie writing configuration parameters. sorted_index (dict): Loaded index file and sorted files in list. group_keys (dict): Dict of group/session name keys paired with UUIDS to match with labels. label_dict (dict): dict of corresponding session UUIDs for all sessions included in labels. output_dir (str): Path to output directory to save crowd movies in. ordering (list): ordering for the new mapping of the relabeled syllable usages. sessions (bool): indicates whether session crowd movies are being generated.

Returns: cm_paths (dict): group/session name keys paired with paths to their respectively generated syllable crowd movies.

moseq2_viz.model.util.merge_models(model_dir, ext='p', count='usage', force_merge=False, cost_function='ar_norm')

WARNING: THIS IS EXPERIMENTAL. USE AT YOUR OWN RISK. Merge model states by using the Hungarian Algorithm: a minimum distance state matching algorithm. User inputs a directory containing models to merge, (and the name of the latest-trained model) to match other model states to.

Args: model_dir (str): path to directory containing all the models to merge. ext (str): model extension to search for. count (str): method to compute syllable mean usage, either ‘usage’ or ‘frames’. force_merge (bool): whether or not to force a merge. Keeping this false will protect you from merging models trained with different kappa values. cost_function (str): either ar_norm or label for the cost function in Hungarian Algorithm.

Returns: model_data (dict): a dictionary containing all the new keys and state-matched labels.

moseq2_viz.model.util.normalize_pcs(pca_scores: dict, method: str = 'zscore') → dict

Normalize PC scores.

Args: pca_scores (dict): dict of uuid to PC-scores key-value pairs. method (str): the type of normalization to perform (demean, zscore, ind-zscore)

Returns: norm_scores (dict): a dictionary of normalized PC scores.

moseq2_viz.model.util.normalize_usages(usage_dict)

Normalize syllable usages to frequency values from [0,1] instead of total counts.

Args: usage_dict (dict): dictionary containing syllable label keys pointing to total counts.

Returns: usage_dict (dict): dictionary containing syllable label keys pointing to usage frequencies.

moseq2_viz.model.util.parse_batch_modeling(filename)

Reads model parameter scan training results into a single dictionary.

Args: filename (str): path to h5 manifest file containing all the model results.

Returns: results_dict (dict): dictionary containing each model’s training results, concatenated into a single list.

moseq2_viz.model.util.parse_model_results(model_obj, restart_idx=0, resample_idx=-1, map_uuid_to_keys: bool = False, sort_labels_by_usage: bool = False, count: str = 'usage') → dict

Reads model file and returns dictionary containing modeled results and some metadata.

Args: model_obj (str or results returned from joblib.load): path to the model fit or a loaded model fit restart_idx (int): Select which model restart to load. (Only change for models with multiple restarts used) resample_idx (int): parameter used to select labels from a specific sampling iteration. Default is the last iteration (-1) map_uuid_to_keys (bool): flag to create a label dictionary where each key->value pair contains the uuid and the labels for that session. sort_labels_by_usage (bool): sort and re-assign labels by their usages. count (str): method to compute syllable mean usage, either ‘usage’ or ‘frames’.

Returns: output_dict (dict): dictionary with labels and model parameters
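A minimal loading sketch (the model path is hypothetical; the ‘labels’ key is an assumption about how the returned dictionary stores the per-session label arrays):

    from moseq2_viz.model.util import parse_model_results

    model = parse_model_results('model.p', sort_labels_by_usage=True, count='usage')
    labels = model['labels']  # assumed key holding per-session label arrays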

moseq2_viz.model.util.prepare_model_dataframe(model_path, pca_path)

Creates a dataframe from syllable labels to be aligned with scalars.

Args: model_path (str): path to model to load label arrays from pca_path (str): path to pca_scores.h5 file.

Returns: _df (pandas.DataFrame): DataFrame object of timestamp aligned syllable label information.

moseq2_viz.model.util.relabel_by_usage(labels, fill_value=-5, count='usage')

Resort model labels by their usages.

Args: labels (list or dict): label sequences loaded from a model fit fill_value (int): value prepended to modeling results to account for nlags count (str): method to compute syllable mean usage, either ‘usage’ or ‘frames’.

Returns: labels (list or dict): label sequences sorted by usage sorting (list): the new label sorting. The index corresponds to the new label, while the value corresponds to the old label.
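For example, a sketch of re-sorting raw model labels so that the lowest label indices correspond to the most-used syllables:

    from moseq2_viz.model.util import relabel_by_usage

    relabeled, sorting = relabel_by_usage(labels, count='usage')
    # per the description above: sorting[new_label] gives the original label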

moseq2_viz.model.util.retrieve_pcs_from_slices(slices, pca_scores, max_dur=60, min_dur=3, max_samples=100, npcs=10, subsampling=None, remove_offset=False, **kwargs)

Subsample Principal components from syllable slices

Args: slices (np.ndarray): syllable slices or subarrays pca_scores (np.ndarray): PC scores for respective session. max_dur (int): maximum syllable length. min_dur (int): minimum syllable length. max_samples (int): maximum number of samples to retrieve. npcs (int): number of pcs to use. subsampling (int): number of syllable subsamples (defined through KMeans clustering). remove_offset (bool): indicate whether to remove initial offset from each PC score. kwargs (dict): used to capture certain arguments in other parts of the codebase.

Returns: syllable_matrix (np.ndarray): 3D matrix of subsampled PC projected syllable slices.

moseq2_viz.model.util.simulate_ar_trajectory(ar_mat, init_points=None, sim_points=100)

Simulate auto-regressive trajectory matrices from a set of initialized points.

Args: ar_mat (2D np.ndarray): numpy array representing the autoregressive matrix of a model state with shape (npcs, npcs * nlags + 1) init_points (2D np.ndarray): pre-initialized array of shape (nlags, npcs) sim_points (int): number of time points to simulate.

Returns: sim_mat[nlags:] (np.ndarray): simulated AR trajectories excluding the lagged initial values.

moseq2_viz.model.util.sort_batch_results(data, averaging=True, filenames=None, **kwargs)

Sort modeling results from batch/parameter scan.

Args: data (np.ndarray): model AR-matrices. averaging (bool): return an average of all the model AR-matrices. filenames (list): list of paths to fit models. kwargs (dict): dict of extra keyword arguments.

Returns: new_matrix (np.ndarray): either average of all AR-matrices, or top sorted matrix param_dict (dict): model parameter dict filename_index (list): list of filenames associated with each model.

moseq2_viz.model.util.sort_syllables_by_stat(complete_df, stat='usage', max_sylls=None)

Computes the sorted ordering of the given DataFrame with respect to the chosen stat.

Args: complete_df (pandas.DataFrame): dataframe containing the summary statistics about scalars and syllable data (mean_df/stats_df) stat (str): choice of statistic to order syllables by. max_sylls (int or None): the index of the maximum number of syllables to include

Returns: ordering (list): list of sorted syllables by stat. relabel_mapping (dict): a dict with key-value pairs {old_ordering: new_ordering}.
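A minimal sketch, assuming complete_df is a mean_df/stats_df summary dataframe:

    from moseq2_viz.model.util import sort_syllables_by_stat

    ordering, relabel_mapping = sort_syllables_by_stat(complete_df, stat='usage')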

moseq2_viz.model.util.sort_syllables_by_stat_difference(complete_df, ctrl_group, exp_group, max_sylls=None, stat='usage')

Compute the syllable ordering for the difference of the inputted groups (exp - ctrl) and sort the syllables by the differences.

Args: complete_df (pandas.DataFrame): dataframe containing the summary statistics about scalars and syllable data (mean_df/stats_df) ctrl_group (str): Control group. exp_group (str): Experimental group. max_sylls (int): the index of the maximum number of syllables to include stat (str): choice of statistic to order mutations by: {usage, duration, speed}.

Returns: ordering (list): list of array indices for the new label mapping.

moseq2_viz.model.util.syll_duration(labels: ndarray) → ndarray

Compute the duration of each syllable.

Args: labels (np.ndarray): array of syllable labels for a session.

Returns: durations (np.ndarray): array of syllable durations.

moseq2_viz.model.util.syll_id(labels: ndarray) → ndarray

Return the syllable label at each onset of a syllable transition.

Args: labels (np.ndarray): array of syllable labels for a mouse.

Returns: labels[onsets] (np.ndarray): an array of compressed labels.

moseq2_viz.model.util.syll_onset(labels: ndarray) → ndarray

Find indices of syllable onsets.

Args: labels (np.ndarray): array of syllable labels for a mouse.

Returns: indices (np.ndarray): an array of indices denoting the beginning of each syllable.

moseq2_viz.model.util.syllable_slices_from_dict(syllable: int = '__no__default__', labels: Dict[str, ndarray] = '__no__default__', index: Dict = '__no__default__', filter_nans: bool = True) → Dict[str, list]

Read a dictionary of syllable labels and return a dict of syllable slices.

Args: syllable (int): syllable to get slices of. labels (dict): dict of label prediction arrays for each session, keyed by uuid. index (dict): index file contents contained in a dict. filter_nans (bool): replace NaN values with 0.

Returns: vals (dict): key-value pairs of syllable slices per session uuid.

moseq2_viz.model.util.whiten_pcs(pca_scores, method='all', center=True)

Whiten PC scores using Cholesky whitening.

Args: pca_scores (dict): dictionary where values are pca_scores. method (str): ‘all’ to whiten using the covariance estimated from all keys, or ‘each’ to whiten each separately center (bool): whether or not to center the data

Returns: whitened_scores (dict): dictionary of whitened pc scores
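A minimal sketch, assuming pca_scores is a dict mapping session uuids to PC score arrays (e.g. loaded from a pca_scores.h5 file):

    from moseq2_viz.model.util import whiten_pcs, normalize_pcs

    whitened_scores = whiten_pcs(pca_scores, method='all', center=True)
    normalized_scores = normalize_pcs(whitened_scores, method='zscore')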