---
datasets:
- Omartificial-Intelligence-Space/Arabic-NLi-Triplet
language:
- ar
base_model: "intfloat/multilingual-e5-small"
library_name: sentence-transformers
pipeline_tag: sentence-similarity
tags:
- sentence-transformers
- sentence-similarity
- feature-extraction
- arabic
- triplet-loss
widget: []
---

# Arabic NLI Triplet - Sentence Transformer Model

This repository contains a Sentence Transformer model fine-tuned on the "Omartificial-Intelligence-Space/Arabic-NLi-Triplet" dataset. The model produces 384-dimensional embeddings for Arabic semantic similarity tasks such as paraphrase mining, sentence similarity, and clustering.

## Model Overview

- **Model Type:** Sentence Transformer
- **Base Model:** `intfloat/multilingual-e5-small`
- **Training Dataset:** [Omartificial-Intelligence-Space/Arabic-NLi-Triplet](https://huggingface.co/datasets/Omartificial-Intelligence-Space/Arabic-NLi-Triplet)
- **Similarity Function:** Cosine Similarity
- **Embedding Dimensionality:** 384 dimensions
- **Maximum Sequence Length:** 128 tokens
- **Performance Improvement:** On the test split of the training dataset, the fine-tuned model scored roughly 10% higher than the base model.

## Dataset

### Arabic NLI Triplet Dataset

The dataset contains triplets of Arabic sentences: an anchor sentence, a positive sentence (semantically similar to the anchor), and a negative sentence (semantically dissimilar to the anchor). It is designed for learning sentence representations with a triplet margin loss.

Dataset Link: [Omartificial-Intelligence-Space/Arabic-NLi-Triplet](https://huggingface.co/datasets/Omartificial-Intelligence-Space/Arabic-NLi-Triplet)

## Training Process

### Loss Function: Triplet Margin Loss

We used the Triplet Margin Loss with a margin of `1.0`. The loss pulls the anchor and positive embeddings together and pushes the negative embedding away, until the negative is at least the margin farther from the anchor than the positive is.

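As a rough illustration of the loss described above, the sketch below computes the triplet margin loss for single triplets in plain Python, using Euclidean distance and a margin of `1.0`. The helper names and the toy 2-D vectors are made up for illustration; PyTorch's `TripletMarginLoss` additionally adds a small epsilon inside the distance and averages over the batch.

```python
import math


def euclidean(u, v):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_margin_loss(anchor, positive, negative, margin=1.0):
    """Loss for one triplet: max(d(a, p) - d(a, n) + margin, 0)."""
    gap = euclidean(anchor, positive) - euclidean(anchor, negative)
    return max(gap + margin, 0.0)


# Toy 2-D "embeddings" (hypothetical values, not model outputs)
anchor = [0.0, 0.0]
positive = [0.3, 0.4]        # about 0.5 away from the anchor
near_negative = [0.6, 0.8]   # about 1.0 away: inside the margin
far_negative = [3.0, 4.0]    # 5.0 away: well outside the margin

# Negative is only about 0.5 farther than the positive, so the loss is about 0.5
hard = triplet_margin_loss(anchor, positive, near_negative)

# Negative is already more than the margin farther away, so the loss is 0
easy = triplet_margin_loss(anchor, positive, far_negative)

print(hard, easy)
```

Triplets whose negative is already far enough away contribute zero loss, so only triplets that violate the margin produce gradients during training.
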
### Training Loss Progress

Training loss logged every 500 steps:

| Step | Training Loss |
|------|---------------|
| 500  | 0.136500 |
| 1000 | 0.126500 |
| 1500 | 0.127300 |
| 2000 | 0.114500 |
| 2500 | 0.110600 |
| 3000 | 0.102300 |
| 3500 | 0.101300 |
| 4000 | 0.106900 |
| 4500 | 0.097200 |
| 5000 | 0.091700 |
| 5500 | 0.092400 |
| 6000 | 0.095500 |

## Model Training Code

The model was trained using the following code (without resuming from checkpoints):

```python
import torch
from datasets import load_dataset
from torch.nn import TripletMarginLoss
from transformers import AutoModel, AutoTokenizer, Trainer, TrainingArguments

# Load dataset
dataset = load_dataset("Omartificial-Intelligence-Space/Arabic-NLi-Triplet")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-small")


# Tokenize the anchor, positive, and negative sentences separately,
# padding or truncating each to 128 tokens
def tokenize_function(examples):
    anchor_encodings = tokenizer(examples['anchor'], truncation=True, padding='max_length', max_length=128)
    positive_encodings = tokenizer(examples['positive'], truncation=True, padding='max_length', max_length=128)
    negative_encodings = tokenizer(examples['negative'], truncation=True, padding='max_length', max_length=128)

    return {
        'anchor_input_ids': anchor_encodings['input_ids'],
        'anchor_attention_mask': anchor_encodings['attention_mask'],
        'positive_input_ids': positive_encodings['input_ids'],
        'positive_attention_mask': positive_encodings['attention_mask'],
        'negative_input_ids': negative_encodings['input_ids'],
        'negative_attention_mask': negative_encodings['attention_mask'],
    }


tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)

# Define triplet loss (Euclidean distance by default)
triplet_loss = TripletMarginLoss(margin=1.0)


def compute_loss(anchor_embedding, positive_embedding, negative_embedding):
    return triplet_loss(anchor_embedding, positive_embedding, negative_embedding)


# Load model
model = AutoModel.from_pretrained("intfloat/multilingual-e5-small")


class TripletTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        anchor_input_ids = inputs['anchor_input_ids'].to(self.args.device)
        anchor_attention_mask = inputs['anchor_attention_mask'].to(self.args.device)
        positive_input_ids = inputs['positive_input_ids'].to(self.args.device)
        positive_attention_mask = inputs['positive_attention_mask'].to(self.args.device)
        negative_input_ids = inputs['negative_input_ids'].to(self.args.device)
        negative_attention_mask = inputs['negative_attention_mask'].to(self.args.device)

        # Sentence embedding = mean of the last hidden states over all 128 positions
        # (padding positions are included in the mean)
        anchor_embeds = model(input_ids=anchor_input_ids, attention_mask=anchor_attention_mask).last_hidden_state.mean(dim=1)
        positive_embeds = model(input_ids=positive_input_ids, attention_mask=positive_attention_mask).last_hidden_state.mean(dim=1)
        negative_embeds = model(input_ids=negative_input_ids, attention_mask=negative_attention_mask).last_hidden_state.mean(dim=1)

        loss = compute_loss(anchor_embeds, positive_embeds, negative_embeds)
        return (loss, None) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        # Batches carry no "labels" key, so evaluation reports the triplet loss only
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            loss = self.compute_loss(model, inputs)
        return loss.detach(), None, None


# Training arguments
training_args = TrainingArguments(
    output_dir="/content/drive/MyDrive/results",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir='/content/drive/MyDrive/logs',
    remove_unused_columns=False,  # keep the custom anchor/positive/negative columns
    fp16=True,
    save_total_limit=3,
)

# Initialize trainer
trainer = TripletTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    # evaluate() needs an eval dataset; this assumes the dataset exposes a "test" split,
    # the one the reported improvement in Model Overview refers to
    eval_dataset=tokenized_datasets['test'],
)

# Start training
trainer.train()

# Save model and evaluate
trainer.save_model("/content/drive/MyDrive/fine-tuned-multilingual-e5")
results = trainer.evaluate()
print(results)
```

## Framework Versions

- Python: 3.10.11
- Sentence Transformers: 3.0.1
- Transformers: 4.44.2
- PyTorch: 2.4.0
- Datasets: 2.21.0

## How to Use

Install Sentence Transformers, then load the model and encode sentences:

```bash
pip install -U sentence-transformers
```

```python
from sentence_transformers import SentenceTransformer

# Load the fine-tuned model
model = SentenceTransformer("gimmeursocks/ara-e5-small")

# Run inference
sentences = [
    'أنا سعيد',          # "I am happy"
    'الجو جميل اليوم',   # "The weather is nice today"
    'هذا كلب كبير',      # "This is a big dog"
]
embeddings = model.encode(sentences)
print(embeddings.shape)  # one 384-dimensional vector per sentence
```

## Citation

If you use this model or dataset, please cite the corresponding paper or dataset source.
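
Once sentences are encoded as in the How to Use snippet, they can be compared with cosine similarity, the similarity function listed in the Model Overview. The following is a minimal sketch in plain Python, with short made-up vectors standing in for the 384-dimensional embeddings; `cosine_similarity` here is an illustrative helper, not part of the library's API.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Toy 3-D vectors (hypothetical values standing in for real embeddings)
query = [1.0, 2.0, 0.0]
candidates = {
    "close": [2.0, 4.0, 0.0],       # same direction as the query
    "orthogonal": [0.0, 0.0, 3.0],  # shares no direction with the query
}

# Score every candidate against the query and pick the most similar one
scores = {name: cosine_similarity(query, vec) for name, vec in candidates.items()}
best = max(scores, key=scores.get)
print(scores, best)
```

Cosine similarity depends only on direction, not vector length, so embeddings do not need to be normalized before they are compared this way.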