Text Classification Using Keras LSTM, CNN and embeddings

A no-fuss pipeline for binary or multiclass text classification


Introduction

Neural networks are often associated with image classification, but several architectures work well on text too. Long Short-Term Memory (LSTM) networks, a type of recurrent neural network that reads a sequence one token at a time, are a common choice for text classification. We also use a convolutional neural network (CNN) layer, an architecture best known for image classification, which in one dimension can pick up patterns in short runs of neighbouring words.

In this blog post, we look at a Kaggle dataset for a 24-class symptom-to-disease classification problem. This is a rather challenging problem because of the large number of labels, the relatively few cases per label, and the many different words used to describe symptoms. The curse of dimensionality at work? We combine an LSTM and a CNN with pretrained word embeddings to build a basic text classification pipeline and see how far we can get with this dataset.

We discuss code fragments here; a pointer to the complete code appears at the end of this article.

Import libraries

import numpy as np
from tqdm import tqdm
import tensorflow as tf
from keras import regularizers, optimizers
# Keras >= 2.6; older releases exposed this layer as
# keras.layers.experimental.preprocessing.TextVectorization
from keras.layers import TextVectorization
from keras.layers import Embedding, Dense, Dropout, Input, LSTM, Conv1D, GlobalMaxPool1D, GlobalAveragePooling1D
from keras.models import Sequential
from keras.initializers import Constant
import spacy
import en_core_web_lg  # spaCy large English model (python -m spacy download en_core_web_lg)

Get your favourite embeddings

We use the word vectors from spaCy's large English model (en_core_web_lg) to encode our words.

nlp = en_core_web_lg.load()
...

Create a vectorizer layer

The Keras TextVectorization layer maps each word to an integer index. The max_tokens argument caps the vocabulary size, keeping only the most frequent words, and output_sequence_length sets a padding/truncation length per document (with None, each batch is padded to its longest document). Because this layer is part of the model, we can feed raw text directly to the Keras model.
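To make the layer's behaviour concrete, here is a simplified pure-Python imitation of TextVectorization in its default integer mode. It is a sketch, not the Keras implementation: the punctuation regex, the tie-breaking between equally frequent words, and the helper names (standardize, build_vocab, vectorize) are assumptions for illustration. As in Keras, index 0 is reserved for padding and index 1 for out-of-vocabulary words, and max_tokens counts both.

```python
import re
from collections import Counter

# Simplified imitation of keras TextVectorization (output_mode="int").
# Assumptions: lowercase + strip punctuation, whitespace split,
# index 0 = padding token "", index 1 = out-of-vocabulary token "[UNK]".

def standardize(text):
    """Lowercase and drop punctuation, roughly like 'lower_and_strip_punctuation'."""
    return re.sub(r"[^\w\s]", "", text.lower())

def build_vocab(docs, max_tokens):
    """Most frequent tokens first; max_tokens includes the '' and '[UNK]' entries."""
    counts = Counter(tok for doc in docs for tok in standardize(doc).split())
    kept = [tok for tok, _ in counts.most_common(max_tokens - 2)]
    return ["", "[UNK]"] + kept

def vectorize(docs, vocab, output_sequence_length=None):
    """Map tokens to indices (unknown -> 1), then pad with 0 or truncate."""
    index = {tok: i for i, tok in enumerate(vocab)}
    seqs = [[index.get(tok, 1) for tok in standardize(doc).split()] for doc in docs]
    # None: pad every document to the longest one, as Keras does per batch
    length = output_sequence_length or max(len(s) for s in seqs)
    return [s[:length] + [0] * (length - len(s[:length])) for s in seqs]

docs = ["fever cough fever", "fever rash cough", "Fever!"]
vocab = build_vocab(docs, max_tokens=4)
print(vocab)                                         # ['', '[UNK]', 'fever', 'cough']
print(vectorize(["Rash and fever", "cough"], vocab))  # [[1, 1, 2], [3, 0, 0]]
```

With max_tokens=4 only two real words fit, so "rash" falls out of the vocabulary and is encoded as the OOV index 1, just like the unseen word "and".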

...
disease_vectorizer = TextVectorization(max_tokens=5000, output_sequence_length=None)

# Learn the vocabulary from our training text, then extract it.
# vocab[0] is the padding token '' and vocab[1] is the out-of-vocabulary token '[UNK]'.
disease_vectorizer.adapt(X_retain.to_numpy())
vocab = disease_vectorizer.get_vocabulary()
...

Create the embedding layer

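The embedding layer sits between the vectorizer and the LSTM. It holds an embedding matrix whose row i is the word vector for vocabulary index i, so it maps each integer word code to the corresponding spaCy embedding. We freeze it (trainable=False) so the pretrained vectors are not updated during training.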

...
# Generating the embedding matrix: row i holds the spaCy vector for vocab[i]
ntokens = len(vocab)
embed_dim = len(nlp('The').vector)  # 300 for en_core_web_lg
embed_matrix = np.zeros((ntokens, embed_dim))
for i, word in enumerate(tqdm(vocab)):
    embed_matrix[i] = nlp(str(word)).vector

# Load the embedding matrix as the (frozen) weights of the embedding layer
Embed_layer = Embedding(
    ntokens,
    embed_dim,
    embeddings_initializer=Constant(embed_matrix),
    trainable=False)
...
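Under the hood, a frozen Embedding layer is a plain row lookup into embed_matrix. The NumPy sketch below illustrates this with a hypothetical four-dimensional toy_vectors dictionary standing in for spaCy (en_core_web_lg actually provides 300-dimensional vectors); embedding_lookup is an illustrative helper, not a Keras function.

```python
import numpy as np

# Hypothetical 4-d stand-in for spaCy word vectors (the real model uses 300-d).
toy_vectors = {
    "fever": np.array([1.0, 0.0, 0.5, 0.0]),
    "cough": np.array([0.0, 1.0, 0.5, 0.0]),
}
vocab = ["", "[UNK]", "fever", "cough"]  # same layout as TextVectorization's vocabulary
embed_dim = 4

# Row i holds the vector for vocab[i]; in this sketch, words without a vector stay all-zero.
embed_matrix = np.zeros((len(vocab), embed_dim))
for i, word in enumerate(vocab):
    if word in toy_vectors:
        embed_matrix[i] = toy_vectors[word]

def embedding_lookup(token_ids, matrix):
    """What a frozen Embedding layer computes: index rows by token id."""
    return matrix[np.asarray(token_ids)]

batch = [[2, 3, 0], [3, 1, 0]]  # vectorizer output, 0 = padding
embedded = embedding_lookup(batch, embed_matrix)
print(embedded.shape)           # (2, 3, 4): batch, sequence length, embedding dim
```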

Create the model

...
FILTERS = 24
KERNEL_SIZE = 3

import keras.metrics as kmetrics
# Keep the model compact: the dataset is small, with few examples per class.
model = Sequential()
# The vectorizer layer
model.add(Input(shape=(1,), dtype=tf.string))
model.add(disease_vectorizer)
# Add the embedding layer
model.add(Embed_layer)

model.add(LSTM(25, return_sequences=True))
model.add(Conv1D(FILTERS,
                 KERNEL_SIZE,
                 padding='same',
                 strides=1,
                 activation='tanh', 
                 ))

model.add(GlobalAveragePooling1D())
#model.add(GlobalMaxPool1D())
model.add(Dropout(0.5))
model.add(Dense(32, activation='tanh', kernel_regularizer=regularizers.l1_l2(l1=1e-5,l2=1e-4)))
model.add(Dropout(0.5))
model.add(Dense(32, activation='tanh', kernel_regularizer=regularizers.l1_l2(l1=1e-5,l2=1e-4)))
model.add(Dense(24,activation='softmax'))

# Now add in your optimizer
model.compile(optimizer='adam', loss='categorical_crossentropy',metrics=['accuracy',kmetrics.Precision(),
                               kmetrics.Recall()])

# Print the summary of the model (summary() prints directly and returns None)
model.summary()
...
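With padding='same' and strides=1, the Conv1D layer slides each of its FILTERS kernels over windows of KERNEL_SIZE consecutive LSTM outputs, so every word position gets a feature vector and the sequence length is unchanged. Below is a minimal NumPy sketch of that computation for a single document; conv1d_same is a hypothetical helper and random weights stand in for learned ones. Like Keras, it computes a cross-correlation, so the kernel is not flipped.

```python
import numpy as np

def conv1d_same(x, kernel, bias):
    """1-D convolution with 'same' zero padding, stride 1, tanh activation.

    x:      (steps, in_channels), e.g. the LSTM outputs for one document
    kernel: (kernel_size, in_channels, filters), the layout Keras uses
    bias:   (filters,)
    """
    k = kernel.shape[0]
    pad_left = (k - 1) // 2
    pad_right = k - 1 - pad_left
    xp = np.pad(x, ((pad_left, pad_right), (0, 0)))  # pad the time axis only
    out = np.empty((x.shape[0], kernel.shape[2]))
    for t in range(x.shape[0]):
        window = xp[t:t + k]                         # (kernel_size, in_channels)
        out[t] = np.tensordot(window, kernel, axes=([0, 1], [0, 1])) + bias
    return np.tanh(out)                              # activation='tanh'

rng = np.random.default_rng(0)
lstm_out = rng.standard_normal((7, 25))              # 7 words, LSTM(25) outputs
kernel = 0.1 * rng.standard_normal((3, 25, 24))      # KERNEL_SIZE=3, FILTERS=24
features = conv1d_same(lstm_out, kernel, np.zeros(24))
print(features.shape)                                # (7, 24)
```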

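The dropout layers help reduce overfitting, and the L1/L2 kernel regularizers on the dense layers penalize large weights for the same reason. The GlobalAveragePooling1D layer averages the convolutional features over all word positions, producing a fixed-length document embedding of sorts regardless of document length. Finally, because the model is compiled with the categorical_crossentropy loss, the labels must be one-hot encoded (one column per disease, 24 in all); with integer labels, use sparse_categorical_crossentropy instead.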
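Both steps are simple array operations. The sketch below shows global average pooling as a mean over the time axis, and one-hot encoding with a fixed, sorted class order so that each of the softmax outputs always refers to the same disease. The helper names and the toy arrays are illustrative assumptions, not the author's code.

```python
import numpy as np

def global_average_pooling_1d(x):
    """Mean over the time axis: (batch, steps, features) -> (batch, features)."""
    return x.mean(axis=1)

def one_hot(labels, classes):
    """One row per label, with 1.0 in the column of that label's class."""
    column = {c: i for i, c in enumerate(classes)}
    encoded = np.zeros((len(labels), len(classes)))
    encoded[np.arange(len(labels)), [column[label] for label in labels]] = 1.0
    return encoded

# Toy conv features: 2 documents, 3 steps, 2 filters
conv_features = np.array([[[1.0, 2.0], [3.0, 4.0], [5.0, 0.0]],
                          [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]]])
pooled = global_average_pooling_1d(conv_features)  # rows [3, 2] and [0, 1]

labels = ["Malaria", "Acne", "Malaria"]
classes = sorted(set(labels))                      # fixed column order: Acne, Malaria
targets = one_hot(labels, classes)                 # rows [0, 1], [1, 0], [0, 1]
```

For labels that are already integer-coded, keras.utils.to_categorical produces the same kind of one-hot matrix.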

Model training

We track training by recording accuracy, precision and recall, on both the training and the validation data, after each epoch; model.fit returns these values in a History object. Plotting them after training shows when validation performance stops improving, and therefore when to stop training to avoid overfitting.
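A small helper (hypothetical, not part of Keras) can read the per-epoch values that model.fit stores in history.history, a dictionary mapping metric names such as 'val_loss' to one value per epoch, and report the best epoch. The numbers in toy_history are made up purely to show the call.

```python
def best_epoch(history_dict, monitor="val_loss", mode="min"):
    """Return (epoch_number, value) of the best epoch for one tracked metric."""
    values = history_dict[monitor]
    pick = min if mode == "min" else max
    best_index = pick(range(len(values)), key=values.__getitem__)
    return best_index + 1, values[best_index]  # Keras logs count epochs from 1

# Made-up values for illustration; in practice pass history.history
toy_history = {"val_loss": [3.1, 2.4, 1.6, 1.5, 1.7],
               "val_accuracy": [0.10, 0.35, 0.60, 0.65, 0.64]}
print(best_epoch(toy_history))                                      # (4, 1.5)
print(best_epoch(toy_history, monitor="val_accuracy", mode="max"))  # (4, 0.65)
```

Instead of choosing the epoch by hand, Keras's EarlyStopping callback can stop training automatically, for example keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True) passed through callbacks=[...] in model.fit; the patience value here is an arbitrary example.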

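n_epochs = 200
batch_size = 20
# Now, fit the model. y_retain must be one-hot encoded (categorical_crossentropy).
# validation_split holds out the *last* 30% of the samples, taken before any
# shuffling, so shuffle X_retain/y_retain beforehand if they are sorted by label.
history = model.fit(x=X_retain, y=y_retain,
                    batch_size=batch_size,
                    epochs=n_epochs,
                    validation_split=.3)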

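The output fragment below shows the first and last epochs. Precision and recall start at zero because Keras's Precision and Recall metrics, by default, count a class as predicted only when its probability exceeds 0.5, and an untrained 24-class softmax is nowhere near that confident.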

Epoch 1/200
26/26 [==============================] - 5s 59ms/step - loss: 3.1886 - accuracy: 0.0676 - precision: 0.0000e+00 - recall: 0.0000e+00 - val_loss: 3.1381 - val_accuracy: 0.0968 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 2/200
26/26 [==============================] - 1s 23ms/step - loss: 3.0858 - accuracy: 0.0974 - precision: 0.0000e+00 - recall: 0.0000e+00 - val_loss: 3.0663 - val_accuracy: 0.0968 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 3/200
26/26 [==============================] - 1s 26ms/step - loss: 2.9604 - accuracy: 0.1750 - precision: 0.0000e+00 - recall: 0.0000e+00 - val_loss: 2.9370 - val_accuracy: 0.1659 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
...

Epoch 198/200
26/26 [==============================] - 1s 26ms/step - loss: 0.2708 - accuracy: 0.9165 - precision: 0.9357 - recall: 0.8966 - val_loss: 1.2114 - val_accuracy: 0.7235 - val_precision: 0.7549 - val_recall: 0.7097
Epoch 199/200
26/26 [==============================] - 1s 24ms/step - loss: 0.3016 - accuracy: 0.9046 - precision: 0.9382 - recall: 0.8748 - val_loss: 1.2121 - val_accuracy: 0.7143 - val_precision: 0.7476 - val_recall: 0.7097
Epoch 200/200
26/26 [==============================] - 1s 25ms/step - loss: 0.3549 - accuracy: 0.8887 - precision: 0.9053 - recall: 0.8748 - val_loss: 1.1697 - val_accuracy: 0.7097 - val_precision: 0.7704 - val_recall: 0.6959
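A quick sanity check on the first epoch: with one-hot labels, categorical cross-entropy reduces to minus the log of the probability the model assigns to the true class. An untrained 24-way softmax outputs roughly 1/24 for every class, so the starting loss should be about ln 24 ≈ 3.18, close to the epoch-1 loss of 3.1886. The small excess is plausibly due to the L1/L2 regularization penalties, which Keras adds to the reported loss, and to outputs that are not exactly uniform. The function below is an illustrative single-example version, not the Keras implementation.

```python
import math

def categorical_crossentropy(one_hot_label, probs):
    """-sum(y * log p) for one example; with one-hot y this is -log p[true class]."""
    return -sum(y * math.log(p) for y, p in zip(one_hot_label, probs) if y > 0)

n_classes = 24
uniform = [1.0 / n_classes] * n_classes
label = [0.0] * n_classes
label[5] = 1.0  # any true class gives the same value under a uniform prediction
print(round(categorical_crossentropy(label, uniform), 4))  # 3.1781
```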

Classification report

                                 precision    recall  f1-score   support

                           Acne       0.95      0.95      0.95        20
                      Arthritis       1.00      0.95      0.97        20
               Bronchial Asthma       0.65      0.65      0.65        20
           Cervical spondylosis       0.65      1.00      0.78        20
                    Chicken pox       0.40      0.30      0.34        20
                    Common Cold       0.71      0.85      0.77        20
                         Dengue       0.58      0.55      0.56        20
          Dimorphic Hemorrhoids       0.72      0.90      0.80        20
               Fungal infection       0.86      0.95      0.90        20
                   Hypertension       0.82      0.90      0.86        20
                       Impetigo       0.95      0.95      0.95        20
                       Jaundice       0.56      1.00      0.71        20
                        Malaria       0.90      0.95      0.93        20
                       Migraine       1.00      0.75      0.86        20
                      Pneumonia       0.80      0.40      0.53        20
                      Psoriasis       0.72      0.65      0.68        20
                        Typhoid       0.60      0.45      0.51        20
                 Varicose Veins       0.73      0.95      0.83        20
                        allergy       0.69      0.55      0.61        20
                       diabetes       0.60      0.45      0.51        20
                  drug reaction       0.50      0.35      0.41        20
gastroesophageal reflux disease       0.67      0.80      0.73        20
           peptic ulcer disease       0.58      0.55      0.56        20
        urinary tract infection       0.71      0.50      0.59        20

                       accuracy                           0.72       480
                      macro avg       0.72      0.72      0.71       480
                   weighted avg       0.72      0.72      0.71       480
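Each row of the report follows directly from counts. Take Chicken pox: a recall of 0.30 on 20 test documents means 6 were classified correctly, and a precision of 0.40 means 6 / 0.40 = 15 documents were predicted as Chicken pox, 9 of them wrongly; the F1 score is 2 × 0.40 × 0.30 / (0.40 + 0.30) ≈ 0.34. Because every class has the same support (20), the macro and weighted averages coincide. The sketch below recomputes these per-class figures from label lists, following the definitions used by sklearn.metrics.classification_report; per_class_report is a hypothetical helper, and with real data y_pred would presumably be the argmax of the model's softmax outputs on the test set.

```python
from collections import Counter

def per_class_report(y_true, y_pred):
    """Precision, recall, F1 and support per class (0.0 when undefined)."""
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    predicted = Counter(y_pred)
    support = Counter(y_true)
    report = {}
    for label in sorted(set(y_true) | set(y_pred)):
        precision = correct[label] / predicted[label] if predicted[label] else 0.0
        recall = correct[label] / support[label] if support[label] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[label] = (precision, recall, f1, support[label])
    return report

# Toy labels, not the article's data
y_true = ["A", "A", "B", "B"]
y_pred = ["A", "B", "B", "B"]
for label, (p, r, f, s) in per_class_report(y_true, y_pred).items():
    print(f"{label}: precision={p:.2f} recall={r:.2f} f1={f:.2f} support={s}")
```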


Takeaway

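The model reaches an overall accuracy of 0.72 on the 480 test documents (20 per disease), in line with the final validation accuracy of about 0.71. Precision and recall vary widely by class: Arthritis, Acne and Impetigo reach F1 scores of 0.95 or more, while Chicken pox (0.34) and drug reaction (0.41) are often missed. Jaundice and Cervical spondylosis have perfect recall but low precision, meaning other diseases are frequently predicted as them. This is understandable, because many diseases share common symptoms. The gap between training accuracy (about 0.89 to 0.92 in the final epochs) and validation accuracy also shows that the model overfits this small dataset.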
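One way to check which diseases are being mixed up is to count the most frequent (true, predicted) pairs among the errors, that is, the largest off-diagonal cells of a confusion matrix. A minimal sketch follows; top_confusions is a hypothetical helper and the toy labels are made up. For the full matrix, sklearn.metrics.confusion_matrix is an alternative.

```python
from collections import Counter

def top_confusions(y_true, y_pred, n=3):
    """Most frequent (true label, predicted label) pairs among misclassified examples."""
    errors = Counter((t, p) for t, p in zip(y_true, y_pred) if t != p)
    return errors.most_common(n)

# Toy labels for illustration only
y_true = ["A", "A", "A", "B", "C", "C"]
y_pred = ["B", "B", "A", "B", "B", "C"]
print(top_confusions(y_true, y_pred, n=2))  # [(('A', 'B'), 2), (('C', 'B'), 1)]
```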

For the complete code in context, refer to my Kaggle notebook.