详解MMoE 模型:多任务学习中的专家混合建模与实践【附代码】

发布于 2025-2-27 12:18
浏览
0收藏

MMOE模型由谷歌研究团队于2018年在论文《Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts》中提出,是一种新颖的多任务学习框架,广泛应用于推荐系统中。

本文从技术背景、演化过程、计算原理、关键问题解析以及基于PyTorch 的代码实现方面对MMoE架构进行深入探究。

1.技术背景

(1)多任务学习的本质是共享表示以及相关任务的相互影响,多任务学习模型并不总是在所有任务上都优于相应的单任务模型。

(2)通常,相似的子任务拥有比较接近的底层特征,那么在多任务学习中就可以很好地进行底层特征共享。但对于不相似的任务,它们的底层特征表示差异很大,在进行参数共享时很可能互相冲突或噪声太多,从而导致多任务学习模型效果不佳。

(3)因此,多任务学习的难点是如何在相似性不高的任务上获得好的效果。

2.MMoE演化过程

(1)MMoE的主干建立在最常用的Shared-Bottom模型(图(a))上,其中所有任务共享一个底层网络,每个任务在底层网络的顶部有一个单独的网络塔。

(2)MMoE模型(图(c))不是共享一个底层网络,而是拥有一组底层网络,其中每个底层网络被称之为专家(Expert)。

(3)与所有任务共享一个门控网络(Gate)的MoE模型(图(b))相比,MMoE为每个任务单独引入一个门控网络。在针对不同任务时可以得到不同的专家权重,从而实现对专家的选择性利用,即不同任务对应的门控网络可以学习到不同的专家组合模式,因此更容易捕捉到任务间的相关性和差异性。

详解MMoE 模型:多任务学习中的专家混合建模与实践【附代码】-AI.x社区

3.MMoE计算原理

详解MMoE 模型:多任务学习中的专家混合建模与实践【附代码】-AI.x社区

(1)当输入为时(是特征维度),任务的输出可以表示为:

其中,代表任务的底层网络,即被多个任务共享的专家网络组。代表专属于任务的顶层网络塔。

(2)底层共享专家组网络的输出可以表示为:

其中,代表第个专家网络(是专家数量),代表任务的门控网络产生的第个专家网络的权重值。最终输出结果为所有专家网络输出的加权和。

(3)对于任务专属的门控网络而言,输入为特征,输出为所有专家的权重值(类似于注意力机制),表示如下:

其中,是一个可训练参数矩阵。

(4)其中,每个专家网络的输入特征和结构都是一样的,每个门控网络的输入和结构也是一样的。专家和门控网络的结构都为前馈神经网络,层数可自定义。

(5)特别注意,由计算原理可知,门控网络的数量取决于任务数量,门控网络的输出大小必须等于专家数量。

4.关键问题解析

(1)专家网络的结构一样,输入特征也一样,是否会导致每个专家学习得到的参数趋向于一致,从而失去最终集成的意义?

(a)在网络参数随机初始化的情况下,不会出现上述问题。因为数据存在多个视角,只要每一个专家网络的参数初始化是不一样的,就会导致每一个专家学习到数据中不同的特征表达。

(2)门控网络的权重极化问题,即某些专家网络的权重过大或者国小,导致模型对这些专家网络的依赖性过高或过低,从而影响模型性能。

(a)使用Dropout:随机地将一部分神经元地输出设置为0,以防止模型过度依赖某些神经元,增强模型地泛化能力和鲁棒性。

(b)正则化:对门控网络的参数进行正则化,如添加 L1 或 L2 正则项,限制参数的大小和复杂度,防止模型过拟合,也有助于缓解权重分布极化现象。

5.基于Pytorch的代码实现

以下代码中的专家和门控神经网络皆以一个简单的线性层代替。

(1)专家神经网络的实现

详解MMoE 模型:多任务学习中的专家混合建模与实践【附代码】-AI.x社区

(2)门控神经网络的实现

详解MMoE 模型:多任务学习中的专家混合建模与实践【附代码】-AI.x社区

(3)任务塔网络的实现

详解MMoE 模型:多任务学习中的专家混合建模与实践【附代码】-AI.x社区

(4)MMoE整体实现

详解MMoE 模型:多任务学习中的专家混合建模与实践【附代码】-AI.x社区

(5)测试运行

详解MMoE 模型:多任务学习中的专家混合建模与实践【附代码】-AI.x社区

参考资料

1、 Ma J , Zhao Z , Yi X , et al. Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts. ACM, 2018.

2、https://mp.weixin.qq.com/s/38tQmvVxngDT0c-QI5B8rg

3、https://blog.csdn.net/u012328159/article/details/123309660

本文转载自 南夏的算法驿站​,作者: 赵南夏


收藏
回复
举报
回复
相关推荐