ProtGPT2 QA Task for Mutation Generation

#3 by hunarbatra

Dear Authors,

Thank you for the excellent work with ProtGPT2 and for making it available through HuggingFace :).

I have an out-of-the-box question about using ProtGPT2, and it would be great to hear your thoughts on it:
Is it possible to fine-tune ProtGPT2 for a mutation generation task (input-output pairs, more like question answering in task terminology), where, given a sequence, it would return a mutated sequence? I would fine-tune it on a dataset of input-output pairs of wild-type and mutated sequences.
If yes, could you please guide me on how this kind of input-output sequence mutation generation could be performed with ProtGPT2?

Thank you so much :)

I figured out how to do that (just as we would with a regular GPT-2 model, by formatting each example as [WILDTYPE]: seq..... \n [MUTATION]: seq......).
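For anyone else trying this, here is a minimal sketch of that formatting; the sequences and the `train.txt` file name below are placeholder assumptions, not part of ProtGPT2 itself:

```python
# Toy (wild-type, mutant) pairs -- placeholders, not real proteins.
pairs = [
    ("MKTAYIAKQR", "MKTAYIVKQR"),
    ("GSHMLEDPVA", "GSHMLEDPIA"),
]

# Write one training example per pair in the [WILDTYPE]/[MUTATION] format above.
with open("train.txt", "w") as f:
    for wildtype_seq, mutated_seq in pairs:
        f.write(f"[WILDTYPE]: {wildtype_seq}\n[MUTATION]: {mutated_seq}\n\n")
```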
However, I'm unable to run the fine-tuning on a Google Colab hosted runtime (GPU) or locally on my machine (16 GB RAM / M1 MacBook Pro):
"RuntimeError: CUDA out of memory. Tried to allocate 160.00 MiB (GPU 0; 14.76 GiB total capacity; 13.64 GiB already allocated; 81.75 MiB free; 13.71 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"

Is it possible to do the fine-tuning on Colab, or with the Colab Pro plan? Any clue?

Hi hunarbatra!

I'm happy that you're using ProtGPT2! It does seem that the model hits OOM on Colab. I haven't tried to load the model on Colab myself, so I have no experience in this regard, but a few thoughts come to mind that should help with the memory issues.

  1. If you are not specifying --per_device_train_batch_size in your command, it is likely using the default batch size of 8. You could try a batch size of 1: add the flag --per_device_train_batch_size 1 to your command.

  2. If you are still getting that CUDA OOM error after this, I'd try gradient_checkpointing: https://huggingface.co/docs/transformers/main/en/performance#gradient-checkpointing.

  3. If that still doesn't fit into memory, I'd use mixed precision (passing the flag --fp16 to the Trainer). A short sketch combining all three settings follows below.
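To make the three suggestions concrete, here is a minimal sketch using TrainingArguments; the output directory is a placeholder, and the same options can also be passed as command-line flags to a fine-tuning script:

```python
from transformers import TrainingArguments

# Memory-saving settings corresponding to points 1-3 above.
training_args = TrainingArguments(
    output_dir="protgpt2-mutation-finetune",  # placeholder path
    per_device_train_batch_size=1,  # 1. smallest possible batch size
    gradient_checkpointing=True,    # 2. recompute activations to save memory
    fp16=True,                      # 3. mixed-precision training
)
```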

I hope this helps, and please let me know if questions remain!
Noelia

Thank you so much! :)
