Yoshinoheart commited on
Commit
bf87a38
1 Parent(s): 339d2fd
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -10,6 +10,10 @@ args = TTSettings(num_beams=5, min_length=1)
10
  tokenizer = T5Tokenizer.from_pretrained("thaboe01/t5-spelling-corrector")
11
  model = T5ForConditionalGeneration.from_pretrained("thaboe01/t5-spelling-corrector")
12
 
 
 
 
 
13
  # Function to split text into chunks
14
  def split_text(text, chunk_size=500):
15
  chunks = []
@@ -19,9 +23,9 @@ def split_text(text, chunk_size=500):
19
 
20
  # Function to correct spelling using T5 model
21
  def correct_spelling(text):
22
- input_ids = tokenizer(text, return_tensors="pt").input_ids
23
  outputs = model.generate(input_ids)
24
- corrected_text = tokenizer.decode(outputs[0])
25
  return corrected_text
26
 
27
  # Streamlit app
 
10
  tokenizer = T5Tokenizer.from_pretrained("thaboe01/t5-spelling-corrector")
11
  model = T5ForConditionalGeneration.from_pretrained("thaboe01/t5-spelling-corrector")
12
 
13
+ # Place the model on the appropriate device
14
+ device = "cuda" if st.session_state.use_gpu else "cpu" # Use GPU if available, otherwise CPU
15
+ model = model.to(device)
16
+
17
  # Function to split text into chunks
18
  def split_text(text, chunk_size=500):
19
  chunks = []
 
23
 
24
  # Function to correct spelling using T5 model
25
  def correct_spelling(text):
26
+ input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
27
  outputs = model.generate(input_ids)
28
+ corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
29
  return corrected_text
30
 
31
  # Streamlit app