修改分割模块
This commit is contained in:
2
.idea/.gitignore
generated
vendored
2
.idea/.gitignore
generated
vendored
@ -1,4 +1,4 @@
|
|||||||
# 默认忽略的文件
|
# Default ignored files
|
||||||
/shelf/
|
/shelf/
|
||||||
/workspace.xml
|
/workspace.xml
|
||||||
# 基于编辑器的 HTTP 客户端请求
|
# 基于编辑器的 HTTP 客户端请求
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
160
chose_bands.py
Normal file
160
chose_bands.py
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
import argparse
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import itertools
|
||||||
|
import re
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
ap = argparse.ArgumentParser(description="Select 3 bands for false-color to maximize inter-class separability")
|
||||||
|
ap.add_argument("--csv", required=True, help="Input CSV: first column is class label, rest are spectra columns")
|
||||||
|
ap.add_argument("--top_k", type=int, default=30, help="How many best single bands to preselect (default: 30)")
|
||||||
|
ap.add_argument("--top_triplets", type=int, default=10, help="How many best triplets to print (default: 10)")
|
||||||
|
ap.add_argument("--map_order", choices=["auto", "wavelength_bgr"], default="auto",
|
||||||
|
help="RGB mapping order: auto=use scored order; wavelength_bgr=short->B, mid->G, long->R")
|
||||||
|
return ap.parse_args()
|
||||||
|
|
||||||
|
def try_parse_wavelength(col_name):
|
||||||
|
# 支持 "wavelength_912.36" / "912.36" / "band_12" 等
|
||||||
|
m = re.search(r"([0-9]+(\.[0-9]+)?)", str(col_name))
|
||||||
|
if m:
|
||||||
|
try:
|
||||||
|
return float(m.group(1))
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def read_data(csv_path):
|
||||||
|
df = pd.read_csv(csv_path)
|
||||||
|
if df.shape[1] < 4:
|
||||||
|
raise ValueError("需要至少1个类别列 + 3个光谱列")
|
||||||
|
y = df.iloc[:, 0].to_numpy()
|
||||||
|
X = df.iloc[:, 1:].to_numpy(dtype=np.float64)
|
||||||
|
cols = list(df.columns[1:])
|
||||||
|
waves = []
|
||||||
|
for c in cols:
|
||||||
|
w = try_parse_wavelength(c)
|
||||||
|
waves.append(w if w is not None else c)
|
||||||
|
return X, y, cols, waves
|
||||||
|
|
||||||
|
def compute_anova_f_scores(X, y):
|
||||||
|
# 一维ANOVA F-score: F = (SS_between/df_between) / (SS_within/df_within)
|
||||||
|
N, M = X.shape
|
||||||
|
classes = []
|
||||||
|
for cls in np.unique(y):
|
||||||
|
idx = (y == cls)
|
||||||
|
if np.sum(idx) > 0:
|
||||||
|
classes.append(idx)
|
||||||
|
k = len(classes)
|
||||||
|
if k < 2:
|
||||||
|
raise ValueError("类别不足2类,无法计算区分度")
|
||||||
|
|
||||||
|
mu = X.mean(axis=0)
|
||||||
|
between_ss = np.zeros(M, dtype=np.float64)
|
||||||
|
within_ss = np.zeros(M, dtype=np.float64)
|
||||||
|
|
||||||
|
for idx in classes:
|
||||||
|
ng = np.sum(idx)
|
||||||
|
Xg = X[idx]
|
||||||
|
if ng == 0:
|
||||||
|
continue
|
||||||
|
mg = Xg.mean(axis=0)
|
||||||
|
# 无偏方差,若ng==1则设为0
|
||||||
|
vg = Xg.var(axis=0, ddof=1) if ng > 1 else np.zeros(M, dtype=np.float64)
|
||||||
|
between_ss += ng * (mg - mu) ** 2
|
||||||
|
within_ss += (ng - 1) * vg
|
||||||
|
|
||||||
|
dfb = k - 1
|
||||||
|
dfw = N - k if N > k else 1
|
||||||
|
F = (between_ss / max(dfb, 1)) / (within_ss / max(dfw, 1) + 1e-12)
|
||||||
|
return F # shape (M,)
|
||||||
|
|
||||||
|
def score_triplet(X, y, idx_triplet, eps=1e-6):
|
||||||
|
# LDA准则: J = trace( (Sw+epsI)^-1 Sb ),三维空间
|
||||||
|
t = list(idx_triplet)
|
||||||
|
Xt = X[:, t] # (N,3)
|
||||||
|
classes = []
|
||||||
|
for cls in np.unique(y):
|
||||||
|
idx = (y == cls)
|
||||||
|
if np.sum(idx) > 0:
|
||||||
|
classes.append(idx)
|
||||||
|
|
||||||
|
m = Xt.mean(axis=0)
|
||||||
|
Sb = np.zeros((3, 3), dtype=np.float64)
|
||||||
|
Sw = np.zeros((3, 3), dtype=np.float64)
|
||||||
|
|
||||||
|
for idx in classes:
|
||||||
|
Xg = Xt[idx]
|
||||||
|
ng = Xg.shape[0]
|
||||||
|
if ng == 0:
|
||||||
|
continue
|
||||||
|
mg = Xg.mean(axis=0)
|
||||||
|
diff = (mg - m).reshape(3, 1)
|
||||||
|
Sb += ng * (diff @ diff.T)
|
||||||
|
if ng > 1:
|
||||||
|
Cg = np.cov(Xg, rowvar=False, ddof=1)
|
||||||
|
else:
|
||||||
|
Cg = np.zeros((3, 3), dtype=np.float64)
|
||||||
|
Sw += (ng - 1) * Cg
|
||||||
|
|
||||||
|
# 稳定逆
|
||||||
|
A = Sw + eps * np.eye(3)
|
||||||
|
try:
|
||||||
|
Ainv = np.linalg.pinv(A)
|
||||||
|
except Exception:
|
||||||
|
return -np.inf
|
||||||
|
J = np.trace(Ainv @ Sb)
|
||||||
|
return float(J)
|
||||||
|
|
||||||
|
def select_bands(csv_path, top_k=30, top_triplets=10, map_order="auto"):
|
||||||
|
X, y, cols, waves = read_data(csv_path)
|
||||||
|
F = compute_anova_f_scores(X, y)
|
||||||
|
order = np.argsort(-F)
|
||||||
|
top_idx = order[:min(top_k, X.shape[1])]
|
||||||
|
|
||||||
|
best = []
|
||||||
|
for comb in itertools.combinations(top_idx, 3):
|
||||||
|
s = score_triplet(X, y, comb)
|
||||||
|
best.append((s, comb))
|
||||||
|
best.sort(key=lambda x: -x[0])
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for s, comb in best[:top_triplets]:
|
||||||
|
names = [cols[i] for i in comb]
|
||||||
|
wavs = [waves[i] for i in comb]
|
||||||
|
# 建议RGB映射
|
||||||
|
if map_order == "wavelength_bgr" and all(isinstance(w, (int, float)) for w in wavs):
|
||||||
|
order_rgb = np.argsort(wavs) # 小->大
|
||||||
|
# 小(blue), 中(green), 大(red)
|
||||||
|
b,g,r = [names[order_rgb[0]]], [names[order_rgb[1]]], [names[order_rgb[2]]]
|
||||||
|
rgb = {"R": r[0], "G": g[0], "B": b[0]}
|
||||||
|
rgb_w = {"R": wavs[order_rgb[2]], "G": wavs[order_rgb[1]], "B": wavs[order_rgb[0]]}
|
||||||
|
else:
|
||||||
|
# 直接用得分顺序:第一->R, 第二->G, 第三->B
|
||||||
|
rgb = {"R": names[0], "G": names[1], "B": names[2]}
|
||||||
|
rgb_w = {"R": waves[comb[0]], "G": waves[comb[1]], "B": waves[comb[2]]}
|
||||||
|
results.append({
|
||||||
|
"score": round(s, 6),
|
||||||
|
"bands": names,
|
||||||
|
"wavelengths": wavs,
|
||||||
|
"rgb_mapping": rgb,
|
||||||
|
"rgb_wavelengths": rgb_w
|
||||||
|
})
|
||||||
|
return results, top_idx, F
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
results, top_idx, F = select_bands(args.csv, args.top_k, args.top_triplets, args.map_order)
|
||||||
|
|
||||||
|
print(f"Top-{args.top_k} bands by ANOVA F-score (index:col_name):")
|
||||||
|
print(", ".join([f"{i}:{i}" for i in top_idx])) # 如需列名可自行打印
|
||||||
|
|
||||||
|
print("\nBest triplets:")
|
||||||
|
for r in results:
|
||||||
|
bands = r["bands"]
|
||||||
|
wavs = r["wavelengths"]
|
||||||
|
rgb = r["rgb_mapping"]
|
||||||
|
print(f"- score={r['score']:.4f} bands={bands} wavelengths={wavs} RGB={rgb}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -587,22 +587,22 @@ def predict_with_model(df, model_path, model_type='SVM', ProcessMethods1='SS', P
|
|||||||
# 主函数,用于训练
|
# 主函数,用于训练
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 使用 pandas 读取 CSV 文件
|
# 使用 pandas 读取 CSV 文件
|
||||||
file_path = r"E:\code\plastic\plastic20260224\plastic\plastic\output\20260224\all.csv"
|
file_path = r"D:\Data2\traindata1\all\isf0303.csv"
|
||||||
df = pd.read_csv(
|
df = pd.read_csv(
|
||||||
file_path,
|
file_path,
|
||||||
encoding='utf-8', # 指定编码,如果出错可尝试 'gbk' 或 'gb18030'
|
encoding='utf-8', # 指定编码,如果出错可尝试 'gbk' 或 'gb18030'
|
||||||
low_memory=False # 避免数据类型推断问题
|
low_memory=False # 避免数据类型推断问题
|
||||||
)
|
)
|
||||||
|
|
||||||
# 使用 pandas 选择要删除的列(第93到117列,索引从0开始)
|
# # 使用 pandas 选择要删除的列(第93到117列,索引从0开始)
|
||||||
cols_to_remove = df.columns[np.r_[1:5, 87:110, 166:169]]
|
# cols_to_remove = df.columns[87:110]
|
||||||
|
#
|
||||||
# 使用 pandas 删除指定列
|
# # 使用 pandas 删除指定列
|
||||||
df_filtered = df.drop(columns=cols_to_remove)
|
# df_filtered = df.drop(columns=cols_to_remove)
|
||||||
|
|
||||||
# 使用 pandas 提取特征数据(从第2列开始到最后,排除第一列标签列)
|
# 使用 pandas 提取特征数据(从第2列开始到最后,排除第一列标签列)
|
||||||
x = df_filtered.iloc[:, 1:]
|
# x = df_filtered.iloc[:, 1:]
|
||||||
# x = df.iloc[:, 1:]
|
x = df.iloc[:, 1:]
|
||||||
# 使用 pandas 提取标签(第一列)
|
# 使用 pandas 提取标签(第一列)
|
||||||
y = df.iloc[:, 0]
|
y = df.iloc[:, 0]
|
||||||
X_train, X_test, y_train, y_test = SpectralQualitativeAnalysis(x, y, 'SS', 'None', 'None', 'random', use_smote=True)
|
X_train, X_test, y_train, y_test = SpectralQualitativeAnalysis(x, y, 'SS', 'None', 'None', 'random', use_smote=True)
|
||||||
@ -622,7 +622,7 @@ if __name__ == "__main__":
|
|||||||
# save_model(clf, r"D:\WQ\plastic\classification_model\modelsave\svm.m", model_type='SVM')
|
# save_model(clf, r"D:\WQ\plastic\classification_model\modelsave\svm.m", model_type='SVM')
|
||||||
|
|
||||||
# 示例2: 使用统一的训练和保存函数(推荐)
|
# 示例2: 使用统一的训练和保存函数(推荐)
|
||||||
save_dir = r"E:\code\plastic\plastic20260224\plastic\plastic\output\20260224\modelsave"
|
save_dir = r"D:\plastic\plastic\modelsave\240model\new\0303"
|
||||||
|
|
||||||
# 训练并保存多个模型
|
# 训练并保存多个模型
|
||||||
models_to_train = ['SVM']#'SVM', 'RF', 'XGBoost', 'LogisticRegression'
|
models_to_train = ['SVM']#'SVM', 'RF', 'XGBoost', 'LogisticRegression'
|
||||||
|
|||||||
@ -132,7 +132,7 @@ def Preprocessing(method, input_spectrum):
|
|||||||
elif method == 'MMS':
|
elif method == 'MMS':
|
||||||
output_spectrum = MMS(input_spectrum.values)
|
output_spectrum = MMS(input_spectrum.values)
|
||||||
elif method == 'SS':
|
elif method == 'SS':
|
||||||
output_spectrum = SS(input_spectrum.values, r'E:\code\plastic\plastic20260224\plastic\plastic\output\20260224\modelsave\scaler_params.pkl')
|
output_spectrum = SS(input_spectrum.values, r'D:\plastic\plastic\modelsave\240model\new\0303\scaler_params.pkl')
|
||||||
elif method == 'CT':
|
elif method == 'CT':
|
||||||
output_spectrum = CT(input_spectrum.values)
|
output_spectrum = CT(input_spectrum.values)
|
||||||
elif method == 'SNV':
|
elif method == 'SNV':
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
692
main.py
692
main.py
@ -14,10 +14,8 @@ import time
|
|||||||
|
|
||||||
matplotlib.use('TkAgg')
|
matplotlib.use('TkAgg')
|
||||||
|
|
||||||
# 训练相机波长(168通道)
|
# 训练相机波长(237通道)
|
||||||
TRAIN_WAVELENGTHS = [
|
TRAIN_WAVELENGTHS = [912.36, 915.68, 919, 922.31, 925.63, 928.95, 932.27, 935.59, 938.91, 942.23, 945.55, 948.87, 952.18, 955.5, 958.82, 962.14, 965.46, 968.78, 972.1, 975.42, 978.74, 982.06, 985.38, 988.7, 992.02, 995.34, 998.65, 1002, 1005.3, 1008.6, 1011.9, 1015.3, 1018.6, 1021.9, 1025.2, 1028.5, 1031.9, 1035.2, 1038.5, 1041.8, 1045.1, 1048.5, 1051.8, 1055.1, 1058.4, 1061.7, 1065.1, 1068.4, 1071.7, 1075, 1078.3, 1081.7, 1085, 1088.3, 1091.6, 1094.9, 1098.3, 1101.6, 1104.9, 1108.2, 1111.5, 1114.9, 1118.2, 1121.5, 1124.8, 1128.1, 1131.5, 1134.8, 1138.1, 1141.4, 1144.8, 1148.1, 1151.4, 1154.7, 1158, 1161.4, 1164.7, 1168, 1171.3, 1174.6, 1178, 1181.3, 1184.6, 1187.9, 1191.3, 1194.6, 1197.9, 1201.2, 1204.5, 1207.9, 1211.2, 1214.5, 1217.8, 1221.2, 1224.5, 1227.8, 1231.1, 1234.4, 1237.8, 1241.1, 1244.4, 1247.7, 1251.1, 1254.4, 1257.7, 1261, 1264.3, 1267.7, 1271, 1274.3, 1277.6, 1281, 1284.3, 1287.6, 1290.9, 1294.2, 1297.6, 1300.9, 1304.2, 1307.5, 1310.9, 1314.2, 1317.5, 1320.8, 1324.2, 1327.5, 1330.8, 1334.1, 1337.4, 1340.8, 1344.1, 1347.4, 1350.7, 1354.1, 1357.4, 1360.7, 1364, 1367.4, 1370.7, 1374, 1377.3, 1380.7, 1384, 1387.3, 1390.6, 1394, 1397.3, 1400.6, 1403.9, 1407.2, 1410.6, 1413.9, 1417.2, 1420.5, 1423.9, 1427.2, 1430.5, 1433.8, 1437.2, 1440.5, 1443.8, 1447.1, 1450.5, 1453.8, 1457.1, 1460.4, 1463.8, 1467.1, 1470.4, 1473.7, 1477.1, 1480.4, 1483.7, 1487, 1490.4, 1493.7, 1497, 1500.3, 1503.7, 1507, 1510.3, 1513.6, 1517, 1520.3, 1523.6, 1526.9, 1530.3, 1533.6, 1536.9, 1540.2, 1543.6, 1546.9, 1550.2, 1553.6, 1556.9, 1560.2, 1563.5, 1566.9, 1570.2, 1573.5, 1576.8, 1580.2, 1583.5, 1586.8, 1590.1, 1593.5, 1596.8, 1600.1, 1603.4, 1606.8, 1610.1, 1613.4, 1616.7, 1620.1, 1623.4, 1626.7, 1630.1, 1633.4, 1636.7, 1640, 1643.4, 1646.7, 1650, 1653.3, 1656.7, 1660, 1663.3, 1666.7, 1670, 1673.3, 1676.6, 1680, 1683.3, 1686.6, 1689.9, 1693.3, 1696.6, 1699.9, 1703.3, 1706.6]
|
||||||
898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2
|
|
||||||
]
|
|
||||||
|
|
||||||
def read_wavelengths_from_hdr(bil_path):
|
def read_wavelengths_from_hdr(bil_path):
|
||||||
hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
|
hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
|
||||||
@ -50,6 +48,28 @@ def resample_spectra_matrix(X, src_waves, dst_waves):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def apply_background_no_resample(df, bg_spectrum):
|
||||||
|
"""
|
||||||
|
仅做背景校正,不做任何重采样。
|
||||||
|
- 自动选择以 wavelength_ 或 band_ 开头的光谱列
|
||||||
|
- 若背景长度与光谱列数不一致,按尾部对齐取最小长度进行校正
|
||||||
|
"""
|
||||||
|
# 识别光谱列
|
||||||
|
spec_cols = [c for c in df.columns if isinstance(c, str) and (c.startswith('wavelength_') or c.startswith('band_'))]
|
||||||
|
if not spec_cols:
|
||||||
|
raise ValueError("未找到光谱列(以 wavelength_ 或 band_ 开头)")
|
||||||
|
|
||||||
|
bg = np.asarray(bg_spectrum, dtype=np.float64).ravel()
|
||||||
|
if bg.size == 0:
|
||||||
|
raise ValueError("背景光谱长度为0,无法进行背景校正")
|
||||||
|
|
||||||
|
# 尾部对齐,取最小长度,避免维度不一致
|
||||||
|
n = min(len(spec_cols), bg.shape[0])
|
||||||
|
use_cols = spec_cols[-n:]
|
||||||
|
df.loc[:, use_cols] = df.loc[:, use_cols].div(bg[-n:], axis=1)
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
def parse_arguments():
|
def parse_arguments():
|
||||||
"""解析命令行参数"""
|
"""解析命令行参数"""
|
||||||
parser = argparse.ArgumentParser(description='Microplastic spectral shape classification')
|
parser = argparse.ArgumentParser(description='Microplastic spectral shape classification')
|
||||||
@ -59,7 +79,7 @@ def parse_arguments():
|
|||||||
parser.add_argument('--output_path', required=True, help='Path to output classification result')
|
parser.add_argument('--output_path', required=True, help='Path to output classification result')
|
||||||
parser.add_argument('--model_path', required=True, help='Path to primary classification model')
|
parser.add_argument('--model_path', required=True, help='Path to primary classification model')
|
||||||
|
|
||||||
# 可选参数
|
|
||||||
# parser.add_argument('--primary_model_type', default='SVM', help='Type of primary model (default: SVM)')
|
# parser.add_argument('--primary_model_type', default='SVM', help='Type of primary model (default: SVM)')
|
||||||
# parser.add_argument('--primary_process_methods1', default='SS', help='Primary process method 1 (default: SS)')
|
# parser.add_argument('--primary_process_methods1', default='SS', help='Primary process method 1 (default: SS)')
|
||||||
# parser.add_argument('--primary_process_methods2', default='None', help='Primary process method 2 (default: None)')
|
# parser.add_argument('--primary_process_methods2', default='None', help='Primary process method 2 (default: None)')
|
||||||
@ -176,11 +196,11 @@ def save_envi_classification(bil_path, df, savepath):
|
|||||||
|
|
||||||
for _, row in df.iterrows():
|
for _, row in df.iterrows():
|
||||||
contour = row['contour']
|
contour = row['contour']
|
||||||
prediction = int(row['Predictions']) + 1
|
prediction = int(row['Predictions']) + 1 # 先加1
|
||||||
contour = np.array(contour, dtype=np.int32)
|
if prediction in (10, 11): # 再判断是否为10或11
|
||||||
# 先将 classification_result 中的 10 和 11 替换为 0
|
prediction = 0 # 视为背景
|
||||||
classification_result[(classification_result == 10)] = 0
|
|
||||||
|
|
||||||
|
contour = np.array(contour, dtype=np.int32)
|
||||||
cv2.fillPoly(classification_result, [contour], prediction)
|
cv2.fillPoly(classification_result, [contour], prediction)
|
||||||
|
|
||||||
output_path = savepath
|
output_path = savepath
|
||||||
@ -206,7 +226,6 @@ byte order = 0
|
|||||||
wavelength units = nm
|
wavelength units = nm
|
||||||
"""
|
"""
|
||||||
filename, ext = os.path.splitext(savepath)
|
filename, ext = os.path.splitext(savepath)
|
||||||
# 替换扩展名为 '.hdr'
|
|
||||||
header_filename = filename + '.hdr'
|
header_filename = filename + '.hdr'
|
||||||
|
|
||||||
with open(header_filename, 'w') as header_file:
|
with open(header_filename, 'w') as header_file:
|
||||||
@ -214,300 +233,463 @@ wavelength units = nm
|
|||||||
|
|
||||||
|
|
||||||
def change_hdr_file(bil_path, wavelengths=None):
|
def change_hdr_file(bil_path, wavelengths=None):
|
||||||
# wavelengths=None 时仅在HDR缺失wavelength字段才写入;若提供则按提供内容写入
|
|
||||||
hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
|
hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
|
||||||
if not os.path.exists(hdr_path):
|
if not os.path.exists(hdr_path):
|
||||||
print(f"错误: 找不到对应的HDR文件: {hdr_path}")
|
print(f"错误: 找不到对应的HDR文件: {hdr_path}")
|
||||||
return
|
return
|
||||||
|
|
||||||
with open(hdr_path, 'r') as file:
|
# 仅在缺少 wavelength 字段时才尝试写入
|
||||||
|
with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as file:
|
||||||
content = file.read()
|
content = file.read()
|
||||||
|
|
||||||
if 'wavelength' in content and wavelengths is None:
|
if 'wavelength' in content:
|
||||||
print(f"File {os.path.basename(hdr_path)} already contains wavelength information; no changes needed.")
|
print(f"{os.path.basename(hdr_path)} 已包含 wavelength 字段,跳过追加。")
|
||||||
return
|
return
|
||||||
|
|
||||||
if wavelengths is None:
|
if wavelengths is None or len(wavelengths) == 0:
|
||||||
print(f"No wavelengths provided and HDR lacks wavelength; skipping write to avoid wrong bands.")
|
print("HDR 缺少 wavelength,但未提供 wavelengths,跳过写入以避免错误。")
|
||||||
return
|
return
|
||||||
|
|
||||||
needs_newline = not content.endswith('\n')
|
needs_newline = not content.endswith('\n')
|
||||||
wavelength_info = "wavelength = {" + ", ".join(str(float(w)) for w in wavelengths) + "}\n"
|
wavelength_info = "wavelength = {" + ", ".join(str(float(w)) for w in wavelengths) + "}\n"
|
||||||
|
|
||||||
with open(hdr_path, 'a') as file:
|
with open(hdr_path, 'a', encoding='utf-8', errors='ignore') as file:
|
||||||
if needs_newline:
|
if needs_newline:
|
||||||
file.write('\n')
|
file.write('\n')
|
||||||
file.write(wavelength_info)
|
file.write(wavelength_info)
|
||||||
|
|
||||||
print(f"Successfully ensured wavelength information in file: {os.path.basename(hdr_path)}")
|
print(f"已在 {os.path.basename(hdr_path)} 末尾追加 wavelength 字段。")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_inputs(bil_path, output_path, model_path):
|
||||||
|
"""验证输入文件和参数"""
|
||||||
|
# 检查BIL和HDR文件存在
|
||||||
|
if not os.path.exists(bil_path):
|
||||||
|
raise FileNotFoundError(f"BIL文件不存在: {bil_path}")
|
||||||
|
|
||||||
|
hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
|
||||||
|
if not os.path.exists(hdr_path):
|
||||||
|
raise FileNotFoundError(f"HDR文件不存在: {hdr_path}")
|
||||||
|
|
||||||
|
# 检查输出目录可写
|
||||||
|
output_dir = os.path.dirname(output_path)
|
||||||
|
if output_dir and not os.path.exists(output_dir):
|
||||||
|
try:
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"无法创建输出目录: {output_dir}") from e
|
||||||
|
|
||||||
|
# 检查模型文件存在
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
raise FileNotFoundError(f"主模型文件不存在: {model_path}")
|
||||||
|
|
||||||
|
# 检查BIL文件波段数是否足够
|
||||||
|
try:
|
||||||
|
from spectral.io import envi
|
||||||
|
img = envi.open(hdr_path, bil_path)
|
||||||
|
n_bands = img.nbands
|
||||||
|
# bil2rgb需要波段索引9, 59, 159
|
||||||
|
if n_bands <= 159:
|
||||||
|
raise ValueError(f"BIL文件波段数不足: 需要至少160个波段,但只有{n_bands}个")
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"无法读取BIL文件头信息: {bil_path}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def generate_rgb(bil_path):
|
||||||
|
"""处理BIL文件生成RGB图像"""
|
||||||
|
try:
|
||||||
|
rgb_img = process_bil_files(bil_path)
|
||||||
|
return rgb_img
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"生成RGB图像失败: bil_path={bil_path}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def run_segmentation(rgb_img, segmentation_model_path=None):
|
||||||
|
"""运行分割获取掩膜"""
|
||||||
|
try:
|
||||||
|
mask, filter_mask_original = detect_microplastic_mask_from_array(
|
||||||
|
image=rgb_img,
|
||||||
|
filter_method='threshold',
|
||||||
|
diameter=None,
|
||||||
|
flow_threshold=0.4,
|
||||||
|
cellprob_threshold=-1,
|
||||||
|
model_path=segmentation_model_path,
|
||||||
|
detect_filter=True
|
||||||
|
)
|
||||||
|
return mask, filter_mask_original
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError("分割失败: 无法检测微塑料颗粒") from e
|
||||||
|
|
||||||
|
|
||||||
|
def extract_primary_features(bil_path, mask):
|
||||||
|
"""提取主要特征"""
|
||||||
|
try:
|
||||||
|
df = process_images(bil_path, mask)
|
||||||
|
return df
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"特征提取失败: bil_path={bil_path}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def compute_background_spectrum(bil_path, mask):
|
||||||
|
"""计算背景光谱"""
|
||||||
|
try:
|
||||||
|
df_correct = process_images_background(bil_path, mask)
|
||||||
|
return df_correct
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"背景光谱计算失败: bil_path={bil_path}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def apply_background_and_optional_resample(df, bg_spectrum, bil_path):
|
||||||
|
"""应用背景校正和可选的重采样"""
|
||||||
|
# 识别光谱列:所有以wavelength_开头的列
|
||||||
|
spec_cols = [c for c in df.columns if c.startswith('wavelength_')]
|
||||||
|
|
||||||
|
if not spec_cols:
|
||||||
|
raise ValueError("未找到光谱列(以wavelength_开头的列)")
|
||||||
|
|
||||||
|
if len(spec_cols) != len(bg_spectrum):
|
||||||
|
raise ValueError(f"光谱列数量({len(spec_cols)})与背景光谱长度({len(bg_spectrum)})不匹配")
|
||||||
|
|
||||||
|
# 背景校正:用背景光谱逐列相除
|
||||||
|
df[spec_cols] = df[spec_cols].div(bg_spectrum, axis=1)
|
||||||
|
|
||||||
|
# 检查是否需要重采样
|
||||||
|
src_waves = read_wavelengths_from_hdr(bil_path)
|
||||||
|
need_resample = (src_waves.size > 0 and
|
||||||
|
(src_waves.size != len(TRAIN_WAVELENGTHS) or
|
||||||
|
not np.allclose(src_waves, TRAIN_WAVELENGTHS, atol=1e-2)))
|
||||||
|
|
||||||
|
if need_resample:
|
||||||
|
print(f"重采样光谱: 源波段数 {src_waves.size} -> 目标波段数 {len(TRAIN_WAVELENGTHS)}")
|
||||||
|
|
||||||
|
# 提取光谱数据
|
||||||
|
X_src = df[spec_cols].to_numpy(dtype=np.float64)
|
||||||
|
X_dst = resample_spectra_matrix(X_src, src_waves, TRAIN_WAVELENGTHS)
|
||||||
|
|
||||||
|
# 替换光谱列
|
||||||
|
spec_col_names = [f"band_{i+1}" for i in range(len(TRAIN_WAVELENGTHS))]
|
||||||
|
df = pd.concat([
|
||||||
|
df.drop(columns=spec_cols), # 移除原有光谱列
|
||||||
|
pd.DataFrame(X_dst, columns=spec_col_names, index=df.index)
|
||||||
|
], axis=1)
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def clean_and_select_columns(df):
|
||||||
|
"""数据清理和列选择"""
|
||||||
|
# 移除NaN值
|
||||||
|
df = df.dropna()
|
||||||
|
|
||||||
|
# 过滤轮廓点数不足的样本
|
||||||
|
df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
|
||||||
|
|
||||||
|
# 过滤面积过小的样本
|
||||||
|
df = df[df['area'] >= 500]
|
||||||
|
|
||||||
|
# 列筛选:使用原来的硬编码索引删除逻辑
|
||||||
|
cols_to_remove = df.columns[np.r_[-14: -1]]
|
||||||
|
df = df.drop(columns=cols_to_remove)
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def run_primary_classification(df, primary_model_path):
|
||||||
|
"""运行主要分类"""
|
||||||
|
try:
|
||||||
|
# 验证特征维度
|
||||||
|
try:
|
||||||
|
import joblib
|
||||||
|
scaler_path = os.path.join(os.path.dirname(primary_model_path), 'scaler_params.pkl')
|
||||||
|
if os.path.exists(scaler_path):
|
||||||
|
scaler = joblib.load(scaler_path)
|
||||||
|
numeric_cols = [c for c in df.columns[1:] if np.issubdtype(df[c].dtype, np.number) and c != 'contour']
|
||||||
|
if hasattr(scaler, 'mean_') and len(numeric_cols) != scaler.mean_.shape[0]:
|
||||||
|
raise ValueError(f"特征维度不匹配: 当前{numeric_cols}列 != 训练时{scaler.mean_.shape[0]}维")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"警告: 无法验证特征维度: {e}")
|
||||||
|
|
||||||
|
df_pre = predict_with_model(
|
||||||
|
df,
|
||||||
|
primary_model_path,
|
||||||
|
model_type='SVM',
|
||||||
|
ProcessMethods1='SS',
|
||||||
|
ProcessMethods2='None'
|
||||||
|
)
|
||||||
|
return df_pre
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"主要分类失败: model_path={primary_model_path}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def run_secondary_classification_if_needed(df_pre, bil_path, mask, filter_mask_original):
|
||||||
|
"""根据需要运行二次分类"""
|
||||||
|
# 二次分类配置(保持在代码内)
|
||||||
|
secondary_model_path = os.path.join(os.path.dirname(__file__), 'modelsave', 'HDPELDPE_model', 'svm.m')
|
||||||
|
secondary_target_classes = [1, 2] # HDPE, LDPE
|
||||||
|
|
||||||
|
target_classes = set(secondary_target_classes or [])
|
||||||
|
mask_secondary = df_pre['Predictions'].isin(target_classes)
|
||||||
|
|
||||||
|
if not mask_secondary.any():
|
||||||
|
print("未找到目标类别样本,跳过二次分类")
|
||||||
|
return df_pre
|
||||||
|
|
||||||
|
print(f"为类别 {sorted(target_classes)} 运行二次分类")
|
||||||
|
|
||||||
|
# 检查二次模型是否存在
|
||||||
|
if not os.path.exists(secondary_model_path):
|
||||||
|
print(f"警告: 二次模型不存在: {secondary_model_path},跳过二次分类")
|
||||||
|
return df_pre
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 图像信息的背景矫正
|
||||||
|
df_correct = shape_correct_background(bil_path, mask, filter_mask_original)
|
||||||
|
|
||||||
|
# 创建只包含目标类别的掩膜
|
||||||
|
mask_second = np.zeros_like(mask, dtype=np.uint16)
|
||||||
|
for idx in df_pre[mask_secondary].index:
|
||||||
|
contour = df_pre.loc[idx, 'contour']
|
||||||
|
if isinstance(contour, list) and len(contour) > 0:
|
||||||
|
contour_array = np.array(contour, dtype=np.int32)
|
||||||
|
cv2.fillPoly(mask_second, [contour_array], idx + 1)
|
||||||
|
|
||||||
|
# 提取特征
|
||||||
|
df_shape = extract_features(df_correct, mask_second)
|
||||||
|
|
||||||
|
# 确保使用前13列作为模型输入特征
|
||||||
|
if len(df_shape.columns) >= 13:
|
||||||
|
df_shape = df_shape.iloc[:, :13]
|
||||||
|
|
||||||
|
# 二次分类
|
||||||
|
df_secondary = predict_with_model(
|
||||||
|
df_shape,
|
||||||
|
secondary_model_path,
|
||||||
|
model_type='SVM',
|
||||||
|
ProcessMethods1='None',
|
||||||
|
ProcessMethods2='None'
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新预测结果(类别+1)
|
||||||
|
df_pre.loc[mask_secondary, 'Predictions'] = df_secondary['Predictions'].values + 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"警告: 二次分类失败,将继续使用主要分类结果: {e}")
|
||||||
|
|
||||||
|
return df_pre
|
||||||
|
|
||||||
|
|
||||||
|
def postprocess_class7_shadow(df_pre, rgb_img):
|
||||||
|
"""后处理类别7/8中的背景阴影(更稳健)"""
|
||||||
|
# 7和8都纳入检查范围
|
||||||
|
mask_targets = df_pre['Predictions'].isin([7, 8])
|
||||||
|
if not mask_targets.any():
|
||||||
|
return df_pre
|
||||||
|
|
||||||
|
print(f"处理 {mask_targets.sum()} 个类别7/8样本,识别背景阴影...")
|
||||||
|
|
||||||
|
# 灰度图
|
||||||
|
if hasattr(rgb_img, 'mode'): # PIL Image
|
||||||
|
rgb_img_array = np.array(rgb_img)
|
||||||
|
else:
|
||||||
|
rgb_img_array = rgb_img
|
||||||
|
if len(rgb_img_array.shape) == 3:
|
||||||
|
gray_img = cv2.cvtColor(rgb_img_array, cv2.COLOR_RGB2GRAY)
|
||||||
|
else:
|
||||||
|
gray_img = rgb_img_array
|
||||||
|
|
||||||
|
# 更稳的梯度(Scharr)
|
||||||
|
grad_x = cv2.Scharr(gray_img, cv2.CV_64F, 1, 0)
|
||||||
|
grad_y = cv2.Scharr(gray_img, cv2.CV_64F, 0, 1)
|
||||||
|
gradient_magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)
|
||||||
|
|
||||||
|
# 统计指标
|
||||||
|
edge_ratios = []
|
||||||
|
contrast_norms = []
|
||||||
|
areas_list = []
|
||||||
|
measures_per_idx = {}
|
||||||
|
|
||||||
|
edge_thick = 3
|
||||||
|
ring_thick = 5
|
||||||
|
eps = 1e-6
|
||||||
|
|
||||||
|
for idx in df_pre[mask_targets].index:
|
||||||
|
try:
|
||||||
|
contour = df_pre.loc[idx, 'contour']
|
||||||
|
if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
|
||||||
|
continue
|
||||||
|
contour_array = np.array(contour, dtype=np.int32)
|
||||||
|
if len(contour_array.shape) == 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
poly_mask = np.zeros(gray_img.shape, dtype=np.uint8)
|
||||||
|
cv2.fillPoly(poly_mask, [contour_array], 255)
|
||||||
|
|
||||||
|
# 边界带
|
||||||
|
edge_mask = np.zeros_like(poly_mask)
|
||||||
|
cv2.drawContours(edge_mask, [contour_array], -1, 255, thickness=edge_thick)
|
||||||
|
|
||||||
|
# 外环:膨胀边界去掉边界本身与内区
|
||||||
|
ring_mask = cv2.dilate(edge_mask, np.ones((ring_thick, ring_thick), np.uint8), iterations=1)
|
||||||
|
ring_mask = cv2.bitwise_and(ring_mask, cv2.bitwise_not(edge_mask))
|
||||||
|
ring_mask = cv2.bitwise_and(ring_mask, cv2.bitwise_not(poly_mask))
|
||||||
|
|
||||||
|
edge_vals = gradient_magnitude[edge_mask > 0]
|
||||||
|
ring_vals = gradient_magnitude[ring_mask > 0]
|
||||||
|
if edge_vals.size == 0 or ring_vals.size == 0:
|
||||||
|
continue
|
||||||
|
r_edge = float(np.median(edge_vals) / (np.median(ring_vals) + eps))
|
||||||
|
|
||||||
|
inside_vals = gray_img[poly_mask > 0]
|
||||||
|
outside_vals = gray_img[ring_mask > 0]
|
||||||
|
if inside_vals.size == 0 or outside_vals.size == 0:
|
||||||
|
continue
|
||||||
|
dI = float(np.median(inside_vals) - np.median(outside_vals))
|
||||||
|
c_norm = abs(dI) / (np.std(outside_vals) + eps)
|
||||||
|
|
||||||
|
# 面积(可选保护)
|
||||||
|
area_val = None
|
||||||
|
if 'area' in df_pre.columns:
|
||||||
|
try:
|
||||||
|
area_val = float(df_pre.loc[idx, 'area'])
|
||||||
|
except Exception:
|
||||||
|
area_val = None
|
||||||
|
|
||||||
|
edge_ratios.append(r_edge)
|
||||||
|
contrast_norms.append(c_norm)
|
||||||
|
areas_list.append(area_val if area_val is not None else 0.0)
|
||||||
|
measures_per_idx[idx] = (r_edge, c_norm, area_val)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not measures_per_idx:
|
||||||
|
print("无可用的7/8类样本进行阴影判别")
|
||||||
|
return df_pre
|
||||||
|
|
||||||
|
def robust_q(arr, q):
|
||||||
|
vals = [v for v in arr if v is not None]
|
||||||
|
return float(np.percentile(vals, q)) if len(vals) > 0 else None
|
||||||
|
|
||||||
|
# 稳健阈值(低于30分位更像阴影)
|
||||||
|
r_thresh = robust_q(edge_ratios, 30.0)
|
||||||
|
c_thresh = robust_q(contrast_norms, 30.0)
|
||||||
|
# 面积保护:仅对较小目标允许改写,阈值取面积分布的40%分位,限定上限避免过大
|
||||||
|
a_thresh = robust_q(areas_list, 40.0)
|
||||||
|
if a_thresh is None or a_thresh <= 0:
|
||||||
|
a_thresh = 1200.0
|
||||||
|
a_thresh = min(a_thresh, 2000.0)
|
||||||
|
|
||||||
|
indices_to_update = []
|
||||||
|
for idx, (r_edge, c_norm, area_val) in measures_per_idx.items():
|
||||||
|
small_enough = (area_val is None) or (area_val <= a_thresh)
|
||||||
|
if (r_thresh is not None and c_thresh is not None and small_enough):
|
||||||
|
# 两个指标都低,且面积不大 -> 判定为阴影
|
||||||
|
if (r_edge < r_thresh) and (c_norm < c_thresh):
|
||||||
|
indices_to_update.append(idx)
|
||||||
|
|
||||||
|
if indices_to_update:
|
||||||
|
# 改为背景(0),而不是9(PVC)
|
||||||
|
df_pre.loc[indices_to_update, 'Predictions'] = 9
|
||||||
|
print(f"将 {len(indices_to_update)} 个样本从类别7/8改为背景(阴影),面积阈值≈{a_thresh:.0f}")
|
||||||
|
else:
|
||||||
|
print("无需更新类别7/8样本")
|
||||||
|
|
||||||
|
return df_pre
|
||||||
|
|
||||||
|
|
||||||
|
def write_outputs(bil_path, df_pre, output_path):
|
||||||
|
"""写入输出结果"""
|
||||||
|
try:
|
||||||
|
# 收缩轮廓
|
||||||
|
df_pre = shrink_contours(bil_path, df_pre, shrink_pixels=1)
|
||||||
|
|
||||||
|
# 保存ENVI分类结果
|
||||||
|
save_envi_classification(bil_path, df_pre, output_path)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"保存结果失败: output_path={output_path}") from e
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
"""主函数"""
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
|
|
||||||
bil_path = args.bil_path
|
bil_path = args.bil_path
|
||||||
output_path = args.output_path
|
output_path = args.output_path
|
||||||
primary_model_path = args.model_path
|
primary_model_path = args.model_path
|
||||||
primary_model_type = 'SVM'
|
segmentation_model_path = None
|
||||||
primary_process_methods1 = 'SS'
|
|
||||||
primary_process_methods2 = "None"
|
|
||||||
|
|
||||||
# secondary_model_path = args.secondary_model
|
|
||||||
# secondary_model_type = args.secondary_model_type
|
|
||||||
# secondary_process_methods1 = args.secondary_process_methods1
|
|
||||||
# secondary_process_methods2 = args.secondary_process_methods2
|
|
||||||
# secondary_target_classes = args.secondary_target_classes
|
|
||||||
|
|
||||||
secondary_model_path = "E:\code\plastic\plastic20260224\plastic\plastic\modelsave\HDPELDPE_model\svm.m"
|
|
||||||
secondary_model_type = 'SVM'
|
|
||||||
secondary_process_methods1 = 'None'
|
|
||||||
secondary_process_methods2 = 'None'
|
|
||||||
secondary_target_classes = [1, 2]
|
|
||||||
# 记录总开始时间
|
# 记录总开始时间
|
||||||
total_start_time = time.time()
|
total_start_time = time.time()
|
||||||
# 处理BIL文件生成RGB图像
|
|
||||||
print("Processing BIL file to generate RGB image...\n")
|
|
||||||
rgb_img = process_bil_files(bil_path)
|
|
||||||
|
|
||||||
# 修改hdr
|
|
||||||
change_hdr_file(bil_path)
|
|
||||||
segmentation_start_time = time.time()
|
|
||||||
# 生成掩膜,mask为16位的塑料标签掩膜
|
|
||||||
print("Generating mask ...\n")
|
|
||||||
mask, filter_mask_original = detect_microplastic_mask_from_array(
|
|
||||||
image=rgb_img, # 直接传入cv2.imread的结果
|
|
||||||
filter_method='threshold',
|
|
||||||
diameter=None,
|
|
||||||
flow_threshold=0.4,
|
|
||||||
cellprob_threshold=-1,
|
|
||||||
detect_filter=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 提取特征
|
|
||||||
print("Extracting features from BIL file...\n")
|
|
||||||
df = process_images(bil_path, mask)
|
|
||||||
|
|
||||||
# 背景校正(保持现有逻辑)
|
|
||||||
print("Applying background correction...\n")
|
|
||||||
df_correct = process_images_background(bil_path, mask)
|
|
||||||
|
|
||||||
# 自动识别光谱列范围:假设从第2列开始连续为光谱列,长度=背景校正矩阵列数
|
|
||||||
spec_start = 1
|
|
||||||
spec_len_src = len(df_correct)
|
|
||||||
spec_end_src = spec_start + spec_len_src
|
|
||||||
|
|
||||||
# 归一化(按通道逐列相除)
|
|
||||||
df.iloc[:, spec_start:spec_end_src] = df.iloc[:, spec_start:spec_end_src].div(df_correct, axis=1)
|
|
||||||
|
|
||||||
# 读取当前BIL的波长,并重采样到训练用波长
|
|
||||||
src_waves = read_wavelengths_from_hdr(bil_path)
|
|
||||||
need_resample = (src_waves.size > 0) and (src_waves.size != len(TRAIN_WAVELENGTHS) or not np.allclose(src_waves, TRAIN_WAVELENGTHS, atol=1e-2))
|
|
||||||
if need_resample:
|
|
||||||
print(f"Resampling spectra: src {src_waves.size} bands -> dst {len(TRAIN_WAVELENGTHS)} bands\n")
|
|
||||||
X_src = df.iloc[:, spec_start:spec_end_src].to_numpy(dtype=np.float64)
|
|
||||||
X_dst = resample_spectra_matrix(X_src, src_waves, TRAIN_WAVELENGTHS)
|
|
||||||
|
|
||||||
# 用训练维度(168)替换原光谱列;其余形状/统计特征保持不变
|
|
||||||
spec_col_names = [f"band_{i+1}" for i in range(len(TRAIN_WAVELENGTHS))]
|
|
||||||
df = pd.concat(
|
|
||||||
[
|
|
||||||
df.iloc[:, :spec_start],
|
|
||||||
pd.DataFrame(X_dst, columns=spec_col_names, index=df.index),
|
|
||||||
df.iloc[:, spec_end_src:]
|
|
||||||
],
|
|
||||||
axis=1
|
|
||||||
)
|
|
||||||
# 更新光谱列的新范围
|
|
||||||
spec_end_src = spec_start + len(TRAIN_WAVELENGTHS)
|
|
||||||
|
|
||||||
# 数据清理(保持原有规则)
|
|
||||||
print("Cleaning data...\n")
|
|
||||||
df = df.dropna()
|
|
||||||
df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
|
|
||||||
df = df[df['area'] >= 500]
|
|
||||||
|
|
||||||
# 列筛选:此时光谱列已与训练对齐为168维,因此区间索引仍可复用
|
|
||||||
cols_to_remove = df.columns[np.r_[1:5, 87:110, 166:169, -10:-1]]
|
|
||||||
df = df.drop(columns=cols_to_remove)
|
|
||||||
|
|
||||||
# 继续后续流程
|
|
||||||
segmentation_time = time.time() - segmentation_start_time
|
|
||||||
df = df.iloc[:, :]
|
|
||||||
# 预测分类(分类阶段)
|
|
||||||
classification_start_time = time.time()
|
|
||||||
# 预测分类
|
|
||||||
print("Predicting classes...\n")
|
|
||||||
loaded_model = load_model(primary_model_path)
|
|
||||||
|
|
||||||
# 断言:数值特征维度应等于训练时的scaler输入维度
|
|
||||||
try:
|
try:
|
||||||
import joblib
|
# 验证输入
|
||||||
scaler = joblib.load(os.path.join(os.path.dirname(primary_model_path), 'scaler_params.pkl'))
|
validate_inputs(bil_path, output_path, primary_model_path)
|
||||||
numeric_cols = [c for c in df.columns[1:] if np.issubdtype(df[c].dtype, np.number) and c != 'contour']
|
bands =[912.36, 915.68, 919, 922.31, 925.63, 928.95, 932.27, 935.59, 938.91, 942.23, 945.55, 948.87, 952.18, 955.5, 958.82, 962.14, 965.46, 968.78, 972.1, 975.42, 978.74, 982.06, 985.38, 988.7, 992.02, 995.34, 998.65, 1002, 1005.3, 1008.6, 1011.9, 1015.3, 1018.6, 1021.9, 1025.2, 1028.5, 1031.9, 1035.2, 1038.5, 1041.8, 1045.1, 1048.5, 1051.8, 1055.1, 1058.4, 1061.7, 1065.1, 1068.4, 1071.7, 1075, 1078.3, 1081.7, 1085, 1088.3, 1091.6, 1094.9, 1098.3, 1101.6, 1104.9, 1108.2, 1111.5, 1114.9, 1118.2, 1121.5, 1124.8, 1128.1, 1131.5, 1134.8, 1138.1, 1141.4, 1144.8, 1148.1, 1151.4, 1154.7, 1158, 1161.4, 1164.7, 1168, 1171.3, 1174.6, 1178, 1181.3, 1184.6, 1187.9, 1191.3, 1194.6, 1197.9, 1201.2, 1204.5, 1207.9, 1211.2, 1214.5, 1217.8, 1221.2, 1224.5, 1227.8, 1231.1, 1234.4, 1237.8, 1241.1, 1244.4, 1247.7, 1251.1, 1254.4, 1257.7, 1261, 1264.3, 1267.7, 1271, 1274.3, 1277.6, 1281, 1284.3, 1287.6, 1290.9, 1294.2, 1297.6, 1300.9, 1304.2, 1307.5, 1310.9, 1314.2, 1317.5, 1320.8, 1324.2, 1327.5, 1330.8, 1334.1, 1337.4, 1340.8, 1344.1, 1347.4, 1350.7, 1354.1, 1357.4, 1360.7, 1364, 1367.4, 1370.7, 1374, 1377.3, 1380.7, 1384, 1387.3, 1390.6, 1394, 1397.3, 1400.6, 1403.9, 1407.2, 1410.6, 1413.9, 1417.2, 1420.5, 1423.9, 1427.2, 1430.5, 1433.8, 1437.2, 1440.5, 1443.8, 1447.1, 1450.5, 1453.8, 1457.1, 1460.4, 1463.8, 1467.1, 1470.4, 1473.7, 1477.1, 1480.4, 1483.7, 1487, 1490.4, 1493.7, 1497, 1500.3, 1503.7, 1507, 1510.3, 1513.6, 1517, 1520.3, 1523.6, 1526.9, 1530.3, 1533.6, 1536.9, 1540.2, 1543.6, 1546.9, 1550.2, 1553.6, 1556.9, 1560.2, 1563.5, 1566.9, 1570.2, 1573.5, 1576.8, 1580.2, 1583.5, 1586.8, 1590.1, 1593.5, 1596.8, 1600.1, 1603.4, 1606.8, 1610.1, 1613.4, 1616.7, 1620.1, 1623.4, 1626.7, 1630.1, 1633.4, 1636.7, 1640, 1643.4, 1646.7, 1650, 1653.3, 1656.7, 1660, 1663.3, 1666.7, 1670, 1673.3, 1676.6, 1680, 1683.3, 1686.6, 1689.9, 1693.3, 1696.6, 1699.9, 1703.3, 1706.6]
|
||||||
if hasattr(scaler, 'mean_') and len(numeric_cols) != scaler.mean_.shape[0]:
|
|
||||||
raise ValueError(f"Feature dimension mismatch: current {len(numeric_cols)} != scaler {scaler.mean_.shape[0]}. Check resampling and column selection.")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
df_pre = predict_with_model(
|
# 修改HDR文件
|
||||||
df,
|
change_hdr_file(bil_path,bands)
|
||||||
primary_model_path,
|
|
||||||
model_type=primary_model_type,
|
|
||||||
ProcessMethods1=primary_process_methods1,
|
|
||||||
ProcessMethods2=primary_process_methods2
|
|
||||||
)
|
|
||||||
|
|
||||||
# 对HDPE和LDPE进行二次分类
|
# 处理BIL文件生成RGB图像
|
||||||
# 从第一次分类结果中提取SECONDARY_TARGET_CLASSES类别的掩膜轮廓
|
print("处理BIL文件生成RGB图像...")
|
||||||
target_classes = set(secondary_target_classes or [])
|
rgb_img = generate_rgb(bil_path)
|
||||||
mask_secondary = df_pre['Predictions'].isin(target_classes)
|
|
||||||
|
|
||||||
if mask_secondary.any():
|
# 分割阶段
|
||||||
# 只有在找到目标类别时才进行背景校正和二次分类
|
segmentation_start_time = time.time()
|
||||||
print(f"Running secondary classification for classes: {sorted(target_classes)}")
|
print("生成掩膜...")
|
||||||
|
mask, filter_mask_original = run_segmentation(rgb_img, segmentation_model_path)
|
||||||
# 图像信息的背景矫正
|
|
||||||
df_correct = shape_correct_background(bil_path, mask, filter_mask_original)
|
|
||||||
|
|
||||||
# 创建新的掩膜mask_second,只包含目标类别的轮廓
|
|
||||||
mask_second = np.zeros_like(mask, dtype=np.uint16)
|
|
||||||
|
|
||||||
for idx in df_pre[mask_secondary].index:
|
|
||||||
contour = df_pre.loc[idx, 'contour']
|
|
||||||
if isinstance(contour, list) and len(contour) > 0:
|
|
||||||
contour_array = np.array(contour, dtype=np.int32)
|
|
||||||
cv2.fillPoly(mask_second, [contour_array], idx + 1) # 使用索引+1作为标签
|
|
||||||
|
|
||||||
# 提取特征
|
# 提取特征
|
||||||
df_shape = extract_features(df_correct, mask_second)
|
print("从BIL文件提取特征...")
|
||||||
|
df = extract_primary_features(bil_path, mask)
|
||||||
|
|
||||||
# 确保使用第2到13列作为模型输入特征
|
# 背景校正
|
||||||
if len(df_shape.columns) >= 13:
|
print("应用背景校正...")
|
||||||
df_shape = df_shape.iloc[:, :13]
|
bg_spectrum = compute_background_spectrum(bil_path, mask)
|
||||||
|
|
||||||
# 二次分类:使用第二个模型预测并更新分类结果
|
# 背景校正 + 仅在与训练相机波长不一致时重采样
|
||||||
if secondary_model_path:
|
df = apply_background_and_optional_resample(df, bg_spectrum, bil_path)
|
||||||
df_secondary = predict_with_model(
|
|
||||||
df_shape,
|
|
||||||
secondary_model_path,
|
|
||||||
model_type=secondary_model_type,
|
|
||||||
ProcessMethods1=secondary_process_methods1,
|
|
||||||
ProcessMethods2=secondary_process_methods2
|
|
||||||
)
|
|
||||||
df_pre.loc[mask_secondary, 'Predictions'] = df_secondary['Predictions'].values + 1
|
|
||||||
else:
|
|
||||||
print("Secondary model path not provided; skipping secondary classification.\n")
|
|
||||||
else:
|
|
||||||
print("No samples from target classes found; skipping secondary classification.\n")
|
|
||||||
|
|
||||||
# 识别类别7中的背景阴影误判:通过边界清晰度特征
|
# 数据清理和列选择
|
||||||
# 真正的类别7边界清晰,背景阴影边界模糊
|
print("清理数据...")
|
||||||
class_7_mask = df_pre['Predictions'] == 7
|
df = clean_and_select_columns(df)
|
||||||
if class_7_mask.any():
|
|
||||||
print(f"Processing {class_7_mask.sum()} samples with class 7 to identify background shadows...\n")
|
|
||||||
|
|
||||||
# 将PIL Image转换为numpy数组
|
segmentation_time = time.time() - segmentation_start_time
|
||||||
if hasattr(rgb_img, 'mode'): # 检查是否是PIL Image
|
|
||||||
rgb_img_array = np.array(rgb_img)
|
|
||||||
else:
|
|
||||||
rgb_img_array = rgb_img
|
|
||||||
|
|
||||||
# 转换为灰度图(用于计算梯度)
|
# 分类阶段
|
||||||
if len(rgb_img_array.shape) == 3:
|
classification_start_time = time.time()
|
||||||
gray_img = cv2.cvtColor(rgb_img_array, cv2.COLOR_RGB2GRAY)
|
print("预测分类...")
|
||||||
else:
|
df_pre = run_primary_classification(df, primary_model_path)
|
||||||
gray_img = rgb_img_array
|
|
||||||
|
|
||||||
# 计算梯度图(使用Sobel算子)
|
# 二次分类
|
||||||
grad_x = cv2.Sobel(gray_img, cv2.CV_64F, 1, 0, ksize=3)
|
df_pre = run_secondary_classification_if_needed(df_pre, bil_path, mask, filter_mask_original)
|
||||||
grad_y = cv2.Sobel(gray_img, cv2.CV_64F, 0, 1, ksize=3)
|
|
||||||
gradient_magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)
|
|
||||||
|
|
||||||
# 先收集所有类别7样本的边缘梯度值,用于确定阈值
|
# 后处理类别7阴影
|
||||||
all_class7_gradients = []
|
df_pre = postprocess_class7_shadow(df_pre, rgb_img)
|
||||||
valid_indices = []
|
|
||||||
for idx in df_pre[class_7_mask].index:
|
|
||||||
try:
|
|
||||||
contour = df_pre.loc[idx, 'contour']
|
|
||||||
if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
|
|
||||||
continue
|
|
||||||
|
|
||||||
contour_array = np.array(contour, dtype=np.int32)
|
classification_time = time.time() - classification_start_time
|
||||||
if len(contour_array.shape) == 1:
|
|
||||||
continue
|
|
||||||
|
|
||||||
mask_img = np.zeros(gray_img.shape, dtype=np.uint8)
|
# 保存结果
|
||||||
cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)
|
print("保存ENVI分类结果...")
|
||||||
edge_gradients = gradient_magnitude[mask_img > 0]
|
write_outputs(bil_path, df_pre, output_path)
|
||||||
if len(edge_gradients) > 0:
|
|
||||||
all_class7_gradients.extend(edge_gradients)
|
|
||||||
valid_indices.append(idx)
|
|
||||||
except:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 基于类别7样本的梯度分布确定阈值
|
print(f"ENVI分类结果已保存至: {output_path}")
|
||||||
# 使用类别7样本梯度值的中位数作为基准,低于某个分位数(如30%)的认为是背景阴影
|
|
||||||
if len(all_class7_gradients) > 0:
|
|
||||||
gradient_threshold = np.percentile(all_class7_gradients, 30) # 使用类别7样本梯度值的30%分位数
|
|
||||||
else:
|
|
||||||
gradient_threshold = np.percentile(gradient_magnitude, 30) # 如果没有有效样本,使用整张图的30%分位数
|
|
||||||
|
|
||||||
print(f"Gradient threshold for class 7: {gradient_threshold:.2f}\n")
|
# 计算总耗时
|
||||||
|
total_time = time.time() - total_start_time
|
||||||
|
|
||||||
# 处理每个类别7的样本,判断是否为背景阴影
|
# 打印耗时统计
|
||||||
indices_to_update = []
|
print(f"\n{'=' * 60}")
|
||||||
for idx in valid_indices:
|
print("处理完成")
|
||||||
try:
|
print(f"{'=' * 60}")
|
||||||
contour = df_pre.loc[idx, 'contour']
|
print(f"分割耗时: {segmentation_time:.2f} 秒")
|
||||||
contour_array = np.array(contour, dtype=np.int32)
|
print(f"分类耗时: {classification_time:.2f} 秒")
|
||||||
|
print(f"总耗时: {total_time:.2f} 秒")
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
print(f"结果已保存至: {output_path}")
|
||||||
|
|
||||||
# 创建轮廓掩膜(线宽为2像素,用于提取边缘)
|
except Exception as e:
|
||||||
mask_img = np.zeros(gray_img.shape, dtype=np.uint8)
|
print(f"处理失败: {e}")
|
||||||
cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)
|
raise
|
||||||
|
|
||||||
# 提取轮廓边缘的梯度值
|
|
||||||
edge_gradients = gradient_magnitude[mask_img > 0]
|
|
||||||
if len(edge_gradients) == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 计算轮廓边缘的平均梯度强度
|
|
||||||
mean_gradient = np.mean(edge_gradients)
|
|
||||||
|
|
||||||
# 如果平均梯度强度低于阈值,认为是背景阴影(边界模糊)
|
|
||||||
if mean_gradient < gradient_threshold:
|
|
||||||
indices_to_update.append(idx)
|
|
||||||
print(
|
|
||||||
f"Sample {idx}: mean_gradient={mean_gradient:.2f}, threshold={gradient_threshold:.2f} -> identified as background shadow")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing sample at index {idx}: {str(e)}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 将背景阴影的类别7改为类别9
|
|
||||||
if indices_to_update:
|
|
||||||
df_pre.loc[indices_to_update, 'Predictions'] = 9
|
|
||||||
print(f"Updated {len(indices_to_update)} samples from class 7 to class 9 (background shadows)\n")
|
|
||||||
else:
|
|
||||||
print("No samples needed to be updated from class 7\n")
|
|
||||||
classification_time = time.time() - classification_start_time
|
|
||||||
|
|
||||||
df_pre = shrink_contours(bil_path, df_pre, shrink_pixels=1)
|
|
||||||
# 保存ENVI分类结果
|
|
||||||
print("Saving ENVI classification results...\n")
|
|
||||||
save_envi_classification(bil_path, df_pre, output_path)
|
|
||||||
print(f"ENVI classification results saved to: {output_path}")
|
|
||||||
# 计算总耗时
|
|
||||||
total_time = time.time() - total_start_time
|
|
||||||
|
|
||||||
# 打印耗时统计
|
|
||||||
print(f"\n{'=' * 60}")
|
|
||||||
print(f"处理完成")
|
|
||||||
print(f"{'=' * 60}")
|
|
||||||
print(f"分割耗时: {segmentation_time:.2f} 秒")
|
|
||||||
print(f"分类耗时: {classification_time:.2f} 秒")
|
|
||||||
print(f"总耗时: {total_time:.2f} 秒")
|
|
||||||
print(f"{'=' * 60}")
|
|
||||||
print(f"结果已保存至: {output_path}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
764
main_batch_nosample.py
Normal file
764
main_batch_nosample.py
Normal file
@ -0,0 +1,764 @@
|
|||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import matplotlib
|
||||||
|
import numpy as np
|
||||||
|
import argparse
|
||||||
|
import pandas as pd
|
||||||
|
from bil2rgb import process_bil_files
|
||||||
|
from classification_model.Parallel.predict_plastic import load_model, predict_with_model
|
||||||
|
from mask import detect_microplastic_mask_from_array
|
||||||
|
from shape_spectral import process_images
|
||||||
|
from shape_spectral_background import process_images_background
|
||||||
|
from extact_shape import shape_correct_background, extract_features
|
||||||
|
import time
|
||||||
|
##批量预测文件夹内的bil文件,不进行降采样,使用新采集的数据进行训练
|
||||||
|
matplotlib.use('TkAgg')
|
||||||
|
|
||||||
|
# 训练相机波长(168通道)
|
||||||
|
TRAIN_WAVELENGTHS = [
|
||||||
|
898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2
|
||||||
|
]
|
||||||
|
|
||||||
|
def read_wavelengths_from_hdr(bil_path):
|
||||||
|
hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
|
||||||
|
if not os.path.exists(hdr_path):
|
||||||
|
return np.array([], dtype=np.float64)
|
||||||
|
with open(hdr_path, 'r') as f:
|
||||||
|
txt = f.read()
|
||||||
|
if 'wavelength' not in txt:
|
||||||
|
return np.array([], dtype=np.float64)
|
||||||
|
seg = txt.split('wavelength', 1)[1]
|
||||||
|
seg = seg[seg.find('{')+1: seg.find('}')]
|
||||||
|
vals = [v.strip() for v in seg.split(',') if v.strip()]
|
||||||
|
try:
|
||||||
|
waves = np.array([float(v) for v in vals], dtype=np.float64)
|
||||||
|
except Exception:
|
||||||
|
waves = np.array([], dtype=np.float64)
|
||||||
|
return waves
|
||||||
|
|
||||||
|
def resample_spectra_matrix(X, src_waves, dst_waves):
|
||||||
|
src = np.asarray(src_waves, dtype=np.float64)
|
||||||
|
dst = np.asarray(dst_waves, dtype=np.float64)
|
||||||
|
X = np.asarray(X, dtype=np.float64)
|
||||||
|
if src.size == 0 or dst.size == 0:
|
||||||
|
return X
|
||||||
|
# 线性插值,越界用端点外推,避免维度缺失
|
||||||
|
out = np.empty((X.shape[0], dst.size), dtype=np.float64)
|
||||||
|
for i in range(X.shape[0]):
|
||||||
|
row = X[i]
|
||||||
|
out[i] = np.interp(dst, src, row, left=row[0], right=row[-1])
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def apply_background_no_resample(df, bg_spectrum):
|
||||||
|
"""
|
||||||
|
仅做背景校正,不做任何重采样。
|
||||||
|
- 自动选择以 wavelength_ 或 band_ 开头的光谱列
|
||||||
|
- 若背景长度与光谱列数不一致,按尾部对齐取最小长度进行校正
|
||||||
|
"""
|
||||||
|
# 识别光谱列
|
||||||
|
spec_cols = [c for c in df.columns if isinstance(c, str) and (c.startswith('wavelength_') or c.startswith('band_'))]
|
||||||
|
if not spec_cols:
|
||||||
|
raise ValueError("未找到光谱列(以 wavelength_ 或 band_ 开头)")
|
||||||
|
|
||||||
|
bg = np.asarray(bg_spectrum, dtype=np.float64).ravel()
|
||||||
|
if bg.size == 0:
|
||||||
|
raise ValueError("背景光谱长度为0,无法进行背景校正")
|
||||||
|
|
||||||
|
# 尾部对齐,取最小长度,避免维度不一致
|
||||||
|
n = min(len(spec_cols), bg.shape[0])
|
||||||
|
use_cols = spec_cols[-n:]
|
||||||
|
df.loc[:, use_cols] = df.loc[:, use_cols].div(bg[-n:], axis=1)
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def parse_arguments():
|
||||||
|
"""解析命令行参数"""
|
||||||
|
parser = argparse.ArgumentParser(description='Microplastic spectral shape classification - Batch processing')
|
||||||
|
|
||||||
|
# 必需参数
|
||||||
|
parser.add_argument('--input_dir', required=True, help='Path to input directory containing BIL files')
|
||||||
|
parser.add_argument('--output_dir', required=True, help='Path to output directory for classification results')
|
||||||
|
parser.add_argument('--model_path', required=True, help='Path to primary classification model')
|
||||||
|
|
||||||
|
# 可选参数
|
||||||
|
# parser.add_argument('--primary_model_type', default='SVM', help='Type of primary model (default: SVM)')
|
||||||
|
# parser.add_argument('--primary_process_methods1', default='SS', help='Primary process method 1 (default: SS)')
|
||||||
|
# parser.add_argument('--primary_process_methods2', default='None', help='Primary process method 2 (default: None)')
|
||||||
|
|
||||||
|
# parser.add_argument('--secondary_model', default="D:\plastic\plastic\modelsave\HDPELDPE_model\svm.m", help='Path to secondary classification model')
|
||||||
|
# parser.add_argument('--secondary_model_type', default='SVM', help='Type of secondary model (default: SVM)')
|
||||||
|
# parser.add_argument('--secondary_process_methods1', default='None',
|
||||||
|
# help='Secondary process method 1 (default: None)')
|
||||||
|
# parser.add_argument('--secondary_process_methods2', default='None',
|
||||||
|
# help='Secondary process method 2 (default: None)')
|
||||||
|
# parser.add_argument('--secondary_target_classes', nargs='+', type=int, default=[1,2],
|
||||||
|
# help='Target classes for secondary classification (space separated)')
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# 配置参数:直接在此修改
|
||||||
|
# ----------------------------
|
||||||
|
# BIL_PATH = r"D:/Data/Test/PET_bottle2.bil"
|
||||||
|
# OUTPUT_PATH = r'D:/Data/PET_bottle2_class.bil'
|
||||||
|
#
|
||||||
|
# PRIMARY_MODEL_PATH = r"D:\plastic\plastic\modelsave\svm.m"
|
||||||
|
# PRIMARY_MODEL_TYPE = 'SVM'
|
||||||
|
# PRIMARY_PROCESS_METHODS1 = 'SS'
|
||||||
|
# PRIMARY_PROCESS_METHODS2 = 'None'
|
||||||
|
#
|
||||||
|
# SECONDARY_MODEL_PATH = "D:\plastic\plastic\modelsave\HDPELDPE_model\svm.m" # 若不需要二次分类,则保持为 None
|
||||||
|
# SECONDARY_MODEL_TYPE = 'SVM'
|
||||||
|
# SECONDARY_PROCESS_METHODS1 = 'None'
|
||||||
|
# SECONDARY_PROCESS_METHODS2 = 'None'
|
||||||
|
# SECONDARY_TARGET_CLASSES = [1, 2]
|
||||||
|
|
||||||
|
|
||||||
|
def read_hdr_file(bil_path):
|
||||||
|
hdr_path = bil_path.replace('.bil', '.hdr')
|
||||||
|
with open(hdr_path, 'r') as f:
|
||||||
|
header = f.readlines()
|
||||||
|
|
||||||
|
samples, lines = None, None
|
||||||
|
|
||||||
|
for line in header:
|
||||||
|
if line.startswith('samples'):
|
||||||
|
samples = int(line.split('=')[-1].strip())
|
||||||
|
if line.startswith('lines'):
|
||||||
|
lines = int(line.split('=')[-1].strip())
|
||||||
|
|
||||||
|
return samples, lines
|
||||||
|
|
||||||
|
|
||||||
|
def shrink_contours(bil_path, df, shrink_pixels=1):
|
||||||
|
"""
|
||||||
|
对DataFrame中的所有轮廓进行收缩操作,避免塑料之间的相连
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bil_path: BIL文件路径,用于获取图像尺寸
|
||||||
|
df: 包含contour列的DataFrame
|
||||||
|
shrink_pixels: 收缩的像素数,默认1像素
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
更新后的DataFrame,contour列已被收缩
|
||||||
|
"""
|
||||||
|
samples, lines = read_hdr_file(bil_path)
|
||||||
|
|
||||||
|
# 创建腐蚀核
|
||||||
|
kernel = np.ones((2 * shrink_pixels + 1, 2 * shrink_pixels + 1), np.uint8)
|
||||||
|
|
||||||
|
# 创建临时掩膜用于处理
|
||||||
|
temp_mask = np.zeros((lines, samples), dtype=np.uint8)
|
||||||
|
|
||||||
|
# 创建DataFrame副本
|
||||||
|
df = df.copy()
|
||||||
|
|
||||||
|
# 遍历每一行,更新轮廓
|
||||||
|
for idx, row in df.iterrows():
|
||||||
|
contour = row['contour']
|
||||||
|
if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
contour_array = np.array(contour, dtype=np.int32)
|
||||||
|
if len(contour_array.shape) == 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 清空临时掩膜
|
||||||
|
temp_mask.fill(0)
|
||||||
|
|
||||||
|
# 填充轮廓
|
||||||
|
cv2.fillPoly(temp_mask, [contour_array], 255)
|
||||||
|
|
||||||
|
# 对掩膜进行腐蚀操作
|
||||||
|
eroded_mask = cv2.erode(temp_mask, kernel, iterations=1)
|
||||||
|
|
||||||
|
# 重新提取轮廓
|
||||||
|
contours, _ = cv2.findContours(eroded_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
|
|
||||||
|
if len(contours) > 0:
|
||||||
|
# 选择最大的轮廓(如果有多个)
|
||||||
|
largest_contour = max(contours, key=cv2.contourArea)
|
||||||
|
# 转换为列表格式,保持与原始格式一致
|
||||||
|
if len(largest_contour) >= 3:
|
||||||
|
updated_contour = largest_contour.reshape(-1, 2).tolist()
|
||||||
|
df.at[idx, 'contour'] = updated_contour
|
||||||
|
except Exception as e:
|
||||||
|
# 如果处理失败,保留原始轮廓
|
||||||
|
continue
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def save_envi_classification(bil_path, df, savepath):
|
||||||
|
samples, lines = read_hdr_file(bil_path)
|
||||||
|
classification_result = np.zeros((lines, samples), dtype=np.uint16)
|
||||||
|
|
||||||
|
# 预处理:清除可能存在的类别10和11(移到循环外以提高效率)
|
||||||
|
classification_result[(classification_result == 10)] = 0
|
||||||
|
classification_result[(classification_result == 11)] = 0
|
||||||
|
|
||||||
|
for _, row in df.iterrows():
|
||||||
|
contour = row['contour']
|
||||||
|
prediction = int(row['Predictions']) + 1
|
||||||
|
contour = np.array(contour, dtype=np.int32)
|
||||||
|
|
||||||
|
cv2.fillPoly(classification_result, [contour], prediction)
|
||||||
|
|
||||||
|
output_path = savepath
|
||||||
|
|
||||||
|
with open(output_path, 'wb') as f:
|
||||||
|
classification_result.tofile(f)
|
||||||
|
|
||||||
|
header_content = f"""ENVI
|
||||||
|
description = {{
|
||||||
|
Classification Result.}}
|
||||||
|
samples = {samples}
|
||||||
|
lines = {lines}
|
||||||
|
bands = 1
|
||||||
|
header offset = 0
|
||||||
|
file type = ENVI Standard
|
||||||
|
data type = 2
|
||||||
|
interleave = bil
|
||||||
|
classes = 10
|
||||||
|
class = {{ background, ABS, HDPE, LDPE, PA6, PET, PP, PS, PTFE, PVC }}
|
||||||
|
single pixel area = 0.000036
|
||||||
|
unit = mm2
|
||||||
|
byte order = 0
|
||||||
|
wavelength units = nm
|
||||||
|
"""
|
||||||
|
filename, ext = os.path.splitext(savepath)
|
||||||
|
# 替换扩展名为 '.hdr'
|
||||||
|
header_filename = filename + '.hdr'
|
||||||
|
|
||||||
|
with open(header_filename, 'w') as header_file:
|
||||||
|
header_file.write(header_content)
|
||||||
|
|
||||||
|
|
||||||
|
def change_hdr_file(bil_path, wavelengths=None):
|
||||||
|
# wavelengths=None 时仅在HDR缺失wavelength字段才写入;若提供则按提供内容写入
|
||||||
|
hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
|
||||||
|
if not os.path.exists(hdr_path):
|
||||||
|
print(f"错误: 找不到对应的HDR文件: {hdr_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
with open(hdr_path, 'r') as file:
|
||||||
|
content = file.read()
|
||||||
|
|
||||||
|
if 'wavelength' in content and wavelengths is None:
|
||||||
|
print(f"File {os.path.basename(hdr_path)} already contains wavelength information; no changes needed.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if wavelengths is None:
|
||||||
|
print(f"No wavelengths provided and HDR lacks wavelength; skipping write to avoid wrong bands.")
|
||||||
|
return
|
||||||
|
|
||||||
|
needs_newline = not content.endswith('\n')
|
||||||
|
wavelength_info = "wavelength = {" + ", ".join(str(float(w)) for w in wavelengths) + "}\n"
|
||||||
|
|
||||||
|
with open(hdr_path, 'a') as file:
|
||||||
|
if needs_newline:
|
||||||
|
file.write('\n')
|
||||||
|
file.write(wavelength_info)
|
||||||
|
|
||||||
|
print(f"Successfully ensured wavelength information in file: {os.path.basename(hdr_path)}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_bil_files(input_dir):
|
||||||
|
"""获取输入目录中的所有BIL文件"""
|
||||||
|
if not os.path.exists(input_dir):
|
||||||
|
raise FileNotFoundError(f"输入目录不存在: {input_dir}")
|
||||||
|
|
||||||
|
bil_files = []
|
||||||
|
for file in os.listdir(input_dir):
|
||||||
|
if file.lower().endswith('.bil'):
|
||||||
|
bil_path = os.path.join(input_dir, file)
|
||||||
|
hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
|
||||||
|
if os.path.exists(hdr_path):
|
||||||
|
bil_files.append(bil_path)
|
||||||
|
else:
|
||||||
|
print(f"警告: 找到BIL文件 {file} 但缺少对应的HDR文件,跳过")
|
||||||
|
|
||||||
|
if not bil_files:
|
||||||
|
raise ValueError(f"在输入目录 {input_dir} 中未找到有效的BIL文件")
|
||||||
|
|
||||||
|
return sorted(bil_files)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_inputs(input_dir, output_dir, model_path):
|
||||||
|
"""验证输入参数"""
|
||||||
|
# 检查输入目录存在
|
||||||
|
if not os.path.exists(input_dir):
|
||||||
|
raise FileNotFoundError(f"输入目录不存在: {input_dir}")
|
||||||
|
|
||||||
|
# 检查输出目录存在,如果不存在则创建
|
||||||
|
if not os.path.exists(output_dir):
|
||||||
|
try:
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"无法创建输出目录: {output_dir}") from e
|
||||||
|
|
||||||
|
# 检查模型文件存在
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
raise FileNotFoundError(f"主模型文件不存在: {model_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_single_bil_file(bil_path):
|
||||||
|
"""验证单个BIL文件"""
|
||||||
|
# 检查BIL和HDR文件存在
|
||||||
|
if not os.path.exists(bil_path):
|
||||||
|
raise FileNotFoundError(f"BIL文件不存在: {bil_path}")
|
||||||
|
|
||||||
|
hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
|
||||||
|
if not os.path.exists(hdr_path):
|
||||||
|
raise FileNotFoundError(f"HDR文件不存在: {hdr_path}")
|
||||||
|
|
||||||
|
# 检查BIL文件波段数是否足够
|
||||||
|
try:
|
||||||
|
from spectral.io import envi
|
||||||
|
img = envi.open(hdr_path, bil_path)
|
||||||
|
n_bands = img.nbands
|
||||||
|
# bil2rgb需要波段索引9, 59, 159
|
||||||
|
if n_bands <= 159:
|
||||||
|
raise ValueError(f"BIL文件波段数不足: 需要至少160个波段,但只有{n_bands}个")
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"无法读取BIL文件头信息: {bil_path}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def generate_rgb(bil_path):
|
||||||
|
"""处理BIL文件生成RGB图像"""
|
||||||
|
try:
|
||||||
|
rgb_img = process_bil_files(bil_path)
|
||||||
|
return rgb_img
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"生成RGB图像失败: bil_path={bil_path}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def run_segmentation(rgb_img):
|
||||||
|
"""运行分割获取掩膜"""
|
||||||
|
try:
|
||||||
|
mask, filter_mask_original = detect_microplastic_mask_from_array(
|
||||||
|
image=rgb_img,
|
||||||
|
filter_method='threshold',
|
||||||
|
diameter=None,
|
||||||
|
flow_threshold=0.4,
|
||||||
|
cellprob_threshold=-1,
|
||||||
|
detect_filter=True
|
||||||
|
)
|
||||||
|
return mask, filter_mask_original
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError("分割失败: 无法检测微塑料颗粒") from e
|
||||||
|
|
||||||
|
|
||||||
|
def extract_primary_features(bil_path, mask):
|
||||||
|
"""提取主要特征"""
|
||||||
|
try:
|
||||||
|
df = process_images(bil_path, mask)
|
||||||
|
return df
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"特征提取失败: bil_path={bil_path}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def compute_background_spectrum(bil_path, mask):
|
||||||
|
"""计算背景光谱"""
|
||||||
|
try:
|
||||||
|
df_correct = process_images_background(bil_path, mask)
|
||||||
|
return df_correct
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"背景光谱计算失败: bil_path={bil_path}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def apply_background_and_optional_resample(df, bg_spectrum, bil_path):
|
||||||
|
"""应用背景校正和可选的重采样"""
|
||||||
|
# 识别光谱列:所有以wavelength_开头的列
|
||||||
|
spec_cols = [c for c in df.columns if c.startswith('wavelength_')]
|
||||||
|
|
||||||
|
if not spec_cols:
|
||||||
|
raise ValueError("未找到光谱列(以wavelength_开头的列)")
|
||||||
|
|
||||||
|
if len(spec_cols) != len(bg_spectrum):
|
||||||
|
raise ValueError(f"光谱列数量({len(spec_cols)})与背景光谱长度({len(bg_spectrum)})不匹配")
|
||||||
|
|
||||||
|
# 背景校正:用背景光谱逐列相除
|
||||||
|
df[spec_cols] = df[spec_cols].div(bg_spectrum, axis=1)
|
||||||
|
|
||||||
|
# 检查是否需要重采样
|
||||||
|
src_waves = read_wavelengths_from_hdr(bil_path)
|
||||||
|
need_resample = (src_waves.size > 0 and
|
||||||
|
(src_waves.size != len(TRAIN_WAVELENGTHS) or
|
||||||
|
not np.allclose(src_waves, TRAIN_WAVELENGTHS, atol=1e-2)))
|
||||||
|
|
||||||
|
if need_resample:
|
||||||
|
print(f"重采样光谱: 源波段数 {src_waves.size} -> 目标波段数 {len(TRAIN_WAVELENGTHS)}")
|
||||||
|
|
||||||
|
# 提取光谱数据
|
||||||
|
X_src = df[spec_cols].to_numpy(dtype=np.float64)
|
||||||
|
X_dst = resample_spectra_matrix(X_src, src_waves, TRAIN_WAVELENGTHS)
|
||||||
|
|
||||||
|
# 替换光谱列
|
||||||
|
spec_col_names = [f"band_{i+1}" for i in range(len(TRAIN_WAVELENGTHS))]
|
||||||
|
df = pd.concat([
|
||||||
|
df.drop(columns=spec_cols), # 移除原有光谱列
|
||||||
|
pd.DataFrame(X_dst, columns=spec_col_names, index=df.index)
|
||||||
|
], axis=1)
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def clean_and_select_columns(df):
    """Clean the feature DataFrame and drop a fixed set of trailing columns.

    Steps: drop NaN rows, drop samples whose contour has fewer than two
    points, drop samples with area below 500 px, then remove columns selected
    by hard-coded positional indices.
    """
    # Drop rows containing NaN values
    df = df.dropna()

    # Filter out samples with too few contour points
    df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]

    # Filter out samples whose area is too small
    df = df[df['area'] >= 500]

    # Column selection: original hard-coded positional removal
    # cols_to_remove = df.columns[np.r_[1:10, 11:15, 97:120, 176:179 ]]
    # NOTE(review): np.r_[-10:-1] selects the last 10 columns *except the
    # final one* (9 columns). If the intent is "drop the last 10", this is
    # off by one — confirm against the feature layout.
    cols_to_remove = df.columns[np.r_[-10: -1]]
    df = df.drop(columns=cols_to_remove)

    return df
|
|
||||||
|
|
||||||
|
def run_primary_classification(df, primary_model_path):
    """Run the primary (first-stage) SVM classification.

    Before predicting, performs a best-effort feature-dimension check against
    the scaler saved next to the model; any problem (including a detected
    mismatch) only emits a warning — prediction is still attempted.

    Args:
        df: feature DataFrame; the first column is an identifier, the
            remaining numeric columns (excluding 'contour') are features.
        primary_model_path: path to the saved primary model file.

    Returns:
        DataFrame produced by predict_with_model (with a 'Predictions' column).

    Raises:
        RuntimeError: if prediction itself fails.
    """
    try:
        # Best-effort validation of the feature dimension. The raised
        # ValueError is caught by the inner except and downgraded to a
        # warning on purpose.
        try:
            import joblib
            scaler_path = os.path.join(os.path.dirname(primary_model_path), 'scaler_params.pkl')
            if os.path.exists(scaler_path):
                scaler = joblib.load(scaler_path)
                numeric_cols = [c for c in df.columns[1:] if np.issubdtype(df[c].dtype, np.number) and c != 'contour']
                if hasattr(scaler, 'mean_') and len(numeric_cols) != scaler.mean_.shape[0]:
                    # BUG FIX: report the column *count*; the original
                    # interpolated the whole column-name list into the message.
                    raise ValueError(f"特征维度不匹配: 当前{len(numeric_cols)}列 != 训练时{scaler.mean_.shape[0]}维")
        except Exception as e:
            print(f"警告: 无法验证特征维度: {e}")

        df_pre = predict_with_model(
            df,
            primary_model_path,
            model_type='SVM',
            ProcessMethods1='SS',
            ProcessMethods2='None'
        )
        return df_pre
    except Exception as e:
        raise RuntimeError(f"主要分类失败: model_path={primary_model_path}") from e
|
|
||||||
|
|
||||||
|
def run_secondary_classification_if_needed(df_pre, bil_path, mask, filter_mask_original):
    """Re-classify HDPE/LDPE samples with a dedicated secondary model.

    Samples whose primary prediction is in classes {1, 2} are re-run through
    a shape-feature SVM; their 'Predictions' values are overwritten in place
    with ``secondary_prediction + 1``. Any failure in the secondary stage is
    reported as a warning and the primary results are kept.

    Returns:
        df_pre, possibly with updated 'Predictions' for the target samples.
    """
    # Secondary-classification configuration (kept in code)
    secondary_model_path = os.path.join(os.path.dirname(__file__), 'modelsave', 'HDPELDPE_model', 'svm.m')
    secondary_target_classes = [1, 2]  # HDPE, LDPE

    target_classes = set(secondary_target_classes or [])
    mask_secondary = df_pre['Predictions'].isin(target_classes)

    if not mask_secondary.any():
        print("未找到目标类别样本,跳过二次分类")
        return df_pre

    print(f"为类别 {sorted(target_classes)} 运行二次分类")

    # Skip gracefully when the secondary model file is missing
    if not os.path.exists(secondary_model_path):
        print(f"警告: 二次模型不存在: {secondary_model_path},跳过二次分类")
        return df_pre

    try:
        # Background correction of the image information
        df_correct = shape_correct_background(bil_path, mask, filter_mask_original)

        # Build a mask containing only the target-class contours
        mask_second = np.zeros_like(mask, dtype=np.uint16)
        for idx in df_pre[mask_secondary].index:
            contour = df_pre.loc[idx, 'contour']
            if isinstance(contour, list) and len(contour) > 0:
                contour_array = np.array(contour, dtype=np.int32)
                # label = DataFrame index + 1 (0 is background)
                cv2.fillPoly(mask_second, [contour_array], idx + 1)

        # Extract shape features for the selected samples
        df_shape = extract_features(df_correct, mask_second)

        # Keep only the first 13 columns as model inputs
        if len(df_shape.columns) >= 13:
            df_shape = df_shape.iloc[:, :13]

        # Secondary classification
        df_secondary = predict_with_model(
            df_shape,
            secondary_model_path,
            model_type='SVM',
            ProcessMethods1='None',
            ProcessMethods2='None'
        )

        # Merge back: secondary class index shifted by +1
        # NOTE(review): assumes df_secondary rows align positionally with the
        # mask_secondary selection order — confirm extract_features ordering.
        df_pre.loc[mask_secondary, 'Predictions'] = df_secondary['Predictions'].values + 1

    except Exception as e:
        print(f"警告: 二次分类失败,将继续使用主要分类结果: {e}")

    return df_pre
|
|
||||||
|
|
||||||
|
def postprocess_class7_shadow(df_pre, rgb_img):
    """Re-label class-7 detections that are actually background shadows.

    Heuristic: real class-7 objects have crisp borders while shadows have
    blurry ones. The mean Sobel gradient magnitude along each contour is
    compared to the 30th percentile of all class-7 edge gradients; samples
    below the threshold are moved to class 9.

    Args:
        df_pre: predictions DataFrame with 'Predictions' and 'contour' columns.
        rgb_img: RGB preview image (PIL Image or numpy array).

    Returns:
        df_pre with shadow samples re-labelled from 7 to 9 (modified in place).
    """
    class_7_mask = df_pre['Predictions'] == 7
    if not class_7_mask.any():
        return df_pre

    print(f"处理 {class_7_mask.sum()} 个类别7样本,识别背景阴影...")

    # Convert a PIL Image to a numpy array if needed
    if hasattr(rgb_img, 'mode'):  # PIL Image
        rgb_img_array = np.array(rgb_img)
    else:
        rgb_img_array = rgb_img

    # Convert to grayscale for gradient computation
    if len(rgb_img_array.shape) == 3:
        gray_img = cv2.cvtColor(rgb_img_array, cv2.COLOR_RGB2GRAY)
    else:
        gray_img = rgb_img_array

    # Gradient magnitude map (Sobel in x and y)
    grad_x = cv2.Sobel(gray_img, cv2.CV_64F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(gray_img, cv2.CV_64F, 0, 1, ksize=3)
    gradient_magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)

    # First pass: collect edge gradients of every class-7 sample
    all_class7_gradients = []
    valid_indices = []

    for idx in df_pre[class_7_mask].index:
        try:
            contour = df_pre.loc[idx, 'contour']
            if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
                continue

            contour_array = np.array(contour, dtype=np.int32)
            if len(contour_array.shape) == 1:
                continue

            # 2-px-wide contour mask used to sample the edge pixels
            mask_img = np.zeros(gray_img.shape, dtype=np.uint8)
            cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)
            edge_gradients = gradient_magnitude[mask_img > 0]

            if len(edge_gradients) > 0:
                all_class7_gradients.extend(edge_gradients)
                valid_indices.append(idx)
        except Exception:
            continue

    # Threshold: 30th percentile of class-7 edge gradients, falling back to
    # the whole-image percentile if no valid sample was collected
    if len(all_class7_gradients) > 0:
        gradient_threshold = np.percentile(all_class7_gradients, 30)
    else:
        gradient_threshold = np.percentile(gradient_magnitude, 30)

    print(f"类别7梯度阈值: {gradient_threshold:.2f}")

    # Second pass: mark samples whose mean edge gradient is below threshold
    indices_to_update = []
    for idx in valid_indices:
        try:
            contour = df_pre.loc[idx, 'contour']
            contour_array = np.array(contour, dtype=np.int32)

            mask_img = np.zeros(gray_img.shape, dtype=np.uint8)
            cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)

            edge_gradients = gradient_magnitude[mask_img > 0]
            if len(edge_gradients) == 0:
                continue

            mean_gradient = np.mean(edge_gradients)

            # Blurry border (low gradient) => background shadow
            if mean_gradient < gradient_threshold:
                indices_to_update.append(idx)
                print(f"样本 {idx}: 平均梯度={mean_gradient:.2f}, 阈值={gradient_threshold:.2f} -> 识别为背景阴影")

        except Exception as e:
            print(f"处理样本 {idx} 时出错: {e}")
            continue

    # Apply the re-labelling (class 9 = background shadow)
    if indices_to_update:
        df_pre.loc[indices_to_update, 'Predictions'] = 9
        print(f"将 {len(indices_to_update)} 个样本从类别7改为类别9(背景阴影)")
    else:
        print("无需更新类别7样本")

    return df_pre
|
|
||||||
|
|
||||||
|
def write_outputs(bil_path, df_pre, output_path):
    """Shrink contours by one pixel and persist the ENVI classification map.

    Any failure is wrapped in a RuntimeError that names the output path.
    """
    try:
        # Shrink contours so touching plastics stay separated in the raster
        shrunk = shrink_contours(bil_path, df_pre, shrink_pixels=1)
        # Write the classification image + header
        save_envi_classification(bil_path, shrunk, output_path)
    except Exception as exc:
        raise RuntimeError(f"保存结果失败: output_path={output_path}") from exc
|
|
||||||
|
|
||||||
|
def process_single_file(bil_path, output_path, primary_model_path):
    """Run the full pipeline for one BIL file.

    Stages: validate -> patch HDR -> RGB preview -> segmentation ->
    feature extraction -> background correction -> cleaning -> primary
    classification -> optional secondary classification -> class-7 shadow
    post-processing -> save ENVI output.

    Returns:
        (segmentation_time, classification_time) in seconds.

    Raises:
        Re-raises any stage failure after printing the file name.
    """
    try:
        # Validate input
        validate_single_bil_file(bil_path)

        # Patch the HDR file (adds the wavelength table if missing)
        change_hdr_file(bil_path)

        # Generate an RGB preview from the BIL file
        print(f" 处理BIL文件生成RGB图像...")
        rgb_img = generate_rgb(bil_path)

        # --- Segmentation stage (timed) ---
        segmentation_start_time = time.time()
        print(f" 生成掩膜...")
        mask, filter_mask_original = run_segmentation(rgb_img)

        # Feature extraction
        print(f" 从BIL文件提取特征...")
        df = extract_primary_features(bil_path, mask)

        # Background correction
        print(f" 应用背景校正...")
        bg_spectrum = compute_background_spectrum(bil_path, mask)

        # Apply background correction only, no resampling
        # NOTE(review): apply_background_no_resample is defined elsewhere in
        # this module — not apply_background_and_optional_resample.
        df = apply_background_no_resample(df, bg_spectrum)

        # Data cleaning and column selection
        print(f" 清理数据...")
        df = clean_and_select_columns(df)

        segmentation_time = time.time() - segmentation_start_time

        # --- Classification stage (timed) ---
        classification_start_time = time.time()
        print(f" 预测分类...")
        df_pre = run_primary_classification(df, primary_model_path)

        # Secondary classification (HDPE/LDPE refinement)
        df_pre = run_secondary_classification_if_needed(df_pre, bil_path, mask, filter_mask_original)

        # Post-process class-7 background shadows
        df_pre = postprocess_class7_shadow(df_pre, rgb_img)

        classification_time = time.time() - classification_start_time

        # Save results
        print(f" 保存ENVI分类结果...")
        write_outputs(bil_path, df_pre, output_path)

        return segmentation_time, classification_time

    except Exception as e:
        print(f"处理文件失败 {os.path.basename(bil_path)}: {e}")
        raise
|
|
||||||
|
|
||||||
|
def main():
    """Batch entry point: classify every BIL file found in the input directory.

    Per-file failures are logged and skipped; timing statistics are printed
    at the end. Re-raises on setup/validation failure.
    """
    args = parse_arguments()

    input_dir = args.input_dir
    output_dir = args.output_dir
    primary_model_path = args.model_path

    # Record overall start time
    total_start_time = time.time()

    try:
        # Validate input arguments
        validate_inputs(input_dir, output_dir, primary_model_path)

        # Collect all BIL files
        bil_files = get_bil_files(input_dir)
        print(f"找到 {len(bil_files)} 个BIL文件待处理")

        # Running statistics
        total_files = len(bil_files)
        processed_files = 0
        failed_files = 0
        total_segmentation_time = 0
        total_classification_time = 0

        # Process files one by one; a failure skips to the next file
        for i, bil_path in enumerate(bil_files, 1):
            print(f"\n{'='*60}")
            print(f"处理文件 {i}/{total_files}: {os.path.basename(bil_path)}")
            print(f"{'='*60}")

            try:
                # Output name: original base name + "_classification.bil"
                base_name = os.path.splitext(os.path.basename(bil_path))[0]
                output_filename = f"{base_name}_classification.bil"
                output_path = os.path.join(output_dir, output_filename)

                # Process a single file
                segmentation_time, classification_time = process_single_file(
                    bil_path, output_path, primary_model_path
                )

                # Update statistics
                total_segmentation_time += segmentation_time
                total_classification_time += classification_time
                processed_files += 1

                print(f"文件 {os.path.basename(bil_path)} 处理完成")
                print(f"结果保存至: {output_path}")

            except Exception as e:
                print(f"文件 {os.path.basename(bil_path)} 处理失败: {e}")
                failed_files += 1
                continue

        # Average timings — only defined when at least one file succeeded;
        # the prints below are guarded by the same condition
        if processed_files > 0:
            avg_segmentation_time = total_segmentation_time / processed_files
            avg_classification_time = total_classification_time / processed_files
            avg_total_time = (total_segmentation_time + total_classification_time) / processed_files

        # Total wall-clock time
        total_time = time.time() - total_start_time

        # Summary
        print(f"\n{'=' * 60}")
        print("批量处理完成")
        print(f"{'=' * 60}")
        print(f"总文件数: {total_files}")
        print(f"成功处理: {processed_files}")
        print(f"处理失败: {failed_files}")
        print(f"成功率: {processed_files/total_files*100:.1f}%" if total_files > 0 else "成功率: 0%")
        print(f"{'=' * 60}")
        if processed_files > 0:
            print(f"平均分割耗时: {avg_segmentation_time:.2f} 秒")
            print(f"平均分类耗时: {avg_classification_time:.2f} 秒")
            print(f"平均总耗时: {avg_total_time:.2f} 秒")
        print(f"实际总耗时: {total_time:.2f} 秒")
        print(f"{'=' * 60}")
        print(f"结果保存目录: {output_dir}")

    except Exception as e:
        print(f"批量处理失败: {e}")
        raise
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
445
mainv1.py
Normal file
445
mainv1.py
Normal file
@ -0,0 +1,445 @@
|
|||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import matplotlib
|
||||||
|
import numpy as np
|
||||||
|
import argparse
|
||||||
|
from bil2rgb import process_bil_files
|
||||||
|
from classification_model.Parallel.predict_plastic import load_model, predict_with_model
|
||||||
|
from mask import detect_microplastic_mask_from_array
|
||||||
|
from shape_spectral import process_images
|
||||||
|
from shape_spectral_background import process_images_background
|
||||||
|
from extact_shape import shape_correct_background, extract_features
|
||||||
|
import time
|
||||||
|
|
||||||
|
matplotlib.use('TkAgg')
|
||||||
|
|
||||||
|
|
||||||
|
def parse_arguments():
    """Parse the command-line arguments of the classification pipeline.

    Returns:
        argparse.Namespace with bil_path, output_path and model_path.
    """
    arg_parser = argparse.ArgumentParser(description='Microplastic spectral shape classification')

    # Required inputs
    arg_parser.add_argument('--bil_path', required=True, help='Path to input BIL file')
    arg_parser.add_argument('--output_path', required=True, help='Path to output classification result')
    arg_parser.add_argument('--model_path', required=True, help='Path to primary classification model')

    # Secondary-stage options were removed from the CLI; the secondary model
    # configuration is hard-coded in main().
    return arg_parser.parse_args()
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# 配置参数:直接在此修改
|
||||||
|
# ----------------------------
|
||||||
|
# BIL_PATH = r"D:/Data/Test/PET_bottle2.bil"
|
||||||
|
# OUTPUT_PATH = r'D:/Data/PET_bottle2_class.bil'
|
||||||
|
#
|
||||||
|
# PRIMARY_MODEL_PATH = r"D:\plastic\plastic\modelsave\svm.m"
|
||||||
|
# PRIMARY_MODEL_TYPE = 'SVM'
|
||||||
|
# PRIMARY_PROCESS_METHODS1 = 'SS'
|
||||||
|
# PRIMARY_PROCESS_METHODS2 = 'None'
|
||||||
|
#
|
||||||
|
# SECONDARY_MODEL_PATH = "D:\plastic\plastic\modelsave\HDPELDPE_model\svm.m" # 若不需要二次分类,则保持为 None
|
||||||
|
# SECONDARY_MODEL_TYPE = 'SVM'
|
||||||
|
# SECONDARY_PROCESS_METHODS1 = 'None'
|
||||||
|
# SECONDARY_PROCESS_METHODS2 = 'None'
|
||||||
|
# SECONDARY_TARGET_CLASSES = [1, 2]
|
||||||
|
|
||||||
|
|
||||||
|
def read_hdr_file(bil_path):
    """Read the image dimensions from the ENVI header next to *bil_path*.

    Returns:
        (samples, lines) as ints; either may be None if the key is absent.
    """
    hdr_path = bil_path.replace('.bil', '.hdr')
    with open(hdr_path, 'r') as hdr_file:
        header_lines = hdr_file.readlines()

    samples, lines = None, None
    for entry in header_lines:
        if entry.startswith('samples'):
            samples = int(entry.split('=')[-1].strip())
        elif entry.startswith('lines'):
            lines = int(entry.split('=')[-1].strip())

    return samples, lines
|
|
||||||
|
|
||||||
|
def shrink_contours(bil_path, df, shrink_pixels=1):
    """
    Shrink every contour in the DataFrame via morphological erosion, to keep
    adjacent plastics from touching in the rasterized output.

    Args:
        bil_path: BIL file path, used to obtain the image dimensions.
        df: DataFrame containing a 'contour' column.
        shrink_pixels: number of pixels to shrink by (default 1).

    Returns:
        A copy of the DataFrame whose 'contour' entries have been shrunk;
        contours that vanish or fail to process keep their original value.
    """
    samples, lines = read_hdr_file(bil_path)

    # Erosion kernel sized to the requested shrink amount
    kernel = np.ones((2 * shrink_pixels + 1, 2 * shrink_pixels + 1), np.uint8)

    # Scratch mask reused for every contour
    temp_mask = np.zeros((lines, samples), dtype=np.uint8)

    # Work on a copy so the caller's DataFrame is untouched
    df = df.copy()

    # Update each contour in turn
    for idx, row in df.iterrows():
        contour = row['contour']
        if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
            continue

        try:
            contour_array = np.array(contour, dtype=np.int32)
            if len(contour_array.shape) == 1:
                continue

            # Clear the scratch mask
            temp_mask.fill(0)

            # Rasterize the contour
            cv2.fillPoly(temp_mask, [contour_array], 255)

            # Erode the filled region
            eroded_mask = cv2.erode(temp_mask, kernel, iterations=1)

            # Re-extract contours from the eroded region
            contours, _ = cv2.findContours(eroded_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            if len(contours) > 0:
                # If erosion split the region, keep only the largest piece
                largest_contour = max(contours, key=cv2.contourArea)
                # Convert back to a list of [x, y] pairs to match the input format
                if len(largest_contour) >= 3:
                    updated_contour = largest_contour.reshape(-1, 2).tolist()
                    df.at[idx, 'contour'] = updated_contour
        except Exception as e:
            # On failure keep the original contour
            continue

    return df
|
|
||||||
|
|
||||||
|
def save_envi_classification(bil_path, df, savepath):
    """Rasterize per-sample predictions into a single-band image and write it
    as an ENVI file (data + .hdr) at *savepath*.

    Each contour is filled with ``Predictions + 1`` so 0 stays background.
    """
    samples, lines = read_hdr_file(bil_path)
    classification_result = np.zeros((lines, samples), dtype=np.uint16)

    for _, row in df.iterrows():
        contour = row['contour']
        prediction = int(row['Predictions']) + 1
        contour = np.array(contour, dtype=np.int32)
        # Clear any previously written value 10 before drawing this contour.
        # NOTE(review): the original comment said "replace 10 and 11 with 0"
        # but only 10 is cleared, and this runs once per row rather than once
        # up front — confirm the intended semantics.
        classification_result[(classification_result == 10)] = 0

        cv2.fillPoly(classification_result, [contour], prediction)

    output_path = savepath

    # Raw pixel data, row-major, matching the header below
    with open(output_path, 'wb') as f:
        classification_result.tofile(f)

    # NOTE(review): header declares "data type = 2" (signed 16-bit int) while
    # the array is uint16 (ENVI type 12). Class values fit either way, but
    # confirm which the downstream reader expects.
    header_content = f"""ENVI
description = {{
Classification Result.}}
samples = {samples}
lines = {lines}
bands = 1
header offset = 0
file type = ENVI Standard
data type = 2
interleave = bil
classes = 10
class = {{ background, ABS, HDPE, LDPE, PA6, PET, PP, PS, PTFE, PVC }}
single pixel area = 0.000036
unit = mm2
byte order = 0
wavelength units = nm
"""
    filename, ext = os.path.splitext(savepath)
    # Replace the extension with '.hdr'
    header_filename = filename + '.hdr'

    with open(header_filename, 'w') as header_file:
        header_file.write(header_content)
|
|
||||||
|
def change_hdr_file(bil_path):
    """Append the fixed wavelength table to the HDR paired with *bil_path*.

    No-ops (with a message) when the HDR is missing or already contains
    wavelength information. The operation is idempotent.
    """
    # Fixed wavelength list appended to headers that lack one
    wavelength_info = """wavelength = {898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2}"""

    # Derive the .hdr path from the .bil path
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'

    # Bail out when the header is missing
    if not os.path.exists(hdr_path):
        print(f"错误: 找不到对应的HDR文件: {hdr_path}")
        return

    with open(hdr_path, 'r') as hdr_file:
        content = hdr_file.read()

    # Idempotence: skip headers that already carry wavelength data
    if 'wavelength' in content:
        print(f"File {os.path.basename(hdr_path)} already contains wavelength information; no changes needed.")
        return

    # Append, making sure the new block starts on a fresh line
    with open(hdr_path, 'a') as hdr_file:
        if not content.endswith('\n'):
            hdr_file.write('\n')
        hdr_file.write(wavelength_info + '\n')

    print(f"Successfully added wavelength information to file: {os.path.basename(hdr_path)}")
|
|
||||||
|
|
||||||
|
def main():
    """Single-file pipeline (legacy mainv1): segment, classify in two stages,
    post-process class-7 shadows, and save an ENVI classification image.

    Inputs come from parse_arguments(); the secondary-stage configuration is
    hard-coded below.
    """
    args = parse_arguments()

    bil_path = args.bil_path
    output_path = args.output_path
    primary_model_path = args.model_path
    primary_model_type = 'SVM'
    primary_process_methods1 = 'SS'
    primary_process_methods2 = "None"

    # secondary_model_path = args.secondary_model
    # secondary_model_type = args.secondary_model_type
    # secondary_process_methods1 = args.secondary_process_methods1
    # secondary_process_methods2 = args.secondary_process_methods2
    # secondary_target_classes = args.secondary_target_classes

    # NOTE(review): hard-coded Windows path whose backslashes are not escaped;
    # it happens to survive because none form escape sequences — prefer r"...".
    secondary_model_path = "D:\plastic\plastic\modelsave\HDPELDPE_model\svm.m"
    secondary_model_type = 'SVM'
    secondary_process_methods1 = 'None'
    secondary_process_methods2 = 'None'
    secondary_target_classes = [1, 2]
    # Record overall start time
    total_start_time = time.time()
    # Generate an RGB image from the BIL file
    print("Processing BIL file to generate RGB image...\n")
    rgb_img = process_bil_files(bil_path)

    # Patch the HDR (adds the wavelength table if missing)
    change_hdr_file(bil_path)
    segmentation_start_time = time.time()
    # Generate the mask; `mask` is a 16-bit label mask of the plastics
    print("Generating mask ...\n")
    mask, filter_mask_original = detect_microplastic_mask_from_array(
        image=rgb_img,  # pass the cv2.imread result directly
        filter_method='threshold',
        diameter=None,
        flow_threshold=0.4,
        cellprob_threshold=-1,
        detect_filter=False
    )

    # Feature extraction
    print("Extracting features from BIL file...\n")
    df = process_images(bil_path, mask)

    # Background correction of spectral columns 1..168
    print("Applying background correction...\n")
    df_correct = process_images_background(bil_path, mask)
    df.iloc[:, 1:169] = df.iloc[:, 1:169].div(df_correct, axis=1)

    # Data cleaning
    print("Cleaning data...\n")
    df = df.dropna()
    df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
    df = df[df['area'] >= 500]

    # Column selection: drop positional columns 87..109 plus -10..-2
    cols_to_remove = df.columns[np.r_[87:110, -10:-1]]
    # cols_to_remove = df.columns[87:110]
    # Drop the selected columns, keeping the DataFrame structure
    df = df.drop(columns=cols_to_remove)
    segmentation_time = time.time() - segmentation_start_time
    # NOTE(review): this full-slice is a no-op copy of all columns; kept as-is.
    df = df.iloc[:, :]
    # Prediction (classification stage)
    classification_start_time = time.time()
    # Predict classes
    print("Predicting classes...\n")
    # NOTE(review): loaded_model is never used afterwards; predict_with_model
    # loads the model itself from primary_model_path.
    loaded_model = load_model(primary_model_path)
    df_pre = predict_with_model(
        df,
        primary_model_path,
        model_type=primary_model_type,
        ProcessMethods1=primary_process_methods1,
        ProcessMethods2=primary_process_methods2
    )

    # Secondary classification for HDPE and LDPE:
    # select samples whose primary prediction is in the target classes
    target_classes = set(secondary_target_classes or [])
    mask_secondary = df_pre['Predictions'].isin(target_classes)

    if mask_secondary.any():
        # Only run background correction + secondary stage when needed
        print(f"Running secondary classification for classes: {sorted(target_classes)}")

        # Background correction of the image information
        df_correct = shape_correct_background(bil_path, mask, filter_mask_original)

        # Build mask_second containing only the target-class contours
        mask_second = np.zeros_like(mask, dtype=np.uint16)

        for idx in df_pre[mask_secondary].index:
            contour = df_pre.loc[idx, 'contour']
            if isinstance(contour, list) and len(contour) > 0:
                contour_array = np.array(contour, dtype=np.int32)
                cv2.fillPoly(mask_second, [contour_array], idx + 1)  # label = index + 1

        # Extract shape features
        df_shape = extract_features(df_correct, mask_second)

        # Keep the first 13 columns as model inputs
        if len(df_shape.columns) >= 13:
            df_shape = df_shape.iloc[:, :13]

        # Secondary classification: predict with the second model, update results
        if secondary_model_path:
            df_secondary = predict_with_model(
                df_shape,
                secondary_model_path,
                model_type=secondary_model_type,
                ProcessMethods1=secondary_process_methods1,
                ProcessMethods2=secondary_process_methods2
            )
            df_pre.loc[mask_secondary, 'Predictions'] = df_secondary['Predictions'].values + 1
        else:
            print("Secondary model path not provided; skipping secondary classification.\n")
    else:
        print("No samples from target classes found; skipping secondary classification.\n")

    # Identify background shadows misclassified as class 7 via edge sharpness:
    # genuine class-7 objects have crisp borders, shadows have blurry ones
    class_7_mask = df_pre['Predictions'] == 7
    if class_7_mask.any():
        print(f"Processing {class_7_mask.sum()} samples with class 7 to identify background shadows...\n")

        # Convert a PIL Image to a numpy array
        if hasattr(rgb_img, 'mode'):  # check whether it is a PIL Image
            rgb_img_array = np.array(rgb_img)
        else:
            rgb_img_array = rgb_img

        # Convert to grayscale (for the gradient computation)
        if len(rgb_img_array.shape) == 3:
            gray_img = cv2.cvtColor(rgb_img_array, cv2.COLOR_RGB2GRAY)
        else:
            gray_img = rgb_img_array

        # Gradient map (Sobel operator)
        grad_x = cv2.Sobel(gray_img, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(gray_img, cv2.CV_64F, 0, 1, ksize=3)
        gradient_magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)

        # First collect edge gradients of all class-7 samples to set a threshold
        all_class7_gradients = []
        valid_indices = []
        for idx in df_pre[class_7_mask].index:
            try:
                contour = df_pre.loc[idx, 'contour']
                if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
                    continue

                contour_array = np.array(contour, dtype=np.int32)
                if len(contour_array.shape) == 1:
                    continue

                mask_img = np.zeros(gray_img.shape, dtype=np.uint8)
                cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)
                edge_gradients = gradient_magnitude[mask_img > 0]
                if len(edge_gradients) > 0:
                    all_class7_gradients.extend(edge_gradients)
                    valid_indices.append(idx)
            except:
                # NOTE(review): bare except also swallows KeyboardInterrupt;
                # prefer `except Exception`.
                continue

        # Threshold from the class-7 gradient distribution: samples below the
        # 30th percentile are treated as background shadow
        if len(all_class7_gradients) > 0:
            gradient_threshold = np.percentile(all_class7_gradients, 30)  # 30th percentile of class-7 edge gradients
        else:
            gradient_threshold = np.percentile(gradient_magnitude, 30)  # fallback: whole-image 30th percentile

        print(f"Gradient threshold for class 7: {gradient_threshold:.2f}\n")

        # Judge each class-7 sample against the threshold
        indices_to_update = []
        for idx in valid_indices:
            try:
                contour = df_pre.loc[idx, 'contour']
                contour_array = np.array(contour, dtype=np.int32)

                # Contour mask (2-px line width, used to sample the edge)
                mask_img = np.zeros(gray_img.shape, dtype=np.uint8)
                cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)

                # Gradient values along the contour edge
                edge_gradients = gradient_magnitude[mask_img > 0]
                if len(edge_gradients) == 0:
                    continue

                # Mean edge-gradient strength
                mean_gradient = np.mean(edge_gradients)

                # Below the threshold => blurry border => background shadow
                if mean_gradient < gradient_threshold:
                    indices_to_update.append(idx)
                    print(
                        f"Sample {idx}: mean_gradient={mean_gradient:.2f}, threshold={gradient_threshold:.2f} -> identified as background shadow")

            except Exception as e:
                print(f"Error processing sample at index {idx}: {str(e)}")
                continue

        # Re-label the shadow samples from class 7 to class 9
        if indices_to_update:
            df_pre.loc[indices_to_update, 'Predictions'] = 9
            print(f"Updated {len(indices_to_update)} samples from class 7 to class 9 (background shadows)\n")
        else:
            print("No samples needed to be updated from class 7\n")
    classification_time = time.time() - classification_start_time

    df_pre = shrink_contours(bil_path, df_pre, shrink_pixels=1)
    # Save the ENVI classification result
    print("Saving ENVI classification results...\n")
    save_envi_classification(bil_path, df_pre, output_path)
    print(f"ENVI classification results saved to: {output_path}")
    # Total elapsed time
    total_time = time.time() - total_start_time

    # Timing summary
    print(f"\n{'=' * 60}")
    print(f"处理完成")
    print(f"{'=' * 60}")
    print(f"分割耗时: {segmentation_time:.2f} 秒")
    print(f"分类耗时: {classification_time:.2f} 秒")
    print(f"总耗时: {total_time:.2f} 秒")
    print(f"{'=' * 60}")
    print(f"结果已保存至: {output_path}")
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
42
mask.py
42
mask.py
@ -901,22 +901,36 @@ class MicroplasticDetectorV2:
|
|||||||
self.filter_method = filter_method
|
self.filter_method = filter_method
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
"""加载Cellpose模型"""
|
"""加载Cellpose模型(优先自定义,其次内置模型)"""
|
||||||
if self.model is None:
|
if self.model is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 设备选择
|
||||||
|
want_cpu = (self.device == 'cpu')
|
||||||
|
use_gpu = (not want_cpu) and torch.cuda.is_available()
|
||||||
|
|
||||||
|
try:
|
||||||
if self.model_path and Path(self.model_path).exists():
|
if self.model_path and Path(self.model_path).exists():
|
||||||
# 加载自定义模型 - Cellpose 4.0 API
|
# 正确方式:在构造时通过 pretrained_model 传入自定义权重
|
||||||
self.model = models.CellposeModel(
|
self.model = models.CellposeModel(
|
||||||
gpu=torch.cuda.is_available() and self.device != 'cpu',
|
gpu=use_gpu,
|
||||||
model_type=None
|
pretrained_model=str(self.model_path),
|
||||||
|
model_type=None, # 自定义模型不需要 model_type
|
||||||
)
|
)
|
||||||
self.model.load_model(self.model_path)
|
print(f"已加载自定义模型: {self.model_path}")
|
||||||
else:
|
else:
|
||||||
# 使用预训练模型 - Cellpose 4.0 API
|
# 内置模型(如未安装 cpsam,可考虑改用 'cyto2')
|
||||||
self.model = models.CellposeModel(
|
self.model = models.CellposeModel(
|
||||||
gpu=torch.cuda.is_available() and self.device != 'cpu',
|
gpu=use_gpu,
|
||||||
model_type='cpsam'
|
model_type='cpsam'
|
||||||
)
|
)
|
||||||
print(f"模型已加载,使用设备: {'GPU' if torch.cuda.is_available() and self.device != 'cpu' else 'CPU'}")
|
print("使用内置模型: cpsam")
|
||||||
|
|
||||||
|
print(f"模型已加载,使用设备: {'GPU' if use_gpu else 'CPU'}")
|
||||||
|
except Exception as e:
|
||||||
|
# 回退策略:若自定义/内置加载失败,退回CPU+cyto2以保证可用性
|
||||||
|
print(f"模型加载失败({e}),回退到 CPU + cyto2")
|
||||||
|
self.model = models.CellposeModel(gpu=False, model_type='cyto2')
|
||||||
|
|
||||||
def detect_microplastics(self, image_path: str, output_dir: str = None,
|
def detect_microplastics(self, image_path: str, output_dir: str = None,
|
||||||
diameter: float = 30, flow_threshold: float = 0.4,
|
diameter: float = 30, flow_threshold: float = 0.4,
|
||||||
@ -975,7 +989,8 @@ class MicroplasticDetectorV2:
|
|||||||
masked_image,
|
masked_image,
|
||||||
diameter=diameter,
|
diameter=diameter,
|
||||||
flow_threshold=flow_threshold,
|
flow_threshold=flow_threshold,
|
||||||
cellprob_threshold=cellprob_threshold)
|
cellprob_threshold=cellprob_threshold,
|
||||||
|
channels=[0, 0]) # 明确指定灰度图像通道
|
||||||
|
|
||||||
# 6. 分析检测结果
|
# 6. 分析检测结果
|
||||||
if debug:
|
if debug:
|
||||||
@ -1247,8 +1262,9 @@ def detect_microplastic_mask_from_array(image, filter_method: str = 'shape',
|
|||||||
masked_image,
|
masked_image,
|
||||||
diameter=diameter,
|
diameter=diameter,
|
||||||
flow_threshold=flow_threshold,
|
flow_threshold=flow_threshold,
|
||||||
cellprob_threshold=cellprob_threshold
|
cellprob_threshold=cellprob_threshold,
|
||||||
)
|
channels=[0, 0]) # 明确指定灰度图像通道
|
||||||
|
|
||||||
|
|
||||||
return masks, filter_mask_original
|
return masks, filter_mask_original
|
||||||
|
|
||||||
@ -1268,7 +1284,7 @@ def main():
|
|||||||
output_dir=output_dir + "_threshold",
|
output_dir=output_dir + "_threshold",
|
||||||
diameter=None,
|
diameter=None,
|
||||||
flow_threshold=0.4,
|
flow_threshold=0.4,
|
||||||
cellprob_threshold=0,
|
cellprob_threshold=-1,
|
||||||
debug=True
|
debug=True
|
||||||
)
|
)
|
||||||
print(f"\n霍夫圆变换方法检测完成!")
|
print(f"\n霍夫圆变换方法检测完成!")
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
BIN
modelsave/svm.m
BIN
modelsave/svm.m
Binary file not shown.
@ -9,6 +9,7 @@ import get_glcm
|
|||||||
import os
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import copy
|
||||||
from plantcv.plantcv._helpers import _iterate_analysis, _cv2_findcontours, _object_composition, _grayscale_to_rgb, \
|
from plantcv.plantcv._helpers import _iterate_analysis, _cv2_findcontours, _object_composition, _grayscale_to_rgb, \
|
||||||
_scale_size
|
_scale_size
|
||||||
from plantcv.plantcv import outputs, within_frame
|
from plantcv.plantcv import outputs, within_frame
|
||||||
@ -332,7 +333,7 @@ def process_images(full_bil_path, mask, outdir='None', debug="None"):
|
|||||||
spectral_array = pcv.readimage(filename=args.image, mode='envi')
|
spectral_array = pcv.readimage(filename=args.image, mode='envi')
|
||||||
#过曝区域掩膜
|
#过曝区域掩膜
|
||||||
bath_100 = spectral_array.array_data[:, :, 100]
|
bath_100 = spectral_array.array_data[:, :, 100]
|
||||||
bath_over = pcv.threshold.binary(gray_img=bath_100, threshold=11000, object_type='dark')
|
bath_over = pcv.threshold.binary(gray_img=bath_100, threshold=11000, object_type='dark')#
|
||||||
|
|
||||||
# mask已经是16位标签掩膜,直接使用
|
# mask已经是16位标签掩膜,直接使用
|
||||||
# 确保掩膜是正确的数据类型
|
# 确保掩膜是正确的数据类型
|
||||||
@ -367,8 +368,11 @@ def process_images(full_bil_path, mask, outdir='None', debug="None"):
|
|||||||
shape_img = size(img=binary_img, labeled_mask=labeled_mask, n_labels=num_valid)
|
shape_img = size(img=binary_img, labeled_mask=labeled_mask, n_labels=num_valid)
|
||||||
|
|
||||||
# 分析光谱反射率
|
# 分析光谱反射率
|
||||||
# 应用过曝掩膜
|
# 应用过曝掩膜 ;如果采集的反射率大于1.1,则会被错误判断为过曝区域导致该样本的光谱没有被统计
|
||||||
labeled_spectral_mask = mask * bath_over_binary
|
# labeled_spectral_mask = mask * bath_over_binary
|
||||||
|
labeled_spectral_mask = mask
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
spectral_hist = pcv.analyze.spectral_reflectance(
|
spectral_hist = pcv.analyze.spectral_reflectance(
|
||||||
hsi=spectral_array,
|
hsi=spectral_array,
|
||||||
@ -377,13 +381,16 @@ def process_images(full_bil_path, mask, outdir='None', debug="None"):
|
|||||||
label=None
|
label=None
|
||||||
)
|
)
|
||||||
|
|
||||||
observations = pcv.outputs.observations
|
# 深拷贝观测结果,避免后续清空影响数据
|
||||||
|
observations = copy.deepcopy(pcv.outputs.observations)
|
||||||
|
# 立即清空全局缓存,防止跨文件/阶段累加
|
||||||
|
pcv.outputs.clear()
|
||||||
# 将结果转换为列表
|
# 将结果转换为列表
|
||||||
combined_data = process_plantcv_outputs(observations)
|
combined_data = process_plantcv_outputs(observations)
|
||||||
|
|
||||||
return combined_data
|
return combined_data
|
||||||
|
|
||||||
# # 示例:批量处理指定路径下的光谱图像和掩膜文件
|
# # # 示例:批量处理指定路径下的光谱图像和掩膜文件
|
||||||
# bil_path = r'D:\WQ\test\Traindata-05'
|
# bil_path = r'D:\WQ\test\Traindata-05'
|
||||||
# mask_path = r'D:\WQ\test\mask'
|
# mask_path = r'D:\WQ\test\mask'
|
||||||
# outdir = r"D:\WQ\test" # 输出文件夹路径
|
# outdir = r"D:\WQ\test" # 输出文件夹路径
|
||||||
|
|||||||
Reference in New Issue
Block a user