Ümit Gündüz committed on
Commit d166ac8
1 Parent(s): d84735c

update for evaluate

Files changed (46)
  1. data/dataset/test/milliyet.json +0 -0
  2. data/dataset/test/milliyet.pickle +3 -0
  3. data/dataset/test/ntv.json +0 -0
  4. data/dataset/test/ntv.pickle +3 -0
  5. data/dataset/test/trthaber.json +0 -0
  6. data/dataset/test/trthaber.pickle +3 -0
  7. data/dataset/{100 → train/100}/aa.json +0 -0
  8. data/dataset/{100 → train/100}/aa.pickle +0 -0
  9. data/dataset/{100 → train/100}/aksam.pickle +0 -0
  10. data/dataset/{100 → train/100}/cnnturk.pickle +0 -0
  11. data/dataset/{100 → train/100}/cumhuriyet.pickle +0 -0
  12. data/dataset/{100 → train/100}/ensonhaber.pickle +0 -0
  13. data/dataset/{100 → train/100}/haber7.pickle +0 -0
  14. data/dataset/{100 → train/100}/haberglobal.pickle +0 -0
  15. data/dataset/{100 → train/100}/haberler.pickle +0 -0
  16. data/dataset/{100 → train/100}/haberturk.pickle +0 -0
  17. data/dataset/{100 → train/100}/hurriyet.pickle +0 -0
  18. data/dataset/{1000 → train/1000}/aa.pickle +0 -0
  19. data/dataset/{1000 → train/1000}/aksam.pickle +0 -0
  20. data/dataset/{1000 → train/1000}/cnnturk.pickle +0 -0
  21. data/dataset/{1000 → train/1000}/cumhuriyet.pickle +0 -0
  22. data/dataset/{1000 → train/1000}/ensonhaber.pickle +0 -0
  23. data/dataset/{1000 → train/1000}/haber7.pickle +0 -0
  24. data/dataset/{1000 → train/1000}/haberglobal.pickle +0 -0
  25. data/dataset/{1000 → train/1000}/haberler.pickle +0 -0
  26. data/dataset/{1000 → train/1000}/haberturk.pickle +0 -0
  27. data/dataset/{1000 → train/1000}/hurriyet.pickle +0 -0
  28. data/dataset/{10000 → train/10000}/aa.pickle +0 -0
  29. data/dataset/{10000 → train/10000}/aksam.pickle +0 -0
  30. data/dataset/{10000 → train/10000}/cnnturk.pickle +0 -0
  31. data/dataset/{10000 → train/10000}/cumhuriyet.pickle +0 -0
  32. data/dataset/{10000 → train/10000}/ensonhaber.pickle +0 -0
  33. data/dataset/{10000 → train/10000}/haber7.pickle +0 -0
  34. data/dataset/{10000 → train/10000}/haberglobal.pickle +0 -0
  35. data/dataset/{10000 → train/10000}/haberler.pickle +0 -0
  36. data/dataset/{10000 → train/10000}/haberturk.pickle +0 -0
  37. data/dataset/{10000 → train/10000}/hurriyet.pickle +0 -0
  38. model/confusion_matrix_test.jpg +0 -0
  39. model/confusion_matrix_train.jpg +0 -0
  40. model/model-10-1000_0_metrics.json +22 -0
  41. model/model-10-1000_1_metrics.json +22 -0
  42. model/model-10-1000_2_metrics.json +22 -0
  43. model/model-10-1000_3_metrics.json +22 -0
  44. model/model-10-1000_4_metrics.json +22 -0
  45. model/model-10-1000_metrics.json +22 -0
  46. src/train.py +42 -7
data/dataset/test/milliyet.json ADDED
The diff for this file is too large to render. See raw diff

data/dataset/test/milliyet.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:585be7ab8cb6bfcaa9008463bae2314c1a44be5f87b7f0adbe6ed22a93e86f19
+ size 3421809

data/dataset/test/ntv.json ADDED
The diff for this file is too large to render. See raw diff

data/dataset/test/ntv.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2dcf2df5b231846a78217d8dee458dba96711b4d0e10b0cb3704f90ecdebdad0
+ size 2565944

data/dataset/test/trthaber.json ADDED
The diff for this file is too large to render. See raw diff

data/dataset/test/trthaber.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d58c157566b662d53ff919d9b64fbae033163f07328908b88540f2388b638fe3
+ size 2848309
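
The three test-set pickles above are stored as Git LFS pointer files, so the objects themselves must be fetched before they can be read. The snippet below is a minimal, illustrative way to inspect one of them locally; it assumes only that the file is a standard pickle and makes no claim about its internal schema.

import pickle
from pathlib import Path

# One of the test pickles added in this commit.
path = Path("data/dataset/test/milliyet.pickle")

with path.open("rb") as f:
    data = pickle.load(f)

# Print only generic information, since the schema is not shown in the diff.
print(type(data))
try:
    print(len(data), "top-level items")
except TypeError:
    print("object has no length")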
data/dataset/{100 → train/100}/aa.json RENAMED
File without changes
data/dataset/{100 → train/100}/aa.pickle RENAMED
File without changes
data/dataset/{100 → train/100}/aksam.pickle RENAMED
File without changes
data/dataset/{100 → train/100}/cnnturk.pickle RENAMED
File without changes
data/dataset/{100 → train/100}/cumhuriyet.pickle RENAMED
File without changes
data/dataset/{100 → train/100}/ensonhaber.pickle RENAMED
File without changes
data/dataset/{100 → train/100}/haber7.pickle RENAMED
File without changes
data/dataset/{100 → train/100}/haberglobal.pickle RENAMED
File without changes
data/dataset/{100 → train/100}/haberler.pickle RENAMED
File without changes
data/dataset/{100 → train/100}/haberturk.pickle RENAMED
File without changes
data/dataset/{100 → train/100}/hurriyet.pickle RENAMED
File without changes
data/dataset/{1000 → train/1000}/aa.pickle RENAMED
File without changes
data/dataset/{1000 → train/1000}/aksam.pickle RENAMED
File without changes
data/dataset/{1000 → train/1000}/cnnturk.pickle RENAMED
File without changes
data/dataset/{1000 → train/1000}/cumhuriyet.pickle RENAMED
File without changes
data/dataset/{1000 → train/1000}/ensonhaber.pickle RENAMED
File without changes
data/dataset/{1000 → train/1000}/haber7.pickle RENAMED
File without changes
data/dataset/{1000 → train/1000}/haberglobal.pickle RENAMED
File without changes
data/dataset/{1000 → train/1000}/haberler.pickle RENAMED
File without changes
data/dataset/{1000 → train/1000}/haberturk.pickle RENAMED
File without changes
data/dataset/{1000 → train/1000}/hurriyet.pickle RENAMED
File without changes
data/dataset/{10000 → train/10000}/aa.pickle RENAMED
File without changes
data/dataset/{10000 → train/10000}/aksam.pickle RENAMED
File without changes
data/dataset/{10000 → train/10000}/cnnturk.pickle RENAMED
File without changes
data/dataset/{10000 → train/10000}/cumhuriyet.pickle RENAMED
File without changes
data/dataset/{10000 → train/10000}/ensonhaber.pickle RENAMED
File without changes
data/dataset/{10000 → train/10000}/haber7.pickle RENAMED
File without changes
data/dataset/{10000 → train/10000}/haberglobal.pickle RENAMED
File without changes
data/dataset/{10000 → train/10000}/haberler.pickle RENAMED
File without changes
data/dataset/{10000 → train/10000}/haberturk.pickle RENAMED
File without changes
data/dataset/{10000 → train/10000}/hurriyet.pickle RENAMED
File without changes
model/confusion_matrix_test.jpg ADDED
model/confusion_matrix_train.jpg ADDED
model/model-10-1000_0_metrics.json ADDED
@@ -0,0 +1,22 @@
+ {
+ "content_precision": 0.9681234674743978,
+ "content_recall": 0.9866235484345142,
+ "content_f1": 0.9772859638905067,
+ "content_number": "6803",
+ "date_precision": 0.9992685183193133,
+ "date_recall": 0.9961823056300269,
+ "date_f1": 0.9977230253689343,
+ "date_number": "46625",
+ "description_precision": 0.9794250194250195,
+ "description_recall": 0.9844120954642009,
+ "description_f1": 0.9819122252169442,
+ "description_number": "32012",
+ "title_precision": 0.9863267466478476,
+ "title_recall": 0.9820241824516205,
+ "title_f1": 0.9841707621213233,
+ "title_number": "34157",
+ "overall_precision": 0.98844452620049,
+ "overall_recall": 0.98844452620049,
+ "overall_f1": 0.98844452620049,
+ "overall_accuracy": 0.98844452620049
+ }
model/model-10-1000_1_metrics.json ADDED
@@ -0,0 +1,22 @@
+ {
+ "content_precision": 0.999260464428339,
+ "content_recall": 0.9930912832573865,
+ "content_f1": 0.9961663226186965,
+ "content_number": "6803",
+ "date_precision": 0.9998069870681335,
+ "date_recall": 0.9998927613941019,
+ "date_f1": 0.9998498723915328,
+ "date_number": "46625",
+ "description_precision": 0.9948554859227388,
+ "description_recall": 0.9967512182931401,
+ "description_f1": 0.9958024498712646,
+ "description_number": "32012",
+ "title_precision": 0.9970703697193414,
+ "title_recall": 0.9963989811751618,
+ "title_f1": 0.9967345623874302,
+ "title_number": "34157",
+ "overall_precision": 0.9976671655643536,
+ "overall_recall": 0.9976671655643536,
+ "overall_f1": 0.9976671655643536,
+ "overall_accuracy": 0.9976671655643536
+ }
model/model-10-1000_2_metrics.json ADDED
@@ -0,0 +1,22 @@
+ {
+ "content_precision": 0.9986713906111603,
+ "content_recall": 0.9944142290166104,
+ "content_f1": 0.9965382632393017,
+ "content_number": "6803",
+ "date_precision": 0.9999570953555722,
+ "date_recall": 0.9997426273458445,
+ "date_f1": 0.9998498498498499,
+ "date_number": "46625",
+ "description_precision": 0.9959747878182726,
+ "description_recall": 0.9970948394352118,
+ "description_f1": 0.9965344989072745,
+ "description_number": "32012",
+ "title_precision": 0.9975995316159251,
+ "title_recall": 0.9976871505108762,
+ "title_f1": 0.9976433391395991,
+ "title_number": "34157",
+ "overall_precision": 0.9981437661479803,
+ "overall_recall": 0.9981437661479803,
+ "overall_f1": 0.9981437661479803,
+ "overall_accuracy": 0.9981437661479803
+ }
model/model-10-1000_3_metrics.json ADDED
@@ -0,0 +1,22 @@
+ {
+ "content_precision": 0.998820406959599,
+ "content_recall": 0.9957371747758342,
+ "content_f1": 0.9972764078027236,
+ "content_number": "6803",
+ "date_precision": 0.9998069994853319,
+ "date_recall": 0.9999571045576408,
+ "date_f1": 0.9998820463879386,
+ "date_number": "46625",
+ "description_precision": 0.9972210947013458,
+ "description_recall": 0.9976883668624266,
+ "description_f1": 0.9974546760567778,
+ "description_number": "32012",
+ "title_precision": 0.9983019088886287,
+ "title_recall": 0.9982726820271101,
+ "title_f1": 0.9982872952439505,
+ "title_number": "34157",
+ "overall_precision": 0.9986287281453549,
+ "overall_recall": 0.9986287281453549,
+ "overall_f1": 0.9986287281453549,
+ "overall_accuracy": 0.9986287281453549
+ }
model/model-10-1000_4_metrics.json ADDED
@@ -0,0 +1,22 @@
+ {
+ "content_precision": 0.9982324348210341,
+ "content_recall": 0.9961781566955755,
+ "content_f1": 0.9972042377869335,
+ "content_number": "6803",
+ "date_precision": 0.9997426659804426,
+ "date_recall": 0.9998927613941019,
+ "date_f1": 0.9998177080540871,
+ "date_number": "46625",
+ "description_precision": 0.9968141924602555,
+ "description_recall": 0.9969698862926403,
+ "description_f1": 0.9968920332974122,
+ "description_number": "32012",
+ "title_precision": 0.9977458356509266,
+ "title_recall": 0.997804256814123,
+ "title_f1": 0.997775045377364,
+ "title_number": "34157",
+ "overall_precision": 0.9983026330091892,
+ "overall_recall": 0.9983026330091892,
+ "overall_f1": 0.9983026330091892,
+ "overall_accuracy": 0.9983026330091892
+ }
model/model-10-1000_metrics.json ADDED
@@ -0,0 +1,22 @@
+ {
+ "content_precision": 0.9982324348210341,
+ "content_recall": 0.9961781566955755,
+ "content_f1": 0.9972042377869335,
+ "content_number": "6803",
+ "date_precision": 0.9997426659804426,
+ "date_recall": 0.9998927613941019,
+ "date_f1": 0.9998177080540871,
+ "date_number": "46625",
+ "description_precision": 0.9968141924602555,
+ "description_recall": 0.9969698862926403,
+ "description_f1": 0.9968920332974122,
+ "description_number": "32012",
+ "title_precision": 0.9977458356509266,
+ "title_recall": 0.997804256814123,
+ "title_f1": 0.997775045377364,
+ "title_number": "34157",
+ "overall_precision": 0.9983026330091892,
+ "overall_recall": 0.9983026330091892,
+ "overall_f1": 0.9983026330091892,
+ "overall_accuracy": 0.9983026330091892
+ }
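
The six JSON files above record evaluation metrics for model-10-1000: five numbered snapshots (indices 0–4) plus a final file that repeats the last set of values. Purely as an illustration — the aggregation below is not part of the committed code — the per-file overall scores could be summarised like this:

import glob
import json
import statistics

# The numbered metrics files added in this commit.
paths = sorted(glob.glob("model/model-10-1000_[0-4]_metrics.json"))

overall_f1 = []
for path in paths:
    with open(path, encoding="utf-8") as f:
        metrics = json.load(f)
    overall_f1.append(metrics["overall_f1"])
    print(f"{path}: overall_f1={metrics['overall_f1']:.4f}")

# Simple summary across the five snapshots.
print(f"mean overall_f1: {statistics.mean(overall_f1):.4f}")
print(f"best overall_f1: {max(overall_f1):.4f}")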
src/train.py CHANGED
@@ -145,6 +145,40 @@ class NewsTrainer:
         ]
         return true_predictions, true_labels
 
+    @staticmethod
+    def __get_labels_2(predictions, references, label_list, device):
+        """
+        Gets the labels using the predictions and the references.
+
+        Args:
+            predictions (torch.Tensor): Tensor of predictions.
+            references (torch.Tensor): Tensor of references.
+            label_list (list): List of labels.
+            device (torch.device): Device type.
+
+        Returns:
+            list, list: Lists of true predictions and true labels.
+
+        """
+        # Convert the predictions and references to numpy arrays
+        if device.type == "cpu":
+            y_pred = predictions.detach().clone().numpy()
+            y_true = references.detach().clone().numpy()
+        else:
+            y_pred = predictions.detach().cpu().clone().numpy()
+            y_true = references.detach().cpu().clone().numpy()
+
+        # Remove the ignore index (special tokens)
+        true_predictions = [
+            [label_list[p] for (p, l) in zip(pred, gold_label)]
+            for pred, gold_label in zip(y_pred, y_true)
+        ]
+        true_labels = [
+            [label_list[l] for (p, l) in zip(pred, gold_label)]
+            for pred, gold_label in zip(y_pred, y_true)
+        ]
+        return true_predictions, true_labels
+
     @staticmethod
     def __compute_metrics(metric, return_entity_level_metrics=True):
         """
@@ -320,9 +354,10 @@ class NewsTrainer:
         train_data_path = "../data/dataset/test"
         model_path = "../model/model.pth"
         label_list = ["" + x for x in list(id2label.values())]
+        label_list = label_list[:4]
         dataset = self.__get_dataset(train_data_path)
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        dataloader = DataLoader(dataset, batch_size=100)
+        dataloader = DataLoader(dataset, batch_size=10)
         model = torch.load(model_path, map_location=torch.device(device))
         i = 0
         y_pred = []
@@ -369,12 +404,12 @@ class NewsTrainer:
 if __name__ == '__main__':
     trainer = NewsTrainer()
     # Training
-    model_name = "model-10-1000"
-    _train_data_path = "./data/dataset/100"
-    _model_output_path = "./models"
-    trainer.run(model_name=model_name,
-                train_data_path=_train_data_path,
-                model_output_path=_model_output_path)
+    # model_name = "model-10-1000"
+    # _train_data_path = "./data/dataset/100"
+    # _model_output_path = "./models"
+    # trainer.run(model_name=model_name,
+    #             train_data_path=_train_data_path,
+    #             model_output_path=_model_output_path)
 
     # Evaluation
     trainer.evaluate()
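
To make the behaviour of the new __get_labels_2 helper concrete, here is a small self-contained example of the same index-to-label mapping on toy tensors. The label names here are invented for the example; in train.py the real list is built from id2label. Note also that the committed comprehension maps every position and, despite its inline comment, does not filter an ignore index such as -100.

import torch

# Hypothetical label list; train.py builds the real one from id2label.
label_list = ["O", "B-title", "B-date", "B-content"]

# Toy batch: 2 sequences of 3 token predictions/references each.
predictions = torch.tensor([[1, 0, 2], [3, 3, 0]])
references = torch.tensor([[1, 0, 2], [3, 0, 0]])

y_pred = predictions.numpy()
y_true = references.numpy()

# Same pairwise mapping used by __get_labels_2: index -> label string.
true_predictions = [[label_list[p] for (p, l) in zip(pred, gold)]
                    for pred, gold in zip(y_pred, y_true)]
true_labels = [[label_list[l] for (p, l) in zip(pred, gold)]
               for pred, gold in zip(y_pred, y_true)]

print(true_predictions)  # [['B-title', 'O', 'B-date'], ['B-content', 'B-content', 'O']]
print(true_labels)       # [['B-title', 'O', 'B-date'], ['B-content', 'O', 'O']]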