修改分割模块
This commit is contained in:
160
chose_bands.py
Normal file
160
chose_bands.py
Normal file
@ -0,0 +1,160 @@
|
||||
import argparse
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import itertools
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
def parse_args():
|
||||
ap = argparse.ArgumentParser(description="Select 3 bands for false-color to maximize inter-class separability")
|
||||
ap.add_argument("--csv", required=True, help="Input CSV: first column is class label, rest are spectra columns")
|
||||
ap.add_argument("--top_k", type=int, default=30, help="How many best single bands to preselect (default: 30)")
|
||||
ap.add_argument("--top_triplets", type=int, default=10, help="How many best triplets to print (default: 10)")
|
||||
ap.add_argument("--map_order", choices=["auto", "wavelength_bgr"], default="auto",
|
||||
help="RGB mapping order: auto=use scored order; wavelength_bgr=short->B, mid->G, long->R")
|
||||
return ap.parse_args()
|
||||
|
||||
def try_parse_wavelength(col_name):
|
||||
# 支持 "wavelength_912.36" / "912.36" / "band_12" 等
|
||||
m = re.search(r"([0-9]+(\.[0-9]+)?)", str(col_name))
|
||||
if m:
|
||||
try:
|
||||
return float(m.group(1))
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
def read_data(csv_path):
|
||||
df = pd.read_csv(csv_path)
|
||||
if df.shape[1] < 4:
|
||||
raise ValueError("需要至少1个类别列 + 3个光谱列")
|
||||
y = df.iloc[:, 0].to_numpy()
|
||||
X = df.iloc[:, 1:].to_numpy(dtype=np.float64)
|
||||
cols = list(df.columns[1:])
|
||||
waves = []
|
||||
for c in cols:
|
||||
w = try_parse_wavelength(c)
|
||||
waves.append(w if w is not None else c)
|
||||
return X, y, cols, waves
|
||||
|
||||
def compute_anova_f_scores(X, y):
|
||||
# 一维ANOVA F-score: F = (SS_between/df_between) / (SS_within/df_within)
|
||||
N, M = X.shape
|
||||
classes = []
|
||||
for cls in np.unique(y):
|
||||
idx = (y == cls)
|
||||
if np.sum(idx) > 0:
|
||||
classes.append(idx)
|
||||
k = len(classes)
|
||||
if k < 2:
|
||||
raise ValueError("类别不足2类,无法计算区分度")
|
||||
|
||||
mu = X.mean(axis=0)
|
||||
between_ss = np.zeros(M, dtype=np.float64)
|
||||
within_ss = np.zeros(M, dtype=np.float64)
|
||||
|
||||
for idx in classes:
|
||||
ng = np.sum(idx)
|
||||
Xg = X[idx]
|
||||
if ng == 0:
|
||||
continue
|
||||
mg = Xg.mean(axis=0)
|
||||
# 无偏方差,若ng==1则设为0
|
||||
vg = Xg.var(axis=0, ddof=1) if ng > 1 else np.zeros(M, dtype=np.float64)
|
||||
between_ss += ng * (mg - mu) ** 2
|
||||
within_ss += (ng - 1) * vg
|
||||
|
||||
dfb = k - 1
|
||||
dfw = N - k if N > k else 1
|
||||
F = (between_ss / max(dfb, 1)) / (within_ss / max(dfw, 1) + 1e-12)
|
||||
return F # shape (M,)
|
||||
|
||||
def score_triplet(X, y, idx_triplet, eps=1e-6):
|
||||
# LDA准则: J = trace( (Sw+epsI)^-1 Sb ),三维空间
|
||||
t = list(idx_triplet)
|
||||
Xt = X[:, t] # (N,3)
|
||||
classes = []
|
||||
for cls in np.unique(y):
|
||||
idx = (y == cls)
|
||||
if np.sum(idx) > 0:
|
||||
classes.append(idx)
|
||||
|
||||
m = Xt.mean(axis=0)
|
||||
Sb = np.zeros((3, 3), dtype=np.float64)
|
||||
Sw = np.zeros((3, 3), dtype=np.float64)
|
||||
|
||||
for idx in classes:
|
||||
Xg = Xt[idx]
|
||||
ng = Xg.shape[0]
|
||||
if ng == 0:
|
||||
continue
|
||||
mg = Xg.mean(axis=0)
|
||||
diff = (mg - m).reshape(3, 1)
|
||||
Sb += ng * (diff @ diff.T)
|
||||
if ng > 1:
|
||||
Cg = np.cov(Xg, rowvar=False, ddof=1)
|
||||
else:
|
||||
Cg = np.zeros((3, 3), dtype=np.float64)
|
||||
Sw += (ng - 1) * Cg
|
||||
|
||||
# 稳定逆
|
||||
A = Sw + eps * np.eye(3)
|
||||
try:
|
||||
Ainv = np.linalg.pinv(A)
|
||||
except Exception:
|
||||
return -np.inf
|
||||
J = np.trace(Ainv @ Sb)
|
||||
return float(J)
|
||||
|
||||
def select_bands(csv_path, top_k=30, top_triplets=10, map_order="auto"):
|
||||
X, y, cols, waves = read_data(csv_path)
|
||||
F = compute_anova_f_scores(X, y)
|
||||
order = np.argsort(-F)
|
||||
top_idx = order[:min(top_k, X.shape[1])]
|
||||
|
||||
best = []
|
||||
for comb in itertools.combinations(top_idx, 3):
|
||||
s = score_triplet(X, y, comb)
|
||||
best.append((s, comb))
|
||||
best.sort(key=lambda x: -x[0])
|
||||
|
||||
results = []
|
||||
for s, comb in best[:top_triplets]:
|
||||
names = [cols[i] for i in comb]
|
||||
wavs = [waves[i] for i in comb]
|
||||
# 建议RGB映射
|
||||
if map_order == "wavelength_bgr" and all(isinstance(w, (int, float)) for w in wavs):
|
||||
order_rgb = np.argsort(wavs) # 小->大
|
||||
# 小(blue), 中(green), 大(red)
|
||||
b,g,r = [names[order_rgb[0]]], [names[order_rgb[1]]], [names[order_rgb[2]]]
|
||||
rgb = {"R": r[0], "G": g[0], "B": b[0]}
|
||||
rgb_w = {"R": wavs[order_rgb[2]], "G": wavs[order_rgb[1]], "B": wavs[order_rgb[0]]}
|
||||
else:
|
||||
# 直接用得分顺序:第一->R, 第二->G, 第三->B
|
||||
rgb = {"R": names[0], "G": names[1], "B": names[2]}
|
||||
rgb_w = {"R": waves[comb[0]], "G": waves[comb[1]], "B": waves[comb[2]]}
|
||||
results.append({
|
||||
"score": round(s, 6),
|
||||
"bands": names,
|
||||
"wavelengths": wavs,
|
||||
"rgb_mapping": rgb,
|
||||
"rgb_wavelengths": rgb_w
|
||||
})
|
||||
return results, top_idx, F
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
results, top_idx, F = select_bands(args.csv, args.top_k, args.top_triplets, args.map_order)
|
||||
|
||||
print(f"Top-{args.top_k} bands by ANOVA F-score (index:col_name):")
|
||||
print(", ".join([f"{i}:{i}" for i in top_idx])) # 如需列名可自行打印
|
||||
|
||||
print("\nBest triplets:")
|
||||
for r in results:
|
||||
bands = r["bands"]
|
||||
wavs = r["wavelengths"]
|
||||
rgb = r["rgb_mapping"]
|
||||
print(f"- score={r['score']:.4f} bands={bands} wavelengths={wavs} RGB={rgb}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user