import numpy as np
import torch
import matplotlib.pyplot as plt
from streamlit_image_coordinates import streamlit_image_coordinates
import streamlit as st
from PIL import Image
from transformers import SamModel, SamProcessor
import cv2
import os

# Define global constants
MAX_WIDTH = 700  # display width (in pixels) of images shown in the app

# Calibration (pixels per millimeter); set once the scale bar is segmented
pixels_per_unit = None

# Define helpful functions
def show_mask(mask, ax, random_color=False):
    """Overlay a semi-transparent segmentation mask on a matplotlib axis."""
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=20):
    """Plot the positive (label == 1) prompt points on a matplotlib axis."""
    pos_points = coords[labels == 1]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='.',
               s=marker_size, edgecolor='white', linewidth=0.2)

def show_points_on_image(raw_image, input_point, ax, input_labels=None):
    """Show an image with its prompt points drawn on top."""
    ax.imshow(raw_image)
    input_point = np.array(input_point)
    if input_labels is None:
        labels = np.ones_like(input_point[:, 0])
    else:
        labels = np.array(input_labels)
    show_points(input_point, labels, ax)
    ax.axis('on')

# Load SAM
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# Get uploaded files from the user
scale = st.file_uploader('Upload Scale Image')
image = st.file_uploader('Upload Particle Image')

# Runs when the scale image is uploaded
if scale:
    scale_np = np.asarray(bytearray(scale.read()), dtype=np.uint8)
    scale_np = cv2.imdecode(scale_np, 1)
    # OpenCV decodes to BGR; convert to RGB for SAM and matplotlib
    scale_np = cv2.cvtColor(scale_np, cv2.COLOR_BGR2RGB)

    # Save the image if it isn't already saved
    if not os.path.exists(scale.name):
        with open(scale.name, "wb") as f:
            f.write(scale.getbuffer())
    scale_pil = Image.open(scale.name)
    # Remove the file when done
    # os.remove(scale.name)

    # Embed the image once; the embeddings are reused for every click
    inputs = processor(scale_np, return_tensors="pt").to(device)
    image_embeddings = model.get_image_embeddings(inputs["pixel_values"])

    # How many times larger scale_np is than the displayed image, per dimension
    scale_factor = scale_np.shape[1] / MAX_WIDTH
    clicked_point = streamlit_image_coordinates(
        scale_pil, height=int(scale_np.shape[0] / scale_factor), width=MAX_WIDTH)

    if clicked_point:
        # Map the click from display coordinates back to full-resolution pixels
        input_point_np = np.array([[clicked_point['x'], clicked_point['y']]]) * scale_factor
        input_point_list = [input_point_np.astype(int).tolist()]
        inputs = processor(scale_np, input_points=input_point_list, return_tensors="pt").to(device)
        inputs.pop("pixel_values", None)
        inputs.update({"image_embeddings": image_embeddings})
        with torch.no_grad():
            outputs = model(**inputs)
        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu())
        mask = torch.squeeze(masks[0])[0]  # (1, 3, h, w) --> (3, h, w); keep the first proposal
        mask = mask.to(torch.int)
        input_label = np.array([1])

        fig, ax = plt.subplots()
        ax.imshow(scale_np)
        show_mask(mask, ax)
        show_points(input_point_np, input_label, ax)
        ax.axis('off')
        st.pyplot(fig)

        # Get pixels per millimeter: assumes the scale bar is horizontal and
        # spans exactly 1 mm, so the mean mask row length is the calibration
        pixels_per_unit = torch.sum(mask, axis=1)
        pixels_per_unit = pixels_per_unit[pixels_per_unit > 0]
        pixels_per_unit = torch.mean(pixels_per_unit, dtype=torch.float).item()

# Runs when the particle image is uploaded
if image:
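    # The particle flow mirrors the scale flow above: embed the image once,
    # let the user click a particle, prompt SAM with that point, and convert
    # the mask's pixel count to mm^2 using the calibration computed above.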
    image_np = np.asarray(bytearray(image.read()), dtype=np.uint8)
    image_np = cv2.imdecode(image_np, 1)
    # OpenCV decodes to BGR; convert to RGB for SAM and matplotlib
    image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)

    # Save the image if it isn't already saved
    if not os.path.exists(image.name):
        with open(image.name, "wb") as f:
            f.write(image.getbuffer())
    image_pil = Image.open(image.name)
    # Remove the file when done
    # os.remove(image.name)

    # Embed the image once; the embeddings are reused for every click
    inputs = processor(image_np, return_tensors="pt").to(device)
    image_embeddings = model.get_image_embeddings(inputs["pixel_values"])

    # How many times larger image_np is than the displayed image, per dimension
    scale_factor = image_np.shape[1] / MAX_WIDTH
    clicked_point = streamlit_image_coordinates(
        image_pil, height=int(image_np.shape[0] / scale_factor), width=MAX_WIDTH)

    if clicked_point:
        # Map the click from display coordinates back to full-resolution pixels
        input_point_np = np.array([[clicked_point['x'], clicked_point['y']]]) * scale_factor
        input_point_list = [input_point_np.astype(int).tolist()]
        inputs = processor(image_np, input_points=input_point_list, return_tensors="pt").to(device)
        inputs.pop("pixel_values", None)
        inputs.update({"image_embeddings": image_embeddings})
        with torch.no_grad():
            outputs = model(**inputs)
        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu())
        mask = torch.squeeze(masks[0])[0]  # (1, 3, h, w) --> (3, h, w); keep the first proposal
        mask = mask.to(torch.int)
        input_label = np.array([1])

        fig, ax = plt.subplots()
        ax.imshow(image_np)
        show_mask(mask, ax)
        show_points(input_point_np, input_label, ax)
        ax.axis('off')
        st.pyplot(fig)

        # Report the area in square millimeters; this requires the scale
        # calibration, so guard against the scale bar not having been clicked yet
        if pixels_per_unit is not None:
            area_mm2 = torch.sum(mask, dtype=torch.float).item() / pixels_per_unit ** 2
            st.write(f'Area: {area_mm2} mm^2')
        else:
            st.warning('Upload a scale image and click the scale bar first to calibrate.')
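# Usage sketch (the filename is an assumption; save this script under any name):
#   streamlit run app.py
# Upload the scale image and click the scale bar to calibrate pixels per
# millimeter, then upload the particle image and click a particle to get its
# area in mm^2.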