Saturday, November 22, 2025

Text Classification from Scratch using PyTorch

Keras 3.x, the AI/ML development framework, has recently introduced support for Torch and JAX backends in addition to TensorFlow. However, given Keras's TensorFlow legacy, large sections of its code are still deeply integrated with TensorFlow.

One such piece of code is text_classification_from_scratch.py from the keras-io/examples project. Without TensorFlow installed, this piece of code simply doesn't run!

Here's text_classification_torch.py, a Torch/PyTorch port of the same code. The bits that needed modification were:

  • Removing all TensorFlow-related imports
  • Loading the IMDB text files in "grain" format in place of "tf" format, by passing the appropriate parameter:

    keras.utils.text_dataset_from_directory(format="grain")

This requires grain to be installed:

    pip3 install grain 

  • Using torchtext for the vocabulary, tokenizing, and vectorizing:

    pip3 install torchtext

  • A few other changes, such as ensuring that the max_features constraint is honoured and that the text is standardized, padded, and so on.
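
To make the last two points concrete, here is a minimal sketch of the standardize / vocabulary / pad steps. It is not the code from text_classification_torch.py: it uses only the Python standard library as a stand-in for torchtext, and every name in it (standardize, build_vocab, vectorize, PAD, UNK) and both constant values are assumptions chosen for illustration.

```python
import re
import string
from collections import Counter

# Illustrative values (assumptions), not taken from the ported script.
MAX_FEATURES = 20000     # upper bound on vocabulary size, incl. PAD and UNK
SEQUENCE_LENGTH = 500    # every review is truncated/padded to this length

PAD, UNK = 0, 1          # reserved ids: padding and out-of-vocabulary tokens


def standardize(text):
    """Lowercase, turn HTML line breaks into spaces, drop punctuation."""
    text = text.lower().replace("<br />", " ")
    return re.sub(f"[{re.escape(string.punctuation)}]", "", text)


def build_vocab(texts, max_features=MAX_FEATURES):
    """Map the most frequent tokens to ids, never exceeding max_features."""
    counts = Counter(tok for t in texts for tok in standardize(t).split())
    # Two slots are reserved for PAD and UNK, so keep max_features - 2 tokens.
    kept = counts.most_common(max_features - 2)
    return {tok: idx + 2 for idx, (tok, _) in enumerate(kept)}


def vectorize(text, vocab, sequence_length=SEQUENCE_LENGTH):
    """Convert text to a fixed-length list of token ids."""
    ids = [vocab.get(tok, UNK) for tok in standardize(text).split()]
    ids = ids[:sequence_length]                       # truncate long reviews
    return ids + [PAD] * (sequence_length - len(ids)) # pad short reviews
```

The lists returned by vectorize can then be wrapped with torch.tensor(...) before being fed to the model. Reserving ids 0 and 1 keeps padding and unknown words distinct from real tokens while still respecting the max_features cap.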
