jianlong-yuan committed on
Commit
a9d6e30
•
1 Parent(s): 80e7272
Files changed (8) hide show
  1. README.md +11 -5
  2. app.py +282 -0
  3. flower.jpg +0 -0
  4. forbidden_city.webp +0 -0
  5. house.png +3 -0
  6. pizza.jpg +0 -0
  7. sunset.jpg +0 -0
  8. utils.py +27 -0
README.md CHANGED
@@ -1,10 +1,16 @@
1
  ---
2
- title: Blip2 T
3
- emoji: 🏒
4
- colorFrom: gray
5
- colorTo: gray
6
- sdk: static
 
 
7
  pinned: false
 
 
 
 
8
  ---
9
 
10
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: BLIP2
3
+ emoji: 🌖
4
+ colorFrom: blue
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 3.17.0
8
+ app_file: app.py
9
  pinned: false
10
+ license: bsd-3-clause
11
+ models:
12
+ - Salesforce/blip2-opt-6.7b
13
+ - Salesforce/blip2-flan-t5-xxl
14
  ---
15
 
16
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+
3
+ import string
4
+ import gradio as gr
5
+ import requests
6
+ from utils import Endpoint, get_token
7
+
8
+
9
def encode_image(image):
    """Serialize an image to an in-memory JPEG stream.

    Args:
        image: A PIL-style image object exposing ``mode``, ``convert`` and
            ``save``.

    Returns:
        A ``BytesIO`` positioned at offset 0 containing the JPEG bytes, ready
        to be handed to ``requests`` as a file upload.
    """
    # JPEG has no alpha channel or palette; Pillow's JPEG writer rejects modes
    # such as RGBA / LA / P with OSError. Modes it can write natively are left
    # untouched so previously-working inputs produce the same bytes.
    jpeg_writable_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if image.mode not in jpeg_writable_modes:
        image = image.convert("RGB")

    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    # Rewind so the consumer reads from the start of the encoded data.
    buffered.seek(0)

    return buffered
15
+
16
+
17
def query_chat_api(
    image, prompt, decoding_method, temperature, len_penalty, repetition_penalty
):
    """POST an image and prompt to the remote ``/api/generate`` endpoint.

    Args:
        image: PIL-style image; it is JPEG-encoded via ``encode_image``.
        prompt: Text prompt forwarded as the ``prompt`` form field.
        decoding_method: UI choice; ``"Nucleus sampling"`` enables nucleus
            sampling, any other value disables it.
        temperature: Forwarded as ``temperature``.
        len_penalty: Forwarded as ``length_penalty``.
        repetition_penalty: Forwarded as ``repetition_penalty``.

    Returns:
        The decoded JSON body when the service answers HTTP 200; otherwise the
        string ``"Error: "`` followed by the response text.

    Raises:
        requests.exceptions.RequestException: On connection failure or when
            the service does not answer within the timeout.
    """
    url = endpoint.url + "/api/generate"

    headers = {
        "User-Agent": "BLIP-2 HuggingFace Space",
        "Auth-Token": get_token(),
    }

    data = {
        "prompt": prompt,
        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
        "temperature": temperature,
        "length_penalty": len_penalty,
        "repetition_penalty": repetition_penalty,
    }

    files = {"image": encode_image(image)}

    # Without a timeout a stalled backend would block this call forever.
    # (connect, read) in seconds; the read budget is generous because
    # generation on the backend may be slow.
    response = requests.post(
        url, data=data, files=files, headers=headers, timeout=(10, 300)
    )

    if response.status_code == 200:
        return response.json()

    return "Error: " + response.text
46
+
47
+
48
def query_caption_api(
    image, decoding_method, temperature, len_penalty, repetition_penalty
):
    """POST an image to the remote ``/api/caption`` endpoint.

    Args:
        image: PIL-style image; it is JPEG-encoded via ``encode_image``.
        decoding_method: UI choice; ``"Nucleus sampling"`` enables nucleus
            sampling, any other value disables it.
        temperature: Forwarded as ``temperature``.
        len_penalty: Forwarded as ``length_penalty``.
        repetition_penalty: Forwarded as ``repetition_penalty``.

    Returns:
        The decoded JSON body when the service answers HTTP 200; otherwise the
        string ``"Error: "`` followed by the response text.

    Raises:
        requests.exceptions.RequestException: On connection failure or when
            the service does not answer within the timeout.
    """
    url = endpoint.url + "/api/caption"

    headers = {
        "User-Agent": "BLIP-2 HuggingFace Space",
        "Auth-Token": get_token(),
    }

    data = {
        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
        "temperature": temperature,
        "length_penalty": len_penalty,
        "repetition_penalty": repetition_penalty,
    }

    files = {"image": encode_image(image)}

    # Without a timeout a stalled backend would block this call forever.
    # (connect, read) in seconds; the read budget is generous because
    # generation on the backend may be slow.
    response = requests.post(
        url, data=data, files=files, headers=headers, timeout=(10, 300)
    )

    if response.status_code == 200:
        return response.json()

    return "Error: " + response.text
76
+
77
+
78
def postprocess_output(output):
    """Ensure the first generated text ends with punctuation.

    Args:
        output: A list whose first element is the generated text. A bare
            string (the ``"Error: ..."`` value the query helpers return on
            failure) is accepted too and wrapped in a one-element list.

    Returns:
        The list with a full stop appended to its first element when that
        element is non-empty and does not already end in punctuation. A list
        argument is modified in place and returned; an empty list or empty
        first element is returned unchanged.
    """
    if isinstance(output, str):
        # Strings do not support item assignment; normalize to a list so the
        # caller can extend its history with whole messages.
        output = [output]

    # Guard against an empty list / empty text, which have no last character.
    if output and output[0] and output[0][-1] not in string.punctuation:
        output[0] += "."

    return output
84
+
85
+
86
def _pair_turns(turns):
    """Group a flat [user, bot, user, bot, ...] list into (user, bot) tuples."""
    return list(zip(turns[::2], turns[1::2]))


def inference_chat(
    image,
    text_input,
    decoding_method,
    temperature,
    length_penalty,
    repetition_penalty,
    history=None,
):
    """Run one chat turn and return the updated chatbot view and history.

    Args:
        image: Image forwarded to ``query_chat_api``.
        text_input: The user's new message.
        decoding_method: Forwarded to ``query_chat_api``.
        temperature: Forwarded to ``query_chat_api``.
        length_penalty: Forwarded to ``query_chat_api``.
        repetition_penalty: Forwarded to ``query_chat_api``.
        history: Flat list of alternating user / model messages from earlier
            turns. A fresh list is created when omitted, so calls never share
            state through the default value.

    Returns:
        A dict keyed by the module-level ``chatbot`` and ``state`` components:
        ``chatbot`` maps to a list of ``(user, model)`` tuples and ``state``
        to the flat history list.
    """
    if history is None:
        history = []

    # The prompt is the whole conversation so far plus the new message.
    prompt = " ".join([*history, text_input])

    output = query_chat_api(
        image, prompt, decoding_method, temperature, length_penalty, repetition_penalty
    )

    if isinstance(output, str):
        # The query helper reports failures as an "Error: ..." string. Show it
        # as the reply for this turn but keep it (and the unanswered message)
        # out of the stored history so later prompts are not polluted.
        return {
            chatbot: _pair_turns(history) + [(text_input, output)],
            state: history,
        }

    # Only record the turn once the request has succeeded, so a failed call
    # cannot leave an unpaired user message behind.
    history.append(text_input)
    history += postprocess_output(output)

    return {chatbot: _pair_turns(history), state: history}
111
+
112
+
113
def inference_caption(
    image,
    decoding_method,
    temperature,
    length_penalty,
    repetition_penalty,
):
    """Request a caption for ``image`` and return the text to display.

    Args:
        image: Image forwarded to ``query_caption_api``.
        decoding_method: Forwarded to ``query_caption_api``.
        temperature: Forwarded to ``query_caption_api``.
        length_penalty: Forwarded to ``query_caption_api``.
        repetition_penalty: Forwarded to ``query_caption_api``.

    Returns:
        The first element of the service response, or the full
        ``"Error: ..."`` message when the query helper reports a failure.
    """
    output = query_caption_api(
        image, decoding_method, temperature, length_penalty, repetition_penalty
    )

    if isinstance(output, str):
        # Failures come back as a plain string; indexing it would show only
        # its first character ("E") instead of the message.
        return output

    return output[0]
125
+
126
+
127
# Static HTML/Markdown snippets rendered at the top of the demo page.
title = """<h1 align="center">BLIP-2</h1>"""
description = """Gradio demo for BLIP-2, image-to-text generation from Salesforce Research. To use it, simply upload your image, or click one of the examples to load them.
<br> <strong>Disclaimer</strong>: This is a research prototype and is not intended for production use. No data including but not restricted to text and images is collected."""
article = """<strong>Paper</strong>: <a href='https://arxiv.org/abs/2301.12597' target='_blank'>BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models</a>
<br> <strong>Code</strong>: BLIP2 is now integrated into GitHub repo: <a href='https://github.com/salesforce/LAVIS' target='_blank'>LAVIS: a One-stop Library for Language and Vision</a>
<br> <strong>🤗 `transformers` integration</strong>: You can now use `transformers` to use our BLIP-2 models! Check out the <a href='https://huggingface.co/docs/transformers/main/en/model_doc/blip-2' target='_blank'> official docs </a>
<p> <strong>Project Page</strong>: <a href='https://github.com/salesforce/LAVIS/tree/main/projects/blip2' target='_blank'> BLIP2 on LAVIS</a>
<br> <strong>Description</strong>: Captioning results from <strong>BLIP2_OPT_6.7B</strong>. Chat results from <strong>BLIP2_FlanT5xxl</strong>.
"""
136
+
137
+ endpoint = Endpoint()
138
+
139
# Sample rows for the demo: each row is [image file, chat prompt], matching
# the [image_input, chat_input] order given to gr.Examples further down.
examples = [
    list(row)
    for row in (
        ("house.png", "How could someone get out of the house?"),
        ("flower.jpg", "Question: What is this flower and where is it's origin? Answer:"),
        ("pizza.jpg", "What are steps to cook it?"),
        ("sunset.jpg", "Here is a romantic message going along the photo:"),
        ("forbidden_city.webp", "In what dynasties was this place built?"),
    )
]
146
+
147
+ with gr.Blocks(
148
+ css="""
149
+ .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
150
+ #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
151
+ """
152
+ ) as iface:
153
+ state = gr.State([])
154
+
155
+ gr.Markdown(title)
156
+ gr.Markdown(description)
157
+ gr.Markdown(article)
158
+
159
+ with gr.Row():
160
+ with gr.Column(scale=1):
161
+ image_input = gr.Image(type="pil")
162
+
163
+ # with gr.Row():
164
+ sampling = gr.Radio(
165
+ choices=["Beam search", "Nucleus sampling"],
166
+ value="Beam search",
167
+ label="Text Decoding Method",
168
+ interactive=True,
169
+ )
170
+
171
+ temperature = gr.Slider(
172
+ minimum=0.5,
173
+ maximum=1.0,
174
+ value=1.0,
175
+ step=0.1,
176
+ interactive=True,
177
+ label="Temperature (used with nucleus sampling)",
178
+ )
179
+
180
+ len_penalty = gr.Slider(
181
+ minimum=-1.0,
182
+ maximum=2.0,
183
+ value=1.0,
184
+ step=0.2,
185
+ interactive=True,
186
+ label="Length Penalty (set to larger for longer sequence, used with beam search)",
187
+ )
188
+
189
+ rep_penalty = gr.Slider(
190
+ minimum=1.0,
191
+ maximum=5.0,
192
+ value=1.5,
193
+ step=0.5,
194
+ interactive=True,
195
+ label="Repeat Penalty (larger value prevents repetition)",
196
+ )
197
+
198
+ with gr.Column(scale=1.8):
199
+
200
+ with gr.Column():
201
+ caption_output = gr.Textbox(lines=1, label="Caption Output")
202
+ caption_button = gr.Button(
203
+ value="Caption it!", interactive=True, variant="primary"
204
+ )
205
+ caption_button.click(
206
+ inference_caption,
207
+ [
208
+ image_input,
209
+ sampling,
210
+ temperature,
211
+ len_penalty,
212
+ rep_penalty,
213
+ ],
214
+ [caption_output],
215
+ )
216
+
217
+ gr.Markdown("""Trying prompting your input for chat; e.g. example prompt for QA, \"Question: {} Answer:\" Use proper punctuation (e.g., question mark).""")
218
+ with gr.Row():
219
+ with gr.Column(
220
+ scale=1.5,
221
+ ):
222
+ chatbot = gr.Chatbot(
223
+ label="Chat Output (from FlanT5)",
224
+ )
225
+
226
+ # with gr.Row():
227
+ with gr.Column(scale=1):
228
+ chat_input = gr.Textbox(lines=1, label="Chat Input")
229
+ chat_input.submit(
230
+ inference_chat,
231
+ [
232
+ image_input,
233
+ chat_input,
234
+ sampling,
235
+ temperature,
236
+ len_penalty,
237
+ rep_penalty,
238
+ state,
239
+ ],
240
+ [chatbot, state],
241
+ )
242
+
243
+ with gr.Row():
244
+ clear_button = gr.Button(value="Clear", interactive=True)
245
+ clear_button.click(
246
+ lambda: ("", [], []),
247
+ [],
248
+ [chat_input, chatbot, state],
249
+ queue=False,
250
+ )
251
+
252
+ submit_button = gr.Button(
253
+ value="Submit", interactive=True, variant="primary"
254
+ )
255
+ submit_button.click(
256
+ inference_chat,
257
+ [
258
+ image_input,
259
+ chat_input,
260
+ sampling,
261
+ temperature,
262
+ len_penalty,
263
+ rep_penalty,
264
+ state,
265
+ ],
266
+ [chatbot, state],
267
+ )
268
+
269
+ image_input.change(
270
+ lambda: ("", "", []),
271
+ [],
272
+ [chatbot, caption_output, state],
273
+ queue=False,
274
+ )
275
+
276
+ examples = gr.Examples(
277
+ examples=examples,
278
+ inputs=[image_input, chat_input],
279
+ )
280
+
281
# Serve requests through a single-worker queue of at most 10 pending jobs.
iface.queue(concurrency_count=1, api_open=False, max_size=10)
# Queueing is already enabled by .queue() above; the deprecated
# `enable_queue` launch argument is therefore not passed.
iface.launch()
flower.jpg ADDED
forbidden_city.webp ADDED
house.png ADDED

Git LFS Details

  • SHA256: a7b8999524f8f178a43d3417b9f7dfa80d8aff7ccb7ea1b5ba0e5f96bc17bdc0
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB
pizza.jpg ADDED
sunset.jpg ADDED
utils.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+
4
class Endpoint:
    """Resolve the inference service URL lazily and cache it after first use."""

    def __init__(self):
        # Filled in on the first successful access of ``url``.
        self._url = None

    @property
    def url(self):
        """Return the cached URL, reading it from the environment on first use.

        Raises:
            ValueError: If the ``endpoint`` environment variable is not set.
        """
        if self._url is None:
            self._url = self.get_url()

        return self._url

    def get_url(self):
        """Read the service URL from the ``endpoint`` environment variable.

        Raises:
            ValueError: If the variable is not set. Failing here gives a clear
                message instead of a later ``None + str`` TypeError, and
                matches how ``get_token`` reports a missing variable.
        """
        endpoint = os.environ.get("endpoint")

        if endpoint is None:
            raise ValueError("endpoint not found in environment variables")

        return endpoint
19
+
20
+
21
def get_token():
    """Return the API auth token from the ``auth_token`` environment variable.

    Raises:
        ValueError: If ``auth_token`` is not set.
    """
    token = os.environ.get("auth_token")

    if token is None:
        # Name the variable exactly as it is looked up so the message points
        # at the setting that actually needs to be defined.
        raise ValueError("auth_token not found in environment variables")

    return token