分类 笔记 下的文章

先上效果图

运行效果图

背景

昨天写了一个小脚本,用Python帮我自动汇总业务人员的活动报告,节省了不少的时间。
俗话说,独乐乐 不如众乐乐,何不将其制作成可执行文件,分发给嘉智联的小伙伴们使用呢?

经常把大象放进冰箱里的帅哥都知道,将业务逻辑用Python验证后,用更高效语言重写的方法一般如下:

  • 打开AI助手
  • py文件和业务表格喂食给AI
  • 提出改写需求(顺便加个UI)
  • 测试并完成

原Python脚本

改写后的Rust脚本

#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]

use anyhow::{Context, Result};
use calamine::{open_workbook, Reader, Xlsx};
use chrono::NaiveDateTime;
use eframe::egui;
use rust_xlsxwriter::{
    Format, FormatAlign, FormatBorder, Workbook,
};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use walkdir::WalkDir;

// 线程安全的共享状态
struct SharedState {
    log: Vec<String>,
    status: String,
    processing: bool,
}

#[derive(Clone)]
struct ExcelSummaryApp {
    input_dir: String,
    output_path: String,
    shared_state: Arc<Mutex<SharedState>>,
}

impl Default for ExcelSummaryApp {
    fn default() -> Self {
        Self {
            input_dir: String::new(),
            output_path: String::new(),
            shared_state: Arc::new(Mutex::new(SharedState {
                log: Vec::new(),
                status: String::new(),
                processing: false,
            })),
        }
    }
}

#[derive(Debug)]
struct ActivityInfo {
    topic: String,
    jzl_person: String,
    location: String,
    time: String,
    organizer: String,
    contact_person: String,
    contact_number: String,
    agenda: String,
    summary: String,
    source_file: String,
}

impl eframe::App for ExcelSummaryApp {
    fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) {
        // 锁定共享状态以读取
        let state = self.shared_state.lock().unwrap();
        let status = state.status.clone();
        let log = state.log.clone();
        let processing = state.processing;
        drop(state); // 提前释放锁

        egui::CentralPanel::default().show(ctx, |ui| {
            ui.heading("嘉智联活动信息汇总工具");
            ui.separator();

            // 输入文件夹选择
            ui.horizontal(|ui| {
                ui.label("输入文件夹:");
                ui.text_edit_singleline(&mut self.input_dir);
                if ui.button("浏览...").clicked() {
                    if let Some(path) = rfd::FileDialog::new().pick_folder() {
                        self.input_dir = path.to_string_lossy().to_string();
                    }
                }
            });

            // 输出文件选择
            ui.horizontal(|ui| {
                ui.label("输出文件:");
                ui.text_edit_singleline(&mut self.output_path);
                if ui.button("浏览...").clicked() {
                    if let Some(path) = rfd::FileDialog::new()
                        .add_filter("Excel文件", &["xlsx"])
                        .save_file()
                    {
                        let path_str = path.to_string_lossy().to_string();
                        self.output_path = if path_str.ends_with(".xlsx") {
                            path_str
                        } else {
                            format!("{}.xlsx", path_str)
                        };
                    }
                }
            });

            ui.separator();

            // 处理按钮
            ui.horizontal(|ui| {
                let button_enabled = !processing 
                    && !self.input_dir.is_empty() 
                    && !self.output_path.is_empty()
                    && Path::new(&self.input_dir).exists();

                let button = ui.add_enabled(button_enabled, egui::Button::new("开始汇总").min_size(egui::vec2(120.0, 30.0)));
                
                if button.clicked() {
                    // 更新状态为处理中
                    let mut state = self.shared_state.lock().unwrap();
                    state.processing = true;
                    state.status = "正在处理...".to_string();
                    state.log.clear();
                    drop(state); // 释放锁

                    // 克隆必要的数据以传递给线程
                    let input_dir = self.input_dir.clone();
                    let output_path = self.output_path.clone();
                    let shared_state = Arc::clone(&self.shared_state);
                    
                    // 在新线程中处理,避免UI卡顿
                    std::thread::spawn(move || {
                        let mut log = Vec::new();
                        let result = process_excel_files(&input_dir, &output_path, &mut log);
                        
                        // 更新状态
                        let mut state = shared_state.lock().unwrap();
                        state.status = match result {
                            Ok(_) => "处理完成!".to_string(),
                            Err(e) => format!("处理失败: {}", e),
                        };
                        state.log = log;
                        state.processing = false;
                    });
                }
            });

            // 状态显示
            ui.label(&status);
            
            // 日志显示
            ui.separator();
            ui.label("处理日志:");
            egui::ScrollArea::vertical().show(ui, |ui| {
                for line in &log {
                    ui.label(line);
                }
            });
        });

        // 每帧请求重绘以更新状态
        ctx.request_repaint();
    }
}

fn process_excel_files(input_dir: &str, output_path: &str, log: &mut Vec<String>) -> Result<()> {
    log.push(format!("开始处理文件夹: {}", input_dir));
    
    // 查找所有Excel文件
    let excel_files: Vec<PathBuf> = WalkDir::new(input_dir)
        .into_iter()
        .filter_map(|e| e.ok())
        .filter(|e| {
            e.file_type().is_file() && 
            matches!(e.path().extension().and_then(|s| s.to_str()), 
            Some("xlsx") | Some("xls")) &&
            !e.path().file_name().and_then(|s| s.to_str()).map_or(false, |n| n.contains("汇总"))
        })
        .map(|e| e.path().to_path_buf())
        .collect();
    
    log.push(format!("找到 {} 个Excel文件", excel_files.len()));
    
    if excel_files.is_empty() {
        return Ok(());
    }
    
    // 提取每个文件的信息
    let mut activities = Vec::new();
    for path in &excel_files {
        log.push(format!("处理文件: {}", path.to_string_lossy()));
        
        match extract_activity_info(path) {
            Ok(info) => activities.push(info),
            Err(e) => log.push(format!("处理文件 {} 时出错: {}", path.to_string_lossy(), e)),
        }
    }
    
    // 生成输出Excel
    generate_output_excel(&activities, output_path, log)?;
    
    log.push(format!("汇总完成,结果保存至: {}", output_path));
    Ok(())
}
fn extract_activity_info(path: &Path) -> Result<ActivityInfo> {
    let mut workbook: Xlsx<_> = open_workbook(path)
        .with_context(|| format!("无法打开文件: {}", path.to_string_lossy()))?;

    let sheet_names = workbook.sheet_names().to_vec();

    if sheet_names.is_empty() {
        return Err(anyhow::anyhow!("文件中没有工作表"));
    }

    let first_sheet = &sheet_names[0];
    let range = workbook.worksheet_range(first_sheet)
        .with_context(|| format!("无法读取工作表: {}", first_sheet))?;

    let max_row = range.height() as usize;

    // 辅助函数:安全获取单元格值,使用具体的Xlsx数据类型
    fn get_cell_value(range: &calamine::Range<calamine::Data>, row: usize, col: usize) -> String {
        if row < range.height() as usize && col < range.width() as usize {
            match range.get((row, col)).unwrap_or(&calamine::Data::Empty) {
                calamine::Data::String(s) => s.clone(),
                calamine::Data::Float(f) => f.to_string(),
                calamine::Data::Int(i) => i.to_string(),
                calamine::Data::Bool(b) => b.to_string(),
                calamine::Data::DateTime(d) => d.to_string(),
                calamine::Data::Empty => String::new(),
                calamine::Data::Error(_) => String::new(),
                calamine::Data::DateTimeIso(s) => s.clone(),
                calamine::Data::DurationIso(s) => s.clone(),
            }
        } else {
            String::new()
        }
    }

    // 查找"嘉智联渠道市场推广活动报告"所在列
    let topic_col = (0..range.width() as usize)
        .find(|&col| get_cell_value(&range, 0, col).contains("嘉智联渠道市场推广活动报告"))
        .ok_or_else(|| anyhow::anyhow!("未找到标题列"))?;

    // 查找各Unnamed列(根据实际数据分布估算)
    let unnamed_2_col = topic_col + 2;
    let unnamed_5_col = topic_col + 5;
    let unnamed_7_col = topic_col + 7;
    let unnamed_8_col = topic_col + 8;

    // 提取所需内容,增加索引检查(与Python脚本保持一致)
    let activity_topic = if 2 < max_row { 
        get_cell_value(&range, 2, topic_col) 
    } else { 
        String::new() 
    };
    
    let jzl_person_in_charge = if 2 < max_row { 
        get_cell_value(&range, 2, unnamed_8_col) 
    } else { 
        String::new() 
    };
    
    let mut activity_location = if 5 < max_row { 
        get_cell_value(&range, 5, topic_col) 
    } else { 
        String::new() 
    };
    // 移除活动地点中的下划线
    activity_location = activity_location.replace('_', "");
    
    let activity_time = if 5 < max_row { 
        get_cell_value(&range, 5, unnamed_5_col) 
    } else { 
        String::new() 
    };
    
    let organizer = if 8 < max_row { 
        get_cell_value(&range, 8, topic_col) 
    } else { 
        String::new() 
    };
    
    let contact_person = if 8 < max_row { 
        get_cell_value(&range, 8, unnamed_7_col) 
    } else { 
        String::new() 
    };
    
    // 根据测试结果修正:联系电话在第8行,Unnamed:7列
    let contact_number = if 9 < max_row { 
        get_cell_value(&range, 9, unnamed_7_col) 
    } else { 
        String::new() 
    };

    // 活动议程部分
    let mut activity_agenda = String::new();
    let agenda_start = 12;
    let agenda_end = 19;
    
    if agenda_start < max_row {
        let actual_end = std::cmp::min(agenda_end, max_row - 1);
        let mut agenda_items = Vec::new();
        
        for row in agenda_start..=actual_end {
            let time = get_cell_value(&range, row, topic_col);
            let agenda = get_cell_value(&range, row, unnamed_2_col);
            let person = get_cell_value(&range, row, unnamed_7_col);
            
            // 处理时间格式,将Excel数字格式转换为时间格式
            let formatted_time = if let Ok(numeric_time) = time.parse::<f64>() {
                // 如果是数字格式的时间,转换为小时:分钟格式
                let hours = (numeric_time * 24.0).floor() as u32;
                let minutes = ((numeric_time * 24.0 * 60.0) % 60.0).round() as u32;
                format!("{:02}:{:02}", hours % 24, minutes)
            } else {
                time.clone()
            };
            
            if !formatted_time.is_empty() || !agenda.is_empty() {
                if !person.is_empty() {
                    agenda_items.push(format!("{} - {}({})", formatted_time, agenda, person));
                } else {
                    agenda_items.push(format!("{} - {}", formatted_time, agenda));
                }
            }
        }
        
        activity_agenda = agenda_items.join(";");
    }

    // 活动小结 - 从36行到43行(index=35到42)和A列到J列(index=0到9)提取内容
    let mut activity_summary = String::new();
    let summary_start = 35;
    let summary_end = 42;
    let col_start = 0;
    let col_end = 9;
    
    if summary_start < max_row {
        let actual_summary_end = std::cmp::min(summary_end, max_row - 1);
        let actual_col_end = std::cmp::min(col_end, range.width() as usize - 1);
        
        let mut summary_cells = Vec::new();
        for row_idx in summary_start..=actual_summary_end {
            for col_idx in col_start..=actual_col_end {
                let cell_value = get_cell_value(&range, row_idx, col_idx);
                if !cell_value.trim().is_empty() {
                    summary_cells.push(cell_value.trim().to_string());
                }
            }
        }
        
        activity_summary = summary_cells.join(" ");
    }
    
    Ok(ActivityInfo {
        topic: activity_topic.trim().to_string(),
        jzl_person: jzl_person_in_charge.trim().to_string(),
        location: activity_location.trim().to_string(),
        time: activity_time.trim().to_string(),
        organizer: organizer.trim().to_string(),
        contact_person: contact_person.trim().to_string(),
        contact_number: contact_number.trim().to_string(),
        agenda: activity_agenda,
        summary: activity_summary,
        source_file: path.file_name().and_then(|n| n.to_str()).unwrap_or("").to_string(),
    })
}
fn generate_output_excel(activities: &[ActivityInfo], output_path: &str, _log: &mut Vec<String>) -> Result<()> {
    let mut workbook = Workbook::new();
    
    // 创建格式 - 使用构建器模式避免所有权问题
    let header_format = Format::new()
        .set_bold()
        .set_border(FormatBorder::Thin)
        .set_align(FormatAlign::Center);
    
    let default_format = Format::new()
        .set_border(FormatBorder::Thin)
        .set_text_wrap();
    
    let date_format = Format::new()
        .set_border(FormatBorder::Thin)
        .set_text_wrap()
        .set_align(FormatAlign::Center)
        .set_num_format("yyyy-mm-dd hh:mm");
    
    let phone_format = Format::new()
        .set_border(FormatBorder::Thin)
        .set_text_wrap()
        .set_num_format("@");
    
    let agenda_format = Format::new()
        .set_border(FormatBorder::Thin)
        .set_text_wrap();
    
    // 添加工作表
    let worksheet = workbook.add_worksheet();
    
    // 设置列宽
    worksheet.set_column_width(0, 30.0)?;  // 活动主题
    worksheet.set_column_width(1, 12.0)?;  // 嘉智联担当
    worksheet.set_column_width(2, 25.0)?;  // 活动地点
    worksheet.set_column_width(3, 20.0)?;  // 活动时间
    worksheet.set_column_width(4, 15.0)?;  // 主办单位
    worksheet.set_column_width(5, 12.0)?;  // 联系人
    worksheet.set_column_width(6, 15.0)?;  // 联系电话
    worksheet.set_column_width(7, 40.0)?;  // 活动议程
    worksheet.set_column_width(8, 50.0)?;  // 活动小结
    worksheet.set_column_width(9, 20.0)?;  // 来源文件
    
    // 写入表头
    let headers = [
        "活动主题", "嘉智联担当", "活动地点", "活动时间", "主办单位",
        "联系人", "联系电话", "活动议程", "活动小结", "原始报告"
    ];
    
    for (col, header) in headers.iter().enumerate() {
        worksheet.write_with_format(0, col as u16, *header, &header_format)?;
    }
    
    // 写入数据
    for (row_idx, activity) in activities.iter().enumerate() {
        let row = (row_idx + 1) as u32;
        
        worksheet.write_with_format(row, 0, &activity.topic, &default_format)?;
        worksheet.write_with_format(row, 1, &activity.jzl_person, &default_format)?;
        worksheet.write_with_format(row, 2, &activity.location, &default_format)?;
        
        // 尝试解析日期
        if let Ok(date) = NaiveDateTime::parse_from_str(&activity.time, "%Y年%m月%d日 %H:%M") {
            worksheet.write_with_format(row, 3, date.to_string(), &date_format)?;
        } else {
            worksheet.write_with_format(row, 3, &activity.time, &date_format)?;
        }
        
        worksheet.write_with_format(row, 4, &activity.organizer, &default_format)?;
        worksheet.write_with_format(row, 5, &activity.contact_person, &default_format)?;
        worksheet.write_with_format(row, 6, &activity.contact_number, &phone_format)?;
        // 使用专用格式写入活动议程列,并将分号替换为换行符以实现真正的自动换行
        let formatted_agenda = activity.agenda.replace(";", "\n");
        worksheet.write_with_format(row, 7, &formatted_agenda, &agenda_format)?;
        worksheet.write_with_format(row, 8, &activity.summary, &default_format)?;
        worksheet.write_with_format(row, 9, &activity.source_file, &default_format)?;
    }
    
    // 保存文件
    workbook.save(output_path).with_context(|| format!("无法保存文件: {}", output_path))?;
    Ok(())
}

fn main() -> Result<()> {
    let options = eframe::NativeOptions {
        initial_window_size: Some(egui::vec2(800.0, 600.0)),
        ..Default::default()
    };
    
    let _ = eframe::run_native(
        "嘉智联活动信息汇总工具",
        options,
        Box::new(|cc| {
            // 设置中文字体支持 - 动态检测并使用系统字体
            setup_fonts(cc);
            Box::new(ExcelSummaryApp::default())
        }),
    );
    
    Ok(())
}

// 设置字体的函数
fn setup_fonts(cc: &eframe::CreationContext<'_>) {
    // 获取系统字体
    let mut fonts = egui::FontDefinitions::default();
    
    // 在Windows系统上尝试使用系统默认中文字体
    if cfg!(target_os = "windows") {
        // Windows系统常见的中文字体路径
        let system_fonts = [
            "C:\\Windows\\Fonts\\msyh.ttc",      // 微软雅黑
            "C:\\Windows\\Fonts\\msyh.ttf",      // 微软雅黑
            "C:\\Windows\\Fonts\\simhei.ttf",    // 黑体
            "C:\\Windows\\Fonts\\simsun.ttc",    // 宋体
            "C:\\Windows\\Fonts\\simkai.ttf",    // 楷体
        ];
        
        // 检查系统字体文件是否存在,如果存在则使用
        for font_path in &system_fonts {
            if std::path::Path::new(font_path).exists() {
                match std::fs::read(font_path) {
                    Ok(font_data) => {
                        fonts.font_data.insert(
                            "SystemChineseFont".to_owned(),
                            egui::FontData::from_owned(font_data),
                        );
                        
                        fonts.families.get_mut(&egui::FontFamily::Proportional).unwrap()
                            .insert(0, "SystemChineseFont".to_owned());
                        
                        fonts.families.get_mut(&egui::FontFamily::Monospace).unwrap()
                            .insert(0, "SystemChineseFont".to_owned());
                        
                        // 找到第一个可用字体就退出循环
                        break;
                    }
                    Err(_) => continue,
                }
            }
        }
    }
    
    cc.egui_ctx.set_fonts(fonts);
}

最终生成exe文件,文件大小仅仅4335K,以及超快的运行速度,这是Python无法比拟的优势。

分享一个Python自动化案例

背景

  • 渠道正在开展一系列共性的市场活动
  • 我们需要总结上一阶段的活动
  • 统计结果以及结算费用
  • 事先通过销售总监安排了每场活动的小结,
  • 计划写一个脚本自动汇总小结生成月度报告
import pandas as pd
import os
import fnmatch


def extract_summary_info(excel_path):
    try:
        # 读取文件
        excel_file = pd.ExcelFile(excel_path)
        # 获取所有表名
        sheet_names = excel_file.sheet_names
        if not sheet_names:
            raise ValueError(f"文件 {excel_path} 中没有工作表")

        # 获取指定工作表中的数据
        df = excel_file.parse(sheet_names[0])

        # 提取所需内容
        activity_topic = df.loc[1, '嘉智联渠道市场推广活动报告']
        jzl_person_in_charge = df.loc[1, 'Unnamed: 8']
        activity_location = df.loc[4, '嘉智联渠道市场推广活动报告']
        activity_time = df.loc[4, 'Unnamed: 5']
        organizer = df.loc[7, '嘉智联渠道市场推广活动报告']
        contact_person = df.loc[7, 'Unnamed: 7']
        contact_number = df.loc[8, 'Unnamed: 8']

        # 活动议程部分,从第 12 行(index=11)到第 15 行(index=14)获取时间和议程信息
        agenda_rows = df.loc[11:14, ['嘉智联渠道市场推广活动报告', 'Unnamed: 2', 'Unnamed: 7']]
        agenda_rows.columns = ['开始时间', '议程', '负责担当']
        activity_agenda = ';'.join([
            f"{row['开始时间']} - {row['议程']}({row['负责担当']})" if pd.notna(row['负责担当']) 
            else f"{row['开始时间']} - {row['议程']}" for _, row in agenda_rows.iterrows()
        ])
        activity_summary = df.loc[35, '嘉智联渠道市场推广活动报告']

        # 创建汇总数据的 DataFrame
        data = {
            '活动主题': [activity_topic],
            '嘉智联担当': [jzl_person_in_charge],
            '活动地点': [activity_location],
            '活动时间': [activity_time],
            '主办单位': [organizer],
            '联系人': [contact_person],
            '联系电话': [contact_number],
            '活动议程': [activity_agenda],
            '活动小结': [activity_summary]
        }

        result_df = pd.DataFrame(data)
        return result_df
    except Exception as e:
        print(f"处理文件 {excel_path} 时出现错误: {e}")
        return pd.DataFrame()


def find_excel_files(root_folder):
    excel_files = []
    for root, dirs, files in os.walk(root_folder):
        for file in files:
            if fnmatch.fnmatch(file, '*.xlsx') or fnmatch.fnmatch(file, '*.xls'):
                excel_files.append(os.path.join(root, file))
    return excel_files


def main(file_list, result_file_path):
    dfs = []
    excel_files = file_list
    for excel_file in excel_files:
        result_df = extract_summary_info(excel_file)
        dfs.append(result_df)

    # 循环结束后一次性合并 DataFrame
    if dfs:
        result = pd.concat(dfs, ignore_index=True)
        result.to_excel(result_file_path, index=False)
        print(result)
    else:
        print("没有找到有效的数据")


if __name__ == '__main__':
    path = r".\会议资料2507"
    main(find_excel_files(path), r".\会议汇总2507.xlsx")

完美,俺又可以愉快的喝茶了~

根据历史销售数据训练预测模型

评估以下模型,找出契合度最高的

  • '移动平均'
  • '自回归模型 (AR)'
  • 'ARIMA 模型'
  • '简单指数平滑法'
  • '季节性 ARIMA 模型 (SARIMAX)'

根据预测模型预测下一阶段的销售趋势

ps.想要用Rust重写,但评估了Rust的学习曲线后,果断放弃。以后有时间再考虑吧~


import pandas as pd
from rich.logging import RichHandler
from statsmodels.tsa.ar_model import AutoReg
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.holtwinters import SimpleExpSmoothing
from statsmodels.tsa.statespace.sarimax import SARIMAX
from prophet import Prophet
from sklearn.metrics import mean_squared_error
import warnings
import logging
from time_counter import timeit  # 导入本地的 time_counter.py 中的 timeit 装饰器
import itertools
import numpy as np
from statsmodels.tsa.seasonal import STL  # 新增导入

logging.basicConfig(
    level="NOTSET",
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler()]
)

log = logging.getLogger("rich")

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# 忽略警告
warnings.filterwarnings("ignore")

# 配置参数
DATA_FILE = '20210401_20250430.xlsx'
DATE_COLUMN = 'Created On(Delivery)'
PRODUCT_NAME_COLUMN = 'Material Short Text'
QUANTITY_COLUMN = '发货单数量'
START_DATE = '2021-01-01'
FREQ = 'MS'

# 将 methods 字典定义为全局变量
methods = {
    '移动平均': lambda data: data.rolling(window=3).mean().iloc[-1],
    '自回归模型 (AR)': lambda data: AutoReg(data, lags=3).fit().forecast(steps=5),
    'ARIMA 模型': lambda data: ARIMA(data, order=(1, 1, 1)).fit().forecast(steps=5),
    '简单指数平滑法': lambda data: SimpleExpSmoothing(data).fit().forecast(steps=5),
    '季节性 ARIMA 模型 (SARIMAX)': lambda data: SARIMAX(data, order=(1, 1, 1), seasonal_order=(1, 1, 1, 12)).fit().forecast(steps=5),
    'Prophet 模型': lambda data: Prophet().fit(pd.DataFrame({'ds': pd.date_range(start=START_DATE, periods=len(data), freq=FREQ), 'y': data})).make_future_dataframe(periods=5, freq=FREQ).tail(5)['yhat']
}

@timeit
def read_and_preprocess_data(product_name):
    """
    读取并预处理数据
    :param product_name: 产品名称
    :return: 预处理后的数据
    """
    try:
        logging.info(f"正在读取 Excel 文件: {DATA_FILE}")
        df = pd.read_excel(DATA_FILE)
        logging.info("Excel 文件读取完成")
        logging.info(f"读取到的数据总行数: {len(df)}")
    except FileNotFoundError:
        logging.error(f"未找到文件: {DATA_FILE}")
        return None

    df[PRODUCT_NAME_COLUMN] = df[PRODUCT_NAME_COLUMN].str.strip().str.upper()
    product_name = product_name.strip().upper()

    logging.info(f"正在筛选 {PRODUCT_NAME_COLUMN} 为 {product_name} 的数据...")
    filtered_df = df[df[PRODUCT_NAME_COLUMN] == product_name].copy()
    logging.info(f"筛选后的数据总行数: {len(filtered_df)}")
    if len(filtered_df) == 0:
        logging.error(f"未找到 {product_name} 相关数据,请检查数据文件。")
        return None
    logging.info("数据筛选完成")

    try:
        logging.info(f"正在将 {DATE_COLUMN} 转换为日期序列...")
        filtered_df[DATE_COLUMN] = pd.to_datetime(filtered_df[DATE_COLUMN])
        logging.info("日期序列转换完成")
    except ValueError:
        logging.error(f"无法将 {DATE_COLUMN} 转换为日期序列,请检查数据格式。")
        return None

    # 提取年 - 月信息
    filtered_df['YearMonth'] = filtered_df[DATE_COLUMN].dt.to_period('M')

    logging.info("正在提取年、月信息...")
    filtered_df['Year'] = filtered_df[DATE_COLUMN].dt.year
    filtered_df['Month'] = filtered_df[DATE_COLUMN].dt.month
    logging.info("年、月信息提取完成")

    logging.info("正在汇总年、月的发货单数量并转换为整数...")
    aggregated_df = filtered_df.groupby(['YearMonth', 'Year', 'Month'])[QUANTITY_COLUMN].sum().astype(int).reset_index()
    aggregated_df.rename(columns={QUANTITY_COLUMN: 'Actual'}, inplace=True)
    logging.info("发货单数量汇总完成")

    # 检查日期连续性(补充缺失月份)
    logging.info("正在检查日期连续性并补充缺失月份...")
    full_date_range = pd.date_range(start=START_DATE, end=aggregated_df['YearMonth'].dt.to_timestamp().max(), freq=FREQ)
    aggregated_df['YearMonth'] = aggregated_df['YearMonth'].dt.to_timestamp()
    aggregated_df = aggregated_df.set_index('YearMonth').reindex(full_date_range).fillna(0).reset_index()
    aggregated_df.rename(columns={'index': 'YearMonth'}, inplace=True)
    aggregated_df['Year'] = aggregated_df['YearMonth'].dt.year
    aggregated_df['Month'] = aggregated_df['YearMonth'].dt.month
    logging.info("日期连续性检查和缺失月份补充完成")

    # 新增季节性分解代码
    stl = STL(aggregated_df['Actual'], period=12).fit()
    aggregated_df = aggregated_df.assign(
        trend=stl.trend,
        seasonal=stl.seasonal,
        residual=stl.resid
    )
    logging.info("季节性分解完成")

    return aggregated_df

@timeit
def train_and_predict(method_name, train_data, steps):
    """
    训练模型并进行预测
    :param method_name: 预测方法名称
    :param train_data: 训练数据
    :param steps: 预测步数
    :return: 预测结果
    """
    if method_name == '移动平均':
        return [methods[method_name](train_data)] * steps
    elif method_name == '自回归模型 (AR)':
        # 定义参数范围
        lags_range = range(1, 10)
        best_aic = float('inf')
        best_lags = None
        for lags in lags_range:
            try:
                model = AutoReg(train_data, lags=lags).fit()
                if model.aic < best_aic:
                    best_aic = model.aic
                    best_lags = lags
            except:
                continue
        model = AutoReg(train_data, lags=best_lags).fit()
        return model.forecast(steps=steps)
    elif method_name == 'ARIMA 模型':
        # 定义参数范围
        p = d = q = range(0, 2)
        pdq = list(itertools.product(p, d, q))
        best_aic = float('inf')
        best_pdq = None
        for param in pdq:
            try:
                model = ARIMA(train_data, order=param).fit()
                if model.aic < best_aic:
                    best_aic = model.aic
                    best_pdq = param
            except:
                continue
        model = ARIMA(train_data, order=best_pdq).fit()
        return model.forecast(steps=steps)
    elif method_name == '简单指数平滑法':
        # 定义参数范围
        smoothing_levels = np.linspace(0.1, 1, 10)
        best_mse = float('inf')
        best_smoothing_level = None
        for smoothing_level in smoothing_levels:
            try:
                model = SimpleExpSmoothing(train_data).fit(smoothing_level=smoothing_level)
                forecast = model.forecast(steps=steps)
                mse = mean_squared_error(train_data, forecast[:len(train_data)])
                if mse < best_mse:
                    best_mse = mse
                    best_smoothing_level = smoothing_level
            except:
                continue
        model = SimpleExpSmoothing(train_data).fit(smoothing_level=best_smoothing_level)
        return model.forecast(steps=steps)
    elif method_name == '季节性 ARIMA 模型 (SARIMAX)':
        # 定义参数范围
        p = d = q = range(0, 2)
        P = D = Q = range(0, 2)
        pdq = list(itertools.product(p, d, q))
        seasonal_pdq = [(x[0], x[1], x[2], 12) for x in list(itertools.product(P, D, Q))]
        best_aic = float('inf')
        best_pdq = None
        best_seasonal_pdq = None
        for param in pdq:
            for param_seasonal in seasonal_pdq:
                try:
                    model = SARIMAX(train_data, order=param, seasonal_order=param_seasonal).fit()
                    if model.aic < best_aic:
                        best_aic = model.aic
                        best_pdq = param
                        best_seasonal_pdq = param_seasonal
                except:
                    continue
        model = SARIMAX(train_data, order=best_pdq, seasonal_order=best_seasonal_pdq).fit()
        return model.forecast(steps=steps)
    elif method_name == 'Prophet 模型':
        train_df = pd.DataFrame({'ds': pd.date_range(start=START_DATE, periods=len(train_data), freq=FREQ), 'y': train_data})
        model = Prophet()
        model.fit(train_df)
        future = model.make_future_dataframe(periods=steps, freq=FREQ)
        return model.predict(future)['yhat'][-steps:]

@timeit
def evaluate_forecasting_methods(aggregated_df, train_size):
    """
    评估不同的预测方法
    :param aggregated_df: 汇总后的数据
    :param train_size: 训练集大小
    :return: 最佳预测方法和最佳均方误差
    """
    test_size = len(aggregated_df) - train_size
    train_data = aggregated_df['Actual'].iloc[:train_size]
    test_data = aggregated_df['Actual'].iloc[train_size:]
    logging.info(f"训练集数据量: {len(train_data)},测试集数据量: {len(test_data)}")
    logging.info("训练集和测试集划分完成")

    logging.info("开始评估每种预测方法...")
    best_method = None
    best_mse = float('inf')
    for method_name in methods:
        logging.info(f"正在使用 {method_name} 方法进行预测...")
        try:
            forecast = train_and_predict(method_name, train_data, test_size)
            mse = mean_squared_error(test_data, forecast)
            logging.info(f"{method_name} 方法的均方误差 (MSE): {mse}")
            if mse < best_mse:
                best_mse = mse
                best_method = method_name
        except Exception as e:
            logging.error(f"{method_name} 计算时出现错误: {e}")
    logging.info("预测方法评估完成")
    return best_method, best_mse

def generate_future_dates(aggregated_df, forecast_steps):
    """
    生成未来多个月的年和月信息
    :param aggregated_df: 汇总后的数据
    :param forecast_steps: 预测步数
    :return: 未来日期列表
    """
    last_year = aggregated_df['Year'].iloc[-1]
    last_month = aggregated_df['Month'].iloc[-1]
    future_dates = []
    for i in range(1, forecast_steps + 1):
        next_month = (last_month + i) % 12
        if next_month == 0:
            next_month = 12
        next_year = last_year + (last_month + i - 1) // 12
        future_dates.append((next_year, next_month))
    return future_dates

@timeit
def make_final_forecast(aggregated_df, best_method, product_name, forecast_steps):
    if best_method:
        print(f"正在使用 {best_method} 方法对未来 {forecast_steps} 个月度周期进行预测...")
        if best_method == 'Prophet 模型':
            train_df = pd.DataFrame({'ds': pd.date_range(start='2021-01-01', periods=len(aggregated_df['Actual']), freq='MS'), 'y': aggregated_df['Actual']})
            model = Prophet()
            model.fit(train_df)
            future = model.make_future_dataframe(periods=forecast_steps, freq='MS')
            final_forecast = model.predict(future)['yhat'][-forecast_steps:]
        else:
            final_forecast = (methods[best_method])(aggregated_df['Actual'])
        # 处理移动平均方法返回单一数值的情况
        if best_method == '移动平均':
            final_forecast = [final_forecast] * forecast_steps
        # 将预测结果转换为整数类型
        final_forecast = pd.Series(final_forecast).astype(int)

        # 获取最后一个已知月份的年和月
        last_year = aggregated_df['Year'].iloc[-1]
        last_month = aggregated_df['Month'].iloc[-1]

        # 生成未来多个月的年和月信息
        future_dates = []
        for i in range(1, forecast_steps + 1):
            next_month = (last_month + i) % 12
            if next_month == 0:
                next_month = 12
            next_year = last_year + (last_month + i - 1) // 12
            future_dates.append((next_year, next_month))

        # 创建一个包含预测结果和对应年月的 DataFrame
        forecast_df = pd.DataFrame({
            'Year': [year for year, _ in future_dates],
            'Month': [month for _, month in future_dates],
            'Actual': [None] * forecast_steps,  # 预测部分实际值为空
            'Forecast': final_forecast
        })

        # 合并历史数据和预测数据
        combined_df = pd.concat([aggregated_df, forecast_df], ignore_index=True)

        # 打印包含年和月信息的预测结果
        print(f"未来 {forecast_steps} 个月度周期的预测结果:")
        for (year, month), forecast in zip(future_dates, final_forecast):
            print(f"{year}年{month}月: {forecast}")

        # 保存为 CSV 文件,文件名包含 product_name
        csv_filename = f'{product_name}_forecast.csv'
        combined_df.to_csv(csv_filename, index=False)
        print(f"预测结果已保存到 {csv_filename}")

if __name__ == "__main__":
    product_name = input("请输入要预测的产品名称: ").strip()
    # product_name = 'IUP27 硒鼓'  # 可修改为其他产品名称
    forecast_steps = 12  # 预测周期
    train_percentage_minmax = (0.5, 0.9)  # 训练集百分比的最小值和最大值,这里以 70% 到 90% 为例
    aggregated_df = read_and_preprocess_data(product_name)

    best_overall_mse = float('inf')
    best_overall_method = None
    best_train_percentage = None

    # 遍历不同的训练集百分比
    for train_percentage in [i / 100 for i in range(int(train_percentage_minmax[0] * 100), int(train_percentage_minmax[1] * 100) + 1)]:
        train_size = int(len(aggregated_df) * train_percentage)
        print(f"\n正在评估训练集百分比为 {train_percentage * 100}% 的情况...")
        best_method, best_mse = evaluate_forecasting_methods(aggregated_df, train_size)
        if best_mse < best_overall_mse:
            best_overall_mse = best_mse
            best_overall_method = best_method
            best_train_percentage = train_percentage

    print(f"\n最佳整体训练集百分比: {best_train_percentage * 100}%")
    print(f"最佳整体预测方法: {best_overall_method},均方误差: {best_overall_mse}")
    best_train_size = int(len(aggregated_df) * best_train_percentage)
    make_final_forecast(aggregated_df, best_overall_method, product_name, forecast_steps)