修改分割模块
This commit is contained in:
42
mask.py
42
mask.py
@ -901,22 +901,36 @@ class MicroplasticDetectorV2:
|
||||
self.filter_method = filter_method
|
||||
|
||||
def load_model(self):
|
||||
"""加载Cellpose模型"""
|
||||
if self.model is None:
|
||||
"""加载Cellpose模型(优先自定义,其次内置模型)"""
|
||||
if self.model is not None:
|
||||
return
|
||||
|
||||
# 设备选择
|
||||
want_cpu = (self.device == 'cpu')
|
||||
use_gpu = (not want_cpu) and torch.cuda.is_available()
|
||||
|
||||
try:
|
||||
if self.model_path and Path(self.model_path).exists():
|
||||
# 加载自定义模型 - Cellpose 4.0 API
|
||||
# 正确方式:在构造时通过 pretrained_model 传入自定义权重
|
||||
self.model = models.CellposeModel(
|
||||
gpu=torch.cuda.is_available() and self.device != 'cpu',
|
||||
model_type=None
|
||||
gpu=use_gpu,
|
||||
pretrained_model=str(self.model_path),
|
||||
model_type=None, # 自定义模型不需要 model_type
|
||||
)
|
||||
self.model.load_model(self.model_path)
|
||||
print(f"已加载自定义模型: {self.model_path}")
|
||||
else:
|
||||
# 使用预训练模型 - Cellpose 4.0 API
|
||||
# 内置模型(如未安装 cpsam,可考虑改用 'cyto2')
|
||||
self.model = models.CellposeModel(
|
||||
gpu=torch.cuda.is_available() and self.device != 'cpu',
|
||||
gpu=use_gpu,
|
||||
model_type='cpsam'
|
||||
)
|
||||
print(f"模型已加载,使用设备: {'GPU' if torch.cuda.is_available() and self.device != 'cpu' else 'CPU'}")
|
||||
print("使用内置模型: cpsam")
|
||||
|
||||
print(f"模型已加载,使用设备: {'GPU' if use_gpu else 'CPU'}")
|
||||
except Exception as e:
|
||||
# 回退策略:若自定义/内置加载失败,退回CPU+cyto2以保证可用性
|
||||
print(f"模型加载失败({e}),回退到 CPU + cyto2")
|
||||
self.model = models.CellposeModel(gpu=False, model_type='cyto2')
|
||||
|
||||
def detect_microplastics(self, image_path: str, output_dir: str = None,
|
||||
diameter: float = 30, flow_threshold: float = 0.4,
|
||||
@ -975,7 +989,8 @@ class MicroplasticDetectorV2:
|
||||
masked_image,
|
||||
diameter=diameter,
|
||||
flow_threshold=flow_threshold,
|
||||
cellprob_threshold=cellprob_threshold)
|
||||
cellprob_threshold=cellprob_threshold,
|
||||
channels=[0, 0]) # 明确指定灰度图像通道
|
||||
|
||||
# 6. 分析检测结果
|
||||
if debug:
|
||||
@ -1247,8 +1262,9 @@ def detect_microplastic_mask_from_array(image, filter_method: str = 'shape',
|
||||
masked_image,
|
||||
diameter=diameter,
|
||||
flow_threshold=flow_threshold,
|
||||
cellprob_threshold=cellprob_threshold
|
||||
)
|
||||
cellprob_threshold=cellprob_threshold,
|
||||
channels=[0, 0]) # 明确指定灰度图像通道
|
||||
|
||||
|
||||
return masks, filter_mask_original
|
||||
|
||||
@ -1268,7 +1284,7 @@ def main():
|
||||
output_dir=output_dir + "_threshold",
|
||||
diameter=None,
|
||||
flow_threshold=0.4,
|
||||
cellprob_threshold=0,
|
||||
cellprob_threshold=-1,
|
||||
debug=True
|
||||
)
|
||||
print(f"\n霍夫圆变换方法检测完成!")
|
||||
|
||||
Reference in New Issue
Block a user