Changes
This commit is contained in:
@ -19,7 +19,6 @@ from sklearn.cross_decomposition import PLSRegression
|
||||
from sklearn.ensemble import GradientBoostingRegressor, AdaBoostRegressor, ExtraTreesRegressor
|
||||
from sklearn.tree import DecisionTreeRegressor
|
||||
from sklearn.neural_network import MLPRegressor
|
||||
from joblib import parallel_backend
|
||||
# 第三方模型导入
|
||||
# try:
|
||||
# import lightgbm as lgb
|
||||
@ -43,6 +42,11 @@ import os
|
||||
from src.preprocessing.spectral_Preprocessing import Preprocessing
|
||||
|
||||
|
||||
def _sklearn_parallel_n_jobs() -> int:
|
||||
"""PyInstaller 等打包环境下,joblib loky 会再次启动当前 exe,出现多个同名进程。"""
|
||||
return 1 if getattr(sys, "frozen", False) else -1
|
||||
|
||||
|
||||
class WaterQualityModelingBatch:
|
||||
"""水质参数反演批量建模类"""
|
||||
|
||||
@ -638,25 +642,26 @@ class WaterQualityModelingBatch:
|
||||
# 网格搜索 - 使用KFold代替StratifiedKFold
|
||||
cv_strategy = KFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
|
||||
|
||||
_n_jobs = _sklearn_parallel_n_jobs()
|
||||
grid_search = GridSearchCV(
|
||||
base_model,
|
||||
config['params'],
|
||||
cv=cv_strategy,
|
||||
scoring=scoring,
|
||||
n_jobs=-1,
|
||||
n_jobs=_n_jobs,
|
||||
verbose=1
|
||||
)
|
||||
|
||||
# 在训练集上训练模型
|
||||
# with parallel_backend("threading", n_jobs=-1):
|
||||
# grid_search.fit(X_train, y_train)
|
||||
grid_search.fit(X_train, y_train)
|
||||
|
||||
# 获取最佳模型
|
||||
best_model = grid_search.best_estimator_
|
||||
|
||||
# 交叉验证评估(在训练集上)
|
||||
cv_scores = cross_val_score(best_model, X_train, y_train, cv=cv_strategy, scoring=scoring)
|
||||
cv_scores = cross_val_score(
|
||||
best_model, X_train, y_train, cv=cv_strategy, scoring=scoring, n_jobs=_n_jobs
|
||||
)
|
||||
|
||||
# 计算训练集上的回归指标
|
||||
y_train_pred = best_model.predict(X_train)
|
||||
|
||||
@ -2263,7 +2263,7 @@ class Step6Panel(QWidget):
|
||||
|
||||
for i, method in enumerate(preproc_methods):
|
||||
checkbox = QCheckBox(method)
|
||||
checkbox.setChecked(True) # 默认全选
|
||||
checkbox.setChecked(method == 'SNV') # 默认全选
|
||||
self.preproc_checkboxes[method] = checkbox
|
||||
preproc_grid.addWidget(checkbox, i // 4, i % 4)
|
||||
|
||||
@ -2344,7 +2344,7 @@ class Step6Panel(QWidget):
|
||||
|
||||
for i, method in enumerate(split_methods):
|
||||
checkbox = QCheckBox(method)
|
||||
checkbox.setChecked(True) # 默认全选
|
||||
checkbox.setChecked(method == 'spxy') # 默认全选
|
||||
self.split_checkboxes[method] = checkbox
|
||||
split_grid.addWidget(checkbox, 0, i)
|
||||
|
||||
@ -4812,7 +4812,7 @@ class Step6_5Panel(QWidget):
|
||||
|
||||
for i, method in enumerate(preproc_methods):
|
||||
checkbox = QCheckBox(method)
|
||||
checkbox.setChecked(True)
|
||||
checkbox.setChecked(method == 'SNV')
|
||||
self.preproc_checkboxes[method] = checkbox
|
||||
preproc_grid.addWidget(checkbox, i // 4, i % 4)
|
||||
|
||||
@ -4843,7 +4843,7 @@ class Step6_5Panel(QWidget):
|
||||
row_layout = QHBoxLayout()
|
||||
row_layout.setContentsMargins(0, 0, 0, 0)
|
||||
checkbox = QCheckBox(algorithm)
|
||||
checkbox.setChecked(True)
|
||||
checkbox.setChecked(algorithm == 'chl_a')
|
||||
spinbox = QSpinBox()
|
||||
spinbox.setRange(0, 500)
|
||||
spinbox.setValue(0)
|
||||
@ -5144,7 +5144,7 @@ class Step6_75Panel(QWidget):
|
||||
for i, method in enumerate(regression_methods):
|
||||
checkbox = QCheckBox(method)
|
||||
# 默认选择常用的方法
|
||||
if method in ['linear', 'exponential', 'power', 'logarithmic']:
|
||||
if method in ['linear', 'exponential', 'power']:
|
||||
checkbox.setChecked(True)
|
||||
self.method_checkboxes[method] = checkbox
|
||||
method_grid.addWidget(checkbox, i // 3, i % 3)
|
||||
@ -5251,7 +5251,7 @@ class Step6_75Panel(QWidget):
|
||||
for i, col in enumerate(self.csv_columns):
|
||||
checkbox = QCheckBox(col)
|
||||
# 默认选择一些常见的指数列
|
||||
if any(keyword in col.lower() for keyword in ['index', 'ratio', 'normalized', 'nd', 'b']):
|
||||
if any(keyword in col.lower() for keyword in ['index', 'ratio', 'normalized', 'nd','BGA_','Chl_','Turb_']):
|
||||
checkbox.setChecked(True)
|
||||
self.x_column_checkboxes[col] = checkbox
|
||||
self.x_columns_layout.addWidget(checkbox, i // 3, i % 3)
|
||||
@ -5260,7 +5260,8 @@ class Step6_75Panel(QWidget):
|
||||
for i, col in enumerate(self.csv_columns):
|
||||
checkbox = QCheckBox(col)
|
||||
# 默认选择一些常见的水质参数列
|
||||
if any(keyword in col.lower() for keyword in ['chl', 'tn', 'tp', 'turbidity', 'do', 'ph', 'conductivity']):
|
||||
target_keywords = ['Chlorophyll', 'COD', 'DO', 'PH', 'Temperature', 'spCond', 'Turbidity', 'TDS', 'Cl-', 'NO3-N', 'NH3-N', 'BGA', 'TT']
|
||||
if any(keyword.lower() in col.lower() for keyword in target_keywords):
|
||||
checkbox.setChecked(True)
|
||||
self.y_column_checkboxes[col] = checkbox
|
||||
self.y_columns_layout.addWidget(checkbox, i // 2, i % 2)
|
||||
@ -6990,7 +6991,6 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
#冻结,只显示1个exe
|
||||
# multiprocessing.freeze_support()
|
||||
multiprocessing.freeze_support()
|
||||
main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user