This commit is contained in:
2026-05-07 17:01:44 +08:00
parent 69ce95cda4
commit 9ce17df28a
2 changed files with 19 additions and 14 deletions

View File

@ -19,7 +19,6 @@ from sklearn.cross_decomposition import PLSRegression
from sklearn.ensemble import GradientBoostingRegressor, AdaBoostRegressor, ExtraTreesRegressor from sklearn.ensemble import GradientBoostingRegressor, AdaBoostRegressor, ExtraTreesRegressor
from sklearn.tree import DecisionTreeRegressor from sklearn.tree import DecisionTreeRegressor
from sklearn.neural_network import MLPRegressor from sklearn.neural_network import MLPRegressor
from joblib import parallel_backend
# 第三方模型导入 # 第三方模型导入
# try: # try:
# import lightgbm as lgb # import lightgbm as lgb
@ -43,6 +42,11 @@ import os
from src.preprocessing.spectral_Preprocessing import Preprocessing from src.preprocessing.spectral_Preprocessing import Preprocessing
def _sklearn_parallel_n_jobs() -> int:
    """Return the ``n_jobs`` value to pass to scikit-learn parallel helpers.

    In a frozen build (PyInstaller and similar packagers), joblib's loky
    backend starts the current executable again, which shows up as several
    processes with the same name. Returning 1 in that case avoids spawning
    those extra workers; otherwise -1 is returned (scikit-learn's "use all
    processors" setting).

    Returns:
        1 when ``sys.frozen`` is set to a truthy value, otherwise -1.
    """
    # Packagers such as PyInstaller set ``sys.frozen``; a normal interpreter
    # does not define the attribute at all, hence the ``False`` default.
    running_frozen = bool(getattr(sys, "frozen", False))
    if running_frozen:
        return 1
    return -1
class WaterQualityModelingBatch: class WaterQualityModelingBatch:
"""水质参数反演批量建模类""" """水质参数反演批量建模类"""
@ -638,25 +642,26 @@ class WaterQualityModelingBatch:
# 网格搜索 - 使用KFold代替StratifiedKFold # 网格搜索 - 使用KFold代替StratifiedKFold
cv_strategy = KFold(n_splits=cv_folds, shuffle=True, random_state=random_state) cv_strategy = KFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
_n_jobs = _sklearn_parallel_n_jobs()
grid_search = GridSearchCV( grid_search = GridSearchCV(
base_model, base_model,
config['params'], config['params'],
cv=cv_strategy, cv=cv_strategy,
scoring=scoring, scoring=scoring,
n_jobs=-1, n_jobs=_n_jobs,
verbose=1 verbose=1
) )
# 在训练集上训练模型 # 在训练集上训练模型
# with parallel_backend("threading", n_jobs=-1):
# grid_search.fit(X_train, y_train)
grid_search.fit(X_train, y_train) grid_search.fit(X_train, y_train)
# 获取最佳模型 # 获取最佳模型
best_model = grid_search.best_estimator_ best_model = grid_search.best_estimator_
# 交叉验证评估(在训练集上) # 交叉验证评估(在训练集上)
cv_scores = cross_val_score(best_model, X_train, y_train, cv=cv_strategy, scoring=scoring) cv_scores = cross_val_score(
best_model, X_train, y_train, cv=cv_strategy, scoring=scoring, n_jobs=_n_jobs
)
# 计算训练集上的回归指标 # 计算训练集上的回归指标
y_train_pred = best_model.predict(X_train) y_train_pred = best_model.predict(X_train)

View File

@ -2263,7 +2263,7 @@ class Step6Panel(QWidget):
for i, method in enumerate(preproc_methods): for i, method in enumerate(preproc_methods):
checkbox = QCheckBox(method) checkbox = QCheckBox(method)
checkbox.setChecked(True) # 默认全选 checkbox.setChecked(method == 'SNV') # 默认全选
self.preproc_checkboxes[method] = checkbox self.preproc_checkboxes[method] = checkbox
preproc_grid.addWidget(checkbox, i // 4, i % 4) preproc_grid.addWidget(checkbox, i // 4, i % 4)
@ -2344,7 +2344,7 @@ class Step6Panel(QWidget):
for i, method in enumerate(split_methods): for i, method in enumerate(split_methods):
checkbox = QCheckBox(method) checkbox = QCheckBox(method)
checkbox.setChecked(True) # 默认全选 checkbox.setChecked(method == 'spxy') # 默认全选
self.split_checkboxes[method] = checkbox self.split_checkboxes[method] = checkbox
split_grid.addWidget(checkbox, 0, i) split_grid.addWidget(checkbox, 0, i)
@ -4812,7 +4812,7 @@ class Step6_5Panel(QWidget):
for i, method in enumerate(preproc_methods): for i, method in enumerate(preproc_methods):
checkbox = QCheckBox(method) checkbox = QCheckBox(method)
checkbox.setChecked(True) checkbox.setChecked(method == 'SNV')
self.preproc_checkboxes[method] = checkbox self.preproc_checkboxes[method] = checkbox
preproc_grid.addWidget(checkbox, i // 4, i % 4) preproc_grid.addWidget(checkbox, i // 4, i % 4)
@ -4843,7 +4843,7 @@ class Step6_5Panel(QWidget):
row_layout = QHBoxLayout() row_layout = QHBoxLayout()
row_layout.setContentsMargins(0, 0, 0, 0) row_layout.setContentsMargins(0, 0, 0, 0)
checkbox = QCheckBox(algorithm) checkbox = QCheckBox(algorithm)
checkbox.setChecked(True) checkbox.setChecked(algorithm == 'chl_a')
spinbox = QSpinBox() spinbox = QSpinBox()
spinbox.setRange(0, 500) spinbox.setRange(0, 500)
spinbox.setValue(0) spinbox.setValue(0)
@ -5144,7 +5144,7 @@ class Step6_75Panel(QWidget):
for i, method in enumerate(regression_methods): for i, method in enumerate(regression_methods):
checkbox = QCheckBox(method) checkbox = QCheckBox(method)
# 默认选择常用的方法 # 默认选择常用的方法
if method in ['linear', 'exponential', 'power', 'logarithmic']: if method in ['linear', 'exponential', 'power']:
checkbox.setChecked(True) checkbox.setChecked(True)
self.method_checkboxes[method] = checkbox self.method_checkboxes[method] = checkbox
method_grid.addWidget(checkbox, i // 3, i % 3) method_grid.addWidget(checkbox, i // 3, i % 3)
@ -5251,7 +5251,7 @@ class Step6_75Panel(QWidget):
for i, col in enumerate(self.csv_columns): for i, col in enumerate(self.csv_columns):
checkbox = QCheckBox(col) checkbox = QCheckBox(col)
# 默认选择一些常见的指数列 # 默认选择一些常见的指数列
if any(keyword in col.lower() for keyword in ['index', 'ratio', 'normalized', 'nd', 'b']): if any(keyword in col.lower() for keyword in ['index', 'ratio', 'normalized', 'nd','BGA_','Chl_','Turb_']):
checkbox.setChecked(True) checkbox.setChecked(True)
self.x_column_checkboxes[col] = checkbox self.x_column_checkboxes[col] = checkbox
self.x_columns_layout.addWidget(checkbox, i // 3, i % 3) self.x_columns_layout.addWidget(checkbox, i // 3, i % 3)
@ -5260,7 +5260,8 @@ class Step6_75Panel(QWidget):
for i, col in enumerate(self.csv_columns): for i, col in enumerate(self.csv_columns):
checkbox = QCheckBox(col) checkbox = QCheckBox(col)
# 默认选择一些常见的水质参数列 # 默认选择一些常见的水质参数列
if any(keyword in col.lower() for keyword in ['chl', 'tn', 'tp', 'turbidity', 'do', 'ph', 'conductivity']): target_keywords = ['Chlorophyll', 'COD', 'DO', 'PH', 'Temperature', 'spCond', 'Turbidity', 'TDS', 'Cl-', 'NO3-N', 'NH3-N', 'BGA', 'TT']
if any(keyword.lower() in col.lower() for keyword in target_keywords):
checkbox.setChecked(True) checkbox.setChecked(True)
self.y_column_checkboxes[col] = checkbox self.y_column_checkboxes[col] = checkbox
self.y_columns_layout.addWidget(checkbox, i // 2, i % 2) self.y_columns_layout.addWidget(checkbox, i // 2, i % 2)
@ -6990,7 +6991,6 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
#冻结只显示1个exe multiprocessing.freeze_support()
# multiprocessing.freeze_support()
main() main()