当前位置: 首页 > news >正文

基于SAM大模型的遥感影像分割工具,用于创建交互式标注、识别地物的能力,可利用Flask进行封装作为Web后台服务

如有帮助,支持一下(GitHub - Lvbta/ImageSegmentationTool-SAM: An interactive annotation case developed based on SAM for remote sensing image annotation, which can generate corresponding segmentation results based on point, multi-point, and rectangular box prompts, and convert the recognition results into vector data shp.)

本项目提供了一个图像分割工具,利用 Segment Anything Model (SAM) 对大规模的卫星或航拍图像进行分割。该工具支持通过单点、多点或边界框输入进行图像分割,并将分割结果保存为 shapefile,以便进一步进行地理空间分析。

功能特点

  • 单点分割:支持基于单个点的输入进行分割。
  • 多点分割:支持使用多个点进行分割。
  • 边界框分割:支持在指定的边界框内进行分割。
  • 地理空间集成:使用 GDAL 读取地理空间图像,并将分割的掩膜转换为多边形。
  • Shapefile 导出:将分割结果保存为 shapefile,方便与 GIS 工具集成。
  • 可视化:在原始图像上可视化分割结果,便于验证和分析。

安装

  1. 克隆仓库:

    git clone https://github.com/Lvbta/ImageSegmentationTool.git
    cd ImageSegmentationTool
  2. 下载SAM权重:

    • default or vit_h: ViT-H SAM model.
    • vit_l: ViT-L SAM model.
    • vit_b: ViT-B SAM model.
  3. 安装所需的依赖:

    pip install -r requirements.txt
  4. 设置环境变量:

    • 代码内已设置 KMP_DUPLICATE_LIB_OK 变量,以避免冲突。

使用方法

步骤 1:准备数据

  • 图像:确保您拥有地理参考的卫星或航拍图像,格式为 TIFF。
  • SAM 模型检查点:下载 SAM 模型检查点文件,并将其放置在项目目录中。

步骤 2:配置参数

在脚本中设置以下参数:

  • image_path: 您的地理参考图像文件的路径(例如 ./sentinel2.tif)。
  • sam_checkpoint: 您的 SAM 模型检查点文件的路径(例如 ./sam_vit_b_01ec64.pth)。
  • model_type: 用于分割的模型类型(vit_bvit_l 等)。
  • device: 用于运行模型的设备(cpu 或 cuda)。
  • output_shp: 保存输出 shapefile 的路径。

步骤 3:运行分割

选择分割模式并指定必要的输入点或边界框:

  • 单点模式

    seg_mode = 'single_point'
    input_points = [[1248, 1507]]
    single_label = [1]
  • 多点模式

    seg_mode = 'multi_point'
    input_points = [[389, 1041],[411, 1094]]
    single_label = [1, 1]
  • 边界框模式

    seg_mode = 'box'
    input_box = [[0, 951, 1909, 2383]]
    single_label = [1]

步骤 4:执行脚本

运行脚本以进行分割:

python main.py

步骤 5:可视化并保存结果

分割的掩膜将被可视化,多边形将作为 shapefile 保存到指定位置。

示例

使用边界框对图像进行分割,脚本配置如下:

# 边界框模式示例配置
seg_mode = 'box'
input_box = [[0, 951, 1909, 2383]]
single_label = [1]segmenter = ImageSegmentation(image_path, sam_checkpoint, model_type, device)
masks, scores, x_off, y_off = segmenter.predict(mode=seg_mode, input_box=input_box, input_labels=single_label, multimask_output=True)
polygons = segmenter.masks_to_polygons(masks, x_off, y_off)
segmenter.save_polygons_gdal(polygons, output_shp)
segmenter.show_masks(seg_mode, masks, scores, x_off, y_off, input_box, single_label, image_chunk)
import numpy as np
import torch
import cv2
import sys
from osgeo import gdal, ogr, osr
from shapely.geometry import Polygon
from shapely.wkb import dumps
import matplotlib.pyplot as plt
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
plt.rcParams['font.sans-serif'] = 'SimHei'  # 设置中文显示
plt.rcParams['axes.unicode_minus'] = False
# plt.style.use('ggplot')class ImageSegmentation:def __init__(self, image_path, sam_checkpoint, model_type='vit_b', device='cpu'):self.image_path = image_pathself.sam_checkpoint = sam_checkpointself.model_type = model_typeself.device = deviceself.geo_transform, self.proj = self.get_geoinfo()self.sam = self.load_sam_model()self.predictor = self.init_predictor()def get_geoinfo(self):dataset = gdal.Open(self.image_path)geo_transform = dataset.GetGeoTransform()proj = dataset.GetProjection()dataset = None  # 关闭return geo_transform, projdef read_image_chunk(self, x_off, y_off, x_size, y_size):dataset = gdal.Open(self.image_path)image = dataset.ReadAsArray(x_off, y_off, x_size, y_size)dataset = None  # 关闭if len(image.shape) == 3:image = np.transpose(image, (1, 2, 0))  # GDAL reads in (bands, height, width) formatelse:image = np.stack([image] * 3, axis=-1)  # If it's a single-band image, stack to (height, width, 3)return imagedef load_sam_model(self):sys.path.append("..")from segment_anything import sam_model_registrysam = sam_model_registry[self.model_type](checkpoint=self.sam_checkpoint)sam.to(device=self.device)return samdef init_predictor(self):from segment_anything import SamPredictorpredictor = SamPredictor(self.sam)return predictordef predict(self, mode='single_point', input_points=None, input_labels=None, input_box=None, multimask_output=None):if mode == 'single_point':assert input_points is not None and input_labels is not None, "Points and labels are required for single point mode."x, y = input_points[0]chunk_size = 512  # or any appropriate sizex_off = max(x - chunk_size // 2, 0)y_off = max(y - chunk_size // 2, 0)x_size = y_size = chunk_sizeimage_chunk = self.read_image_chunk(x_off, y_off, x_size, y_size)self.predictor.set_image(image_chunk)adjusted_points = [(x - x_off, y - y_off)]masks, scores, logits = self.predictor.predict(point_coords=np.array(adjusted_points),point_labels=np.array(input_labels),multimask_output=multimask_output,)elif mode == 'multi_point':assert input_points is not None and input_labels is not None, "Points and labels are required for multi point mode."# Determine bounding box of all pointsx_min = min(p[0] for p in input_points)y_min = min(p[1] for p in input_points)x_max = max(p[0] for p in input_points)y_max = max(p[1] for p in input_points)margin = 256  # or any appropriate marginx_off = max(x_min - margin, 0)y_off = max(y_min - margin, 0)x_size = min(x_max - x_min + 2 * margin, 2048)y_size = min(y_max - y_min + 2 * margin, 2048)image_chunk = self.read_image_chunk(x_off, y_off, x_size, y_size)self.predictor.set_image(image_chunk)adjusted_points = [(x - x_off, y - y_off) for x, y in input_points]masks, scores, logits = self.predictor.predict(point_coords=np.array(adjusted_points),point_labels=np.array(input_labels),multimask_output=multimask_output,)elif mode == 'box':assert input_box is not None, "Box coordinates are required for box mode."x_min, y_min, x_max, y_max = input_box[0]margin = 256  # or any appropriate marginx_off = max(x_min - margin, 0)y_off = max(y_min - margin, 0)x_size = min(x_max - x_min + 2 * margin, 2048)y_size = min(y_max - y_min + 2 * margin, 2048)image_chunk = self.read_image_chunk(x_off, y_off, x_size, y_size)self.predictor.set_image(image_chunk)adjusted_box = [(x_min - x_off, y_min - y_off, x_max - x_off, y_max - y_off)]masks, scores, logits = self.predictor.predict(box=np.array(adjusted_box).reshape(1, -1),multimask_output=multimask_output,)else:raise ValueError("Mode must be 'single_point', 'multi_point', or 'box'.")return masks, scores, x_off, y_offdef masks_to_polygons(self, masks, x_off, y_off):polygons = []for mask in masks:contours, _ = cv2.findContours((mask > 0).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)for contour in contours:contour = contour.squeeze()if len(contour.shape) == 2 and len(contour) >= 3:  # valid polygongeo_contour = [self.pixel_to_geo(x + x_off, y + y_off) for x, y in contour]polygon = Polygon(geo_contour)if polygon.is_valid:polygons.append(polygon)return polygonsdef pixel_to_geo(self, x, y):geox = self.geo_transform[0] + x * self.geo_transform[1] + y * self.geo_transform[2]geoy = self.geo_transform[3] + x * self.geo_transform[4] + y * self.geo_transform[5]return geox, geoydef save_polygons_gdal(self, polygons, output_shp):driver = ogr.GetDriverByName("ESRI Shapefile")data_source = driver.CreateDataSource(output_shp)spatial_ref = osr.SpatialReference()spatial_ref.ImportFromWkt(self.proj)  # 使用图像的投影信息layer = data_source.CreateLayer("segmentation", spatial_ref, ogr.wkbPolygon)layer_defn = layer.GetLayerDefn()for i, polygon in enumerate(polygons):feature = ogr.Feature(layer_defn)geom_wkb = dumps(polygon)  # 将Shapely几何对象转换为WKBogr_geom = ogr.CreateGeometryFromWkb(geom_wkb)  # 从WKB创建OGR几何对象feature.SetGeometry(ogr_geom)feature.SetField("id", i + 1)layer.CreateFeature(feature)feature = Nonedata_source = Nonedef show_masks(self, mode, masks, scores,x_off, y_off, input_point, input_label, image):for i, (mask, score) in enumerate(zip(masks, scores)):plt.figure(figsize=(10, 10))plt.imshow(image)self.show_mask(mask, plt.gca())if mode == 'box':self.show_box(np.array(input_point[0]), plt.gca(), x_off, y_off)else:self.show_points(np.array(input_point), np.array(input_label), plt.gca(), x_off, y_off)plt.title(f"{mode}模式 {i + 1}, Score: {score:.3f}", fontsize=18)plt.axis('on')plt.show()def show_mask(self, mask, ax, x_off=0, y_off=0):mask_resized = np.zeros((mask.shape[0] + y_off, mask.shape[1] + x_off), dtype=np.uint8)mask_resized[y_off:y_off + mask.shape[0], x_off:x_off + mask.shape[1]] = mask.astype(np.uint8)contours, _ = cv2.findContours(mask_resized, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)for contour in contours:contour[:, :, 0] += x_offcontour[:, :, 1] += y_offax.plot(contour[:, 0, 0], contour[:, 0, 1], color='lime', linewidth=2)def show_points(self, points, labels, ax, x_off, y_off):for point, label in zip(points, labels):x, y = pointx -= x_off  y -= y_off  ax.scatter(x, y, c='red', marker='o', label=f'Label: {label}')@staticmethoddef show_box(box, ax, x_off, y_off):x0, y0 = box[0]-x_off, box[1]-y_offw, h = box[2] - box[0], box[3] - box[1]ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='red', facecolor=(0, 0, 0, 0), lw=2))if __name__ == '__main__':# Usageimage_path = r'./data/sentinel2.tif'sam_checkpoint = "./model/sam_vit_b_01ec64.pth"model_type = "vit_b"device = "cpu"output_shp = r'./result/segmentation_results.shp'# # 预测模式# seg_mode = 'single_point'# # # 模型参数# input_points = [[1248, 1507]]# single_label = [1]# # 预测模式# seg_mode = 'multi_point'# # 模型参数# input_points = [[389, 1041],[411, 1094]]# single_label = [1, 1]# 预测模式seg_mode = 'box'# 模型参数input_box = [[0, 951, 1909, 2383]]single_label = [1]# 实例化类segmenter = ImageSegmentation(image_path, sam_checkpoint, model_type, device)# # 调用segAnything模型# masks, scores, x_off, y_off = segmenter.predict(mode=seg_mode, input_points=input_points,#                                                     input_labels=single_label, multimask_output=False)# boxmasks, scores, x_off, y_off = segmenter.predict(mode=seg_mode, input_box=input_box,input_labels=single_label, multimask_output=True)# 模型预测结果转矢量多边形polygons = segmenter.masks_to_polygons(masks, x_off, y_off)# 保存为shpsegmenter.save_polygons_gdal(polygons, output_shp)# 可视化image_chunk = segmenter.read_image_chunk(x_off, y_off, 512, 512)# segmenter.show_masks(seg_mode, masks, scores, x_off, y_off, input_points, single_label, image_chunk)# boxsegmenter.show_masks(seg_mode, masks, scores, x_off, y_off, input_box, single_label, image_chunk)

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 利用H5无插件播放RTSP流的实现方案
  • 【二等奖论文】2024年华为杯研究生数学建模F题成品论文(后续会更新)
  • 搜维尔科技:Unity中的A.R.T.测量工具
  • Spring Cloud Alibaba-(4)Sentinel【流控和降级】
  • C# 入坑JAVA 潜规则 大小写敏感文件名和类名 枚举等 入门系列2
  • 策略模式在 Spring Boot 框架中的应用
  • 实验3 Hadoop集群运行环境搭建和使用
  • 创建索引遇到这个Bug,19c中还没有修复
  • 粒子向上持续瀑布动画效果(直接粘贴到记事本改html即可)
  • 【AI实战攻略】保姆级教程:用AI打造治愈动画vlog,轻松打造爆款,快速涨粉!
  • maxcompute使用篇
  • 8. 防火墙
  • Nginx从入门到入土(二): 学习内容与安装
  • LeetCode 面试经典150题 190.颠倒二进制位
  • 微服务Docker相关指令
  • 【347天】每日项目总结系列085(2018.01.18)
  • 2017 前端面试准备 - 收藏集 - 掘金
  • Android优雅地处理按钮重复点击
  • Date型的使用
  • Debian下无root权限使用Python访问Oracle
  • ES6系列(二)变量的解构赋值
  • JS函数式编程 数组部分风格 ES6版
  • Otto开发初探——微服务依赖管理新利器
  • Puppeteer:浏览器控制器
  • Python 使用 Tornado 框架实现 WebHook 自动部署 Git 项目
  • Solarized Scheme
  • Wamp集成环境 添加PHP的新版本
  • 对象管理器(defineProperty)学习笔记
  • 浮现式设计
  • 关于 Linux 进程的 UID、EUID、GID 和 EGID
  • 可能是历史上最全的CC0版权可以免费商用的图片网站
  • 坑!为什么View.startAnimation不起作用?
  • 力扣(LeetCode)357
  • 码农张的Bug人生 - 见面之礼
  • 你不可错过的前端面试题(一)
  • 使用Tinker来调试Laravel应用程序的数据以及使用Tinker一些总结
  • 原生js练习题---第五课
  • ​用户画像从0到100的构建思路
  • ### RabbitMQ五种工作模式:
  • #NOIP 2014# day.2 T2 寻找道路
  • (55)MOS管专题--->(10)MOS管的封装
  • (CVPRW,2024)可学习的提示:遥感领域小样本语义分割
  • (C语言)fread与fwrite详解
  • (Java)【深基9.例1】选举学生会
  • (k8s中)docker netty OOM问题记录
  • (备份) esp32 GPIO
  • (代码示例)使用setTimeout来延迟加载JS脚本文件
  • (二)windows配置JDK环境
  • (二)七种元启发算法(DBO、LO、SWO、COA、LSO、KOA、GRO)求解无人机路径规划MATLAB
  • (附源码)ssm基于微信小程序的疫苗管理系统 毕业设计 092354
  • (九)c52学习之旅-定时器
  • (每日一问)设计模式:设计模式的原则与分类——如何提升代码质量?
  • (四)stm32之通信协议
  • (一)插入排序
  • (转)jQuery 基础