The AI/ML development framework Keras 3.x has recently introduced support for PyTorch and JAX backends in addition to TensorFlow. However, given Keras's TensorFlow legacy, large sections of the code remain deeply integrated with TensorFlow.
One such piece of code is text_classification_from_scratch.py from the keras-io/examples project. Without TensorFlow, this code simply doesn't run!
Here's text_classification_torch.py, a PyTorch port of the same code. The parts that needed modification were:
- Removing all TensorFlow-related imports
- Loading the IMDB text files in "grain" format instead of "tf" format, by passing the appropriate parameter:
keras.utils.text_dataset_from_directory(..., format="grain")
This, of course, requires grain to be installed:
pip3 install grain
- Using torchtext for vocabulary building, tokenization and vectorization:
pip3 install torchtext
- A few other changes, such as ensuring that the max_features constraint is honoured and that text is standardized, padded, and so on.
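The standardize / cap-vocabulary / pad steps in the last two items can be sketched without TensorFlow. What follows is a minimal, standard-library-only sketch, not the port's actual torchtext code: the function names standardize, build_vocab and vectorize are hypothetical. It assumes the Keras TextVectorization conventions (id 0 reserved for padding, id 1 for out-of-vocabulary tokens, both counted toward max_features) and a standardization like the one in the Keras example (lowercase, replace <br /> tags, strip punctuation).

```python
import re
import string
from collections import Counter

# Hypothetical sketch of the preprocessing the Torch port needs,
# using only the standard library instead of torchtext.

PAD_INDEX = 0  # assumed: id 0 reserved for padding (TextVectorization convention)
OOV_INDEX = 1  # assumed: id 1 reserved for out-of-vocabulary tokens

# Character class matching any ASCII punctuation character.
_PUNCT_RE = re.compile(f"[{re.escape(string.punctuation)}]")


def standardize(text: str) -> str:
    """Lowercase, replace <br /> tags with spaces, strip ASCII punctuation."""
    text = text.lower().replace("<br />", " ")
    return _PUNCT_RE.sub("", text)


def build_vocab(texts, max_features):
    """Map tokens to ids so that at most max_features ids exist in total.

    Two ids are reserved (padding and OOV), so at most max_features - 2
    real tokens are kept, most frequent first (ties broken alphabetically).
    """
    counts = Counter(tok for t in texts for tok in standardize(t).split())
    ranked = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    kept = [tok for tok, _ in ranked[: max(max_features - 2, 0)]]
    # Real tokens start at id 2, after the two reserved ids.
    return {tok: i + 2 for i, tok in enumerate(kept)}


def vectorize(text, vocab, sequence_length):
    """Convert text to a fixed-length list of ids: truncate, then pad with 0."""
    ids = [vocab.get(tok, OOV_INDEX) for tok in standardize(text).split()]
    ids = ids[:sequence_length]
    return ids + [PAD_INDEX] * (sequence_length - len(ids))
```

For example, with max_features=5 only three real tokens survive, because ids 0 and 1 are reserved; every other token maps to id 1. The alphabetical tie-breaking is a choice made here for deterministic output and may not match how Keras or torchtext order equally frequent tokens. In a PyTorch pipeline, the resulting id lists would typically be wrapped with torch.tensor before being fed to an embedding layer.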