Do you wonder how Artificial Intelligence can be helpful for healthcare? Are you interested in learning Machine Learning?
Note: To get the most out of this article, you can download the code for free here.
This article will explore how healthcare can benefit from Artificial Intelligence through Natural Language Processing technology.
In a previous article, I provided a bird’s-eye view of how Conversational UI could benefit a healthcare provider.
This article is the first in a series in which we will build an end-to-end Machine Learning project for mental healthcare.
Disclaimer: This blog post on suicidal tendency detection is for educational purposes only. It is not meant to be a reliable, highly accurate mental illness diagnosis system, nor has it been professionally or academically vetted.
This project is also a way to explore the data science thought process behind building a Machine Learning project.
A Machine Learning Project is made of the following steps:
- Problem Statement
- Data Collection
- Exploratory Data Analysis
- Error Analysis
The result will be a web app available for everybody to use. It will invite users to describe their current feelings and, upon submission, prompt them to seek help if the model detects suicidal tendencies.
In this article, after defining the problem, we will collect the data and perform a first round of data cleaning to prepare a dataset for future exploration and modeling.
Let’s get right into it!
To follow this article, you can download the Jupyter Notebook file for free here.
Mental disorders can be very distressing, and individuals facing them might feel at a loss about what to do, without realizing that they may need to seek help to get better.
Let’s suppose that a mental health clinic wants to detect suicidal patients among depressed ones in order to provide them with more immediate care.
We will build a Machine Learning classification model based on two subreddits: r/depression and r/SuicideWatch.
In terms of algorithms, we will compare a few Machine Learning techniques, such as Logistic Regression, Naive Bayes, Random Forest, XGBoost, and Gradient Boosting.
If these Machine Learning algorithms do not perform well enough, we will consider Deep Learning solutions such as Recurrent Neural Networks.
In terms of metrics, we will report the accuracy, AUC, precision, sensitivity, specificity, and F1 score.
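To make these metrics concrete, here is a minimal pure-Python sketch that derives them from confusion-matrix counts. It assumes binary labels where 1 stands for r/SuicideWatch (the class we care most about) and 0 for r/depression; this encoding and the helper names are assumptions for illustration. AUC is left out because it needs predicted scores rather than hard labels; scikit-learn’s roc_auc_score computes it.

```python
def confusion_counts(y_true, y_pred, positive=1):
    """Count true/false positives and negatives for a binary task."""
    tp = fp = tn = fn = 0
    for truth, pred in zip(y_true, y_pred):
        if pred == positive and truth == positive:
            tp += 1
        elif pred == positive:
            fp += 1  # depression post flagged as suicidal
        elif truth == positive:
            fn += 1  # suicidal post missed: the error we most want to avoid
        else:
            tn += 1
    return tp, fp, tn, fn


def _ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def classification_metrics(y_true, y_pred, positive=1):
    """Accuracy, precision, sensitivity, specificity and F1 for one class."""
    tp, fp, tn, fn = confusion_counts(y_true, y_pred, positive)
    precision = _ratio(tp, tp + fp)
    sensitivity = _ratio(tp, tp + fn)  # also called recall
    specificity = _ratio(tn, tn + fp)
    return {
        "accuracy": _ratio(tp + tn, tp + fp + tn + fn),
        "precision": precision,
        "sensitivity": sensitivity,
        "specificity": specificity,
        "f1": _ratio(2 * precision * sensitivity, precision + sensitivity),
    }


# Toy labels, not real results: 1 = SuicideWatch, 0 = depression
y_true = [1, 1, 1, 0, 0, 0]
y_pred = [1, 1, 0, 1, 0, 0]
print(classification_metrics(y_true, y_pred))
```

Sensitivity is the number to watch for this project: it measures the share of suicidal posts that the model actually catches.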
The main goal is to accurately classify posts showing suicidal tendencies. In other words, we will consider it acceptable for a post indicating depression to be misclassified as suicidal, but we want to avoid, as much as possible, posts showing suicidal tendencies being misclassified as depression. In metric terms, this means prioritizing sensitivity (recall) on the suicidal class.
This is because people who start having recurrent suicidal thoughts are more likely to harm themselves, and we do not want to leave them isolated.
On the other hand, depressed people misclassified as suicidal might have to attend a check-up to make sure they will not harm themselves.
The reason behind this project is that I wanted to build something potentially useful while working on an interesting Machine Learning classification problem. The frontier between depression and suicidal tendencies can be thin, so this project assumes that our data are correctly labeled, meaning that posts from the depression category do not indicate suicidal tendencies.
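In practice, favoring the suicidal class usually means judging models on sensitivity rather than accuracy alone, and possibly lowering the decision threshold of a probabilistic classifier. The sketch below (hypothetical helper names, toy probabilities) picks the highest threshold that still reaches a target sensitivity. It is one possible way to encode the trade-off described above, not a method this series commits to.

```python
def sensitivity_at(y_true, probabilities, threshold):
    """Share of positive (suicidal) posts predicted positive at this threshold."""
    positives = sum(1 for truth in y_true if truth == 1)
    caught = sum(
        1 for truth, prob in zip(y_true, probabilities)
        if truth == 1 and prob >= threshold
    )
    return caught / positives if positives else 0.0


def pick_threshold(y_true, probabilities, min_sensitivity=0.95):
    """Highest threshold whose sensitivity reaches `min_sensitivity`.

    Sensitivity can only stay the same or grow as the threshold decreases,
    so scanning candidates from high to low and stopping at the first hit
    keeps as many depression posts as possible below the threshold.
    """
    for threshold in sorted(set(probabilities), reverse=True):
        if sensitivity_at(y_true, probabilities, threshold) >= min_sensitivity:
            return threshold
    return None  # target unreachable (e.g. no positive examples)


# Toy data: 1 = SuicideWatch, 0 = depression; probabilities are made up
y_true = [1, 1, 1, 0, 0]
probs = [0.9, 0.6, 0.3, 0.4, 0.1]
print(sensitivity_at(y_true, probs, 0.5))   # default cut-off misses one suicidal post
print(pick_threshold(y_true, probs, 0.95))  # lower cut-off that catches all of them
```

The price of the lower threshold is visible in the toy data: the depression post scored 0.4 is now flagged as suicidal, which is exactly the kind of error we decided to tolerate.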
Let’s describe what Reddit is before going through the compilation of our dataset.
What Is Reddit
Reddit is one of the most visited websites in the United States and defines itself as “the front page of the internet”. Reddit is a vast collection of forums where people can anonymously share news and content or comment on other people’s posts. Because users can remain anonymous, we can expect the opinions and thoughts shared on the platform to be less filtered than they would be otherwise. Since Reddit data can be collected programmatically, it is a great resource for a Machine Learning project.
Reddit is divided into over one million subreddits, each one covering a specific topic.
Reddit will be our data source for this project.
We will collect data from r/SuicideWatch and r/depression. Reddit data can be collected through an API, which is what we are going to do in this article. In general, using an API is better practice than programming a web scraper.
To scrape data from Reddit, we will use the Python Pushshift.io API Wrapper (PSAW) to scrape over 20000 submissions from Reddit. If you are interested, you can find the documentation here.
To do so, you might need to install PSAW, which can conveniently be done with the pip install command:
pip install psaw
Next, we need to initialize the Pushshift API.
For this project, we are going to collect 25000 submissions or posts before 30 July 2021.
We will call the API’s submission search method. SuicideWatch is passed as the subreddit parameter, and the before parameter is set to the Unix timestamp (an integer) of 30 July 2021, computed with the built-in datetime module. Finally, 25000 is passed as the limit parameter to fetch a maximum of 25000 posts.
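Putting these steps together, here is a minimal sketch of the collection call. It assumes PSAW’s PushshiftAPI class and its search_submissions() generator, whose results expose the raw fields through a d_ attribute. The filter list of fields, the helper names, and the choice of midnight UTC for the cutoff are assumptions for illustration. The PSAW import happens inside the function, so the timestamp helper runs even without PSAW installed; the fetch itself needs a reachable Pushshift service.

```python
import datetime as dt


def before_timestamp(year=2021, month=7, day=30):
    """Unix timestamp (int) for midnight UTC on the given date."""
    cutoff = dt.datetime(year, month, day, tzinfo=dt.timezone.utc)
    return int(cutoff.timestamp())


def fetch_submissions(subreddit, before, limit=25000):
    """Collect up to `limit` submissions from `subreddit` posted before `before`.

    Needs `pip install psaw` and network access to Pushshift.
    """
    from psaw import PushshiftAPI  # lazy import: only needed when fetching

    api = PushshiftAPI()
    results = api.search_submissions(
        subreddit=subreddit,
        before=before,
        # Assumed field selection; PSAW may also add a derived `created` field.
        filter=["title", "selftext", "subreddit", "created_utc"],
        limit=limit,
    )
    posts = []  # empty list filled by the for loop, as described in the text
    for submission in results:
        posts.append(submission.d_)  # d_ holds the raw fields as a dict
    return posts


# Example call (requires PSAW and a reachable Pushshift service):
# suicide_posts = fetch_submissions("SuicideWatch", before_timestamp())
```

The same call with "depression" as the subreddit would collect the second class.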
Then, we initialize an empty list and append the posts into the list by iterating through a for loop.
We finally create a data frame for our dataset and save it as a CSV file.
To perform this action, we need the pandas library, a powerful Python library for working with data. If necessary, you can install it by running pip install pandas.
To finalize the dataset, we call the DataFrame() constructor from the pandas library, and then we save our data frame as a CSV file.
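A minimal sketch of this step, assuming each scraped post is available as a dict of fields. The records and the file name reddit_posts.csv are hypothetical placeholders; index=False keeps the pandas row index out of the file.

```python
import pandas as pd


def build_dataset(*post_lists):
    """Concatenate lists of post dicts into a single DataFrame."""
    rows = []
    for posts in post_lists:
        rows.extend(posts)
    return pd.DataFrame(rows)


# Hypothetical sample records standing in for the scraped posts
suicide_posts = [{"title": "t1", "selftext": "body 1",
                  "subreddit": "SuicideWatch", "created_utc": 1627000000}]
depression_posts = [{"title": "t2", "selftext": "body 2",
                     "subreddit": "depression", "created_utc": 1627000100}]

df = build_dataset(suicide_posts, depression_posts)
df.to_csv("reddit_posts.csv", index=False)  # hypothetical file name
```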
Now that we have scraped our data, in the next part we will perform some minor data cleaning operations before saving our dataset for later use. To follow along, do not forget to download the associated Jupyter notebook for free here.
To prepare our data for exploratory data analysis and modeling, we need to clean our dataset.
At this stage, I kept the data cleaning minimal. The goal here was to have properly formatted text without duplicates or missing values.
First, we can check if there are any duplicates in our dataset.
Running the duplicated() method reveals 4377 duplicate rows, which are easily removed.
We now have 20363 posts in our dataset.
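The duplicate check can be sketched with pandas on a toy frame (made-up rows, not the scraped data): duplicated() flags each row identical to an earlier one, and drop_duplicates() keeps only the first occurrence.

```python
import pandas as pd

# Toy frame with one exact duplicate row (not the real data)
df = pd.DataFrame({
    "title": ["a", "a", "b"],
    "selftext": ["x", "x", "y"],
    "subreddit": ["depression", "depression", "SuicideWatch"],
})

n_duplicates = df.duplicated().sum()  # rows identical to an earlier row
df = df.drop_duplicates().reset_index(drop=True)
print(n_duplicates, len(df))
```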
Then, I removed the created columns.
I was also curious to know how many posts each subreddit contributes, so I called groupby().count() on the subreddit column. The resulting counts show that our dataset is well balanced between the two classes.
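Here is a small sketch of these two operations on toy data. The column names created_utc and created are assumptions based on the fields Pushshift data usually carries; adjust them to whatever your scraped frame contains.

```python
import pandas as pd

# Toy frame; column names are assumptions, values are made up
df = pd.DataFrame({
    "created_utc": [1, 2, 3],
    "created": [1.0, 2.0, 3.0],
    "title": ["a", "b", "c"],
    "selftext": ["x", "y", "z"],
    "subreddit": ["depression", "SuicideWatch", "depression"],
})

# Drop the timestamp columns, which the text classifier does not use
df = df.drop(columns=["created_utc", "created"])

# Number of non-null entries per column for each subreddit (class balance)
counts = df.groupby("subreddit").count()
print(counts)
```

df["subreddit"].value_counts() gives the same per-class totals as a single Series.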
Next, I created a dictionary to rename the columns to text and label.
I merged the title column with the text column and then dropped the redundant column.
I checked for null values with isnull().any().any().
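The renaming, merging, and null check can be sketched as follows. The original column names selftext (the post body in Pushshift data) and subreddit are assumptions, since the exact mapping is not shown here, and joining the title and body with a single space is one reasonable choice among several.

```python
import pandas as pd

# Toy frame with assumed original column names
df = pd.DataFrame({
    "title": ["Feeling low", "Can't go on"],
    "selftext": ["Nothing helps.", ""],
    "subreddit": ["depression", "SuicideWatch"],
})

# Rename with a dictionary: assumed selftext -> text, subreddit -> label
df = df.rename(columns={"selftext": "text", "subreddit": "label"})

# Prepend the title to the body; strip() removes the stray space left by empty bodies
df["text"] = (df["title"] + " " + df["text"]).str.strip()
df = df.drop(columns=["title"])

# True if any cell in the whole frame is missing
has_nulls = df.isnull().any().any()
print(has_nulls)
```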
Then, I removed the emojis.
I also removed Reddit-specific Markdown formatting characters that appear here and there in the text. To do so, I used the redditcleaner library, which you can install by running pip install redditcleaner.
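The exact emoji-removal method is not shown, so the sketch below swaps in a common regex-based approach that strips characters from a few Unicode emoji blocks; the ranges listed are an assumption and do not cover every emoji. The redditcleaner wrapper assumes the package’s clean() function and imports it lazily, so the emoji helper works on its own.

```python
import re

# Assumed selection of Unicode blocks that hold most common emojis
EMOJI_PATTERN = re.compile(
    "["
    "\U0001F600-\U0001F64F"  # emoticons
    "\U0001F300-\U0001F5FF"  # symbols and pictographs
    "\U0001F680-\U0001F6FF"  # transport and map symbols
    "\U0001F1E0-\U0001F1FF"  # regional indicators (flags)
    "\U0001F900-\U0001F9FF"  # supplemental symbols and pictographs
    "\U00002600-\U000026FF"  # miscellaneous symbols
    "\U00002700-\U000027BF"  # dingbats
    "]+"
)


def remove_emojis(text):
    """Delete characters that fall in the emoji ranges above."""
    return EMOJI_PATTERN.sub("", text)


def clean_reddit_markdown(text):
    """Strip Reddit Markdown (requires `pip install redditcleaner`)."""
    import redditcleaner  # lazy import; assumed entry point is redditcleaner.clean()
    return redditcleaner.clean(text)


print(remove_emojis("sad 😢 today"))
```

On a DataFrame, each helper can be applied to the text column with df["text"].apply(...).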
To expand the contractions, I used a library called contractions, which you can install by typing pip install contractions. I called the fix() function from the contractions library, using a list comprehension to iterate through each word. Finally, we use a lambda function to apply it to the text column of the dataset.
Note that this operation produces a list of words for each post, in which each expanded contraction stays grouped as a single element, “I have”, rather than two elements, “I” and “have”. Therefore, we need to convert each list back to a string with a map function: we use a list comprehension again to iterate through the whole column and join the words back together into a string.
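A sketch of this two-step process, with hypothetical helper names: load_fixer() returns contractions.fix, and expand_column() first builds one list of expanded words per post, then joins each list back into a string. Passing the fixer in as an argument is a design choice that lets the logic run with a toy fixer when the library is not installed.

```python
import pandas as pd


def load_fixer():
    """Return contractions.fix (requires `pip install contractions`)."""
    import contractions
    return contractions.fix


def expand_column(series, fix):
    """Expand contractions in every post of a Series of strings."""
    # Step 1: one list of expanded words per post; "I've" becomes the single
    # element "I have", not the two elements "I" and "have".
    word_lists = series.apply(lambda text: [fix(word) for word in text.split()])
    # Step 2: join each list back into one string.
    return pd.Series([" ".join(map(str, words)) for words in word_lists],
                     index=series.index)


# Toy fixer standing in for contractions.fix in this demo
toy_fix = lambda word: {"I've": "I have", "can't": "cannot"}.get(word, word)
print(expand_column(pd.Series(["I've tried", "I can't sleep"]), toy_fix).tolist())
```

With the real library, the call becomes df["text"] = expand_column(df["text"], load_fixer()). Because text.split() splits on any whitespace, line breaks inside a post are collapsed into single spaces.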
I removed URLs using regular expressions and saved the dataset for later use in the next part.
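The regular expression itself is not shown, so here is one plausible pattern, an assumption rather than the exact regex used: it removes any token starting with http://, https://, or www.

```python
import re

# Assumed pattern: a URL is a run of non-space characters after http(s):// or www.
URL_PATTERN = re.compile(r"https?://\S+|www\.\S+")


def remove_urls(text):
    """Delete URL-like tokens from a post."""
    return URL_PATTERN.sub("", text)


print(remove_urls("see https://example.com now"))
```

On the dataset, this would be applied with df["text"] = df["text"].apply(remove_urls) before saving with df.to_csv(...).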
Here is what our dataset looks like at the end of this cleaning session:
The cleaning done at this stage is not enough to fit a model yet. We need to normalize the text further and possibly build a custom dictionary. For example, we need to expand abbreviations, check spelling, and standardize spelling between UK and US English. We will do this in the following article.
We then save our dataset for future use.
In this part, after defining the problem, we collected data from Reddit using the Pushshift.io API Wrapper. We also cleaned our data and saved it as a CSV file.
In the following article, we will try to understand our data better, preprocess our dataset for Machine Learning, and see if we can extract some valuable features.
You can download the source code for free here and run it to get your dataset. I also invite you to play with the code and modify it.