Keras 3.x, the AI/ML development framework, now supports PyTorch and JAX backends in addition to TensorFlow. However, given Keras's TensorFlow legacy, large sections of the code remain deeply integrated with TensorFlow.
One such piece of code is text_classification_from_scratch.py from the keras-io examples project. Without TensorFlow installed, this piece of code simply won't run!
Here's text_classification_torch.py, a pure PyTorch port of the same code. The bits that needed modification:
- Removing all TensorFlow-related imports
- Loading the IMDB text files in "grain" format instead of "tf" format, by passing the appropriate parameter:
keras.utils.text_dataset_from_directory(format="grain")
Grain also needs to be installed:
pip3 install grain
- Using torchtext to build the vocabulary, tokenize, and vectorize the text:
pip3 install torchtext
- A few other changes, such as ensuring that the max_features constraint is honoured and that the text is standardized and padded
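
To make the last point concrete, here is a minimal pure-Python sketch of the three steps: standardizing the text, capping the vocabulary at max_features, and padding to a fixed length. It is not the torchtext-based code from text_classification_torch.py. The function names, the "<pad>"/"<unk>" tokens and their ids, and the tiny sample texts are all assumptions made for illustration.

```python
import re
import string
from collections import Counter

# Hypothetical special tokens: id 0 pads short sequences, id 1 marks unknown words.
PAD, UNK = "<pad>", "<unk>"


def standardize(text):
    """Lowercase, replace HTML line breaks with a space, strip punctuation."""
    text = text.lower().replace("<br />", " ")
    return re.sub(f"[{re.escape(string.punctuation)}]", "", text)


def build_vocab(texts, max_features):
    """Map the most frequent tokens to ids, with at most max_features entries."""
    counts = Counter(tok for t in texts for tok in standardize(t).split())
    vocab = {PAD: 0, UNK: 1}
    # Two slots are reserved for PAD and UNK, so keep max_features - 2 tokens.
    for tok, _ in counts.most_common(max_features - 2):
        vocab[tok] = len(vocab)
    return vocab


def vectorize(text, vocab, sequence_length):
    """Convert text to a fixed-length list of ids: truncate, then right-pad."""
    ids = [vocab.get(tok, vocab[UNK]) for tok in standardize(text).split()]
    ids = ids[:sequence_length]
    return ids + [vocab[PAD]] * (sequence_length - len(ids))


texts = ["A great movie!<br />Great cast.", "A dull movie."]
vocab = build_vocab(texts, max_features=5)
row = vectorize("Great movie, dull plot", vocab, sequence_length=6)
print(len(vocab), row)
```

With max_features=5 only the three most frequent words get their own ids, so "dull" and "plot" both map to the unknown id, and the row is padded with zeros up to sequence_length. In the actual port, the resulting id lists would be turned into tensors before being fed to the model.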