kevinwang676 committed
Commit 0fc73ad
1 Parent(s): d7de8f4

Update bark/generation.py

Files changed (1)
  1. bark/generation.py +9 -11
bark/generation.py CHANGED
@@ -494,8 +494,7 @@ def generate_text_semantic(
                 )
             if top_p is not None:
                 # faster to convert to numpy
-                logits_device = relevant_logits.device
-                logits_dtype = relevant_logits.type()
+                original_device = relevant_logits.device
                 relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
                 sorted_indices = np.argsort(relevant_logits)[::-1]
                 sorted_logits = relevant_logits[sorted_indices]
@@ -505,7 +504,7 @@ def generate_text_semantic(
                 sorted_indices_to_remove[0] = False
                 relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                 relevant_logits = torch.from_numpy(relevant_logits)
-                relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
+                relevant_logits = relevant_logits.to(original_device)
             if top_k is not None:
                 v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                 relevant_logits[relevant_logits < v[-1]] = -float("Inf")
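
Both hunks change the same pattern: the logits are shipped to CPU as float32 numpy for nucleus (top-p) filtering and moved back afterwards, and the commit now records only the original device rather than device plus dtype. A minimal self-contained sketch of that round trip, assuming scipy's softmax as imported at the top of generation.py:

import numpy as np
import torch
from scipy.special import softmax

def top_p_filter(relevant_logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # Nucleus filtering via a CPU/numpy round trip, mirroring the hunks above.
    original_device = relevant_logits.device
    logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
    sorted_indices = np.argsort(logits)[::-1]                    # descending order
    cumulative_probs = np.cumsum(softmax(logits[sorted_indices]))
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
    sorted_indices_to_remove[0] = False                          # always keep the top token
    logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
    # Only the device is restored; the result stays float32.
    return torch.from_numpy(logits).to(original_device)

probs = torch.softmax(top_p_filter(torch.randn(16), top_p=0.9), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)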
@@ -599,10 +598,10 @@ def generate_coarse(
             and x_coarse_history.shape[-1] >= 0
             and x_coarse_history.min() >= 0
             and x_coarse_history.max() <= CODEBOOK_SIZE - 1
-            and (
-                round(x_coarse_history.shape[-1] / len(x_semantic_history), 1)
-                == round(semantic_to_coarse_ratio / N_COARSE_CODEBOOKS, 1)
-            )
+            #and (
+            #    round(x_coarse_history.shape[-1] / len(x_semantic_history), 1)
+            #    == round(semantic_to_coarse_ratio / N_COARSE_CODEBOOKS, 1)
+            #)
         )
         x_coarse_history = _flatten_codebooks(x_coarse_history) + SEMANTIC_VOCAB_SIZE
         # trim histories correctly
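
This hunk disables the length-consistency check on the history prompt, presumably so that custom voice prompts whose coarse and semantic histories do not line up exactly are still accepted. For reference, a standalone sketch of what the commented-out assertion enforced, with constant values taken from bark's generation.py:

SEMANTIC_RATE_HZ = 49.9
COARSE_RATE_HZ = 75
N_COARSE_CODEBOOKS = 2

semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS

def history_lengths_consistent(n_coarse: int, n_semantic: int) -> bool:
    # Per codebook, the coarse history must be ~1.5x the semantic history length.
    return round(n_coarse / n_semantic, 1) == round(
        semantic_to_coarse_ratio / N_COARSE_CODEBOOKS, 1
    )

assert history_lengths_consistent(150, 100)      # 1.5 == 1.5
assert not history_lengths_consistent(200, 100)  # 2.0 != 1.5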
@@ -685,8 +684,7 @@ def generate_coarse(
                 relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
                 if top_p is not None:
                     # faster to convert to numpy
-                    logits_device = relevant_logits.device
-                    logits_dtype = relevant_logits.type()
+                    original_device = relevant_logits.device
                     relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
                     sorted_indices = np.argsort(relevant_logits)[::-1]
                     sorted_logits = relevant_logits[sorted_indices]
@@ -696,7 +694,7 @@ def generate_coarse(
                     sorted_indices_to_remove[0] = False
                     relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                     relevant_logits = torch.from_numpy(relevant_logits)
-                    relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
+                    relevant_logits = relevant_logits.to(original_device)
                 if top_k is not None:
                     v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                     relevant_logits[relevant_logits < v[-1]] = -float("Inf")
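
These two hunks apply the same change to the coarse sampling loop. One consequence worth noting: because only the device is restored, the filtered logits now stay float32 even when the model emits lower-precision logits. A tiny check of that behavior, under the assumption of fp16 model output:

import torch

logits_fp16 = torch.randn(8, dtype=torch.float16)  # stand-in for model logits
roundtrip = torch.from_numpy(
    logits_fp16.detach().cpu().type(torch.float32).numpy()
).to(logits_fp16.device)

assert roundtrip.dtype == torch.float32            # dtype is not restored
assert roundtrip.device == logits_fp16.device      # device is restored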
@@ -862,4 +860,4 @@ def codec_decode(fine_tokens):
     del arr, emb, out
     if OFFLOAD_CPU:
         model.to("cpu")
-    return audio_arr
+    return audio_arr
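
The last hunk touches the return at the end of codec_decode, after the codec model has been parked back on the CPU. A minimal sketch of that offload pattern, with illustrative names (decode_audio is hypothetical, not bark's API):

import numpy as np
import torch

OFFLOAD_CPU = True  # assumption: mirrors bark's idle-models-on-CPU behavior

def decode_audio(model: torch.nn.Module, codes: torch.Tensor) -> np.ndarray:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)                    # on the accelerator only for this call
    with torch.no_grad():
        out = model(codes.to(device))
    audio_arr = out.detach().cpu().numpy().squeeze()
    del out                             # drop accelerator references first
    if OFFLOAD_CPU:
        model.to("cpu")                 # free accelerator memory between calls
    return audio_arr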
 