Files
micro_plastic/chose_bands.py
2026-03-05 17:12:01 +08:00

160 lines
5.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import pandas as pd
import numpy as np
import itertools
import re
from collections import defaultdict
def parse_args(argv=None):
    """Parse the command-line options of the band-selection script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            behaviour existing callers rely on.

    Returns:
        argparse.Namespace with the attributes ``csv``, ``top_k``,
        ``top_triplets`` and ``map_order``.
    """
    ap = argparse.ArgumentParser(
        description="Select 3 bands for false-color to maximize inter-class separability"
    )
    ap.add_argument(
        "--csv",
        required=True,
        help="Input CSV: first column is class label, rest are spectra columns",
    )
    ap.add_argument(
        "--top_k",
        type=int,
        default=30,
        help="How many best single bands to preselect (default: 30)",
    )
    ap.add_argument(
        "--top_triplets",
        type=int,
        default=10,
        help="How many best triplets to print (default: 10)",
    )
    ap.add_argument(
        "--map_order",
        choices=["auto", "wavelength_bgr"],
        default="auto",
        help="RGB mapping order: auto=use scored order; wavelength_bgr=short->B, mid->G, long->R",
    )
    # Passing argv through lets tests and other modules drive the parser
    # without touching the process-wide sys.argv.
    return ap.parse_args(argv)
def try_parse_wavelength(col_name):
    """Extract the first unsigned decimal number from a column name.

    Handles names such as "wavelength_912.36", "912.36" or "band_12".
    Only the first run of digits (with an optional fractional part) in
    ``str(col_name)`` is used.

    Returns:
        The number as a float, or None when the name contains no digits
        or the conversion fails.
    """
    match = re.search(r"([0-9]+(\.[0-9]+)?)", str(col_name))
    if match is None:
        return None
    try:
        value = float(match.group(1))
    except Exception:
        value = None
    return value
def read_data(csv_path):
    """Load the spectra table and split it into labels and features.

    The first CSV column is taken as the class label; every remaining
    column is treated as one spectral band.

    Returns:
        A tuple ``(X, y, cols, waves)``:
        X     -- float64 array built from all columns after the first.
        y     -- array of the first column's values.
        cols  -- list of the band column names.
        waves -- per band, the number parsed from its name by
                 try_parse_wavelength, or the original column name when
                 nothing could be parsed.

    Raises:
        ValueError: if the table has fewer than 4 columns
            (1 label column + 3 band columns).
    """
    frame = pd.read_csv(csv_path)
    if frame.shape[1] < 4:
        raise ValueError("需要至少1个类别列 + 3个光谱列")

    labels = frame.iloc[:, 0].to_numpy()
    spectra = frame.iloc[:, 1:].to_numpy(dtype=np.float64)
    band_names = list(frame.columns[1:])

    # Fall back to the raw column name when no number can be extracted.
    parsed = (try_parse_wavelength(name) for name in band_names)
    wavelengths = [
        name if value is None else value
        for name, value in zip(band_names, parsed)
    ]
    return spectra, labels, band_names, wavelengths
def compute_anova_f_scores(X, y):
    """Compute a one-way ANOVA F-score for every column of X.

    F = (SS_between / df_between) / (SS_within / df_within), evaluated
    independently per column. A constant 1e-12 is added to the
    denominator so that a zero within-class variance does not divide
    by zero.

    Args:
        X: 2-D array, one row per sample and one column per band.
        y: array of class labels, compared element-wise against each
            unique label to build the per-class row masks.

    Returns:
        1-D array with one F-score per column of X.

    Raises:
        ValueError: if fewer than two classes select at least one row.
    """
    n_samples, n_bands = X.shape

    # One boolean row mask per label; labels that select no rows are dropped.
    class_masks = [
        mask
        for mask in (y == label for label in np.unique(y))
        if np.sum(mask) > 0
    ]
    n_classes = len(class_masks)
    if n_classes < 2:
        raise ValueError("类别不足2类无法计算区分度")

    grand_mean = X.mean(axis=0)
    ss_between = np.zeros(n_bands, dtype=np.float64)
    ss_within = np.zeros(n_bands, dtype=np.float64)
    for mask in class_masks:
        size = np.sum(mask)
        group = X[mask]
        group_mean = group.mean(axis=0)
        # Unbiased variance; a single-sample class contributes zero.
        if size > 1:
            group_var = group.var(axis=0, ddof=1)
        else:
            group_var = np.zeros(n_bands, dtype=np.float64)
        ss_between += size * (group_mean - grand_mean) ** 2
        ss_within += (size - 1) * group_var

    df_between = n_classes - 1
    df_within = n_samples - n_classes if n_samples > n_classes else 1
    ms_between = ss_between / max(df_between, 1)
    ms_within = ss_within / max(df_within, 1)
    return ms_between / (ms_within + 1e-12)  # shape (M,)
def score_triplet(X, y, idx_triplet, eps=1e-6):
    """Score a band subset with the LDA criterion J = trace((Sw + eps*I)^-1 Sb).

    Sb is the between-class scatter and Sw the within-class scatter of
    the selected columns. A larger J means the classes are better
    separated in the space spanned by those bands.

    Args:
        X: 2-D array, one row per sample and one column per band.
        y: array of class labels, compared element-wise against each
            unique label.
        idx_triplet: iterable of column indices into X. The script uses
            three, but any number >= 1 is accepted.
        eps: ridge term added to the diagonal of Sw before inversion.

    Returns:
        J as a float, or ``-inf`` when the pseudo-inverse cannot be
        computed (numpy raises LinAlgError).
    """
    band_idx = list(idx_triplet)
    Xt = X[:, band_idx]  # (N, d)
    d = Xt.shape[1]  # subset size; no longer hard-coded to 3

    overall_mean = Xt.mean(axis=0)
    Sb = np.zeros((d, d), dtype=np.float64)
    Sw = np.zeros((d, d), dtype=np.float64)
    for cls in np.unique(y):
        mask = (y == cls)
        if np.sum(mask) == 0:
            continue
        Xg = Xt[mask]
        ng = Xg.shape[0]
        mg = Xg.mean(axis=0)
        diff = (mg - overall_mean).reshape(d, 1)
        Sb += ng * (diff @ diff.T)
        # Within-class scatter accumulated directly. It equals
        # (ng - 1) * unbiased covariance, is zero for a single-sample
        # class, and stays 2-D for d == 1.
        centered = Xg - mg
        Sw += centered.T @ centered

    # Regularized inverse for numerical stability.
    A = Sw + eps * np.eye(d)
    try:
        Ainv = np.linalg.pinv(A)
    except np.linalg.LinAlgError:
        # Only a failed decomposition maps to the sentinel score; other
        # errors (for example a wrong dtype) are real bugs and propagate.
        return -np.inf
    return float(np.trace(Ainv @ Sb))
def select_bands(csv_path, top_k=30, top_triplets=10, map_order="auto"):
    """Rank 3-band combinations for a false-color composite.

    Steps: load the CSV, rank single bands by ANOVA F-score, keep the
    best ``top_k`` of them, score every 3-combination of those bands
    with score_triplet, and describe the ``top_triplets`` best ones.

    Args:
        csv_path: path handed to read_data.
        top_k: number of best single bands to preselect.
        top_triplets: number of best triplets to return.
        map_order: "wavelength_bgr" assigns shortest->B, middle->G,
            longest->R when all three wavelengths are numeric; any other
            case assigns the triplet's bands to R, G, B in order.

    Returns:
        ``(results, top_idx, F)`` where ``results`` is a list of dicts
        with keys "score", "bands", "wavelengths", "rgb_mapping" and
        "rgb_wavelengths", ``top_idx`` holds the preselected band
        indices and ``F`` the F-score of every band.
    """
    X, y, cols, waves = read_data(csv_path)
    F = compute_anova_f_scores(X, y)

    # Indices of the bands with the largest F-scores, best first.
    top_idx = np.argsort(-F)[:min(top_k, X.shape[1])]

    scored = [
        (score_triplet(X, y, triplet), triplet)
        for triplet in itertools.combinations(top_idx, 3)
    ]
    scored.sort(key=lambda item: -item[0])

    results = []
    for score, triplet in scored[:top_triplets]:
        names = [cols[i] for i in triplet]
        wavs = [waves[i] for i in triplet]

        # Suggested RGB assignment.
        all_numeric = all(isinstance(w, (int, float)) for w in wavs)
        if map_order == "wavelength_bgr" and all_numeric:
            # Ascending wavelength: shortest (blue), middle (green), longest (red).
            lo, mid, hi = np.argsort(wavs)
            rgb = {"R": names[hi], "G": names[mid], "B": names[lo]}
            rgb_w = {"R": wavs[hi], "G": wavs[mid], "B": wavs[lo]}
        else:
            # Keep the triplet's own order: first->R, second->G, third->B.
            rgb = dict(zip("RGB", names))
            rgb_w = dict(zip("RGB", wavs))

        results.append({
            "score": round(score, 6),
            "bands": names,
            "wavelengths": wavs,
            "rgb_mapping": rgb,
            "rgb_wavelengths": rgb_w,
        })
    return results, top_idx, F
def main():
    """CLI entry point: select bands and print the ranking.

    Prints the preselected single bands as ``index:col_name`` pairs,
    followed by one line per best triplet with its score, band names,
    wavelengths and suggested RGB mapping.
    """
    args = parse_args()
    results, top_idx, _ = select_bands(
        args.csv, args.top_k, args.top_triplets, args.map_order
    )

    # The indices in top_idx count columns after the label column, so
    # read just the header row to translate them into column names.
    # Previously the index was printed twice despite the
    # "index:col_name" label.
    col_names = pd.read_csv(args.csv, nrows=0).columns[1:]

    print(f"Top-{args.top_k} bands by ANOVA F-score (index:col_name):")
    print(", ".join(f"{i}:{col_names[i]}" for i in top_idx))
    print("\nBest triplets:")
    for r in results:
        bands = r["bands"]
        wavs = r["wavelengths"]
        rgb = r["rgb_mapping"]
        print(f"- score={r['score']:.4f} bands={bands} wavelengths={wavs} RGB={rgb}")
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()