Pendrokar committed
Commit eba4564
1 Parent(s): 0408757

HF Spaces API support


Uses the Gradio Client to fetch the required parameters of public HF Spaces and overrides their TTS defaults where necessary.

A way to fix #15, and to help with #21 and #23.
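In short: the new code path asks each Space for its endpoint signature and example inputs via `view_api()`, substitutes the arena's text (plus any `OVERRIDE_INPUTS`), and calls `predict()`. A minimal sketch of that flow, assuming only the `gradio_client` package; the Space (`Pendrokar/xVASynth`) and endpoint (`/predict`) are taken from the mapping added in this diff:

```python
# Minimal sketch, not the committed implementation (which also
# type-coerces the example inputs; see _get_param_examples below).
from gradio_client import Client

client = Client("Pendrokar/xVASynth")

# view_api() describes each endpoint: its parameters, their component
# types, and an example input for every parameter.
api = client.view_api(all_endpoints=True, print_info=False, return_format='dict')
params = api['named_endpoints']['/predict']['parameters']

# Start from the Space's own example inputs...
inputs = [p['example_input'] for p in params]
# ...override the text prompt (and any defaults that need fixing)...
inputs[0] = 'Hello from the TTS Arena!'

# ...and synthesize. The result is a filepath or a tuple of outputs.
result = client.predict(*inputs, api_name='/predict')
print(result)
```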

Files changed (1)
  1. app.py +129 -4
app.py CHANGED
@@ -36,7 +36,54 @@ AVAILABLE_MODELS = {
    'ElevenLabs': 'eleven',
    'OpenVoice': 'openvoice',
    'Pheme': 'pheme',
-   'MetaVoice': 'metavoice'
+   'MetaVoice': 'metavoice',
+
+   # '<Space>': '<function>#<return-index-of-audio-param>'
+   # 'coqui/xtts': '1#1',
+   # 'collabora/WhisperSpeech': '/whisper_speech_demo#0',
+   # 'myshell-ai/OpenVoice': '1#1',
+   # 'PolyAI/pheme': '/predict#0', # sleepy HF Space
+   # 'mrfakename/MetaVoice-1B-v0.1': '/tts#0',
+
+   # xVASynth (CPU)
+   'Pendrokar/xVASynth': '/predict#0',
+
+   # MeloTTS
+   # 'mrfakename/MeloTTS': '0#0', # API disabled
+
+   # CoquiTTS (CPU)
+   'coqui/CoquiTTS': '0#0',
+
+   # 'pytorch/Tacotron2': '0#0', # old Gradio
+}
+
+OVERRIDE_INPUTS = {
+   'coqui/xtts': {
+       1: 'en',
+       2: 'https://cdn-uploads.huggingface.co/production/uploads/641de0213239b631552713e4/iKHHqWxWy6Zfmp6QP6CZZ.wav', # voice sample - Scarlett Johansson
+       3: 'https://cdn-uploads.huggingface.co/production/uploads/641de0213239b631552713e4/iKHHqWxWy6Zfmp6QP6CZZ.wav', # voice sample - Scarlett Johansson
+       4: False, # use_mic
+       5: False, # cleanup_reference
+       6: False, # auto_detect
+   },
+   'collabora/WhisperSpeech': {
+       1: 'https://cdn-uploads.huggingface.co/production/uploads/641de0213239b631552713e4/iKHHqWxWy6Zfmp6QP6CZZ.wav', # voice sample - Scarlett Johansson
+       2: 'https://cdn-uploads.huggingface.co/production/uploads/641de0213239b631552713e4/iKHHqWxWy6Zfmp6QP6CZZ.wav', # voice sample - Scarlett Johansson
+       3: 14.0, # Tempo - Gradio Slider issue: takes min. rather than value
+   },
+   'myshell-ai/OpenVoice': {
+       1: 'default', # style
+       2: 'https://cdn-uploads.huggingface.co/production/uploads/641de0213239b631552713e4/iKHHqWxWy6Zfmp6QP6CZZ.wav', # voice sample - Scarlett Johansson
+   },
+   'PolyAI/pheme': {
+       1: 'YOU1000000044_S0000798', # voice
+       2: 210,
+       3: 0.7, # Tempo - Gradio Slider issue: takes min. rather than value
+   },
+   'Pendrokar/xVASynth': {
+       1: 'ccby_nvidia_hifi_92_F', # fine-tuned voice model name
+       3: 1.0, # pacing/duration - Gradio Slider issue: takes min. rather than value
+   },
}

SPACE_ID = os.getenv('SPACE_ID')
@@ -45,6 +92,9 @@ MIN_SAMPLE_TXT_LENGTH = 10
DB_DATASET_ID = os.getenv('DATASET_ID')
DB_NAME = "database.db"

+SPACE_ID = 'Pendrokar/TTS-Arena'
+DB_DATASET_ID = 'PenLocal'
+
# If /data available => means local storage is enabled => let's use it!
DB_PATH = f"/data/{DB_NAME}" if os.path.isdir("/data") else DB_NAME
print(f"Using {DB_PATH}")
@@ -118,6 +168,7 @@ if not os.path.isfile(DB_PATH):
# Create DB table (if doesn't exist)
create_db_if_missing()

+hf_token = os.getenv('HF_TOKEN')
# Sync local DB with remote repo every 5 minute (only if a change is detected)
scheduler = CommitScheduler(
    repo_id=DB_DATASET_ID,
@@ -133,7 +184,7 @@ scheduler = CommitScheduler(
####################################
# Router API
####################################
-router = Client("TTS-AGI/tts-router", hf_token=os.getenv('HF_TOKEN'))
+router = Client("TTS-AGI/tts-router", hf_token=hf_token)
####################################
# Gradio app
####################################
@@ -291,6 +342,9 @@ model_licenses = {
    'metavoice': 'Apache 2.0',
    'elevenlabs': 'Proprietary',
    'whisperspeech': 'MIT',
+
+   'Pendrokar/xVASynth': 'GPT3',
+   'Pendrokar/xVASynthStreaming': 'GPT3',
}
model_links = {
    'styletts2': 'https://github.com/yl4579/StyleTTS2',
@@ -561,7 +615,44 @@ def synthandreturn(text):
    def predict_and_update_result(text, model, result_storage):
        try:
            if model in AVAILABLE_MODELS:
-               result = router.predict(text, AVAILABLE_MODELS[model].lower(), api_name="/synthesize")
+               if '/' in model:
+                   # Use public HF Space
+                   mdl_space = Client(model, hf_token=hf_token)
+                   # the audio return index is assumed to be a single digit (0-9)
+                   return_audio_index = int(AVAILABLE_MODELS[model][-1])
+                   endpoints = mdl_space.view_api(all_endpoints=True, print_info=False, return_format='dict')
+
+                   api_name = None
+                   fn_index = None
+                   # has named endpoint
+                   if '/' == AVAILABLE_MODELS[model][:1]:
+                       # strip the '#<return-index>' suffix to get the endpoint name
+                       api_name = AVAILABLE_MODELS[model][:-2]
+
+                       space_inputs = _get_param_examples(
+                           endpoints['named_endpoints'][api_name]['parameters']
+                       )
+                   # has unnamed endpoint
+                   else:
+                       # endpoint index is the first character
+                       fn_index = int(AVAILABLE_MODELS[model][0])
+
+                       space_inputs = _get_param_examples(
+                           endpoints['unnamed_endpoints'][str(fn_index)]['parameters']
+                       )
+
+                   space_inputs = _override_params(space_inputs, model)
+
+                   # force text
+                   space_inputs[0] = text
+
+                   results = mdl_space.predict(*space_inputs, api_name=api_name, fn_index=fn_index)
+
+                   # return path to audio
+                   result = results[return_audio_index] if (not isinstance(results, str)) else results
+               else:
+                   # Use the private HF Space
+                   result = router.predict(text, AVAILABLE_MODELS[model].lower(), api_name="/synthesize")
            else:
                result = router.predict(text, model.lower(), api_name="/synthesize")
        except:
@@ -593,6 +684,40 @@ def synthandreturn(text):
    # doloudnorm(result)
    # except:
    # pass
+
+   def _get_param_examples(parameters):
+       example_inputs = []
+       for param_info in parameters:
+           if (
+               param_info['component'] == 'Radio'
+               or param_info['component'] == 'Dropdown'
+               or param_info['component'] == 'Audio'
+               or param_info['python_type']['type'] == 'str'
+           ):
+               example_inputs.append(str(param_info['example_input']))
+               continue
+           if param_info['python_type']['type'] == 'int':
+               example_inputs.append(int(param_info['example_input']))
+               continue
+           if param_info['python_type']['type'] == 'float':
+               example_inputs.append(float(param_info['example_input']))
+               continue
+           if param_info['python_type']['type'] == 'bool':
+               example_inputs.append(bool(param_info['example_input']))
+               continue
+
+       return example_inputs
+
+   def _override_params(inputs, modelname):
+       try:
+           for key, value in OVERRIDE_INPUTS[modelname].items():
+               inputs[key] = value
+           print(f"Default inputs overridden for {modelname}")
+       except:
+           pass
+
+       return inputs
+
    results = {}
    thread1 = threading.Thread(target=predict_and_update_result, args=(text, mdl1, results))
    thread2 = threading.Thread(target=predict_and_update_result, args=(text, mdl2, results))
@@ -709,4 +834,4 @@ with gr.Blocks(theme=theme, css="footer {visibility: hidden}textbox{resize:none}
    gr.Markdown(f"If you use this data in your publication, please cite us!\n\nCopy the BibTeX citation to cite this source:\n\n```bibtext\n{CITATION_TEXT}\n```\n\nPlease remember that all generated audio clips should be assumed unsuitable for redistribution or commercial use.")


-demo.queue(api_open=False, default_concurrency_limit=40).launch(show_api=False)
+demo.queue(api_open=False, default_concurrency_limit=40).launch(show_api=False)
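For reference, the `AVAILABLE_MODELS` values follow the `'<function>#<return-index-of-audio-param>'` convention commented in the diff. A small sketch of how `predict_and_update_result` decodes them; the helper name is illustrative only, not part of the commit:

```python
# Illustrative helper: decode an AVAILABLE_MODELS value the same way
# the committed code does ([-1], [:1] == '/', [:-2], [0]).
def parse_space_value(value: str):
    return_audio_index = int(value[-1])  # last char: which output is the audio
    if value.startswith('/'):
        # named endpoint, e.g. '/predict#0' -> api_name='/predict'
        return value[:-2], None, return_audio_index
    # unnamed endpoint, e.g. '0#0' -> fn_index=0
    return None, int(value[0]), return_audio_index

assert parse_space_value('/predict#0') == ('/predict', None, 0)
assert parse_space_value('0#0') == (None, 0, 0)
```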