mt-ranker / app.py
ibraheemmoosa's picture
Update app.py
891b9bc verified
raw
history blame contribute delete
No virus
2.39 kB
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer, PretrainedConfig, PreTrainedModel, MT5EncoderModel
class MTRankerConfig(PretrainedConfig):
    """Hugging Face config for ``MTRanker``.

    Args:
        backbone: Model id passed to ``MT5EncoderModel.from_pretrained`` by
            ``MTRanker.__init__``; also used by this app to load the tokenizer.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """
    def __init__(self, backbone: str = 'google/mt5-base', **kwargs):
        # Stored as a plain attribute before the base initializer runs, so it is
        # available on the config alongside the standard PretrainedConfig fields.
        self.backbone = backbone
        super().__init__(**kwargs)
class MTRanker(PreTrainedModel):
    """Two-class ranker: mT5 encoder, masked mean pooling, then a linear head.

    The encoder is loaded from ``config.backbone``; the head maps the pooled
    hidden state (size ``encoder.config.hidden_size``) to ``num_classes`` logits.
    """

    config_class = MTRankerConfig

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): this downloads and loads the backbone weights, which
        # MTRanker.from_pretrained then overwrites with the checkpoint's weights.
        self.encoder = MT5EncoderModel.from_pretrained(config.backbone)
        self.num_classes = 2
        self.classifier = torch.nn.Linear(self.encoder.config.hidden_size, self.num_classes)

    def forward(self, input_ids, attention_mask):
        """Return classification logits for a batch of token sequences.

        Args:
            input_ids: Token ids, assumed (batch, seq_len).
            attention_mask: Same shape as ``input_ids``; nonzero marks real tokens.

        Returns:
            Logits of shape (batch, num_classes).
        """
        encoder_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # (batch, seq_len, 1) in the hidden-state dtype; broadcasting over the
        # hidden dimension replaces an explicit expand.
        mask = attention_mask.unsqueeze(-1).to(encoder_output.dtype)
        summed = torch.sum(encoder_output * mask, dim=1)
        # Clamp so a fully masked row yields a zero vector instead of 0/0 = NaN.
        seq_lengths = torch.sum(mask, dim=1).clamp(min=1)
        pooled_hidden_state = summed / seq_lengths
        return self.classifier(pooled_hidden_state)
# Hub ids: the mT5 backbone (source of the tokenizer) and the fine-tuned ranker.
_BACKBONE_ID = 'google/mt5-base'
_RANKER_ID = 'ibraheemmoosa/mt-ranker-base'

config = MTRankerConfig(backbone=_BACKBONE_ID)
# The tokenizer is taken from the backbone named in the locally built config.
tokenizer = AutoTokenizer.from_pretrained(config.backbone)
model = MTRanker.from_pretrained(_RANKER_ID)
def build_model_input(source, translation1, translation2):
    """Format the source and both candidates into the ranker's text prompt."""
    # The prompt numbers candidates from 0, while the UI numbers them from 1.
    return "Source: {} Translation 0: {} Translation 1: {}".format(source, translation1, translation2)


def scores_to_labels(scores):
    """Map a length-2 probability tensor onto the UI labels as Python floats."""
    first, second = scores.tolist()
    return {'Translation 1': first, 'Translation 2': second}


def predict(source, translation1, translation2):
    """Score two candidate translations of ``source`` with the ranker.

    Args:
        source: Source-language sentence.
        translation1: First candidate ("Translation 0" in the model prompt).
        translation2: Second candidate ("Translation 1" in the model prompt).

    Returns:
        Dict mapping 'Translation 1' / 'Translation 2' to the softmax
        probability of model class 0 / 1, as plain floats for ``gr.Label``.
        NOTE(review): presumably class i means candidate i is preferred — confirm.
    """
    model_input = build_model_input(source, translation1, translation2)
    # A single sequence needs no padding. Padding to 512 only added encoder work
    # on positions that attention_mask excludes in both the encoder and the pool.
    inputs = tokenizer([model_input], max_length=512, truncation=True, return_tensors='pt')
    with torch.inference_mode():
        logits = model(inputs.input_ids, inputs.attention_mask)
        output_scores = torch.softmax(logits, dim=1)[0]
    return scores_to_labels(output_scores)
# Gradio UI: three pre-filled text inputs feed `predict`; its dict of scores is
# rendered by a Label component. The interface is launched at import time.
source_textbox = gr.Textbox(
    label="Source",
    info="Source Sentence",
    # "tapis" is masculine, so the correct French article is "le".
    value="Le chat est sur le tapis.",
)
translation1_textbox = gr.Textbox(
    label="Translation 1",
    info="Translation 1",
    value="The cat is on the bed.",
)
translation2_textbox = gr.Textbox(
    label="Translation 2",
    info="Translation 2",
    value="The cat is on the carpet.",
)
output = gr.Label(label="Result")
iface = gr.Interface(
    fn=predict,
    inputs=[source_textbox, translation1_textbox, translation2_textbox],
    outputs=output,
)
iface.launch()