使用 TiDE 进行时间序列预测

人工智能
TiDE(Time-series Dense Encoder)是一种用于时间序列预测的机器学习模型。它的全称是时间序列密集编码器,是一种基于多层感知机(MLP)结构的模型,专门设计用于处理多变量、长期的时间序列预测问题。

今天云朵君和大家一起学习一种新颖的时间序列预测模型 - TiDE(Time-series Dense Encoder)。

时间序列预测一直是数据科学领域的一个热门研究课题,广泛应用于能源、金融、交通等诸多行业。传统的统计模型如ARIMA、GARCH等因其简单高效而被广泛使用。而近年来,随着深度学习的兴起,基于神经网络的预测模型也备受关注,表现出了强大的预测能力。

其中,Transformer模型因其出色的捕捉长期依赖关系的能力,一度被认为是解决时间序列预测问题的利器。但最新研究发现,这些基于Transformer的模型在长期预测任务中,性能并不如人意,反而被一些简单的线性模型超越。

有鉴于此,谷歌的研究团队在2023年提出了TiDE模型。该模型摒弃了Transformer的复杂结构,转而采用了多层感知器(MLP)的编码器-解码器架构。虽然设计简洁,但TiDE能有效捕捉时间序列的非线性依赖关系,并能很好地处理动态协变量和静态属性数据,展现出令人惊艳的预测性能。

在多个公开基准数据集的实验中,TiDE不仅精度超越了当前最优模型,而且在推理速度和训练效率上也领先于Transformer模型5-10倍以上。这种简单高效的特点使TiDE非常适合应用于工业级的大规模部署场景。

如果您对TiDE模型的原理和细节有进一步的了解兴趣,我们强烈推荐大家阅读原论文(https://arxiv.org/pdf/2304.08424.pdf)。希望TiDE这一创新预测模型能为时间序列分析领域注入新的活力,为解决实际问题提供更多的可能性。

探索 TiDE

TiDE 这个名字看似生涩,其实就是" Time-series Dense Encoder "的英文缩写。它的设计思路非常巧妙,摒弃了目前流行的转换器(Transformer)结构,而是采用了编码器-解码器的框架,使用简单的多层感知器(MLP)网络来完成编码和解码的工作。

那它是如何工作的呢?首先,编码器会将历史的时间序列数据和相关的协变量(如节假日、促销活动等)输入进去,学习一个紧凑的表示向量,捕捉数据的内在模式。接下来,解码器会根据这个向量,结合已知的未来时间步的协变量,生成相应的预测值。

TiDE的巧妙之处在于,它利用MLP的非线性映射能力来提取复杂特征,同时避免了转换器的注意力计算,大幅提高了模型的训练和预测速度。事实上,在多个公开数据集的测试中,TiDE不仅精度超过了现有最佳模型,其运算效率甚至比基于转换器的模型快了5-10倍之多!

这种高效而精准的特性,使得TiDE十分适合工业级的大规模部署场景。如果你对时间序列预测有研究兴趣,不防一探 TiDE 模型的奥秘。

TiDE 的结构

TiDE 的架构如下图所示。

TiDE 的结构TiDE 的结构

从上图我们可以看出,该模型将每个序列视为一个独立通道,即每次只传递一个序列及其协变量。

我们还可以看到,该模型有三个主要组成部分:编码器、解码器和时序解码器,它们都依赖于残差块结构。

这张图包含了很多信息,让我们来更详细地探讨每个组件。

探索残差块

如前所述,残差块是 TiDE 架构的基础层。

残差块的组成残差块的组成

从上图中,我们可以看到这是一个具有一个隐藏层和 ReLU 激活的 MLP。然后是一个剔除层、一个跳转连接和最后的层归一化步骤。

然后,这个组件会在整个网络中重复使用,以进行编码、解码和预测。

了解编码器

在这一步中,模型会将时间序列的过去和协变因素映射到一个密集的表示中。

第一步是进行特征投影。这就是利用残差块将动态协变量(随时间变化的外生变量)映射到低维投影中。

请记住,在进行多元预测时,我们需要特征的未来值。因此,模型必须处理回望窗口和水平序列。

这些序列可能会很长,因此,通过向低维空间投影,我们可以保持长度可控,并允许模型处理更长的序列,包括历史窗口和预测范围。

第二步是将序列的过去与其属性以及过去和未来协变量的投影连接起来。然后将其发送给编码器,编码器就是一叠残差块。

因此,编码器负责学习输入的表示。这可以看作是一种学习嵌入。

完成后,嵌入将被发送到解码器。

了解解码器

在这里,解码器负责接收编码器的学习表示并生成预测。

第一步是密集解码器,它也是由一叠残差块组成。它获取编码信息并输出一个矩阵,然后输入时序解码器。

解码输出与预测特征堆叠,以捕捉未来协变量的直接影响。例如,节假日是准时事件,会对某些时间序列产生重要影响。有了这种残差联系,模型就能捕捉并利用这些信息。

第二步是时空解码器,在此生成预测结果。在这里,它只是一个输出大小为 1 的残差块,这样我们就能得到给定时间序列的预测结果。

现在,我们已经了解了 TiDE 的每个关键组成部分,让我们用 Python 将其应用到一个小型预测项目中。

使用 TiDE 进行预测

现在,让我们在一个小型预测项目中应用 TiDE,并将其性能与 TSMixer 进行比较。

有趣的是,TSMixer 也是谷歌研究人员开发的基于 MLP 的多元预测架构,但它比 TiDE 早一个月发布。因此,我认为在一个小实验中比较这两种模型是很有趣的。

Etth1 数据集: https://github.com/zhouhaoyi/ETDataset。

这是文献中广泛使用的时间序列预测基准。它与其他协变量一起跟踪电力变压器的每小时油温,是进行多元预测的绝佳场景。

导入库并读取数据

第一步自然是导入项目所需的库并读取数据。

虽然 TiDE 原始论文的源代码在  上公开,但我还是选择使用 Darts 中的实现。

GitHub: https://github.com/google-research/google-research/tree/master/tide)

它将为我们提供更大的灵活性,而且它还带有超参数优化功能,而这些功能在原始代码库中是没有的。

导入 darts 以及其他标准软件包。

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from darts import TimeSeries
from darts.datasets import ETTh1Dataset

然后,我们就可以读取数据了。Darts 提供了学术界常用的标准数据集,比如 Etth1 数据集。

series = ETTh1Dataset().load()

拆分数据,将最后 96 个时间步骤保留给测试集。

train, test = series[:-96], series[-96:]

训练 TiDE

要访问 TiDE,只需从 darts 库中导入它。在训练之前,还需要手动缩放数据。这样可以确保训练过程更快、更稳定。

from darts.models.forecasting.tide_model import TiDEModel
from darts.dataprocessing.transformers import Scaler

train_scaler = Scaler()
scaled_train = train_scaler.fit_transform(train)

然后,初始化模型并指定其参数。在这里,我使用的优化参数与论文中针对该特定数据集介绍的参数相同。

tide = TiDEModel(
    input_chunk_length=720, 
    output_chunk_length=96,
    num_encoder_layers=2,
    num_decoder_layers=2,
    decoder_output_dim=32,
    hidden_size=512,
    temporal_decoder_hidden=16,
    use_layer_norm=True,
    dropout=0.5,
    random_state=42)

然后,就可以简单地训练30 个epochs拟合模型了。

tide.fit(
    scaled_train,
    epochs=30
)

一旦模型完成训练,我们就可以访问其预测结果。请注意,由于我们对训练数据进行了缩放,因此模型也会输出缩放的预测结果。因此,我们必须反向转换。

scaled_pred_tide = tide.predict(n=96)

pred_tide = train_scaler.inverse_transform(scaled_pred_tide)

完美!然后,我们就可以评估 TiDE 的性能了。

评估性能

为了评估模型的性能,我们将预测值和实际值存储在一个 DataFrame 中。

preds_df = pred_tide.pd_dataframe()
test_df = test.pd_dataframe()

我们还可以选择将预测结果可视化。为了简单起见,我只绘制了四列。

cols_to_plot = ['OT', 'HULL', 'MUFL', 'MULL']

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12,8))

for i, ax in enumerate(axes.flatten()):
    col = cols_to_plot[i]
        
    ax.plot(test_df[col], label='Actual', ls='-', color='blue')
    ax.plot(preds_df[col], label='TiDE', ls='--', color='green')
    
    ax.legend(loc='best')
    ax.set_xlabel('Date')
    ax.set_title(col)
    
plt.tight_layout()
fig.autofmt_xdate()

可视化 TiDE 预测可视化 TiDE 预测

从上图中我们可以看出,TiDE 对每个序列的预测都相当出色。

当然,评估性能的最佳方法是计算误差指标,因此我们来计算一下平均绝对误差(MAE)和平均平方误差(MSE)。

from darts.metrics import mae, mse

tide_mae = mae(test, pred_tide)
tide_mse = mse(test, pred_tide)

print(tide_mae, tide_mse)

由此得出 MAE 为 1.19,MSE 为 3.58。

目前,还没有现成的实现方法,因此我们必须手动完成许多步骤。

现在,我们只报告 TSMixer 在 Etth1 数据集上对 96 个时间步长进行多元预测的性能。

图片图片

TiDE 和 TSMixer 对 Etth1 数据集在 96 个时间步长范围内进行多元预测的性能指标。我们可以看到,TiDE 的性能最好。

我们使用了一个名为Etth1的标准数据集,在96个时间步长的范围内进行评估。结果显示,在这个数据集上,TiDE模型的平均绝对误差(MAE)和均方误差(MSE)都比TSMixer更低,这意味着TiDE在预测精度上表现更优秀。

当然,这只是一个有限的实验案例,并不能完全说明TiDE在任何情况下都会胜过TSMixer。事实上,TiDE可能是对TSMixer的一种渐进式改进。因此,对于每个具体的应用场景,我们都应当分别评估并选择最适合的模型。

总的来说,时间序列预测是一个错综复杂的领域,没有放之四海而皆准的万能模型。选择合适的模型需要结合具体数据和应用场景,并进行反复试验和调优。我们应该保持开放和客观的态度,虚心学习不同模型的优缺点,努力寻找最佳实践。

写在最后

TiDE(Time-series Dense Encoder)是一种用于时间序列预测的机器学习模型。它的全称是时间序列密集编码器,是一种基于多层感知机(MLP)结构的模型,专门设计用于处理多变量、长期的时间序列预测问题。

TiDE模型的工作原理是,首先利用残差模块对协变量(影响预测目标的其他相关变量)和历史数据进行编码,将它们映射到一个内部表示空间中。然后,模型会对这个学习到的内部表示进行解码,从而生成对未来时间步的预测值。

由于TiDE模型结构仅包含全连接层,因此相比循环神经网络等复杂模型,它的训练时间更短。但即便如此,在长期多步预测任务中,TiDE仍能取得很高的预测性能。

不过,针对不同的预测问题,模型的表现也会有所差异。因此,在实际应用中,建议对TiDE以及其他潜在的模型方案进行评估和测试,选择最佳的方案。

责任编辑:武晓燕 来源: 数据STUDIO
相关推荐

2023-03-16 07:27:30

CnosDB数据库

2021-04-07 10:02:00

XGBoostPython代码

2023-03-27 07:34:28

XGBoostInluxDB时间序列

2024-01-30 01:12:37

自然语言时间序列预测Pytorch

2024-10-23 17:10:49

2024-11-04 15:34:01

2022-11-24 17:00:01

模型ARDL开发

2024-05-09 16:23:14

2024-12-16 13:15:15

机器学习时间序列分析数据缺失

2017-01-09 09:20:07

Ubuntu NTP同步

2024-06-27 16:38:57

2023-10-16 18:02:29

2024-06-17 16:02:58

2021-07-02 10:05:45

PythonHot-winters指数平滑

2021-07-01 21:46:30

PythonHot-Winters数据

2023-10-13 15:34:55

时间序列TimesNet

2022-08-16 09:00:00

机器学习人工智能数据库

2023-03-16 18:09:00

机器学习数据集

2022-06-09 09:14:31

机器学习PythonJava

2024-09-04 16:36:48

点赞
收藏

51CTO技术栈公众号