修改分割模块
This commit is contained in:
764
main_batch_nosample.py
Normal file
764
main_batch_nosample.py
Normal file
@ -0,0 +1,764 @@
|
||||
import os
|
||||
import cv2
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
import argparse
|
||||
import pandas as pd
|
||||
from bil2rgb import process_bil_files
|
||||
from classification_model.Parallel.predict_plastic import load_model, predict_with_model
|
||||
from mask import detect_microplastic_mask_from_array
|
||||
from shape_spectral import process_images
|
||||
from shape_spectral_background import process_images_background
|
||||
from extact_shape import shape_correct_background, extract_features
|
||||
import time
|
||||
##批量预测文件夹内的bil文件,不进行降采样,使用新采集的数据进行训练
|
||||
matplotlib.use('TkAgg')
|
||||
|
||||
# 训练相机波长(168通道)
|
||||
TRAIN_WAVELENGTHS = [
|
||||
898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2
|
||||
]
|
||||
|
||||
def read_wavelengths_from_hdr(bil_path):
|
||||
hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
|
||||
if not os.path.exists(hdr_path):
|
||||
return np.array([], dtype=np.float64)
|
||||
with open(hdr_path, 'r') as f:
|
||||
txt = f.read()
|
||||
if 'wavelength' not in txt:
|
||||
return np.array([], dtype=np.float64)
|
||||
seg = txt.split('wavelength', 1)[1]
|
||||
seg = seg[seg.find('{')+1: seg.find('}')]
|
||||
vals = [v.strip() for v in seg.split(',') if v.strip()]
|
||||
try:
|
||||
waves = np.array([float(v) for v in vals], dtype=np.float64)
|
||||
except Exception:
|
||||
waves = np.array([], dtype=np.float64)
|
||||
return waves
|
||||
|
||||
def resample_spectra_matrix(X, src_waves, dst_waves):
|
||||
src = np.asarray(src_waves, dtype=np.float64)
|
||||
dst = np.asarray(dst_waves, dtype=np.float64)
|
||||
X = np.asarray(X, dtype=np.float64)
|
||||
if src.size == 0 or dst.size == 0:
|
||||
return X
|
||||
# 线性插值,越界用端点外推,避免维度缺失
|
||||
out = np.empty((X.shape[0], dst.size), dtype=np.float64)
|
||||
for i in range(X.shape[0]):
|
||||
row = X[i]
|
||||
out[i] = np.interp(dst, src, row, left=row[0], right=row[-1])
|
||||
return out
|
||||
|
||||
|
||||
def apply_background_no_resample(df, bg_spectrum):
|
||||
"""
|
||||
仅做背景校正,不做任何重采样。
|
||||
- 自动选择以 wavelength_ 或 band_ 开头的光谱列
|
||||
- 若背景长度与光谱列数不一致,按尾部对齐取最小长度进行校正
|
||||
"""
|
||||
# 识别光谱列
|
||||
spec_cols = [c for c in df.columns if isinstance(c, str) and (c.startswith('wavelength_') or c.startswith('band_'))]
|
||||
if not spec_cols:
|
||||
raise ValueError("未找到光谱列(以 wavelength_ 或 band_ 开头)")
|
||||
|
||||
bg = np.asarray(bg_spectrum, dtype=np.float64).ravel()
|
||||
if bg.size == 0:
|
||||
raise ValueError("背景光谱长度为0,无法进行背景校正")
|
||||
|
||||
# 尾部对齐,取最小长度,避免维度不一致
|
||||
n = min(len(spec_cols), bg.shape[0])
|
||||
use_cols = spec_cols[-n:]
|
||||
df.loc[:, use_cols] = df.loc[:, use_cols].div(bg[-n:], axis=1)
|
||||
return df
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
"""解析命令行参数"""
|
||||
parser = argparse.ArgumentParser(description='Microplastic spectral shape classification - Batch processing')
|
||||
|
||||
# 必需参数
|
||||
parser.add_argument('--input_dir', required=True, help='Path to input directory containing BIL files')
|
||||
parser.add_argument('--output_dir', required=True, help='Path to output directory for classification results')
|
||||
parser.add_argument('--model_path', required=True, help='Path to primary classification model')
|
||||
|
||||
# 可选参数
|
||||
# parser.add_argument('--primary_model_type', default='SVM', help='Type of primary model (default: SVM)')
|
||||
# parser.add_argument('--primary_process_methods1', default='SS', help='Primary process method 1 (default: SS)')
|
||||
# parser.add_argument('--primary_process_methods2', default='None', help='Primary process method 2 (default: None)')
|
||||
|
||||
# parser.add_argument('--secondary_model', default="D:\plastic\plastic\modelsave\HDPELDPE_model\svm.m", help='Path to secondary classification model')
|
||||
# parser.add_argument('--secondary_model_type', default='SVM', help='Type of secondary model (default: SVM)')
|
||||
# parser.add_argument('--secondary_process_methods1', default='None',
|
||||
# help='Secondary process method 1 (default: None)')
|
||||
# parser.add_argument('--secondary_process_methods2', default='None',
|
||||
# help='Secondary process method 2 (default: None)')
|
||||
# parser.add_argument('--secondary_target_classes', nargs='+', type=int, default=[1,2],
|
||||
# help='Target classes for secondary classification (space separated)')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# 配置参数:直接在此修改
|
||||
# ----------------------------
|
||||
# BIL_PATH = r"D:/Data/Test/PET_bottle2.bil"
|
||||
# OUTPUT_PATH = r'D:/Data/PET_bottle2_class.bil'
|
||||
#
|
||||
# PRIMARY_MODEL_PATH = r"D:\plastic\plastic\modelsave\svm.m"
|
||||
# PRIMARY_MODEL_TYPE = 'SVM'
|
||||
# PRIMARY_PROCESS_METHODS1 = 'SS'
|
||||
# PRIMARY_PROCESS_METHODS2 = 'None'
|
||||
#
|
||||
# SECONDARY_MODEL_PATH = "D:\plastic\plastic\modelsave\HDPELDPE_model\svm.m" # 若不需要二次分类,则保持为 None
|
||||
# SECONDARY_MODEL_TYPE = 'SVM'
|
||||
# SECONDARY_PROCESS_METHODS1 = 'None'
|
||||
# SECONDARY_PROCESS_METHODS2 = 'None'
|
||||
# SECONDARY_TARGET_CLASSES = [1, 2]
|
||||
|
||||
|
||||
def read_hdr_file(bil_path):
|
||||
hdr_path = bil_path.replace('.bil', '.hdr')
|
||||
with open(hdr_path, 'r') as f:
|
||||
header = f.readlines()
|
||||
|
||||
samples, lines = None, None
|
||||
|
||||
for line in header:
|
||||
if line.startswith('samples'):
|
||||
samples = int(line.split('=')[-1].strip())
|
||||
if line.startswith('lines'):
|
||||
lines = int(line.split('=')[-1].strip())
|
||||
|
||||
return samples, lines
|
||||
|
||||
|
||||
def shrink_contours(bil_path, df, shrink_pixels=1):
|
||||
"""
|
||||
对DataFrame中的所有轮廓进行收缩操作,避免塑料之间的相连
|
||||
|
||||
Args:
|
||||
bil_path: BIL文件路径,用于获取图像尺寸
|
||||
df: 包含contour列的DataFrame
|
||||
shrink_pixels: 收缩的像素数,默认1像素
|
||||
|
||||
Returns:
|
||||
更新后的DataFrame,contour列已被收缩
|
||||
"""
|
||||
samples, lines = read_hdr_file(bil_path)
|
||||
|
||||
# 创建腐蚀核
|
||||
kernel = np.ones((2 * shrink_pixels + 1, 2 * shrink_pixels + 1), np.uint8)
|
||||
|
||||
# 创建临时掩膜用于处理
|
||||
temp_mask = np.zeros((lines, samples), dtype=np.uint8)
|
||||
|
||||
# 创建DataFrame副本
|
||||
df = df.copy()
|
||||
|
||||
# 遍历每一行,更新轮廓
|
||||
for idx, row in df.iterrows():
|
||||
contour = row['contour']
|
||||
if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
|
||||
continue
|
||||
|
||||
try:
|
||||
contour_array = np.array(contour, dtype=np.int32)
|
||||
if len(contour_array.shape) == 1:
|
||||
continue
|
||||
|
||||
# 清空临时掩膜
|
||||
temp_mask.fill(0)
|
||||
|
||||
# 填充轮廓
|
||||
cv2.fillPoly(temp_mask, [contour_array], 255)
|
||||
|
||||
# 对掩膜进行腐蚀操作
|
||||
eroded_mask = cv2.erode(temp_mask, kernel, iterations=1)
|
||||
|
||||
# 重新提取轮廓
|
||||
contours, _ = cv2.findContours(eroded_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
if len(contours) > 0:
|
||||
# 选择最大的轮廓(如果有多个)
|
||||
largest_contour = max(contours, key=cv2.contourArea)
|
||||
# 转换为列表格式,保持与原始格式一致
|
||||
if len(largest_contour) >= 3:
|
||||
updated_contour = largest_contour.reshape(-1, 2).tolist()
|
||||
df.at[idx, 'contour'] = updated_contour
|
||||
except Exception as e:
|
||||
# 如果处理失败,保留原始轮廓
|
||||
continue
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def save_envi_classification(bil_path, df, savepath):
|
||||
samples, lines = read_hdr_file(bil_path)
|
||||
classification_result = np.zeros((lines, samples), dtype=np.uint16)
|
||||
|
||||
# 预处理:清除可能存在的类别10和11(移到循环外以提高效率)
|
||||
classification_result[(classification_result == 10)] = 0
|
||||
classification_result[(classification_result == 11)] = 0
|
||||
|
||||
for _, row in df.iterrows():
|
||||
contour = row['contour']
|
||||
prediction = int(row['Predictions']) + 1
|
||||
contour = np.array(contour, dtype=np.int32)
|
||||
|
||||
cv2.fillPoly(classification_result, [contour], prediction)
|
||||
|
||||
output_path = savepath
|
||||
|
||||
with open(output_path, 'wb') as f:
|
||||
classification_result.tofile(f)
|
||||
|
||||
header_content = f"""ENVI
|
||||
description = {{
|
||||
Classification Result.}}
|
||||
samples = {samples}
|
||||
lines = {lines}
|
||||
bands = 1
|
||||
header offset = 0
|
||||
file type = ENVI Standard
|
||||
data type = 2
|
||||
interleave = bil
|
||||
classes = 10
|
||||
class = {{ background, ABS, HDPE, LDPE, PA6, PET, PP, PS, PTFE, PVC }}
|
||||
single pixel area = 0.000036
|
||||
unit = mm2
|
||||
byte order = 0
|
||||
wavelength units = nm
|
||||
"""
|
||||
filename, ext = os.path.splitext(savepath)
|
||||
# 替换扩展名为 '.hdr'
|
||||
header_filename = filename + '.hdr'
|
||||
|
||||
with open(header_filename, 'w') as header_file:
|
||||
header_file.write(header_content)
|
||||
|
||||
|
||||
def change_hdr_file(bil_path, wavelengths=None):
|
||||
# wavelengths=None 时仅在HDR缺失wavelength字段才写入;若提供则按提供内容写入
|
||||
hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
|
||||
if not os.path.exists(hdr_path):
|
||||
print(f"错误: 找不到对应的HDR文件: {hdr_path}")
|
||||
return
|
||||
|
||||
with open(hdr_path, 'r') as file:
|
||||
content = file.read()
|
||||
|
||||
if 'wavelength' in content and wavelengths is None:
|
||||
print(f"File {os.path.basename(hdr_path)} already contains wavelength information; no changes needed.")
|
||||
return
|
||||
|
||||
if wavelengths is None:
|
||||
print(f"No wavelengths provided and HDR lacks wavelength; skipping write to avoid wrong bands.")
|
||||
return
|
||||
|
||||
needs_newline = not content.endswith('\n')
|
||||
wavelength_info = "wavelength = {" + ", ".join(str(float(w)) for w in wavelengths) + "}\n"
|
||||
|
||||
with open(hdr_path, 'a') as file:
|
||||
if needs_newline:
|
||||
file.write('\n')
|
||||
file.write(wavelength_info)
|
||||
|
||||
print(f"Successfully ensured wavelength information in file: {os.path.basename(hdr_path)}")
|
||||
|
||||
|
||||
def get_bil_files(input_dir):
|
||||
"""获取输入目录中的所有BIL文件"""
|
||||
if not os.path.exists(input_dir):
|
||||
raise FileNotFoundError(f"输入目录不存在: {input_dir}")
|
||||
|
||||
bil_files = []
|
||||
for file in os.listdir(input_dir):
|
||||
if file.lower().endswith('.bil'):
|
||||
bil_path = os.path.join(input_dir, file)
|
||||
hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
|
||||
if os.path.exists(hdr_path):
|
||||
bil_files.append(bil_path)
|
||||
else:
|
||||
print(f"警告: 找到BIL文件 {file} 但缺少对应的HDR文件,跳过")
|
||||
|
||||
if not bil_files:
|
||||
raise ValueError(f"在输入目录 {input_dir} 中未找到有效的BIL文件")
|
||||
|
||||
return sorted(bil_files)
|
||||
|
||||
|
||||
def validate_inputs(input_dir, output_dir, model_path):
|
||||
"""验证输入参数"""
|
||||
# 检查输入目录存在
|
||||
if not os.path.exists(input_dir):
|
||||
raise FileNotFoundError(f"输入目录不存在: {input_dir}")
|
||||
|
||||
# 检查输出目录存在,如果不存在则创建
|
||||
if not os.path.exists(output_dir):
|
||||
try:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"无法创建输出目录: {output_dir}") from e
|
||||
|
||||
# 检查模型文件存在
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"主模型文件不存在: {model_path}")
|
||||
|
||||
|
||||
def validate_single_bil_file(bil_path):
|
||||
"""验证单个BIL文件"""
|
||||
# 检查BIL和HDR文件存在
|
||||
if not os.path.exists(bil_path):
|
||||
raise FileNotFoundError(f"BIL文件不存在: {bil_path}")
|
||||
|
||||
hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
|
||||
if not os.path.exists(hdr_path):
|
||||
raise FileNotFoundError(f"HDR文件不存在: {hdr_path}")
|
||||
|
||||
# 检查BIL文件波段数是否足够
|
||||
try:
|
||||
from spectral.io import envi
|
||||
img = envi.open(hdr_path, bil_path)
|
||||
n_bands = img.nbands
|
||||
# bil2rgb需要波段索引9, 59, 159
|
||||
if n_bands <= 159:
|
||||
raise ValueError(f"BIL文件波段数不足: 需要至少160个波段,但只有{n_bands}个")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"无法读取BIL文件头信息: {bil_path}") from e
|
||||
|
||||
|
||||
def generate_rgb(bil_path):
|
||||
"""处理BIL文件生成RGB图像"""
|
||||
try:
|
||||
rgb_img = process_bil_files(bil_path)
|
||||
return rgb_img
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"生成RGB图像失败: bil_path={bil_path}") from e
|
||||
|
||||
|
||||
def run_segmentation(rgb_img):
|
||||
"""运行分割获取掩膜"""
|
||||
try:
|
||||
mask, filter_mask_original = detect_microplastic_mask_from_array(
|
||||
image=rgb_img,
|
||||
filter_method='threshold',
|
||||
diameter=None,
|
||||
flow_threshold=0.4,
|
||||
cellprob_threshold=-1,
|
||||
detect_filter=True
|
||||
)
|
||||
return mask, filter_mask_original
|
||||
except Exception as e:
|
||||
raise RuntimeError("分割失败: 无法检测微塑料颗粒") from e
|
||||
|
||||
|
||||
def extract_primary_features(bil_path, mask):
|
||||
"""提取主要特征"""
|
||||
try:
|
||||
df = process_images(bil_path, mask)
|
||||
return df
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"特征提取失败: bil_path={bil_path}") from e
|
||||
|
||||
|
||||
def compute_background_spectrum(bil_path, mask):
|
||||
"""计算背景光谱"""
|
||||
try:
|
||||
df_correct = process_images_background(bil_path, mask)
|
||||
return df_correct
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"背景光谱计算失败: bil_path={bil_path}") from e
|
||||
|
||||
|
||||
def apply_background_and_optional_resample(df, bg_spectrum, bil_path):
|
||||
"""应用背景校正和可选的重采样"""
|
||||
# 识别光谱列:所有以wavelength_开头的列
|
||||
spec_cols = [c for c in df.columns if c.startswith('wavelength_')]
|
||||
|
||||
if not spec_cols:
|
||||
raise ValueError("未找到光谱列(以wavelength_开头的列)")
|
||||
|
||||
if len(spec_cols) != len(bg_spectrum):
|
||||
raise ValueError(f"光谱列数量({len(spec_cols)})与背景光谱长度({len(bg_spectrum)})不匹配")
|
||||
|
||||
# 背景校正:用背景光谱逐列相除
|
||||
df[spec_cols] = df[spec_cols].div(bg_spectrum, axis=1)
|
||||
|
||||
# 检查是否需要重采样
|
||||
src_waves = read_wavelengths_from_hdr(bil_path)
|
||||
need_resample = (src_waves.size > 0 and
|
||||
(src_waves.size != len(TRAIN_WAVELENGTHS) or
|
||||
not np.allclose(src_waves, TRAIN_WAVELENGTHS, atol=1e-2)))
|
||||
|
||||
if need_resample:
|
||||
print(f"重采样光谱: 源波段数 {src_waves.size} -> 目标波段数 {len(TRAIN_WAVELENGTHS)}")
|
||||
|
||||
# 提取光谱数据
|
||||
X_src = df[spec_cols].to_numpy(dtype=np.float64)
|
||||
X_dst = resample_spectra_matrix(X_src, src_waves, TRAIN_WAVELENGTHS)
|
||||
|
||||
# 替换光谱列
|
||||
spec_col_names = [f"band_{i+1}" for i in range(len(TRAIN_WAVELENGTHS))]
|
||||
df = pd.concat([
|
||||
df.drop(columns=spec_cols), # 移除原有光谱列
|
||||
pd.DataFrame(X_dst, columns=spec_col_names, index=df.index)
|
||||
], axis=1)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def clean_and_select_columns(df):
|
||||
"""数据清理和列选择"""
|
||||
# 移除NaN值
|
||||
df = df.dropna()
|
||||
|
||||
# 过滤轮廓点数不足的样本
|
||||
df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
|
||||
|
||||
# 过滤面积过小的样本
|
||||
df = df[df['area'] >= 500]
|
||||
|
||||
# 列筛选:使用原来的硬编码索引删除逻辑
|
||||
# cols_to_remove = df.columns[np.r_[1:10, 11:15, 97:120, 176:179 ]]
|
||||
cols_to_remove = df.columns[np.r_[-10: -1]]
|
||||
df = df.drop(columns=cols_to_remove)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def run_primary_classification(df, primary_model_path):
|
||||
"""运行主要分类"""
|
||||
try:
|
||||
# 验证特征维度
|
||||
try:
|
||||
import joblib
|
||||
scaler_path = os.path.join(os.path.dirname(primary_model_path), 'scaler_params.pkl')
|
||||
if os.path.exists(scaler_path):
|
||||
scaler = joblib.load(scaler_path)
|
||||
numeric_cols = [c for c in df.columns[1:] if np.issubdtype(df[c].dtype, np.number) and c != 'contour']
|
||||
if hasattr(scaler, 'mean_') and len(numeric_cols) != scaler.mean_.shape[0]:
|
||||
raise ValueError(f"特征维度不匹配: 当前{numeric_cols}列 != 训练时{scaler.mean_.shape[0]}维")
|
||||
except Exception as e:
|
||||
print(f"警告: 无法验证特征维度: {e}")
|
||||
|
||||
df_pre = predict_with_model(
|
||||
df,
|
||||
primary_model_path,
|
||||
model_type='SVM',
|
||||
ProcessMethods1='SS',
|
||||
ProcessMethods2='None'
|
||||
)
|
||||
return df_pre
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"主要分类失败: model_path={primary_model_path}") from e
|
||||
|
||||
|
||||
def run_secondary_classification_if_needed(df_pre, bil_path, mask, filter_mask_original):
|
||||
"""根据需要运行二次分类"""
|
||||
# 二次分类配置(保持在代码内)
|
||||
secondary_model_path = os.path.join(os.path.dirname(__file__), 'modelsave', 'HDPELDPE_model', 'svm.m')
|
||||
secondary_target_classes = [1, 2] # HDPE, LDPE
|
||||
|
||||
target_classes = set(secondary_target_classes or [])
|
||||
mask_secondary = df_pre['Predictions'].isin(target_classes)
|
||||
|
||||
if not mask_secondary.any():
|
||||
print("未找到目标类别样本,跳过二次分类")
|
||||
return df_pre
|
||||
|
||||
print(f"为类别 {sorted(target_classes)} 运行二次分类")
|
||||
|
||||
# 检查二次模型是否存在
|
||||
if not os.path.exists(secondary_model_path):
|
||||
print(f"警告: 二次模型不存在: {secondary_model_path},跳过二次分类")
|
||||
return df_pre
|
||||
|
||||
try:
|
||||
# 图像信息的背景矫正
|
||||
df_correct = shape_correct_background(bil_path, mask, filter_mask_original)
|
||||
|
||||
# 创建只包含目标类别的掩膜
|
||||
mask_second = np.zeros_like(mask, dtype=np.uint16)
|
||||
for idx in df_pre[mask_secondary].index:
|
||||
contour = df_pre.loc[idx, 'contour']
|
||||
if isinstance(contour, list) and len(contour) > 0:
|
||||
contour_array = np.array(contour, dtype=np.int32)
|
||||
cv2.fillPoly(mask_second, [contour_array], idx + 1)
|
||||
|
||||
# 提取特征
|
||||
df_shape = extract_features(df_correct, mask_second)
|
||||
|
||||
# 确保使用前13列作为模型输入特征
|
||||
if len(df_shape.columns) >= 13:
|
||||
df_shape = df_shape.iloc[:, :13]
|
||||
|
||||
# 二次分类
|
||||
df_secondary = predict_with_model(
|
||||
df_shape,
|
||||
secondary_model_path,
|
||||
model_type='SVM',
|
||||
ProcessMethods1='None',
|
||||
ProcessMethods2='None'
|
||||
)
|
||||
|
||||
# 更新预测结果(类别+1)
|
||||
df_pre.loc[mask_secondary, 'Predictions'] = df_secondary['Predictions'].values + 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"警告: 二次分类失败,将继续使用主要分类结果: {e}")
|
||||
|
||||
return df_pre
|
||||
|
||||
|
||||
def postprocess_class7_shadow(df_pre, rgb_img):
|
||||
"""后处理类别7中的背景阴影"""
|
||||
class_7_mask = df_pre['Predictions'] == 7
|
||||
if not class_7_mask.any():
|
||||
return df_pre
|
||||
|
||||
print(f"处理 {class_7_mask.sum()} 个类别7样本,识别背景阴影...")
|
||||
|
||||
# 将PIL Image转换为numpy数组
|
||||
if hasattr(rgb_img, 'mode'): # PIL Image
|
||||
rgb_img_array = np.array(rgb_img)
|
||||
else:
|
||||
rgb_img_array = rgb_img
|
||||
|
||||
# 转换为灰度图
|
||||
if len(rgb_img_array.shape) == 3:
|
||||
gray_img = cv2.cvtColor(rgb_img_array, cv2.COLOR_RGB2GRAY)
|
||||
else:
|
||||
gray_img = rgb_img_array
|
||||
|
||||
# 计算梯度图
|
||||
grad_x = cv2.Sobel(gray_img, cv2.CV_64F, 1, 0, ksize=3)
|
||||
grad_y = cv2.Sobel(gray_img, cv2.CV_64F, 0, 1, ksize=3)
|
||||
gradient_magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)
|
||||
|
||||
# 收集类别7样本的边缘梯度值
|
||||
all_class7_gradients = []
|
||||
valid_indices = []
|
||||
|
||||
for idx in df_pre[class_7_mask].index:
|
||||
try:
|
||||
contour = df_pre.loc[idx, 'contour']
|
||||
if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
|
||||
continue
|
||||
|
||||
contour_array = np.array(contour, dtype=np.int32)
|
||||
if len(contour_array.shape) == 1:
|
||||
continue
|
||||
|
||||
mask_img = np.zeros(gray_img.shape, dtype=np.uint8)
|
||||
cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)
|
||||
edge_gradients = gradient_magnitude[mask_img > 0]
|
||||
|
||||
if len(edge_gradients) > 0:
|
||||
all_class7_gradients.extend(edge_gradients)
|
||||
valid_indices.append(idx)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 确定梯度阈值
|
||||
if len(all_class7_gradients) > 0:
|
||||
gradient_threshold = np.percentile(all_class7_gradients, 30)
|
||||
else:
|
||||
gradient_threshold = np.percentile(gradient_magnitude, 30)
|
||||
|
||||
print(f"类别7梯度阈值: {gradient_threshold:.2f}")
|
||||
|
||||
# 处理每个类别7样本
|
||||
indices_to_update = []
|
||||
for idx in valid_indices:
|
||||
try:
|
||||
contour = df_pre.loc[idx, 'contour']
|
||||
contour_array = np.array(contour, dtype=np.int32)
|
||||
|
||||
mask_img = np.zeros(gray_img.shape, dtype=np.uint8)
|
||||
cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)
|
||||
|
||||
edge_gradients = gradient_magnitude[mask_img > 0]
|
||||
if len(edge_gradients) == 0:
|
||||
continue
|
||||
|
||||
mean_gradient = np.mean(edge_gradients)
|
||||
|
||||
if mean_gradient < gradient_threshold:
|
||||
indices_to_update.append(idx)
|
||||
print(f"样本 {idx}: 平均梯度={mean_gradient:.2f}, 阈值={gradient_threshold:.2f} -> 识别为背景阴影")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理样本 {idx} 时出错: {e}")
|
||||
continue
|
||||
|
||||
# 更新分类结果
|
||||
if indices_to_update:
|
||||
df_pre.loc[indices_to_update, 'Predictions'] = 9
|
||||
print(f"将 {len(indices_to_update)} 个样本从类别7改为类别9(背景阴影)")
|
||||
else:
|
||||
print("无需更新类别7样本")
|
||||
|
||||
return df_pre
|
||||
|
||||
|
||||
def write_outputs(bil_path, df_pre, output_path):
|
||||
"""写入输出结果"""
|
||||
try:
|
||||
# 收缩轮廓
|
||||
df_pre = shrink_contours(bil_path, df_pre, shrink_pixels=1)
|
||||
|
||||
# 保存ENVI分类结果
|
||||
save_envi_classification(bil_path, df_pre, output_path)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"保存结果失败: output_path={output_path}") from e
|
||||
|
||||
|
||||
def process_single_file(bil_path, output_path, primary_model_path):
|
||||
"""处理单个BIL文件的完整流程"""
|
||||
try:
|
||||
# 验证输入
|
||||
validate_single_bil_file(bil_path)
|
||||
|
||||
# 修改HDR文件
|
||||
change_hdr_file(bil_path)
|
||||
|
||||
# 处理BIL文件生成RGB图像
|
||||
print(f" 处理BIL文件生成RGB图像...")
|
||||
rgb_img = generate_rgb(bil_path)
|
||||
|
||||
# 分割阶段
|
||||
segmentation_start_time = time.time()
|
||||
print(f" 生成掩膜...")
|
||||
mask, filter_mask_original = run_segmentation(rgb_img)
|
||||
|
||||
# 提取特征
|
||||
print(f" 从BIL文件提取特征...")
|
||||
df = extract_primary_features(bil_path, mask)
|
||||
|
||||
# 背景校正
|
||||
print(f" 应用背景校正...")
|
||||
bg_spectrum = compute_background_spectrum(bil_path, mask)
|
||||
|
||||
# 仅应用背景校正,不进行重采样
|
||||
df = apply_background_no_resample(df, bg_spectrum)
|
||||
|
||||
# 数据清理和列选择
|
||||
print(f" 清理数据...")
|
||||
df = clean_and_select_columns(df)
|
||||
|
||||
segmentation_time = time.time() - segmentation_start_time
|
||||
|
||||
# 分类阶段
|
||||
classification_start_time = time.time()
|
||||
print(f" 预测分类...")
|
||||
df_pre = run_primary_classification(df, primary_model_path)
|
||||
|
||||
# 二次分类
|
||||
df_pre = run_secondary_classification_if_needed(df_pre, bil_path, mask, filter_mask_original)
|
||||
|
||||
# 后处理类别7阴影
|
||||
df_pre = postprocess_class7_shadow(df_pre, rgb_img)
|
||||
|
||||
classification_time = time.time() - classification_start_time
|
||||
|
||||
# 保存结果
|
||||
print(f" 保存ENVI分类结果...")
|
||||
write_outputs(bil_path, df_pre, output_path)
|
||||
|
||||
return segmentation_time, classification_time
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理文件失败 {os.path.basename(bil_path)}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数 - 批量处理"""
|
||||
args = parse_arguments()
|
||||
|
||||
input_dir = args.input_dir
|
||||
output_dir = args.output_dir
|
||||
primary_model_path = args.model_path
|
||||
|
||||
# 记录总开始时间
|
||||
total_start_time = time.time()
|
||||
|
||||
try:
|
||||
# 验证输入参数
|
||||
validate_inputs(input_dir, output_dir, primary_model_path)
|
||||
|
||||
# 获取所有BIL文件
|
||||
bil_files = get_bil_files(input_dir)
|
||||
print(f"找到 {len(bil_files)} 个BIL文件待处理")
|
||||
|
||||
# 统计信息
|
||||
total_files = len(bil_files)
|
||||
processed_files = 0
|
||||
failed_files = 0
|
||||
total_segmentation_time = 0
|
||||
total_classification_time = 0
|
||||
|
||||
# 逐个处理文件
|
||||
for i, bil_path in enumerate(bil_files, 1):
|
||||
print(f"\n{'='*60}")
|
||||
print(f"处理文件 {i}/{total_files}: {os.path.basename(bil_path)}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
try:
|
||||
# 生成输出文件名:原文件名 + "_classification.bil"
|
||||
base_name = os.path.splitext(os.path.basename(bil_path))[0]
|
||||
output_filename = f"{base_name}_classification.bil"
|
||||
output_path = os.path.join(output_dir, output_filename)
|
||||
|
||||
# 处理单个文件
|
||||
segmentation_time, classification_time = process_single_file(
|
||||
bil_path, output_path, primary_model_path
|
||||
)
|
||||
|
||||
# 更新统计信息
|
||||
total_segmentation_time += segmentation_time
|
||||
total_classification_time += classification_time
|
||||
processed_files += 1
|
||||
|
||||
print(f"文件 {os.path.basename(bil_path)} 处理完成")
|
||||
print(f"结果保存至: {output_path}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"文件 {os.path.basename(bil_path)} 处理失败: {e}")
|
||||
failed_files += 1
|
||||
continue
|
||||
|
||||
# 计算平均耗时
|
||||
if processed_files > 0:
|
||||
avg_segmentation_time = total_segmentation_time / processed_files
|
||||
avg_classification_time = total_classification_time / processed_files
|
||||
avg_total_time = (total_segmentation_time + total_classification_time) / processed_files
|
||||
|
||||
# 计算总耗时
|
||||
total_time = time.time() - total_start_time
|
||||
|
||||
# 打印汇总统计
|
||||
print(f"\n{'=' * 60}")
|
||||
print("批量处理完成")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"总文件数: {total_files}")
|
||||
print(f"成功处理: {processed_files}")
|
||||
print(f"处理失败: {failed_files}")
|
||||
print(f"成功率: {processed_files/total_files*100:.1f}%" if total_files > 0 else "成功率: 0%")
|
||||
print(f"{'=' * 60}")
|
||||
if processed_files > 0:
|
||||
print(f"平均分割耗时: {avg_segmentation_time:.2f} 秒")
|
||||
print(f"平均分类耗时: {avg_classification_time:.2f} 秒")
|
||||
print(f"平均总耗时: {avg_total_time:.2f} 秒")
|
||||
print(f"实际总耗时: {total_time:.2f} 秒")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"结果保存目录: {output_dir}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"批量处理失败: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user