修改分割模块

This commit is contained in:
2026-03-05 17:12:01 +08:00
parent d84d886f35
commit 10fd2b00d4
43 changed files with 1858 additions and 284 deletions

160
chose_bands.py Normal file
View File

@ -0,0 +1,160 @@
import argparse
import pandas as pd
import numpy as np
import itertools
import re
from collections import defaultdict
def parse_args():
ap = argparse.ArgumentParser(description="Select 3 bands for false-color to maximize inter-class separability")
ap.add_argument("--csv", required=True, help="Input CSV: first column is class label, rest are spectra columns")
ap.add_argument("--top_k", type=int, default=30, help="How many best single bands to preselect (default: 30)")
ap.add_argument("--top_triplets", type=int, default=10, help="How many best triplets to print (default: 10)")
ap.add_argument("--map_order", choices=["auto", "wavelength_bgr"], default="auto",
help="RGB mapping order: auto=use scored order; wavelength_bgr=short->B, mid->G, long->R")
return ap.parse_args()
def try_parse_wavelength(col_name):
# 支持 "wavelength_912.36" / "912.36" / "band_12" 等
m = re.search(r"([0-9]+(\.[0-9]+)?)", str(col_name))
if m:
try:
return float(m.group(1))
except Exception:
return None
return None
def read_data(csv_path):
df = pd.read_csv(csv_path)
if df.shape[1] < 4:
raise ValueError("需要至少1个类别列 + 3个光谱列")
y = df.iloc[:, 0].to_numpy()
X = df.iloc[:, 1:].to_numpy(dtype=np.float64)
cols = list(df.columns[1:])
waves = []
for c in cols:
w = try_parse_wavelength(c)
waves.append(w if w is not None else c)
return X, y, cols, waves
def compute_anova_f_scores(X, y):
# 一维ANOVA F-score: F = (SS_between/df_between) / (SS_within/df_within)
N, M = X.shape
classes = []
for cls in np.unique(y):
idx = (y == cls)
if np.sum(idx) > 0:
classes.append(idx)
k = len(classes)
if k < 2:
raise ValueError("类别不足2类无法计算区分度")
mu = X.mean(axis=0)
between_ss = np.zeros(M, dtype=np.float64)
within_ss = np.zeros(M, dtype=np.float64)
for idx in classes:
ng = np.sum(idx)
Xg = X[idx]
if ng == 0:
continue
mg = Xg.mean(axis=0)
# 无偏方差若ng==1则设为0
vg = Xg.var(axis=0, ddof=1) if ng > 1 else np.zeros(M, dtype=np.float64)
between_ss += ng * (mg - mu) ** 2
within_ss += (ng - 1) * vg
dfb = k - 1
dfw = N - k if N > k else 1
F = (between_ss / max(dfb, 1)) / (within_ss / max(dfw, 1) + 1e-12)
return F # shape (M,)
def score_triplet(X, y, idx_triplet, eps=1e-6):
# LDA准则: J = trace( (Sw+epsI)^-1 Sb ),三维空间
t = list(idx_triplet)
Xt = X[:, t] # (N,3)
classes = []
for cls in np.unique(y):
idx = (y == cls)
if np.sum(idx) > 0:
classes.append(idx)
m = Xt.mean(axis=0)
Sb = np.zeros((3, 3), dtype=np.float64)
Sw = np.zeros((3, 3), dtype=np.float64)
for idx in classes:
Xg = Xt[idx]
ng = Xg.shape[0]
if ng == 0:
continue
mg = Xg.mean(axis=0)
diff = (mg - m).reshape(3, 1)
Sb += ng * (diff @ diff.T)
if ng > 1:
Cg = np.cov(Xg, rowvar=False, ddof=1)
else:
Cg = np.zeros((3, 3), dtype=np.float64)
Sw += (ng - 1) * Cg
# 稳定逆
A = Sw + eps * np.eye(3)
try:
Ainv = np.linalg.pinv(A)
except Exception:
return -np.inf
J = np.trace(Ainv @ Sb)
return float(J)
def select_bands(csv_path, top_k=30, top_triplets=10, map_order="auto"):
X, y, cols, waves = read_data(csv_path)
F = compute_anova_f_scores(X, y)
order = np.argsort(-F)
top_idx = order[:min(top_k, X.shape[1])]
best = []
for comb in itertools.combinations(top_idx, 3):
s = score_triplet(X, y, comb)
best.append((s, comb))
best.sort(key=lambda x: -x[0])
results = []
for s, comb in best[:top_triplets]:
names = [cols[i] for i in comb]
wavs = [waves[i] for i in comb]
# 建议RGB映射
if map_order == "wavelength_bgr" and all(isinstance(w, (int, float)) for w in wavs):
order_rgb = np.argsort(wavs) # 小->大
# 小(blue), 中(green), 大(red)
b,g,r = [names[order_rgb[0]]], [names[order_rgb[1]]], [names[order_rgb[2]]]
rgb = {"R": r[0], "G": g[0], "B": b[0]}
rgb_w = {"R": wavs[order_rgb[2]], "G": wavs[order_rgb[1]], "B": wavs[order_rgb[0]]}
else:
# 直接用得分顺序:第一->R, 第二->G, 第三->B
rgb = {"R": names[0], "G": names[1], "B": names[2]}
rgb_w = {"R": waves[comb[0]], "G": waves[comb[1]], "B": waves[comb[2]]}
results.append({
"score": round(s, 6),
"bands": names,
"wavelengths": wavs,
"rgb_mapping": rgb,
"rgb_wavelengths": rgb_w
})
return results, top_idx, F
def main():
args = parse_args()
results, top_idx, F = select_bands(args.csv, args.top_k, args.top_triplets, args.map_order)
print(f"Top-{args.top_k} bands by ANOVA F-score (index:col_name):")
print(", ".join([f"{i}:{i}" for i in top_idx])) # 如需列名可自行打印
print("\nBest triplets:")
for r in results:
bands = r["bands"]
wavs = r["wavelengths"]
rgb = r["rgb_mapping"]
print(f"- score={r['score']:.4f} bands={bands} wavelengths={wavs} RGB={rgb}")
if __name__ == "__main__":
main()