hysts HF staff commited on
Commit
661a9c1
1 Parent(s): 5336236

Process examples when loaded

Browse files
Files changed (1) hide show
  1. app.py +67 -15
app.py CHANGED
@@ -297,23 +297,11 @@ class ImageConductor:
297
  guidance_scale,
298
  num_inference_steps,
299
  personalized,
300
- examples_type,
301
  ):
302
  print("Run!")
303
- if examples_type != "":
304
- ### for adapting high version gradio
305
- tracking_points = gr.State([])
306
- first_frame_path = IMAGE_PATH[examples_type]
307
- points = json.load(open(POINTS[examples_type]))
308
- tracking_points.value.extend(points)
309
- print("example first_frame_path", first_frame_path)
310
- print("example tracking_points", tracking_points.value)
311
 
312
  original_width, original_height = 384, 256
313
- if isinstance(tracking_points, list):
314
- input_all_points = tracking_points
315
- else:
316
- input_all_points = tracking_points.value
317
 
318
  print("input_all_points", input_all_points)
319
  resized_all_points = [
@@ -415,7 +403,7 @@ class ImageConductor:
415
  # outputs_path = os.path.join(output_dir, f'output_{i}_{id}.gif')
416
  # save_videos_grid(sample[0][None], outputs_path)
417
  print("Done!")
418
- return {output_image: visualized_drag, output_video: outputs_path}
419
 
420
 
421
  def reset_states(first_frame_path, tracking_points):
@@ -487,6 +475,54 @@ def add_tracking_points(
487
  return {tracking_points_var: tracking_points, input_image: trajectory_map}
488
 
489
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
490
  def add_drag(tracking_points):
491
  if not tracking_points or tracking_points[-1]:
492
  tracking_points.append([])
@@ -571,6 +607,15 @@ def delete_last_step(tracking_points, first_frame_path, drag_mode):
571
  return {tracking_points_var: tracking_points, input_image: trajectory_map}
572
 
573
 
 
 
 
 
 
 
 
 
 
574
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
575
  ImageConductor_net = ImageConductor(
576
  device=device,
@@ -725,9 +770,16 @@ with block:
725
  guidance_scale,
726
  num_inference_steps,
727
  personalized,
728
- examples_type,
729
  ],
730
  [output_image, output_video],
731
  )
732
 
 
 
 
 
 
 
 
 
733
  block.queue().launch()
 
297
  guidance_scale,
298
  num_inference_steps,
299
  personalized,
 
300
  ):
301
  print("Run!")
 
 
 
 
 
 
 
 
302
 
303
  original_width, original_height = 384, 256
304
+ input_all_points = tracking_points
 
 
 
305
 
306
  print("input_all_points", input_all_points)
307
  resized_all_points = [
 
403
  # outputs_path = os.path.join(output_dir, f'output_{i}_{id}.gif')
404
  # save_videos_grid(sample[0][None], outputs_path)
405
  print("Done!")
406
+ return visualized_drag, outputs_path
407
 
408
 
409
  def reset_states(first_frame_path, tracking_points):
 
475
  return {tracking_points_var: tracking_points, input_image: trajectory_map}
476
 
477
 
478
+ def preprocess_example_image(image_path, tracking_points, drag_mode):
479
+ image_pil = image2pil(image_path)
480
+ raw_w, raw_h = image_pil.size
481
+ resize_ratio = max(384 / raw_w, 256 / raw_h)
482
+ image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
483
+ image_pil = transforms.CenterCrop((256, 384))(image_pil.convert("RGB"))
484
+ id = str(uuid.uuid4())[:4]
485
+ first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg")
486
+ image_pil.save(first_frame_path, quality=95)
487
+
488
+ if drag_mode == "object":
489
+ color = (255, 0, 0, 255)
490
+ elif drag_mode == "camera":
491
+ color = (0, 0, 255, 255)
492
+
493
+ transparent_background = Image.open(first_frame_path).convert("RGBA")
494
+ w, h = transparent_background.size
495
+ transparent_layer = np.zeros((h, w, 4))
496
+
497
+ for track in tracking_points:
498
+ if len(track) > 1:
499
+ for i in range(len(track) - 1):
500
+ start_point = track[i]
501
+ end_point = track[i + 1]
502
+ vx = end_point[0] - start_point[0]
503
+ vy = end_point[1] - start_point[1]
504
+ arrow_length = np.sqrt(vx**2 + vy**2)
505
+ if i == len(track) - 2:
506
+ cv2.arrowedLine(
507
+ transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length
508
+ )
509
+ else:
510
+ cv2.line(
511
+ transparent_layer,
512
+ tuple(start_point),
513
+ tuple(end_point),
514
+ color,
515
+ 2,
516
+ )
517
+ else:
518
+ cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
519
+
520
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
521
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
522
+
523
+ return trajectory_map, first_frame_path
524
+
525
+
526
  def add_drag(tracking_points):
527
  if not tracking_points or tracking_points[-1]:
528
  tracking_points.append([])
 
607
  return {tracking_points_var: tracking_points, input_image: trajectory_map}
608
 
609
 
610
+ def load_example(drag_mode, examples_type):
611
+ example_image_path = IMAGE_PATH[examples_type]
612
+ with open(POINTS[examples_type]) as f:
613
+ tracking_points = json.load(f)
614
+ tracking_points = np.round(tracking_points).astype(int).tolist()
615
+ trajectory_map, first_frame_path = preprocess_example_image(example_image_path, tracking_points, drag_mode)
616
+ return {input_image: trajectory_map, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points}
617
+
618
+
619
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
620
  ImageConductor_net = ImageConductor(
621
  device=device,
 
770
  guidance_scale,
771
  num_inference_steps,
772
  personalized,
 
773
  ],
774
  [output_image, output_video],
775
  )
776
 
777
+ examples_type.change(
778
+ fn=load_example,
779
+ inputs=[drag_mode, examples_type],
780
+ outputs=[input_image, first_frame_path_var, tracking_points_var],
781
+ api_name=False,
782
+ queue=False,
783
+ )
784
+
785
  block.queue().launch()