
Text mining Star Trek dialogue and classifying characters using machine learning

Graphs, text mining and analysis using the #TidyTuesday data set for week 34 of 2021 (17/8/2021): “Star Trek voice commands”

Ronan Harrington https://github.com/rnnh/
2021-08-18

Introduction

In this article, the Star Trek voice commands data set from the #TidyTuesday project is used to investigate the diction of individual characters using text mining, and to train a machine learning model that distinguishes between people and computers in the data set. The techniques used in this article are taken from the following textbooks, both of which are freely available online:

  1. Text Mining with R
  2. Supervised Machine Learning for Text Analysis in R

Setup

Loading the R libraries and data set.

# Loading libraries
library(tidytuesdayR) # For loading data set
library(textrecipes) # For adding tf-idf to recipes (step_tfidf)
library(tidyverse) # For the ggplot2, dplyr, forcats libraries
library(tidytext) # For text mining
library(glmnet) # For lasso model
library(tidymodels) # For various modelling libraries

# Loading data set
tt <- tt_load("2021-08-17")

    Downloading file 1 of 1: `computer.csv`

Most spoken words by prominent characters

The first graph plots the words most frequently used by prominent characters when talking to computers. These characters were chosen by selecting those with at least 150 lines in the data set.

# Creating a vector of characters with most interactions with computers
chars_with_most_interactions <- tt$computer %>%
  filter(char_type != "Computer") %>% # Filtering out rows for computer lines
  count(char, sort = TRUE) %>% # Counting the number of rows for each character
  filter(n >= 150) %>% # Filtering out characters with under 150 lines
  pull(char) # Pulling the names of the characters as a vector

# Counting words spoken per character
character_words <- tt$computer %>%
  filter(char %in% chars_with_most_interactions) %>% # Selecting rows
  unnest_tokens(word, interaction) %>% # Creating a row per word in interaction
  anti_join(get_stopwords(), by = "word") %>% # Removing stop words
  count(char, word, sort = TRUE) # Counting words per character

# Plotting most spoken words per character
character_words %>%
  group_by(char) %>%
  slice_max(n, n = 15, with_ties = FALSE) %>%
  ungroup() %>%
  ggplot(aes(n, fct_reorder(word, n), fill = char)) +
  geom_col(show.legend = FALSE) +
  facet_wrap(~char, ncol = 2, scales = "free") +
  scale_fill_viridis_d() +
  theme_classic() +
  labs(title = "Words most frequently spoken by prominent characters",
       x = "Utterances", y = "Words")

Figure 1: The most frequently spoken word is computer.

Words specific to prominent characters

From Figure 1, we can see that the most frequently spoken word across all four characters is “computer”. This is not surprising: the data set records interactions between characters and computers, and the characters address the computers as “computer”. While expected, this result does not tell us anything about the characters themselves. What if we wanted to see which words were specific to each character?

To see which words are specific to these characters, we can use term frequency-inverse document frequency (tf-idf) to measure word importance. For a given word, tf-idf increases with how often the word is used in a document, and decreases with the number of documents the word appears in. In this case, each “document” refers to a character. So even though “computer” is the most frequent term used by the characters, its tf-idf value will be zero: it is used by every character, so it appears in every document and its inverse document frequency is ln(1) = 0.
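The tf-idf calculation can be sketched by hand with dplyr. This is a minimal illustration on a hypothetical toy table (`toy_counts` and its values are invented for the example), assuming the definitions used by `bind_tf_idf()`: tf is a word's count divided by the total number of words in its document, and idf is the natural log of the number of documents divided by the number of documents that contain the word.

```r
library(dplyr)

# Hypothetical toy word counts: each "document" is a character
toy_counts <- tibble::tribble(
  ~char, ~word,      ~n,
  "A",   "computer", 4,
  "A",   "please",   2,
  "B",   "computer", 3,
  "B",   "jazz",     1,
  "C",   "computer", 5
)

# Number of documents (characters) in the toy corpus
n_docs <- n_distinct(toy_counts$char)

toy_tf_idf <- toy_counts %>%
  group_by(char) %>%
  mutate(tf = n / sum(n)) %>%                       # share of the document's words
  group_by(word) %>%
  mutate(idf = log(n_docs / n_distinct(char))) %>%  # natural log of docs / docs containing word
  ungroup() %>%
  mutate(tf_idf = tf * idf) %>%
  arrange(desc(tf_idf))

toy_tf_idf
```

Because “computer” appears in all three toy documents, its idf (and therefore its tf-idf) is zero. The same arithmetic reproduces the `glimpse()` output that follows: Riker's “little” has tf = 18 / 1144 ≈ 0.0157 and, being used by only one of the four characters, idf = ln(4) ≈ 1.386.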

# Counting total number of words spoken per character
total_words <- character_words %>% 
  group_by(char) %>%  # Grouping words by character
  summarize(total = sum(n)) # Counting words spoken per character ("total")

# Joining tbl_dfs: adding "total" variable to character_words
character_words <- left_join(character_words, total_words, by = "char")

# Adding tf_idf to character_words
character_words <- character_words %>%
  bind_tf_idf(term = word, document = char, n = n) %>%
  arrange(desc(tf_idf)) # Arranging rows by descending tf_idf values

# Printing a summary of the character_words object
glimpse(character_words)
Rows: 1,492
Columns: 7
$ char   <chr> "Riker", "Riker", "Riker", "Riker", "Riker", "Riker",…
$ word   <chr> "little", "omega", "indeed", "jazz", "good", "audienc…
$ n      <int> 18, 15, 14, 14, 24, 12, 12, 12, 12, 12, 12, 12, 12, 1…
$ total  <int> 1144, 1144, 1144, 1144, 1144, 1144, 1144, 1144, 1144,…
$ tf     <dbl> 0.015734266, 0.013111888, 0.012237762, 0.012237762, 0…
$ idf    <dbl> 1.3862944, 1.3862944, 1.3862944, 1.3862944, 0.6931472…
$ tf_idf <dbl> 0.02181232, 0.01817694, 0.01696514, 0.01696514, 0.014…
# Plotting words with the top tf-idf values per character
character_words %>%
  group_by(char) %>%
  slice_max(tf_idf, n = 15, with_ties = FALSE) %>%
  ungroup() %>%
  ggplot(aes(tf_idf, fct_reorder(word, tf_idf), fill = char)) +
  geom_col(show.legend = FALSE) +
  facet_wrap(~char, ncol = 2, scales = "free") +
  scale_fill_viridis_d() +
  theme_classic() +
  labs(title = "Spoken words specific to each character",
       subtitle = "Words listed are used often by each character relative to other characters",
       y = "Words",
       x = "term frequency-inverse document frequency (tf-idf)")

Figure 2: Data is the most likely character to say please when addressing a computer. Riker likes jazz. Picard likes to say his own name.

Specifying a model to classify People and Computers in the data set

In this section, a machine learning model for classifying interactions into categories is set up. The variable “char_type” divides the data set into two categories: lines spoken by people to computers (char_type = “Person”) and responses from the computers (char_type = “Computer”). The goal of this model is to look at the dialogue in the “interaction” variable and determine, based on the words used, whether that dialogue was spoken by a “Person” or a “Computer” character. To do this, the data are split into training and test sets, 10-fold cross-validation folds are created from the training set, a text preprocessing recipe is defined, and a lasso-regularised logistic regression model is specified and combined with the recipe in a workflow.

The preprocessing for this data set includes a downsampling step. There is a class imbalance in this data set: there are over twice as many “Person” interactions as “Computer” ones (1506 vs. 708, as shown below). The downsampling step addresses this imbalance by randomly removing observations of the majority class (“Person”) from the training data, so that both classes are equally represented when the model is fit.
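The idea behind downsampling can be illustrated without themis. The sketch below is a hypothetical, hand-rolled version on an invented toy data frame (it is not how `themis::step_downsample()` is implemented): it randomly keeps as many rows of each class as there are in the smallest class.

```r
library(dplyr)

set.seed(20210818)

# Hypothetical imbalanced toy data: 20 "Person" rows and 9 "Computer" rows
toy <- tibble::tibble(
  id = 1:29,
  char_type = c(rep("Person", 20), rep("Computer", 9))
)

# Size of the smallest (minority) class
minority_n <- min(table(toy$char_type))

# Keep a random sample of minority_n rows from every class
toy_downsampled <- toy %>%
  group_by(char_type) %>%
  slice_sample(n = minority_n) %>%
  ungroup()

table(toy_downsampled$char_type)
```

In the recipe, `step_downsample()` performs this inside the workflow; by default it is applied only when the model is trained, not when predicting on new data.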

# Summarising the values in the char_type variable
tt$computer %>% select(char_type) %>% table()
.
Computer   Person 
     708     1506 
# Setting seed for reproducibility
set.seed(20210818)
# Splitting data into training and test subsets
char_type_split <- tt$computer %>%
  initial_split(strata = char_type) # Keeps char_type proportions similar in both subsets

# Creating training and test sets using char_type_split
char_type_train <- training(char_type_split)
char_type_test <- testing(char_type_split)

# Setting seed for reproducibility
set.seed(20210818)
# Setting up 10-fold cross-validation (CV) using training data
char_type_folds <- vfold_cv(char_type_train,
                            strata = char_type) # Keeps char_type proportions similar across folds

# Creating preprocessing recipe for predicting if lines are spoken by a computer
char_type_rec <- recipe(char_type ~ interaction, data = char_type_train) %>%
  step_tokenize(interaction) %>% # Splitting interaction variable into words
  step_tokenfilter(interaction, max_tokens = 1500) %>% # Keeping the 1500 most frequent tokens
  step_tfidf(interaction) %>% # Weighting tokens by tf-idf
  themis::step_downsample(char_type) # Downsampling to address class imbalance

# Specifying a lasso regularised model
lasso_spec <- logistic_reg(penalty = 0.01, mixture = 1) %>%
  set_mode("classification") %>%
  set_engine("glmnet")

# Creating a supervised machine learning workflow
char_type_wf <- workflow() %>%
  add_recipe(char_type_rec) %>%
  add_model(lasso_spec)

# Printing workflow
char_type_wf
══ Workflow ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: logistic_reg()

── Preprocessor ──────────────────────────────────────────────────────
4 Recipe Steps

• step_tokenize()
• step_tokenfilter()
• step_tfidf()
• step_downsample()

── Model ─────────────────────────────────────────────────────────────
Logistic Regression Model Specification (classification)

Main Arguments:
  penalty = 0.01
  mixture = 1

Computational engine: glmnet 
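For reference, the glmnet engine fits a penalised logistic regression. A sketch of the objective it minimises, written in the notation of the glmnet documentation, is shown below. Here parsnip's `penalty` corresponds to $\lambda$ and `mixture` to $\alpha$, $x_i$ is the vector of tf-idf features for line $i$, and $y_i = 1$ is assumed to denote the second factor level of `char_type` (“Person”).

$$
\min_{\beta_0,\,\beta}\; -\frac{1}{N}\sum_{i=1}^{N}\Big[y_i\big(\beta_0 + x_i^\top\beta\big) - \log\big(1 + e^{\beta_0 + x_i^\top\beta}\big)\Big] + \lambda\Big[\frac{1-\alpha}{2}\lVert\beta\rVert_2^2 + \alpha\lVert\beta\rVert_1\Big]
$$

With `mixture = 1` ($\alpha = 1$) only the $\ell_1$ (lasso) term remains, so `penalty = 0.01` gives a penalty of $0.01\lVert\beta\rVert_1$. This penalty can shrink the coefficients of uninformative tokens to exactly zero, so the fitted model may use only a subset of the up to 1500 tf-idf features.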

Fitting the model to resampled training data

Now that the machine learning workflow is in place, the model’s performance can be estimated using the training data. The model will be fit ten times, once for each fold of the 10-fold cross-validation specified above: each time, it is trained on nine folds and assessed on the remaining held-out fold. Because every fit is assessed on a different set of observations, averaging across the folds gives a more reliable estimate of model performance than training and testing once.
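The fold assignment behind stratified 10-fold cross-validation can be sketched in base R. This is a simplified, hypothetical helper (`stratified_folds()` is not part of rsample, and `vfold_cv()` handles stratification in its own way): within each class, observations are shuffled and dealt out to the k folds in turn, so each fold receives a similar share of every class.

```r
# Hypothetical helper: assign each observation to one of k folds, class by class
stratified_folds <- function(strata, k = 10) {
  fold_id <- integer(length(strata))
  for (level in unique(strata)) {
    idx <- which(strata == level)
    # Deal this class's rows out to folds 1..k as evenly as possible, in random order
    fold_id[idx] <- sample(rep_len(seq_len(k), length(idx)))
  }
  fold_id
}

set.seed(20210818)
# Hypothetical imbalanced labels: 30 "Person" and 14 "Computer"
labels <- c(rep("Person", 30), rep("Computer", 14))
folds <- stratified_folds(labels, k = 10)

# For fold j, the model is trained on the other nine folds and assessed on fold j
j <- 1
analysis_rows <- which(folds != j)
assessment_rows <- which(folds == j)

# Class counts per fold
table(labels, folds)
```

Every observation lands in exactly one assessment set, so across the ten fits each line is used for assessment once and for training nine times.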

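Two metrics are used to estimate model performance: accuracy (the proportion of observations assigned the correct class) and the area under the receiver operating characteristic curve (ROC AUC).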

The ROC curve is plotted for each of the training data resamples. The ROC curve plots the true positive rate (the proportion of actual positives, here “Computer” lines, that are classified as positive) against the false positive rate (the proportion of actual negatives that are classified as positive) across all classification thresholds. The closer the ROC curve is to the top-left corner of the plot, the higher the true positive rate for a given false positive rate, and the larger the area under the curve (AUC). A model that guessed at random would produce a ROC curve close to the diagonal line from (0, 0) to (1, 1), with an AUC of 0.5.
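To make the ROC AUC concrete, the sketch below computes a ROC curve and its area by hand in base R on invented toy predictions. The helpers `roc_points()` and `auc_trapezoid()` are hypothetical; in the article itself, yardstick's `roc_curve()` and `roc_auc()` do this work. “Computer” is treated as the positive class, as with `.pred_Computer` in the code below.

```r
# Hypothetical toy predictions: true class and predicted probability of "Computer"
truth <- c("Computer", "Computer", "Person", "Computer", "Person", "Person")
prob_computer <- c(0.9, 0.8, 0.7, 0.4, 0.3, 0.1)

# TPR and FPR when predicting "Computer" for every probability >= each threshold
roc_points <- function(truth, prob, positive = "Computer") {
  is_pos <- truth == positive
  thresholds <- sort(unique(prob), decreasing = TRUE)
  tpr <- vapply(thresholds, function(t) mean(prob[is_pos] >= t), numeric(1))
  fpr <- vapply(thresholds, function(t) mean(prob[!is_pos] >= t), numeric(1))
  # Start the curve at (0, 0), where nothing is predicted positive
  data.frame(fpr = c(0, fpr), tpr = c(0, tpr))
}

# Area under the curve using the trapezoidal rule (FPR is non-decreasing here)
auc_trapezoid <- function(points) {
  widths <- diff(points$fpr)
  heights <- (head(points$tpr, -1) + tail(points$tpr, -1)) / 2
  sum(widths * heights)
}

toy_roc <- roc_points(truth, prob_computer)
toy_auc <- auc_trapezoid(toy_roc)
toy_auc # 8/9, about 0.889

# A random classifier's ROC curve is the diagonal from (0, 0) to (1, 1)
auc_trapezoid(data.frame(fpr = c(0, 1), tpr = c(0, 1))) # 0.5
```

The toy AUC of 8/9 equals the fraction of (Computer, Person) pairs in which the Computer line received the higher predicted probability, which is another common way to read the AUC.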

# Fitting model to resampled folds to estimate performance
char_type_rs <- fit_resamples(
  object = char_type_wf, # Machine learning workflow
  resamples = char_type_folds, # 10-fold cross-validation
  control = control_resamples(save_pred = TRUE)
)
# Collecting predictions from resampled folds
char_type_rs_predictions <- collect_predictions(char_type_rs)

# Collecting model performance metrics
char_type_rs_metrics <- collect_metrics(char_type_rs)
# Printing performance metrics
char_type_rs_metrics
# A tibble: 2 x 6
  .metric  .estimator  mean     n std_err .config             
  <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
1 accuracy binary     0.880    10 0.00650 Preprocessor1_Model1
2 roc_auc  binary     0.952    10 0.00420 Preprocessor1_Model1
# Plotting ROC curves for resampled folds
char_type_rs_predictions %>%
  group_by(id) %>%
  roc_curve(truth = char_type, .pred_Computer) %>%
  autoplot() + 
  scale_color_viridis_d() +
  labs(color = "Resample (fold)",
       title = "ROC curve for character type based on interaction",
       subtitle = "True Positive Rate (TPR) vs. False Positive Rate (FPR)",
       x = "FPR (1 - Specificity)", y = "TPR (Sensitivity)") +
  theme_bw()

Figure 3: The machine learning model is far better at classifying character types than random guessing.

Evaluating the model using the test data

From the 10-fold cross-validation results, we can see that this model is effective at classifying character type based on the words used in interactions. To put this model to the test, we can use it to classify unseen data. Again, accuracy and ROC AUC will be used to evaluate model performance on the test data. A confusion matrix will also be plotted to illustrate performance: this matrix counts the test observations in each combination of predicted and true class.
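A confusion matrix is a cross-tabulation of predicted against true classes. The base-R sketch below builds one from invented toy predictions and derives accuracy from it. It uses a Prediction-by-Truth layout similar to yardstick's `conf_mat()`, but it is a hypothetical illustration, not that function.

```r
# Hypothetical toy test-set results (not the article's predictions)
levels_ct <- c("Computer", "Person")
truth <- factor(c("Computer", "Computer", "Person", "Person", "Person"), levels = levels_ct)
pred  <- factor(c("Computer", "Person", "Person", "Person", "Computer"), levels = levels_ct)

# Rows are predicted classes, columns are true classes
cm <- table(Prediction = pred, Truth = truth)
cm

# Accuracy: correctly classified observations (the diagonal) over all observations
accuracy <- sum(diag(cm)) / sum(cm)
accuracy # 0.6
```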

# Fitting model to training data, evaluating performance on test data
char_type_last_fit <- last_fit(char_type_wf, char_type_split)
# Printing metrics for final fit
collect_metrics(char_type_last_fit)
# A tibble: 2 x 4
  .metric  .estimator .estimate .config             
  <chr>    <chr>          <dbl> <chr>               
1 accuracy binary         0.895 Preprocessor1_Model1
2 roc_auc  binary         0.952 Preprocessor1_Model1
# Plotting confusion matrix of testing data
collect_predictions(char_type_last_fit) %>%
  conf_mat(truth = char_type, estimate = .pred_class) %>%
  autoplot(type = "heatmap") +
  scale_fill_continuous() +
  labs(title = "Confusion matrix of lasso classification model performance on test subset",
       subtitle = "Counts of predicted vs. true character type (Computer or Person)")

Figure 4: The model correctly classifies most observations in the test data.

Important words in the final model

The model is effective at classifying observations in the test data, with an accuracy of 0.895 (89.5%) and a ROC AUC of 0.952. At this point, we may want to know more about how the model works. For example, which words had the largest impact on the model’s decision to classify a line as spoken by a “Computer” or a “Person”? To see which words were important in the final model, we can extract the coefficient assigned to each word in the final fit, and plot the words whose coefficients have the largest absolute values in each direction.
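The sign of each coefficient can be read through the logistic function. In this sketch the intercept, the coefficients, and the word names `word_a` and `word_b` are all hypothetical, and “Person” is assumed to be the modelled event (the second factor level of `char_type`, consistent with positive estimates being labelled “More likely from a Person” in the plot below).

```r
# Hypothetical coefficients, invented for illustration (not from the fitted model)
intercept <- 0.2
coefs <- c(word_a = 1.5, word_b = -2.0)

# Hypothetical tf-idf values for one line of dialogue
x_line <- c(word_a = 1.0, word_b = 0.0)

# Linear predictor, then the logistic function gives P(char_type == "Person")
eta <- intercept + sum(coefs * x_line)
p_person <- plogis(eta)
p_person

# Raising the tf-idf of the positive-coefficient word raises P(Person)
p_person_more_a <- plogis(intercept + sum(coefs * c(word_a = 2.0, word_b = 0.0)))
# Raising the tf-idf of the negative-coefficient word lowers P(Person)
p_person_more_b <- plogis(intercept + sum(coefs * c(word_a = 1.0, word_b = 1.0)))
```

Larger absolute coefficients move the predicted probability further for the same change in tf-idf, which is why the plot below ranks words by the absolute value of their estimates.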

# Extracting the parsnip model fit (extract_fit_parsnip() supersedes the deprecated pull_workflow_fit())
char_type_workflow_fit <- extract_fit_parsnip(char_type_last_fit$.workflow[[1]])

# Visualising the most important words for predicting whether a line was spoken
# by a Computer or Person character
tidy(char_type_workflow_fit) %>%
  filter(term != "Bias", term != "(Intercept)") %>%
  group_by(sign = estimate > 0) %>%
  slice_max(abs(estimate), n = 15) %>% 
  ungroup() %>%
  mutate(term = str_remove(term, "tfidf_interaction_"), # Tidying terms
         sign = ifelse(sign, "More likely from a Person",
                       "More likely from a Computer")) %>%
  ggplot(aes(abs(estimate), fct_reorder(term, abs(estimate)), fill = sign)) +
  geom_col(show.legend = FALSE) +
  facet_wrap(~sign, scales = "free") +
  theme_classic() +
  scale_fill_viridis_d() +
  labs(title = "Which words are more likely to be used by a Computer vs. Person character?",
       subtitle = "Importance assigned by lasso classification model, based on lines spoken",
       x = "Absolute coefficient from lasso regularised model", y = "Words")

Figure 5: The model’s most important word for classifying Person characters was computer.


Corrections

If you see mistakes or want to suggest changes, please create an issue on the source repository.

Reuse

Text and figures are licensed under Creative Commons Attribution CC BY 4.0. Source code is available at https://github.com/rnnh/TidyTuesday/, unless otherwise noted. The figures that have been reused from other sources don't fall under this license and can be recognized by a note in their caption: "Figure from ...".

Citation

For attribution, please cite this work as

Harrington (2021, Aug. 18). Ronan's #TidyTuesday blog: Text mining Star Trek dialogue and classifying characters using machine learning. Retrieved from https://tidytuesday.netlify.app/posts/2021-08-18-star-trek-voice-commands/

BibTeX citation

@misc{harrington2021text,
  author = {Harrington, Ronan},
  title = {Ronan's #TidyTuesday blog: Text mining Star Trek dialogue and classifying characters using machine learning},
  url = {https://tidytuesday.netlify.app/posts/2021-08-18-star-trek-voice-commands/},
  year = {2021}
}