Narsil HF staff commited on
Commit
b166743
1 Parent(s): 05e59e7

Initial commit.

Browse files

Lfs.

Removing assets.

Files changed (6) hide show
  1. .gitattributes +10 -0
  2. DejaVuSans.ttf +3 -0
  3. GoNotoCurrent.ttf +3 -0
  4. README.md +4 -4
  5. app.py +169 -0
  6. screenshot.py +387 -0
.gitattributes CHANGED
@@ -25,3 +25,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ DejaVuSans.ttf filter=lfs diff=lfs merge=lfs -text
29
+ GoNotoCurrent.ttf filter=lfs diff=lfs merge=lfs -text
30
+ assets/ filter=lfs diff=lfs merge=lfs -text
31
+ assets/image_5.png filter=lfs diff=lfs merge=lfs -text
32
+ assets/image_6.png filter=lfs diff=lfs merge=lfs -text
33
+ assets/image_7.png filter=lfs diff=lfs merge=lfs -text
34
+ assets/image_1.png filter=lfs diff=lfs merge=lfs -text
35
+ assets/image_2.png filter=lfs diff=lfs merge=lfs -text
36
+ assets/image_3.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/image_4.png filter=lfs diff=lfs merge=lfs -text
DejaVuSans.ttf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7da195a74c55bef988d0d48f9508bd5d849425c1770dba5d7bfc6ce9ed848954
3
+ size 757076
GoNotoCurrent.ttf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89dea97df90b5fa98dc97b12a9fb2ad072a5bc345fba4b70de7ad1f1bb8e49bd
3
+ size 14804740
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Bloom Demo
3
- emoji: 👁
4
- colorFrom: red
5
- colorTo: red
6
  sdk: gradio
7
  sdk_version: 3.0.24
8
  app_file: app.py
 
1
  ---
2
+ title: Bloom Test
3
+ emoji: 🐨
4
+ colorFrom: green
5
+ colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 3.0.24
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import re
3
+ import requests
4
+ import json
5
+ import os
6
+ from screenshot import BG_COMP, BOX_COMP, GENERATION_VAR, PROMPT_VAR, main
7
+ from pathlib import Path
8
+
9
+ title = "BLOOM"
10
+ description = """Gradio Demo for BLOOM. To use it, simply add your text, or click one of the examples to load them.
11
+ Tips:
12
+ - Do NOT talk to BLOOM as an entity, it's not a chatbot but a webpage/blog/article completion model.
13
+ - For the best results: MIMIC a few sentences of a webpage similar to the content you want to generate.
14
+ Start a paragraph as if YOU were writing a blog, webpage, math post, coding article and BLOOM will generate a coherent follow-up. Longer prompts usually give more interesting results.
15
+ Options:
16
+ - sampling: imaginative completions (may be not super accurate e.g. math/history)
17
+ - greedy: accurate completions (may be more boring or have repetitions)
18
+ """
19
+
20
+ API_URL = os.getenv("API_URL")
21
+
22
+ examples = [
23
+ [
24
+ 'A "whatpu" is a small, furry animal native to Tanzania. An example of a sentence that uses the word whatpu is: We were traveling in Africa and we saw these very cute whatpus. To do a "farduddle" means to jump up and down really fast. An example of a sentence that uses the word farduddle is:',
25
+ 32,
26
+ "Sample",
27
+ False,
28
+ "Sample 1",
29
+ ],
30
+ [
31
+ "A poem about the beauty of science by Alfred Edgar Brittle\nTitle: The Magic Craft\nIn the old times",
32
+ 50,
33
+ "Sample",
34
+ False,
35
+ "Sample 1",
36
+ ],
37
+ ["استخراج العدد العاملي في لغة بايثون:", 30, "Greedy", False, "Sample 1"],
38
+ ["Pour déguster un ortolan, il faut tout d'abord", 32, "Sample", False, "Sample 1"],
39
+ [
40
+ "Traduce español de España a español de Argentina\nEl coche es rojo - el auto es rojo\nEl ordenador es nuevo - la computadora es nueva\nel boligrafo es negro -",
41
+ 16,
42
+ "Sample",
43
+ False,
44
+ "Sample 1",
45
+ ],
46
+ [
47
+ "Estos ejemplos quitan vocales de las palabras\nEjemplos:\nhola - hl\nmanzana - mnzn\npapas - pps\nalacran - lcrn\npapa -",
48
+ 16,
49
+ "Sample",
50
+ False,
51
+ "Sample 1",
52
+ ],
53
+ [
54
+ "Question: If I put cheese into the fridge, will it melt?\nAnswer:",
55
+ 32,
56
+ "Sample",
57
+ False,
58
+ "Sample 1",
59
+ ],
60
+ ["Math exercise - answers:\n34+10=44\n54+20=", 16, "Greedy", False, "Sample 1"],
61
+ [
62
+ "Question: Where does the Greek Goddess Persephone spend half of the year when she is not with her mother?\nAnswer:",
63
+ 24,
64
+ "Greedy",
65
+ False,
66
+ "Sample 1",
67
+ ],
68
+ [
69
+ "spelling test answers.\nWhat are the letters in « language »?\nAnswer: l-a-n-g-u-a-g-e\nWhat are the letters in « Romanian »?\nAnswer:",
70
+ 24,
71
+ "Greedy",
72
+ False,
73
+ "Sample 1",
74
+ ],
75
+ ]
76
+
77
+
78
+ def query(payload):
79
+ print(payload)
80
+ response = requests.request("POST", API_URL, json=payload)
81
+ print(response)
82
+ return json.loads(response.content.decode("utf-8"))
83
+
84
+
85
+ def inference(input_sentence, max_length, sample_or_greedy, raw_text=False, seed=42):
86
+ if sample_or_greedy == "Sample":
87
+ parameters = {
88
+ "max_new_tokens": max_length,
89
+ "top_p": 0.9,
90
+ "do_sample": True,
91
+ "seed": seed,
92
+ "early_stopping": False,
93
+ "length_penalty": 0.0,
94
+ "eos_token_id": None,
95
+ }
96
+ else:
97
+ parameters = {
98
+ "max_new_tokens": max_length,
99
+ "do_sample": False,
100
+ "seed": seed,
101
+ "early_stopping": False,
102
+ "length_penalty": 0.0,
103
+ "eos_token_id": None,
104
+ }
105
+
106
+ payload = {"inputs": input_sentence, "parameters": parameters}
107
+
108
+ data = query(payload)
109
+
110
+ if raw_text:
111
+ return None, data[0]["generated_text"]
112
+
113
+ width, height = 3246, 3246
114
+ assets_path = "assets"
115
+ font_mapping = {
116
+ "latin characters (faster)": "DejaVuSans.ttf",
117
+ "complete alphabet (slower)": "GoNotoCurrent.ttf",
118
+ }
119
+ working_dir = Path(__file__).parent.resolve()
120
+ font_path = str(working_dir / font_mapping["complete alphabet (slower)"])
121
+ img_save_path = str(working_dir / "output.jpeg")
122
+ colors = {
123
+ BG_COMP: "#000000",
124
+ PROMPT_VAR: "#FFFFFF",
125
+ GENERATION_VAR: "#FF57A0",
126
+ BOX_COMP: "#120F25",
127
+ }
128
+
129
+ new_string = data[0]["generated_text"].split(input_sentence, 1)[1]
130
+
131
+ _, img = main(
132
+ input_sentence,
133
+ new_string,
134
+ width,
135
+ height,
136
+ # assets_path=assets_path,
137
+ font_path=font_path,
138
+ colors=colors,
139
+ frame_to_box_margin=200,
140
+ text_to_text_box_margin=50,
141
+ init_font_size=150,
142
+ right_align=False,
143
+ )
144
+ return img, data[0]["generated_text"]
145
+
146
+
147
+ gr.Interface(
148
+ inference,
149
+ [
150
+ gr.inputs.Textbox(label="Input"),
151
+ gr.inputs.Slider(1, 64, default=32, step=1, label="Tokens to generate"),
152
+ gr.inputs.Radio(
153
+ ["Sample", "Greedy"], label="Sample or greedy", default="Sample"
154
+ ),
155
+ gr.Checkbox(label="Just output raw text"),
156
+ gr.inputs.Radio(
157
+ ["Sample 1", "Sample 2", "Sample 3", "Sample 4", "Sample 5"],
158
+ default="Sample 1",
159
+ label="Sample other generations (only work in 'Sample' mode",
160
+ type="index",
161
+ ),
162
+ ],
163
+ ["image", "text"],
164
+ examples=examples,
165
+ # article=article,
166
+ cache_examples=False,
167
+ title=title,
168
+ description=description,
169
+ ).launch()
screenshot.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ from pathlib import Path
4
+ import re
5
+ from time import time
6
+ from typing import List, Optional, Tuple
7
+
8
+ from PIL import Image, ImageDraw, ImageFont
9
+ from PIL.ImageFont import FreeTypeFont
10
+
11
+ from PIL import ImageFilter
12
+ logging.basicConfig(level="DEBUG")
13
+ logger = logging.getLogger(__name__)
14
+
15
+ BG_COMP = "bg_composant"
16
+ PROMPT_VAR = "prompt_variable"
17
+ GENERATION_VAR = "generation_var"
18
+ BOX_COMP = "box_comp"
19
+
20
+
21
+ def wrap_text(font: FreeTypeFont, text: str, max_width: int, direction: str = "ltr") -> str:
22
+ """
23
+ Wraps the text at the given width.
24
+ :param font: Font to use.
25
+ :param text: Text to fit.
26
+ :param max_width: Maximum width of the final text, in pixels.
27
+ :param max_height: Maximum height height of the final text, in pixels.
28
+ :param spacing: The number of pixels between lines.
29
+ :param direction: Direction of the text. It can be 'rtl' (right to
30
+ left), 'ltr' (left to right) or 'ttb' (top to bottom).
31
+ Requires libraqm.
32
+ :return: The wrapped text.
33
+ """
34
+ words = text.split()
35
+
36
+ lines: list[str] = [""]
37
+ curr_line_width = 0
38
+
39
+ for word in words:
40
+ if curr_line_width == 0:
41
+ word_width = font.getlength(word, direction)
42
+
43
+ lines[-1] = word
44
+ curr_line_width = word_width
45
+ else:
46
+ new_line_width = font.getlength(f"{lines[-1]} {word}", direction)
47
+
48
+ if new_line_width > max_width:
49
+ # Word is too long to fit on the current line
50
+ word_width = font.getlength(word, direction)
51
+
52
+ # Put the word on the next line
53
+ lines.append(word)
54
+ curr_line_width = word_width
55
+ else:
56
+ # Put the word on the current line
57
+ lines[-1] = f"{lines[-1]} {word}"
58
+ curr_line_width = new_line_width
59
+
60
+ return "\n".join(lines)
61
+
62
+
63
+ # tr.wrap(my_str, width=30)
64
+
65
+
66
+ def _fit_paragraph(
67
+ paragraph,
68
+ font: FreeTypeFont,
69
+ max_width: int,
70
+ max_height: int,
71
+ line_height,
72
+ num_lines: List[str],
73
+ spacing: int = 4,
74
+ direction: str = "ltr",
75
+
76
+ ):
77
+ paragraph_lines: list[str] = [""]
78
+ curr_line_width = 0
79
+ # This is a very bad splitter for Chinese
80
+ words = list(paragraph) if re.search(u'[\u4e00-\u9fff]',paragraph) else paragraph.split(" ")
81
+
82
+ for word in words:
83
+ if curr_line_width == 0:
84
+ word_width = font.getlength(word, direction)
85
+
86
+ if word_width > max_width:
87
+ # Word is longer than max_width
88
+ return None
89
+
90
+ paragraph_lines[-1] = word
91
+ curr_line_width = word_width
92
+ else:
93
+ new_line_width = font.getlength(f"{paragraph_lines[-1]} {word}", direction)
94
+
95
+ if new_line_width > max_width:
96
+ # Word is too long to fit on the current line
97
+ word_width = font.getlength(word, direction)
98
+ new_num_lines = num_lines + len(paragraph_lines) + 1
99
+ new_text_height = (new_num_lines * line_height) + (new_num_lines * spacing)
100
+
101
+ if word_width > max_width or new_text_height > max_height:
102
+ # Word is longer than max_width, and
103
+ # adding a new line would make the text too tall
104
+ return None
105
+
106
+ # Put the word on the next line
107
+ paragraph_lines.append(word)
108
+ curr_line_width = word_width
109
+ else:
110
+ # Put the word on the current line
111
+ paragraph_lines[-1] = f"{paragraph_lines[-1]} {word}"
112
+ curr_line_width = new_line_width
113
+ return paragraph_lines
114
+
115
+ def try_fit_text(
116
+ font: FreeTypeFont,
117
+ prompt: str,
118
+ generation: str,
119
+ max_width: int,
120
+ max_height: int,
121
+ spacing: int = 4,
122
+ direction: str = "ltr",
123
+ ) -> Optional[str]:
124
+ """
125
+ Attempts to wrap the text into a rectangle.
126
+ Tries to fit the text into a box using the given font at decreasing sizes,
127
+ based on ``scale_factor``. Makes ``max_iterations`` attempts.
128
+ :param font: Font to use.
129
+ :param text: Text to fit.
130
+ :param max_width: Maximum width of the final text, in pixels.
131
+ :param max_height: Maximum height height of the final text, in pixels.
132
+ :param spacing: The number of pixels between lines.
133
+ :param direction: Direction of the text. It can be 'rtl' (right to
134
+ left), 'ltr' (left to right) or 'ttb' (top to bottom).
135
+ Requires libraqm.
136
+ :return: If able to fit the text, the wrapped text. Otherwise, ``None``.
137
+ """
138
+
139
+ line_height = font.size
140
+
141
+ if line_height > max_height:
142
+ # The line height is already too big
143
+ return None
144
+
145
+ prompt_lines: list[str] = []
146
+ paragraphs = prompt.split("\n")
147
+ for paragraph in paragraphs:
148
+ paragraph_lines = _fit_paragraph(
149
+ paragraph=paragraph,
150
+ font=font,
151
+ max_width=max_width,
152
+ max_height=max_height,
153
+ line_height=line_height,
154
+ spacing=spacing,
155
+ num_lines=len(prompt_lines),
156
+ direction=direction,
157
+ )
158
+ if paragraph_lines is None:
159
+ return None
160
+ prompt_lines.extend(paragraph_lines)
161
+ generation_lines: list[str] = []
162
+ paragraphs = f"{prompt_lines[-1]}{generation}".split("\n")
163
+ for paragraph in paragraphs:
164
+ paragraph_lines = _fit_paragraph(
165
+ paragraph=paragraph,
166
+ font=font,
167
+ max_width=max_width,
168
+ max_height=max_height,
169
+ line_height=line_height,
170
+ spacing=spacing,
171
+ num_lines=len(prompt_lines) + len(generation_lines),
172
+ direction=direction,
173
+ )
174
+ if paragraph_lines is None:
175
+ return None
176
+ generation_lines.extend(paragraph_lines)
177
+ generation_lines[0] = generation_lines[0][len(prompt_lines[-1]):]
178
+ return "\n".join(prompt_lines), "\n".join(generation_lines)
179
+
180
+
181
+ # pylint: disable=too-many-arguments
182
+ def fit_text(
183
+ font,
184
+ prompt: str,
185
+ generation: str,
186
+ max_width: int,
187
+ max_height: int,
188
+ spacing: int = 4,
189
+ scale_factor: float = 0.8,
190
+ max_iterations: int = 5,
191
+ direction: str = "ltr",
192
+ ) -> Tuple[FreeTypeFont, str]:
193
+ """
194
+ Automatically determines text wrapping and appropriate font size.
195
+ Tries to fit the text into a box using the given font at decreasing sizes,
196
+ based on ``scale_factor``. Makes ``max_iterations`` attempts.
197
+ If unable to find an appropriate font size within ``max_iterations``
198
+ attempts, wraps the text at the last attempted size.
199
+ :param font: Font to use.
200
+ :param text: Text to fit.
201
+ :param max_width: Maximum width of the final text, in pixels.
202
+ :param max_height: Maximum height height of the final text, in pixels.
203
+ :param spacing: The number of pixels between lines.
204
+ :param scale_factor:
205
+ :param max_iterations: Maximum number of attempts to try to fit the text.
206
+ :param direction: Direction of the text. It can be 'rtl' (right to
207
+ left), 'ltr' (left to right) or 'ttb' (top to bottom).
208
+ Requires libraqm.
209
+ :return: The font at the appropriate size and the wrapped text.
210
+ """
211
+ initial_font_size = font.size
212
+
213
+ # logger.debug('Trying to fit text "%s"', text)
214
+
215
+ for i in range(max_iterations):
216
+ trial_font_size = int(initial_font_size * pow(scale_factor, i))
217
+ trial_font = font.font_variant(size=trial_font_size)
218
+
219
+ logger.debug("Trying font size %i", trial_font_size)
220
+
221
+ wrapped = try_fit_text(trial_font, prompt, generation, max_width, max_height, spacing, direction)
222
+
223
+ if wrapped is not None:
224
+ logger.debug("Successfully fit text")
225
+ return (trial_font, wrapped)
226
+
227
+ # Give up and wrap the text at the last size
228
+ logger.debug("Gave up trying to fit text; just wrapping text")
229
+ wrapped = wrap_text(trial_font, prompt, max_width, direction) + wrap_text(
230
+ trial_font, generation, max_width, direction
231
+ )
232
+
233
+ return (trial_font, wrapped)
234
+
235
+
236
+
237
+ def main(
238
+ prompt,
239
+ generation,
240
+ width,
241
+ height,
242
+ assets_path,
243
+ font_path,
244
+ colors,
245
+ frame_to_box_margin,
246
+ text_to_text_box_margin,
247
+ init_font_size,
248
+ right_align = False
249
+ ):
250
+ # prompt_color = "#ffffff"
251
+ # text_color = "#FF57A0"
252
+ # text_box = ((500, 500, 2300, 1800))
253
+ # margin = 50
254
+ # input_color = "#CDD2E3"
255
+ right_align_params = {"direction":'rtl',"align":'right',"features":'rtla'} if right_align else {}
256
+ text_box_margin_r = frame_to_box_margin
257
+ text_box_margin_t = frame_to_box_margin
258
+ text_box_margin_l = frame_to_box_margin
259
+ text_box_margin_b = int(height / 4.5)
260
+
261
+ text_box = (
262
+ text_box_margin_l,
263
+ text_box_margin_t,
264
+ width - text_box_margin_r,
265
+ height - text_box_margin_b,
266
+ )
267
+
268
+ background = Image.new("RGB", (width, height), color=colors[BG_COMP])
269
+ # Get assets
270
+ assets_path = Path(assets_path)
271
+ flower_1 = Image.open(assets_path / "image_1.png").copy()
272
+ flower_2 = Image.open(assets_path / "image_2.png").copy()
273
+ shadow = Image.open(assets_path / "image_3.png").copy()
274
+ bloom_logo = Image.open(assets_path / "image_4.png").copy()
275
+ input_info = Image.open(assets_path / "image_7.png").copy()
276
+ output_info = Image.open(assets_path / "image_6.png").copy()
277
+ details = Image.open(assets_path / "image_5.png").copy()
278
+
279
+ flower_1_offsets = (int(width - flower_1.width * 2 / 3), int(-flower_1.height / 3))
280
+ background.paste(flower_1, flower_1_offsets, flower_1)
281
+
282
+ flower_2_offsets = (
283
+ -int(flower_2.width * 2 / 5),
284
+ int(height / 2 - flower_2.height / 2),
285
+ )
286
+ background.paste(flower_2, flower_2_offsets, flower_2)
287
+
288
+ bloom_offsets = (
289
+ frame_to_box_margin,
290
+ int(height - bloom_logo.height - frame_to_box_margin),
291
+ )
292
+ background.paste(bloom_logo, bloom_offsets, bloom_logo)
293
+
294
+ input_info_offsets = (
295
+ width - details.width - text_to_text_box_margin - input_info.width - output_info.width ,
296
+ int(height - details.height - frame_to_box_margin),
297
+ )
298
+ background.paste(input_info, input_info_offsets, input_info)
299
+
300
+ output_info_offsets =(
301
+ width - details.width - text_to_text_box_margin - input_info.width - output_info.width ,
302
+ int(height - details.height - frame_to_box_margin + input_info.height + text_to_text_box_margin),
303
+ )
304
+ background.paste(output_info, output_info_offsets, output_info)
305
+
306
+ details_offsets = (
307
+ width - frame_to_box_margin - details.width ,
308
+ int(height - details.height - frame_to_box_margin),
309
+ )
310
+ background.paste(details, details_offsets, details)
311
+ box_margin = (
312
+ text_box[0] + text_to_text_box_margin,
313
+ text_box[1] + text_to_text_box_margin,
314
+ text_box[2] - text_to_text_box_margin,
315
+ text_box[3] - text_to_text_box_margin,
316
+ )
317
+
318
+ # text_box = ImageText(box_margin)
319
+ drawing = ImageDraw.Draw(background, "RGBA")
320
+
321
+ # Text box for main text
322
+ input_color = colors[BOX_COMP][1:]
323
+ input_color = tuple(int(input_color[i : i + 2], 16) for i in (0, 2, 4)) + (
324
+ int(255 * 0.99),
325
+ ) # RGB + A
326
+ drawing.rounded_rectangle(text_box, outline="#000", fill=input_color, radius=47.9)
327
+
328
+ # Adapt text size to box
329
+ # font, (prompt_a, generation_a) = adapt_text_to_ratio(box_margin, prompt, generation)
330
+ # generation_a = adapt_text_to_ratio(box_margin, generation, prompt)
331
+ # font = adapt_font_to_text(box_margin, prompt_a + generation_a, font_path=font_path)
332
+
333
+ init_font = ImageFont.truetype(font_path, size=init_font_size)
334
+ font, (prompt_a, generation_a) = fit_text(
335
+ prompt=prompt,
336
+ generation=generation,
337
+ font=init_font,
338
+ max_height=box_margin[3] - box_margin[1],
339
+ max_width=box_margin[2] - box_margin[0],
340
+ max_iterations=50,
341
+ scale_factor=0.95,
342
+ direction="rtl" if right_align else "ltr"
343
+ )
344
+ # Prompt, main, then the last line
345
+ prompt_s = prompt_a.split("\n")
346
+ prompt_main = "\n".join(prompt_s[:-1])
347
+ prompt_lastline = prompt_s[-1]
348
+ drawing.multiline_text(
349
+ (box_margin[0], box_margin[1]),
350
+ prompt_main,
351
+ colors[PROMPT_VAR],
352
+ font,
353
+ **right_align_params
354
+ )
355
+ end_prompt_main = font.getsize_multiline(prompt_main)
356
+ end_prompt_last = font.getsize_multiline(prompt_lastline)
357
+ drawing.multiline_text(
358
+ ((box_margin[2] - end_prompt_last[0]) if right_align else box_margin[0], box_margin[1] + end_prompt_main[1]),
359
+ prompt_lastline,
360
+ colors[PROMPT_VAR],
361
+ font,
362
+ **right_align_params
363
+ )
364
+
365
+ # Generated text, first line, then the rest
366
+ generation_split = generation_a.split("\n")
367
+ generation_firstline = generation_split[0]
368
+ generation_main = "\n".join(generation_split[1:])
369
+ drawing.multiline_text(
370
+ # margin x + length(last line of prompt), margin Y + length(main part of prompt)
371
+ (box_margin[0], box_margin[1] + end_prompt_main[1]) if right_align else \
372
+ (box_margin[0] + end_prompt_last[0], box_margin[1] + end_prompt_main[1]),
373
+ generation_firstline,
374
+ colors[GENERATION_VAR],
375
+ font,
376
+ **right_align_params
377
+ )
378
+ drawing.multiline_text(
379
+ (box_margin[0], box_margin[1] + end_prompt_main[1] + end_prompt_last[1]),
380
+ generation_main,
381
+ colors[GENERATION_VAR],
382
+ font,
383
+ **right_align_params
384
+ )
385
+
386
+ final_font_size = font.size
387
+ return final_font_size, background