萤火跑模型 | Informer 上手实践

Vachel    April 18, 2021

长时间序列预测(Long Sequence Time-Series Forecasting,以下称为 LTSF)在现实世界中是比较基础但又十分重要的研究场景,例如商品销量及库存的预测、电力消耗的规划、股票价格、疾病传播与扩散等实际问题。然而LTSF因为其历史数据量大、计算复杂性高、预测精度要求高,一直以来并没有取得太好的效果。 ​

今年人工智能顶级大会AAAI的最佳论文奖项中有一篇来自北京航空航天大学的工作:Informer,其主要的工作是改造 Transfomer 算法来实现LTSF,并开源了代码与数据。笔者最近在幻方AI的萤火平台上尝试复现了该论文的实验,为大家带来第一手的测试体验。 ​

论文标题:Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting

原文地址https://ojs.aaai.org/index.php/AAAI/article/view/17325

源码地址https://github.com/zhouhaoyi/Informer2020

模型仓库https://github.com/HFAiLab/ltsf-former ​ ​

模型介绍

近年来的研究表明,Transformer具有提高预测能力的潜力。然而,Transformer也存在几个问题,使其不能直接适用于LTSF问题,例如时间复杂度、高内存使用和“编码-解码”体系结构的固有局限性。为了解决这些问题,作者基于Transformer设计了一种适用于LTSF问题的模型,即Informer模型,该模型具有三个显著特征:

  1. ProbSpare self-attention机制,有效降低了时间复杂度和内存使用量。
  2. 通过将级联层输入减半来突出Self-attention中的主导因子,有效地处理过长的输入序列。
  3. 对长时间序列进行一次预测而不是一步步方式进行预测,极大提高了长序列预测的推理速度。

Informer模型的整体框架如下图所示,可以看出该模型仍然保存了Encoder-Decoder的架构:

image.png

编码过程(左):编码器接收长序列输入(绿色部分),通过ProbSparse自注意力模块和自注意力蒸馏模块,得到特征表示。ProbSparse自注意力机制利用稀疏矩阵来替代原来的注意力矩阵,大幅减少算力需求的同时并保持的良好的性能。

解码过程(右):解码器接收长序列输入(预测目标部分设置为0),通过多头注意力与编码特征进行交互,最后直接预测输出目标部分(橙黄色部分)。这里作者采用的是一次生成式预测方式,并说明该方式相比step-by-step方式推理速度更快且效果相当。

至于特征输入,在 LTSF 问题中,时序建模不仅需要局部时序信息还需要层次时序信息,如星期、月和年等,以及突发事件或某些节假日等。经典自注意力机制很难直接适配,因此Informer提出了三层特征输入形式,如下图所示:

image.png

Informer的输入分为了三种位置嵌入表示:

  1. 局部时间戳,即Transformer中的固定位置嵌入。
  2. 全局时间戳。对于层次时间信息,构建一个词汇表,通过Embedding特征表示每一个词汇。
  3. 对齐维度,使用一维卷积将输入的序列标量转化为向量。

模型实践

看完上述的介绍,是否有想一试算法的冲动了呢?笔者也是带着好奇,在幻方萤火集群上尝试复现论文中的实验效果。 ​

作者的开源比较完备,包含ETT(变压器温度)、ECL(耗电量)和WTH(气象)3个数据集,采用PyTorch实现且没有特殊包依赖的模型代码。作者提供了执行的脚本 scripts/*.sh,包括了不同的实验参数:

### M

python -u main_informer.py --model informer --data ETTh1 --features M --seq_len 48 --label_len 48 --pred_len 24 --e_layers 2 --d_layers 1 --attn prob --des 'Exp' --itr 5 --factor 3

python -u main_informer.py --model informer --data ETTh1 --features M --seq_len 96 --label_len 48 --pred_len 48 --e_layers 2 --d_layers 1 --attn prob --des 'Exp' --itr 5

python -u main_informer.py --model informer --data ETTh1 --features M --seq_len 168 --label_len 168 --pred_len 168 --e_layers 2 --d_layers 1 --attn prob --des 'Exp' --itr 5

### S

python -u main_informer.py --model informer --data ETTh1 --features S --seq_len 720 --label_len 168 --pred_len 24 --e_layers 2 --d_layers 1 --attn prob --des 'Exp' --itr 5

python -u main_informer.py --model informer --data ETTh1 --features S --seq_len 720 --label_len 168 --pred_len 48 --e_layers 2 --d_layers 1 --attn prob --des 'Exp' --itr 5

python -u main_informer.py --model informer --data ETTh1 --features S --seq_len 720 --label_len 336 --pred_len 168 --e_layers 2 --d_layers 1 --attn prob --des 'Exp' --itr 5

一共有60组不同的实验,我们用幻方萤火平台快速运行一下。 ​

萤火提供了统一的训练管理平台,可以将大量的训练任务依据优先级负载均衡到不同的GPU上执行,充分利用起计算算力,节省模型训练的时间。我们只需加入几行代码,便可以把训练任务提交给集群,喝杯咖啡的功夫就能拿到结果啦。

  1. 登录萤火平台,引入haienv, hfai
# ./main_informer.py
import haienv
haienv.set_env('202111')

# ./exp/exp_informer.py
import hfai
  1. 对于每一轮训练,加入接收集群调度的逻辑代码,并做好模型checkpoint的保存
# 获取当前训练的进度
for epoch in range(hfai.get_whole_life_state()//len(train_loader)%self.args.train_epochs, self.args.train_epochs):
    ...
    
    # 对于之前已经训练过的轮次,直接跳过
    if epoch*len(train_loader)+i <hfai.get_whole_life_state():
        continue
    ...
    
    # 状态更新
    steps += 1
    
    # 收到调度信号,保存模型,设置当前执行的状态信息
    if hfai.receive_suspend_command():
        ...
        hfai.set_whole_life_state(steps)
        hfai.go_suspend()

image.png

  1. 提交任务,等待执行结果

image.png

  1. 完成

image.png ​ Informer开源的测试数据不大(10MB以内),集群测试中单个Epoch执行花费6-7s左右。对每组实验配一块A100卡,我们很快就可以完成60组实验。在长时序单变量预测和长时序多变量预测这两个任务上,测试结果基本与论文所公布的MSE和MAE结果吻合,在某些参数上甚至跑出了更好的成绩。 ​

体验总结

​ Informer作为今年AAAI的Best Paper之一,对Transformer模型进行了很多切实有效的改进,使其计算、内存和体系结构更加高效。同时,作者也做了完整的开源,代码结构清晰,笔者能够很流畅的对其进行复现。借助萤火平台,我们很快就拿到了全部的实验数据。 ​ 综合体验打分如下:

  1. 研究指数:★★

    该模型着眼于长时间序列的预测问题,是一个广泛研究的基础课题。

  2. 开源指数:★★★★★

    数据和代码都已开源,代码逻辑清晰可读性高。

  3. 门槛指数:★★★★

    数据量小,模型计算、内存和体系结构均已优化,普通高性能PC即可运行。

  4. 通用指数:★★★

    模型是对Transformer的优化,深入改进了自注意力机制,能适用于长序列预测场景,其他场景待验证。

  5. 适配指数:★★★★★

    依赖简单,PyTorch框架构建的模型,只需要修改几行代码就能在萤火集群上执行。

幻方 AI 紧跟 AI 研究的前沿浪潮,致力于用领先算力助力AI落地与价值创造,欢迎各方数据研究者与开发者们一同共建。


本文作者: Vachel


You are free to reprint the content in this Blog or excerpt or quote it without contravening the authors' intentions under the following terms: Attribution — You shall give credit to the author(s), but not in any way that suggests High-Flyer endorses you or imposes any negative influence on High-Flyer's rights. Non Commercial — You may not use the content in this Blog for commercial purposes. No Derivatives — If you remix, transform or create upon the content, you may not publish or distribute the modified content but for personal use only.