import torch
from torch.utils.data import Dataset
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import torch.nn as nn
import math
import gradio as gr

# Suppress torchtext deprecation warnings
import torchtext
torchtext.disable_torchtext_deprecation_warning()

# Define the CSS styles
css_styles = '''
@import url('https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@400;600;700;800&display=swap');

.gradio-container {
    font-family: 'Plus Jakarta Sans', sans-serif;
}

button.primary-button {
    width: 300px;
    height: 48px;
    padding: 2px;
    font-weight: 700;
    border: 3px solid #5964C2;
    border-radius: 5px;
    background-color: #7583FF;
    color: white;
    font-size: 20px;
    transition: 0.3s ease;
}

button.primary-button:hover {
    background-color: #5C67C9;
    border: 3px solid #31376B;
}

input[type="text"], textarea {
    width: 100%;
    outline: none;
    border: 3px solid #B4CFBB !important;
    background-color: #DEFFE7 !important;
    border-radius: 10px !important;
    color: #B4CFBB !important;
    padding: 2px !important;
    font-weight: 600 !important;
    transition: 0.3s ease;
}

input[type="text"]:focus, textarea:focus {
    background-color: #88A88D !important;
    border: 3px solid #657D69 !important;
    color: #657D69 !important;
    font-size: 16px !important;
}
'''

# Define the TranslationDataset class (simplified for vocab loading)
class TranslationDataset(Dataset):
    def __init__(self, file_path):
        self.src_tokenizer = get_tokenizer('basic_english')
        self.tgt_tokenizer = get_tokenizer('basic_english')
        self.src_vocab = build_vocab_from_iterator(
            self._yield_tokens(file_path, 0),
            specials=["<unk>", "<pad>", "<bos>", "<eos>"]
        )
        self.tgt_vocab = build_vocab_from_iterator(
            self._yield_tokens(file_path, 1),
            specials=["<unk>", "<pad>", "<bos>", "<eos>"]
        )
        # Map out-of-vocabulary tokens to <unk>
        self.src_vocab.set_default_index(self.src_vocab["<unk>"])
        self.tgt_vocab.set_default_index(self.tgt_vocab["<unk>"])

    def _yield_tokens(self, file_path, index):
        with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
            for line in f:
                line = line.strip()
                if line:
                    try:
                        # Split the quoted source/target pair and strip the
                        # surrounding punctuation from each field
                        src, tgt = line.split('","')
                        src = src[2:]
                        tgt = tgt[:-3]
                        yield self.src_tokenizer(src) if index == 0 else self.tgt_tokenizer(tgt)
                    except ValueError:
                        continue

# Define the PositionalEncoding class
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        if d_model % 2 == 1:
            # For odd d_model, the cosine channels have one fewer column
            pe[:, 1::2] = torch.cos(position * div_term[:-1])
        else:
            pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)  # shape: (max_len, 1, d_model) for sequence-first input
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

# Define the TransformerModel class
class TransformerModel(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, nhead=8,
                 num_encoder_layers=6, num_decoder_layers=6,
                 dim_feedforward=1024, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers,
                                          num_decoder_layers, dim_feedforward, dropout)
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
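        # Note: nn.Transformer defaults to batch_first=False, so every tensor
        # flowing through this model is sequence-first, i.e. shaped
        # (seq_len, batch, d_model).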
        self.d_model = d_model
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, tgt, src_mask, tgt_mask,
                src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
        # Scale embeddings by sqrt(d_model), as in the original Transformer paper
        src = self.src_embedding(src) * math.sqrt(self.d_model)
        tgt = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        tgt = self.pos_encoder(tgt)
        memory = self.transformer(
            src, tgt, src_mask, tgt_mask, None,
            src_padding_mask, tgt_padding_mask, memory_key_padding_mask
        )
        output = self.fc_out(memory)
        return output

# Translation function: greedy decoding, one token at a time
def translate(model, src_sentence, src_vocab, tgt_vocab, device, max_len=50):
    model.eval()
    src_tokenizer = get_tokenizer('basic_english')
    src_tokens = ([src_vocab["<bos>"]]
                  + [src_vocab[token] for token in src_tokenizer(src_sentence)]
                  + [src_vocab["<eos>"]])
    src_tensor = torch.LongTensor(src_tokens).unsqueeze(1).to(device)
    src_mask = torch.zeros((src_tensor.size(0), src_tensor.size(0)), device=device).type(torch.bool)

    # Encode the source sentence once; reuse the memory for every decode step
    with torch.no_grad():
        memory = model.transformer.encoder(
            model.pos_encoder(model.src_embedding(src_tensor) * math.sqrt(model.d_model)),
            src_mask
        )

    ys = torch.ones(1, 1).fill_(tgt_vocab["<bos>"]).type(torch.long).to(device)
    for _ in range(max_len - 1):
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(ys.size(0)).to(device)
        with torch.no_grad():
            out = model.transformer.decoder(
                model.pos_encoder(model.tgt_embedding(ys) * math.sqrt(model.d_model)),
                memory,
                tgt_mask
            )
            out = model.fc_out(out)
        # Greedily pick the most likely next token
        prob = out[-1].detach()
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src_tensor.data).fill_(next_word)], dim=0)
        if next_word == tgt_vocab["<eos>"]:
            break

    ys = ys.flatten()
    translated_tokens = [
        tgt_vocab.get_itos()[token] for token in ys
        if token not in [tgt_vocab["<bos>"], tgt_vocab["<eos>"], tgt_vocab["<pad>"]]
    ]
    return " ".join(translated_tokens)

# Load the model and dataset
def load_model_and_data():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load the dataset (for vocabulary)
    file_path = 'newcode15M.txt'  # Replace with the path to your dataset file
    dataset = TranslationDataset(file_path)

    # Model hyperparameters (make sure these match your trained model)
    SRC_VOCAB_SIZE = len(dataset.src_vocab)
    TGT_VOCAB_SIZE = len(dataset.tgt_vocab)
    D_MODEL = 256
    NHEAD = 8
    NUM_ENCODER_LAYERS = 6
    NUM_DECODER_LAYERS = 6
    DIM_FEEDFORWARD = 512
    DROPOUT = 0.2

    # Initialize the model
    model = TransformerModel(
        SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, D_MODEL, NHEAD,
        NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, DIM_FEEDFORWARD, DROPOUT
    ).to(device)

    # Load the trained weights
    model.load_state_dict(torch.load('AllOneLM.pth', map_location=device))
    model.eval()

    return model, dataset.src_vocab, dataset.tgt_vocab, device

# Load model and data
model, src_vocab, tgt_vocab, device = load_model_and_data()

# Define the translation function for Gradio
def translate_sentence(src_sentence):
    return translate(model, src_sentence, src_vocab, tgt_vocab, device)

# Create Gradio interface
iface = gr.Interface(
    fn=translate_sentence,
    inputs=gr.Textbox(label="Enter a sentence:", lines=2, placeholder="Type here..."),
    outputs=gr.Textbox(label="Translated:"),
    title="Translation Talking Script",
    description="Enter a sentence to translate.",
    css=css_styles
)

# Launch the interface
if __name__ == "__main__":
    iface.launch()
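# Quick sanity check (a minimal sketch, assuming 'newcode15M.txt' and
# 'AllOneLM.pth' sit alongside this script): bypass the web UI and call the
# translation function directly, e.g.
#
#     print(translate_sentence("hello world"))
#
# which prints the greedy-decoded translation as a plain string. To expose
# the app over a temporary public URL, Gradio's launch() also accepts
# share=True:
#
#     iface.launch(share=True)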