MIT提出:引入贝叶斯深度学习在医疗监控上的应用

人工智能 深度学习
来给大家讲一下我们发表在Nature Medicine上的一个工作,这算是我在MIT期间做的最有意思的工作之一了。

 

Paper:

http://wanghao.in/paper/NatureMedicine21_MSA.pdf

Bayesian Formulation及算法细节:

http://wanghao.in/BayesDL4MSA.html

Bayesian Deep Learning Survey:

http://wanghao.in/paper/CSUR20_BDL.pdf

Bayesian Deep Learning Github Repo:

https://github.com/js05212/BayesianDeepLearning-Survey/blob/master/README.md

来给大家讲一下我们发表在Nature Medicine上的一个工作,这算是我在MIT期间做的最有意思的工作之一了。希望这个帖子能够贡献一个数据点,让大家看看机器学习(特别是贝叶斯深度学习,or Bayesian Deep Learning)在医疗监控(Health Monitoring)上的应用。

一、应用场景

简单(科幻)地说,我们做的这个系统能够通过感知房子里面的wifi信号,来监测病人是否遵医嘱,按时使用胰岛素笔(Insulin Pen)或者定量吸入器(Inhaler)之类的医疗工具来治疗自己。因为这类医疗工具的使用有点复杂(比如胰岛素笔有8个步骤,而定量吸入器有6个步骤),病人经常会出现使用失误,我们这个系统还能自动检测出病人有没有漏掉哪个步骤,或者有没有哪个步骤做得不到位。我们把这个应用叫做‘ 自我给药 ’(Medication Self-Administration,or MSA)。具体使用场景如下图。

关于胰岛素笔和定量吸入器的使用步骤可以看下图。

二、 连续时间域的概率推理

熟悉机器学习的同学可能已经发现了,这个问题其实是个比较复杂的概率推理问题:

1. 不同的步骤持续的时间长度不同,比如上图Fig. 4a的第1步(Step 1)的‘拿起工具’一般只有4秒左右,而第6步(Step 6)‘用药并按住’一般会持续12秒左右。因此可以认为,不同步骤的时长都遵循不同的概率分布。如下图。而我们的模型需要把这些先验知识整合进去。

2. 不同步骤之间的空白时间也有长有短(比如上面绿色中间的白色区域)。

3. 病人经常会忘记里面的一些关键步骤。比如,对于胰岛素笔(见上图)来说,病人经常会忘记的步骤是第2步‘放入药芯’(Load Cartridge)和第4步‘预备胰岛素笔’(Prime Insulin Pen)。那么此时,我们可以把整个胰岛素笔的流程画成如下图的有限状态机。图中的从Step 1出发的2个’50%’的路径表示,一个病人有一半的(先验)概率会忘记Step 2而直接进行Step 3。而这个也是我们的模型需要整合的先验知识。

三、 贝叶斯深度学习

(深度学习和概率推理的结合)

从技术的角度上说,在这个work里面,我们结合底层的FMCW雷达Perception和顶层的Continuous-time BayesNet Reasoning做了一个Bayesian Deep Learning的model,用于全天候、无接触地推断慢性病人是否按时使用Insulin Pen、Inhaler之类的医疗工具治疗自己,同时检测自动具体使用步骤的异常。整个系统的流程图如下。

这里面包含两个联动的模型。

第一个模型是用来处理类Wifi信号(底层的FMCW雷达信号)的深度神经网络。对应上图的阶段1(Stage 1)和阶段2(Stage 2)的合并。这个深度模型的 输入 是:长达几分钟的很多帧(frame)的雷达信号(见上图的第一行);而他的 输出 是:每一帧属于不同步骤的概率(见上图的Stage 2的输出),也就是说,如果这个用药过程包含6个步骤,那么每一帧的输出会是一个6维的向量,这6维的数字加起来会恒等于1。

第二个模型对应上图的阶段3(Stage 3),它是基于Stage 2的原始概率分数(Raw Probability Score),然后结合我们前面讲到的‘ 连续时间域的概率推理 ’来进行进一步的概率推断,从而输出最终预测(见上图最后一行)。

这两个模型一个作为 深度模块(也叫做感知模块) ,负责对高维信号进行处理,另一个作为 概率推断模块(也叫做任务模块) ,负责进行任务相关的概率推断。两个模块可以联动地、端到端(end-to-end)地起作用,我们把这个称为 贝叶斯深度学习

有兴趣的同学请看  A Survey on Bayesian Deep Learning :

http://wanghao.in/paper/CSUR20_BDL.pdf

四、 深度学习vs贝叶斯深度学习

那么问题来了,此处为何需要第二个模型的联动呢?为何直接使用第一个模型没法解决问题?

这是因为, 第一个模型作为一个深度神经网络,只负责给出每帧独立的概率预测,无法结合本帧的前后部分来进行推理。 这样的后果是,它直接输出的逐帧预测经常是不符合常理的。比如它可能输出这样的预测:前0.1秒这个病人还在使用设备的第3步,后0.1秒就直接跳回到第1步,而后的0.1秒又是在第4步。一个正常人显然无法做出这样的动作序列。因此,第二个模型的作用就是,把我们在连续时间域上的先验知识(前面‘ 连续时间域的概率推理 ’章节讲到的那三个方面)整合入模型里面,进行端到端(end-to-end)的推断,从而拿到最终的预测。这整个结合概率推理和深度学习的框架,我们把它叫做 贝叶斯深度学习 

这样做的好处有两方面。 第一方面 是大大提高模型的 准确度和鲁棒性 。由于概率推理的存在,模型会根据整个几分钟的动作序列来判断病人是否在使用医疗用药工具,这样既自动纠正了第一个模型的一些错误预测,同时也使得整个系统受无关噪音的影响大大减小。 第二方面 是给模型提供了 可解释性 。正如  @kkhenry  所说,在医疗相关的应用中,可解释性异常重要,因为这个关系到,AI系统的使用者(医生等医疗从业人员)能否相信你模型的预测。有了概率推断的部分,我们就可以给出对每个步骤预测的概率,以及模型预测于先验知识的偏差程度,从而提供解释。比如,模型可以给出‘此病人在早上9点使用了医疗工具,但是使用错误’的结论,同时解释‘这是因为模型有95%的把握他/她跳过了Step 2。而医生可以根据模型提供的解释,来决定要不要进一步检查此病人的具体数据,并提醒病人。如下图。

五、技术细节之如何结合

两个模型的预测

关于Output,每个帧会有2个预测(Prediction)。第一个预测是来自于第一个模型(深度学习模型)给出的逐帧(Frame-level)预测,这个很简单,可以理解为神经网络对输出进行Softmax操作后,得到的各类的概率。第二个预测来自于第二个模型(概率推断模型)。它是来自于一个作为先验(Prior)的随机过程,具体地讲,这是一个连续时间域的 点过程 (Point Process)和 马尔科夫链 (Markov chain)的结合。点过程负责对每个步骤(比如Step 2)的长度进行建模,马尔科夫链负责对各个步骤之间的转换进行建模(比如进行Step 1后,有一半的概率会进行Step 2,有另一半的概率会进行Step 3)。

这里有一个有趣的点就是,如果我们只用一般的点过程,比如 泊松过程 (Poisson Process),是没有办法很好的对每个步骤的长度进行建模的。这是因为泊松过程假设每个步骤的长度是一个指数分布(Exponential Distribution),而指数分布一旦他的期望值(均值)确定了(比如是a),他的方差也就确定了(等于a^2)。因此它没法像高斯分布(Gaussian Distribution)一样那么灵活,可以自由地描述一个分布的期望值和方差。所以这个地方我们灵机一动,把泊松过程的指数分布替换成 高斯分布 ,用来作为我们模型的先验之一。而每个步骤持续时间的高斯分布的期望值和方差都不一样,这些都可以从训练数据里面直接估计出来。

所以,我们直接结合了第一个模型(深度学习模型)的 预测分数 ,以及第二个模型提供的 先验分数 ,在加上一个 近似的动态规划算法 ,就可以进行联动的(jointly or end-to-end)概率推断,得到最终的预测。下图展示了一些我们模型预测(AI Prediction)和人工标注(Human Annotation)的对比,前3个例子(a、b、c)是3个不同的病人在使用胰岛素笔,总共有8个步骤。后面3个例子是3个不同的病人在使用定量吸入器,总共有6个步骤。可以看到, 我们的模型最终预测是非常准确的,而且不会出现 physically impossible 的预测。

六、写在最后

整个帖子算是抛砖引玉,讲了下机器学习(更具体的是贝叶斯深度学习)及其在医疗上的应用。

最后要感谢一下赵老板拉我入伙一起做这个work。遥想当年刚进去MIT的时候就想着把贝叶斯深度学习用到医疗上,说要用深度模块(即感知模块)来对无线信号建模,用概率模块(即任务模块)来做医疗相关的概率推断。没想到最后真的实现了。可谓念念不忘必有回响:)

Illustration by  Maria Shukshina   from Icons8

 

责任编辑:张燕妮 来源: 将门创投
相关推荐

2012-09-24 10:13:35

贝叶斯

2017-06-12 06:31:55

深度学习贝叶斯算法

2022-05-06 12:13:55

模型AI

2023-01-31 15:49:51

机器学习函数评分函数

2016-08-30 00:14:09

大数据贝叶斯

2020-05-21 14:50:37

算法深度学习人工智能

2021-11-05 15:22:46

神经网络AI算法

2023-09-12 11:36:15

携程模型

2017-07-24 10:36:37

Python机器学习朴素贝叶斯

2021-08-30 11:53:36

机器学习人工智能计算机

2017-08-07 13:02:32

全栈必备贝叶斯

2013-05-08 09:05:48

狐狸贝叶斯大数据

2017-03-29 14:50:18

2021-04-18 09:57:45

Java朴素贝叶斯贝叶斯定理

2024-10-11 16:53:16

贝叶斯人工智能网络

2016-08-30 00:19:30

2023-10-18 08:00:00

贝叶斯网络Python医疗保健

2017-07-12 11:27:05

朴素贝叶斯情感分析Python

2012-02-14 10:55:24

2022-09-28 08:00:00

Python机器学习算法
点赞
收藏

51CTO技术栈公众号