414 lines
14 KiB
Python
414 lines
14 KiB
Python
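"""
高光谱分析工具包输出处理模块

提供统一的输出格式化和结果摘要功能

Output-processing module of the hyperspectral analysis toolkit: unified
output formatting and result-summary utilities.
"""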
|
||
|
||
import os
|
||
import json
|
||
import pandas as pd
|
||
import numpy as np
|
||
from pathlib import Path
|
||
from typing import Dict, Any, List, Optional, Union, Tuple
|
||
from datetime import datetime
|
||
import matplotlib.pyplot as plt
|
||
import warnings
|
||
|
||
# Silence every warning (all categories, all modules) process-wide as soon as
# this module is imported; equivalent to warnings.filterwarnings('ignore').
warnings.simplefilter('ignore')
|
||
|
||
|
||
class OutputFormatter:
    """Turn arbitrary task results into small, JSON-friendly summary dicts.

    All methods are static; the class only serves as a namespace.
    """

    @staticmethod
    def format_task_result(task_name: str, result: Any) -> Dict[str, Any]:
        """Wrap ``result`` in a summary envelope.

        Args:
            task_name: Name of the task that produced ``result``.
            result: Task output. ``None``, dicts, lists/tuples, NumPy arrays
                and pandas DataFrames get a type-specific summary; any other
                value is summarized as ``str(result)``.

        Returns:
            A dict with ``task_name``, an ISO-8601 local ``timestamp``,
            ``success`` (always ``True`` here), ``result_type`` (the class
            name of ``result``) and the type-specific ``summary`` dict.
        """
        formatted = {
            'task_name': task_name,
            'timestamp': datetime.now().isoformat(),
            'success': True,
            'result_type': type(result).__name__,
            'summary': {},
        }

        if result is None:
            formatted['summary'] = {'message': '任务完成,无返回结果'}
        elif isinstance(result, dict):
            formatted['summary'] = OutputFormatter._summarize_dict_result(result)
        elif isinstance(result, (list, tuple)):
            formatted['summary'] = OutputFormatter._summarize_list_result(result)
        elif isinstance(result, np.ndarray):
            formatted['summary'] = OutputFormatter._summarize_array_result(result)
        elif isinstance(result, pd.DataFrame):
            formatted['summary'] = OutputFormatter._summarize_dataframe_result(result)
        else:
            formatted['summary'] = {'value': str(result)}

        return formatted

    @staticmethod
    def _summarize_dict_result(result: Dict) -> Dict[str, Any]:
        """Summarize a dict: key count, the first 10 keys, and scalar numeric values.

        ``int``/``float`` values (this includes ``bool`` and NumPy float
        scalars, which subclass them) are copied unchanged. Single-element
        NumPy arrays (0-d included) are converted with ``float``; those whose
        content is not convertible (e.g. strings) are skipped rather than
        aborting the whole summary.
        """
        summary = {
            'keys_count': len(result),
            'keys': list(result.keys())[:10],  # only show the first 10 keys
        }

        numeric_values = {}
        for key, value in result.items():
            if isinstance(value, (int, float)):
                numeric_values[key] = value
            elif isinstance(value, np.ndarray) and value.size == 1:
                # A 0-d array always has size 1, so this branch covers it too.
                try:
                    numeric_values[key] = float(value.item())
                except (TypeError, ValueError):
                    # Non-numeric payload: not a metric, leave it out.
                    continue

        if numeric_values:
            summary['numeric_metrics'] = numeric_values

        return summary

    @staticmethod
    def _summarize_list_result(result: Union[List, Tuple]) -> Dict[str, Any]:
        """Summarize a list/tuple: length, element type names, numeric statistics.

        ``element_types`` lists the distinct type names of the first 10
        elements in sorted order, so the output is identical across runs.
        When every element is an ``int``/``float``/NumPy number, NaN-ignoring
        mean/std/min/max are added as Python floats.
        """
        summary = {
            'length': len(result),
            # Sorted: iteration order of a set of str depends on per-process
            # hash randomization.
            'element_types': sorted({type(x).__name__ for x in result[:10]}),
        }

        if result and all(isinstance(x, (int, float, np.number)) for x in result):
            values = np.array([float(x) for x in result], dtype=float)
            summary['statistics'] = {
                'mean': float(np.nanmean(values)),
                'std': float(np.nanstd(values)),
                'min': float(np.nanmin(values)),
                'max': float(np.nanmax(values)),
            }

        return summary

    @staticmethod
    def _summarize_array_result(result: np.ndarray) -> Dict[str, Any]:
        """Summarize a NumPy array: shape, dtype, rank, statistics and NaN/Inf flags.

        Statistics ignore NaN entries, so a few masked pixels do not turn
        every value into NaN; ``has_nan`` / ``has_inf`` record their presence.
        Statistics and flags are only computed for boolean, integer and real
        floating-point arrays, because ``np.mean`` / ``np.isnan`` raise on
        string, object or datetime data.
        """
        summary = {
            'shape': result.shape,
            'dtype': str(result.dtype),
            'dimensions': result.ndim,
        }

        if result.size > 0 and result.dtype.kind in 'biuf':
            summary['statistics'] = {
                'mean': float(np.nanmean(result)),
                'std': float(np.nanstd(result)),
                'min': float(np.nanmin(result)),
                'max': float(np.nanmax(result)),
            }

            if np.isnan(result).any():
                summary['has_nan'] = True
            if np.isinf(result).any():
                summary['has_inf'] = True

        return summary

    @staticmethod
    def _summarize_dataframe_result(result: pd.DataFrame) -> Dict[str, Any]:
        """Summarize a DataFrame: shape, columns, dtypes and per-numeric-column stats.

        Keys of ``dtypes`` and ``numeric_statistics`` are ``str(column)``
        because ``json.dump`` rejects tuple keys (e.g. MultiIndex columns)
        even when given ``default=str``. Columns are iterated as
        ``(name, Series)`` pairs so duplicate column names do not break the
        per-column statistics (the last duplicate wins in the dict).
        """
        summary = {
            'shape': result.shape,
            'columns': list(result.columns),
            'dtypes': {str(col): str(dtype) for col, dtype in result.dtypes.items()},
        }

        numeric = result.select_dtypes(include=[np.number])
        if numeric.shape[1] > 0:
            summary['numeric_statistics'] = {
                str(col): {
                    'mean': float(series.mean()),
                    'std': float(series.std()),
                    'min': float(series.min()),
                    'max': float(series.max()),
                }
                for col, series in numeric.items()
            }

        return summary
|
||
|
||
|
||
class ResultExporter:
    """Persist task summaries and execution logs as UTF-8 JSON files.

    Both public methods are best-effort: on any failure they print a warning
    and return ``None`` instead of raising, so reporting never aborts the
    caller's task.
    """

    @staticmethod
    def _write_json(data: Any, path: Path) -> None:
        """Write ``data`` as indented JSON to ``path`` without leaving partial files.

        The text is serialized fully in memory first, so a serialization
        error never touches the disk. It is then written to a sibling
        ``.tmp`` file and moved over ``path``. Objects the json module cannot
        encode natively are converted with ``str`` (``default=str``). Missing
        parent directories are created.
        """
        text = json.dumps(data, indent=2, ensure_ascii=False, default=str)
        path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = path.with_name(path.name + '.tmp')
        try:
            tmp_path.write_text(text, encoding='utf-8')
            tmp_path.replace(path)
        finally:
            # Only present if the write or the move failed.
            if tmp_path.exists():
                try:
                    tmp_path.unlink()
                except OSError:
                    pass

    @staticmethod
    def save_task_summary(summary: Dict[str, Any], output_dir: str,
                          prefix: str = "task_summary") -> Optional[str]:
        """Save ``summary`` to ``<output_dir>/<prefix>.json``.

        Returns:
            The written file path as a string, or ``None`` if saving failed
            (a warning is printed in that case).
        """
        output_path = Path(output_dir) / f"{prefix}.json"

        try:
            ResultExporter._write_json(summary, output_path)
        except Exception as e:  # best-effort boundary: report, never raise
            print(f"警告: 无法保存任务摘要: {e}")
            return None
        return str(output_path)

    @staticmethod
    def save_execution_log(task_name: str, args: Dict[str, Any], result: Any,
                           output_dir: str, success: bool = True) -> Optional[str]:
        """Append one entry to ``<output_dir>/execution_log.json``.

        The log file holds a JSON list; it is read (if present), extended
        with the new entry and rewritten.

        Args:
            task_name: Name of the executed task.
            args: Task parameters, stored verbatim under ``parameters``.
            result: Task result, stored as ``OutputFormatter.format_task_result``.
            output_dir: Directory holding the log file (created if missing).
            success: Outcome of the task, recorded on the entry and its
                nested ``result_summary``.

        Returns:
            The log path as a string, or ``None`` if the log could not be
            read or written (a warning is printed in that case).
        """
        log_entry = {
            'timestamp': datetime.now().isoformat(),
            'task_name': task_name,
            'parameters': args,
            'success': success,
            'result_summary': OutputFormatter.format_task_result(task_name, result),
        }
        # format_task_result always reports success=True; align it with the
        # real outcome so the entry does not contradict itself.
        log_entry['result_summary']['success'] = success

        log_path = Path(output_dir) / "execution_log.json"

        try:
            logs = []
            if log_path.exists():
                with open(log_path, 'r', encoding='utf-8') as f:
                    logs = json.load(f)

            logs.append(log_entry)

            # Atomic rewrite: a failure can no longer truncate the existing log
            # (which would make every later append fail on json.load).
            ResultExporter._write_json(logs, log_path)
        except Exception as e:  # best-effort boundary: report, never raise
            print(f"警告: 无法保存执行日志: {e}")
            return None
        return str(log_path)
|
||
|
||
|
||
class SummaryGenerator:
    """Build and print the plain-text task summary shown on the console."""

    @staticmethod
    def _public_args(args) -> List[Tuple[str, Any]]:
        """Return ``(name, value)`` pairs of ``args`` whose name does not start with ``_``.

        ``args`` may be an ``argparse.Namespace`` (or any object with a
        ``__dict__``), a mapping, or ``None`` (meaning no parameters).
        """
        from collections.abc import Mapping

        if args is None:
            return []
        items = args.items() if isinstance(args, Mapping) else vars(args).items()
        return [(key, value) for key, value in items if not str(key).startswith('_')]

    @staticmethod
    def _get_arg(args, name: str, default=None):
        """Read ``name`` from a namespace-like object or mapping, else ``default``."""
        from collections.abc import Mapping

        if args is None:
            return default
        if isinstance(args, Mapping):
            return args.get(name, default)
        return getattr(args, name, default)

    @staticmethod
    def _format_stat(value) -> str:
        """Format a statistic with 4 decimals; non-numbers (e.g. 'N/A') fall back to ``str``."""
        try:
            return f"{value:.4f}"
        except (TypeError, ValueError):
            return str(value)

    @staticmethod
    def generate_task_summary(task_name: str, args, result: Any,
                              execution_time: Optional[float] = None) -> str:
        """Render a multi-line text summary of one task run.

        The summary contains:
        - the current time;
        - the duration, when ``execution_time`` is given (0.0 included);
        - the public parameters;
        - the statistics/shape/metrics extracted by
          ``OutputFormatter.format_task_result``;
        - the existing ``<output_prefix>.{dat,hdr,csv,png,json}`` files in
          ``output_dir``, when both ``output_dir`` and ``output_prefix`` are
          set (not ``None``) on ``args``.

        Returns:
            The summary text, framed by lines of 60 ``=`` characters.
        """
        lines = [
            "=" * 60,
            f"任务执行摘要 - {task_name}",
            "=" * 60,
            f"执行时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        ]
        # `is not None`: a measured duration of 0.0 s is still worth reporting.
        if execution_time is not None:
            lines.append(f"运行时长: {execution_time:.2f} 秒")

        lines.append("\n输入参数:")
        for key, value in SummaryGenerator._public_args(args):
            lines.append(f" {key}: {value}")

        lines.append("\n结果摘要:")
        summary = OutputFormatter.format_task_result(task_name, result)['summary']

        if 'statistics' in summary:
            stats = summary['statistics']
            fmt = SummaryGenerator._format_stat
            lines.append(" 数据统计:")
            lines.append(f" 均值: {fmt(stats.get('mean', 'N/A'))}")
            lines.append(f" 标准差: {fmt(stats.get('std', 'N/A'))}")
            lines.append(f" 最小值: {fmt(stats.get('min', 'N/A'))}")
            lines.append(f" 最大值: {fmt(stats.get('max', 'N/A'))}")

        if 'shape' in summary:
            lines.append(f" 数据形状: {summary['shape']}")

        if 'keys_count' in summary:
            lines.append(f" 结果包含 {summary['keys_count']} 个项目")

        if 'numeric_metrics' in summary:
            lines.append(" 数值指标:")
            for key, value in summary['numeric_metrics'].items():
                lines.append(f" {key}: {value}")

        output_dir = SummaryGenerator._get_arg(args, 'output_dir')
        output_prefix = SummaryGenerator._get_arg(args, 'output_prefix')
        # Both must be set: Path(None) would raise TypeError.
        if output_dir is not None and output_prefix is not None:
            output_dir = Path(output_dir)
            lines.append("\n输出文件:")
            lines.append(f" 输出目录: {output_dir.absolute()}")

            # Report only the candidate output files that actually exist.
            for suffix in ('.dat', '.hdr', '.csv', '.png', '.json'):
                filename = f"{output_prefix}{suffix}"
                filepath = output_dir / filename
                if filepath.exists():
                    lines.append(f" ✓ {filename} ({filepath.stat().st_size} bytes)")

        lines.append("=" * 60)

        return "\n".join(lines)

    @staticmethod
    def print_colored_summary(summary_text: str, success: bool = True):
        """Print ``summary_text`` in green (``success``) or red (failure).

        Colors come from the optional ``colorama`` package; without it the
        text is printed uncolored.
        """
        try:
            from colorama import Fore, Style, init
            init()
            color = Fore.GREEN if success else Fore.RED
            print(color + summary_text + Style.RESET_ALL)
        except ImportError:
            # colorama is optional: fall back to plain output.
            print(summary_text)
|
||
|
||
|
||
class ReportGenerator:
    """Render a self-contained HTML report describing one task run."""

    @staticmethod
    def _public_args(args) -> List[Tuple[str, Any]]:
        """Return ``(name, value)`` pairs of ``args`` whose name does not start with ``_``.

        ``args`` may be an ``argparse.Namespace`` (or any object with a
        ``__dict__``), a mapping, or ``None`` (meaning no parameters).
        """
        from collections.abc import Mapping

        if args is None:
            return []
        items = args.items() if isinstance(args, Mapping) else vars(args).items()
        return [(key, value) for key, value in items if not str(key).startswith('_')]

    @staticmethod
    def _param_rows(args) -> str:
        """Render the public parameters as HTML table rows with escaped names/values.

        Parameter values are arbitrary text (paths, user strings); without
        escaping, a ``<`` or ``&`` in them would corrupt the markup.
        """
        import html

        return "".join(
            f" <tr><td>{html.escape(str(key))}</td><td>{html.escape(str(value))}</td></tr>\n"
            for key, value in ReportGenerator._public_args(args)
        )

    @staticmethod
    def generate_html_report(task_name: str, args, result: Any,
                             output_dir: str,
                             execution_time: Optional[float] = None) -> Optional[str]:
        """Write ``<output_dir>/<task_name>_report.html`` summarizing a task run.

        The report lists the public parameters and whatever
        ``OutputFormatter.format_task_result`` extracts from ``result``:
        statistics, shape/dtype, key count and numeric metrics. All dynamic
        text is HTML-escaped (dtype strings such as ``<U1`` would otherwise
        open a bogus tag). ``output_dir`` is created if missing.

        Returns:
            The report path as a string, or ``None`` if writing failed
            (a warning is printed in that case).
        """
        import html

        task_html = html.escape(str(task_name))
        # `is not None`: a measured duration of 0.0 s is still reported.
        duration_html = (
            f"<p><strong>运行时长:</strong> {execution_time:.2f} 秒</p>"
            if execution_time is not None else ""
        )

        html_content = f"""
<!DOCTYPE html>
<html>
<head>
    <title>高光谱分析报告 - {task_html}</title>
    <meta charset="utf-8">
    <style>
        body {{ font-family: Arial, sans-serif; margin: 40px; }}
        .header {{ background: #f0f0f0; padding: 20px; border-radius: 5px; }}
        .section {{ margin: 20px 0; padding: 15px; border: 1px solid #ddd; border-radius: 5px; }}
        .success {{ border-color: #4CAF50; background: #f9fff9; }}
        .metric {{ display: inline-block; margin: 10px; padding: 10px; background: #e8f4f8; border-radius: 3px; }}
        table {{ border-collapse: collapse; width: 100%; }}
        th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
        th {{ background-color: #f2f2f2; }}
    </style>
</head>
<body>
    <div class="header">
        <h1>高光谱分析执行报告</h1>
        <p><strong>任务:</strong> {task_html}</p>
        <p><strong>执行时间:</strong> {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
        {duration_html}
    </div>

    <div class="section">
        <h2>输入参数</h2>
        <table>
            <tr><th>参数</th><th>值</th></tr>
"""

        html_content += ReportGenerator._param_rows(args)

        html_content += """
        </table>
    </div>

    <div class="section success">
        <h2>执行结果</h2>
"""

        summary = OutputFormatter.format_task_result(task_name, result)['summary']

        if 'statistics' in summary:
            stats = summary['statistics']
            html_content += """
        <h3>数据统计</h3>
        <div class="metric">均值: {:.4f}</div>
        <div class="metric">标准差: {:.4f}</div>
        <div class="metric">最小值: {:.4f}</div>
        <div class="metric">最大值: {:.4f}</div>
""".format(
                stats.get('mean', 0),
                stats.get('std', 0),
                stats.get('min', 0),
                stats.get('max', 0),
            )

        if 'shape' in summary:
            shape_html = html.escape(str(summary['shape']))
            dtype_html = html.escape(str(summary.get('dtype', 'unknown')))
            html_content += f"""
        <h3>数据信息</h3>
        <p>形状: {shape_html}</p>
        <p>数据类型: {dtype_html}</p>
"""

        # Dict results: mirror what the console summary reports.
        if 'keys_count' in summary:
            html_content += f"        <p>结果包含 {summary['keys_count']} 个项目</p>\n"

        if 'numeric_metrics' in summary:
            html_content += "        <h3>数值指标</h3>\n"
            for key, value in summary['numeric_metrics'].items():
                html_content += (
                    f'        <div class="metric">{html.escape(str(key))}: '
                    f'{html.escape(str(value))}</div>\n'
                )

        html_content += """
    </div>
</body>
</html>
"""

        report_path = Path(output_dir) / f"{task_name}_report.html"
        try:
            Path(output_dir).mkdir(parents=True, exist_ok=True)
            report_path.write_text(html_content, encoding='utf-8')
        except Exception as e:  # best-effort boundary: report, never raise
            print(f"警告: 无法生成HTML报告: {e}")
            return None
        return str(report_path)
|
||
|
||
|
||
def create_unified_output(task_name: str, args, result: Any,
                          output_dir: str, execution_time: Optional[float] = None,
                          success: bool = True) -> Dict[str, str]:
    """Produce all standard outputs for one task run.

    Prints the colored console summary, then writes the JSON task summary,
    appends to the execution log and renders the HTML report, all inside
    ``output_dir`` (created if missing). Each writer is best-effort, so a
    failing artifact is simply missing from the returned dict.

    Args:
        task_name: Name of the executed task (also used in file names).
        args: Task parameters, typically an ``argparse.Namespace``; a mapping
            or ``None`` is also accepted when recording the log parameters.
        result: Task result to summarize.
        output_dir: Destination directory for all files.
        execution_time: Run time in seconds, if measured.
        success: Outcome of the task; recorded in the saved summary and log.

    Returns:
        Dict mapping ``'summary'``, ``'log'`` and ``'report'`` to the paths
        of the files that were written successfully.
    """
    from collections.abc import Mapping

    output_files: Dict[str, str] = {}

    # Create the destination once up front: every writer below opens a file
    # directly inside output_dir and would otherwise fail (each only prints
    # a warning) when the directory does not exist yet.
    try:
        Path(output_dir).mkdir(parents=True, exist_ok=True)
    except OSError as e:
        print(f"警告: 无法创建输出目录 {output_dir}: {e}")

    # Console summary
    summary_text = SummaryGenerator.generate_task_summary(
        task_name, args, result, execution_time
    )
    SummaryGenerator.print_colored_summary(summary_text, success)

    # JSON task summary; format_task_result always says success=True, so
    # record the real outcome instead.
    formatted_result = OutputFormatter.format_task_result(task_name, result)
    formatted_result['success'] = success
    summary_path = ResultExporter.save_task_summary(
        formatted_result, output_dir, f"{task_name}_summary"
    )
    if summary_path:
        output_files['summary'] = summary_path

    # Execution log
    if args is None:
        params = {}
    elif isinstance(args, Mapping):
        params = dict(args)
    else:
        params = vars(args)
    log_path = ResultExporter.save_execution_log(
        task_name, params, result, output_dir, success
    )
    if log_path:
        output_files['log'] = log_path

    # HTML report
    html_path = ReportGenerator.generate_html_report(
        task_name, args, result, output_dir, execution_time
    )
    if html_path:
        output_files['report'] = html_path

    return output_files
|