PathAI利用机器学习推动药物开发

译文 精选
人工智能 机器学习
传统的手工病理学易于出现主观性和观察者的变异性,从而对诊断和药物开发试验产生负面影响。在深入研究如何使用Pytorch改进诊断工作流程之前,让我们先介绍一下不依赖机器学习的传统模拟病理工作流程。

​译者 | 朱先忠

审校 | 孙淑娟

位于美国波士顿的病理人工智能技术公司PathAI是病理学(疾病研究)人工智能技术工具和服务的领先供应商。他们开发的平台旨在利用机器学习中的现代方法,如图像分割、图神经网络和多实例学习,大幅提高复杂疾病的诊断准确性和疗效评估。

传统的手工病理学易于出现主观性和观察者的变异性,从而对诊断和药物开发试验产生负面影响。在深入研究如何使用Pytorch改进诊断工作流程之前,让我们先介绍一下不依赖机器学习的传统模拟病理工作流程。

传统生物制药的工作原理

生物制药公司可以通过多种途径发现新的治疗方法或诊断方法。其中一个途径在很大程度上依赖于通过病理切片的分析来回答各种问题:特定的细胞通信途径是如何工作的?特定疾病状态是否与特定蛋白质的存在或缺乏有关?为什么临床试验中的特定药物对某些患者有效,而对其他患者无效?患者治疗效果与新的生物标志物之间是否存在关联?等等。

为了帮助回答这些问题,生物制药公司一般都是依靠专业病理学家来分析幻灯片,并帮助评估他们可能存在的问题。

正如你可能想象的那样,需要一位经过专家委员会认证的病理学家才能做出准确的解释和诊断。在一项研究中,36名不同的病理学家获得了一个单一的活检结果,结果出现18种不同的诊断结论,其严重程度包括从不需治疗到必要的积极治疗等各种情形。病理学家也经常在困难的边缘病例中征求同事的反馈。鉴于问题的复杂性,即使经过专家培训和合作,病理学家仍很难做出正确的诊断。这种潜在差异可能是被批准的药物与未通过临床试验的药物之间的差异所致。

PathAI如何利用机器学习推动药物开发

PathAI公司开发了一系列机器学习模型,旨在为药物研发、临床试验和诊断提供见解。为此,PathAI利用Pytork框架进行幻灯片级推理,使用各种方法,包括图神经网络(GNN)和多实例学习等途径。在这种情况下,“载玻片”是指载玻片的全尺寸扫描图像,载玻片就是玻璃片,中间有一薄层组织,染色后显示各种细胞形成物。PyTorch使我们的团队能够使用这些不同的方法共享一个通用框架,该框架足够强大,可以在我们需要的所有条件下工作。此外,PyTorch的高级、命令式和pythonic语法使我们能够快速原型化模型,然后在得到想要的结果后将这些模型扩展。

千兆图像的多实例学习

将机器学习技术应用于病理学的一个独特挑战是图像的巨大尺寸。这些数字幻灯片的分辨率通常为100000 x 100000像素或更高,大小为GB级。在GPU内存中加载完整图像并在其上应用传统的计算机视觉算法几乎是不可能的任务。注释完整的幻灯片图像(100k x 100k)也需要花费大量的时间和资源,尤其是当注释者需要的是领域专家(委员会认证的病理学家)时。我们经常建立模型来预测图像级别的标签,例如在覆盖整个图像数千像素的患者幻灯片上是否存在癌症。癌变区域有时只是整个幻灯片的一小部分,这使得机器学习问题类似于大海捞针。另一方面,有些问题,如某些组织学生物标志物的预测,需要从整个载玻片中聚集信息,然而由于图像的尺寸问题,使得这一目标同样很难实现。当将机器学习技术应用于病理学问题时,所有这些因素都增加了显著的算法、计算和逻辑复杂性。

将图像分解为较小的切片(patch),学习切片表示,然后将这些表示合并以预测图像级标签是解决此问题的一种方法,如下图所示。一种常用的方法称为多实例学习(Multiple Instance Learning,简称“MIL”)。每个切片被视为一个“实例”,一组切片形成一个“包”。将单个切片表示汇总在一起,以预测最终的包级标签。在算法上,包中的单个切片实例不需要标签;因此,允许我们以弱监督的方式学习包级标签。它们还使用置换不变池函数,使预测独立于切片的顺序,并允许有效地聚合信息。

通常,使用基于注意力的池功能,不仅可以有效聚合,还可以为包中的每个切片提供注意力值。这些值表明了相应切片在预测中的重要性,可以可视化以更好地理解模型预测。可解释性的这一要素对于推动这些模型在现实世界中的采用非常重要,我们使用诸如加性MIL模型之类的变体来实现这种空间解释性。在计算上,MIL模型避免了将神经网络应用于大尺寸图像的问题,因为切片表示是独立于图像大小获得的。

在PathAI中,我们使用基于深度网络的自定义MIL模型来预测图像级标签。该过程概述如下:

1.使用不同的采样方法从幻灯片中选择切片。

2、基于随机抽样或启发式规则构造一包切片。

3、基于预训练模型或大规模表示学习模型为每个实例生成切片表示。

4.应用置换不变池函数来获得最终的幻灯片级别分数。

现在,我们已经了解了Pytork中有关MIL的一些高级细节。接下来,让我们看看一些代码,在Pytork中从构思到生产代码有多么简单。我们首先定义采样器、转换和MIL数据集:

#创建一袋采样器,从幻灯片中随机采样切片
bag_sampler = RandomBagSampler(bag_size=12)

#设置转换
crop_transform = FlipRotateCenterCrop(use_flips=True)

#创建为每个包加载切片的数据集
train_dataset = MILDataset(
bag_sampler=bag_sampler,
samples_loader=sample_loader,
transform=crop_transform,
)

在定义了采样器和数据集之后,我们需要定义使用该数据集实际训练的模型。通过使用大家熟悉的PyTorch模型定义语法很容易做到这一点,同时也允许我们创建定制模型。

classifier = DefaultPooledClassifier(hidden_dims=[256, 256], input_dims=1024, output_dims=1)

pooling = DefaultAttentionModule(
input_dims=1024,
hidden_dims=[256, 256],
output_activation=StableSoftmax()
)

# 定义由特征化器、池模块和分类器组成的模型
model = DefaultMILGraph(featurizer=ShuffleNetV2(), classifier=classifier, pooling = pooling)

由于这些模型经过端到端的训练,因此它们提供了一种强大的方法,可以直接从千兆像素的整张幻灯片图像转换为单个标签。由于其广泛适用于不同的生物问题,其实施和部署的两个方面很重要:

  • 对管道每个部分的可配置控制,包括数据加载器、模型的模块化部分以及它们之间的交互。
  • 通过“形成概念-实现-实验-产品化”循环能够快速迭代。

Pytork在MIL建模方面具有各种优势。它提供了一种直观的方法来创建具有灵活控制流的动态计算图,这对于快速研究实验非常有用。映射风格的数据集、可配置的采样器和批量采样器允许我们自定义如何构建切片包,从而实现更快的实验。由于MIL模型是IO密集型的,数据并行性和pythonic数据加载程序使任务非常高效且用户友好。最后,PyTorch的面向对象特性支持构建可重用的模块,这有助于快速实验、可维护的实现和易于构建管道的组合组件。

在PyTorch中用GNN探索空间组织结构

在健康组织和病变组织中,细胞的空间排列和结构往往与细胞本身一样重要。例如,在评估肺癌时,病理学家试图观察肿瘤细胞的整体分组和结构(它们形成固体薄片吗?还是以较小的局部簇出现?)来确定癌症是否属于差异很大的特定亚型。细胞和其他组织结构之间的这种空间关系可以使用图建模,以便同时捕捉组织拓扑和细胞组成。图神经网络(GNN)允许学习这些图中与其他临床变量相关的空间模式,例如某些癌症中的基因过度表达。

2020年末,当PathAI公司开始在组织样本上使用GNN时,PyTorch通过PyG包获得了对GNN功能的最佳和最成熟的支持。这使得PyTork成为我们团队的自然选择,因为我们知道GNN模型是我们想要探索的重要机器学习概念。

在组织样本的背景下,GNN的主要附加值之一是,图本身可以揭示空间关系,否则仅通过视觉检查很难找到这些关系。在我们最近的AACR出版论文中,我们指出,通过使用GNN,我们可以更好地了解肿瘤微环境中免疫细胞聚集体(特别是三级淋巴结构,或TLS)的存在对患者预后的影响。在这种情况下,GNN方法用于预测与TLS存在相关的基因表达,并识别TLS区域以外与TLS相关的组织学特征。如果没有ML模型的帮助,则很难从组织样本图像中识别这种对基因表达的见解。

我们成功使用的最有前途的GNN变体之一是自注意力图池。接下来,让我们看一下我们是如何使用PyTorch和PyG来定义自注意力图池(SAGPool)模型的:

class SAGPool(torch.nn.Module):
def __init__(self, ...):
super().__init__()
self.conv1 = GraphConv(in_features, hidden_features, aggr='mean')
self.convs = torch.nn.ModuleList()
self.pools = torch.nn.ModuleList()
self.convs.extend([GraphConv(hidden_features, hidden_features, aggr='mean') for i in range(num_layers - 1)])
self.pools.extend([SAGPooling(hidden_features, ratio, GNN=GraphConv, min_score=min_score) for i in range((num_layers) // 2)])
self.jump = JumpingKnowledge(mode='cat')
self.lin1 = Linear(num_layers * hidden_features, hidden_features)
self.lin2 = Linear(hidden_features, out_features)
self.out_activation = out_activation
self.dropout = dropout

在上面的代码中,我们首先定义一个卷积图层,然后添加两个模块列表层,允许我们传入可变数量的层。然后,我们获取空模块列表,并附加可变数量的GraphConv层,后跟可变数量的SAGPooling层。然后,我们通过添加JumpingKnowledge层、两个线性层、激活函数和退出值来完成SAGPool定义。PyTorch直观的语法使我们能够抽象出使用最先进方法(如SAG池)的复杂性,同时保持我们熟悉的通用模型开发方法。

像我们上面描述的使用一个SAG池这样的模型只是GNN与PyTorch如何允许我们探索新想法的一个例子。我们最近还探索了多模式CNN-GNN混合模型,其结果比传统病理学家共识分数高20%。这些创新以及传统CNN和GNN之间的相互作用,再次得益于从研究到生产的短期模型开发循环。

改善患者预后

总而言之,为了实现我们使用人工智能驱动的病理学改善患者预后的使命,PathAI需要借助于ML开发框架,该框架:(1)在开发和探索的初始阶段促进快速迭代和轻松扩展(即模型配置为代码)(2)将模型训练和推理扩展到海量图像(3)轻松可靠地为我们产品的生产使用(在临床试验及以后)提供模型。

正如我们在本文中所展示的,PyTorch为我们提供了所有上述功能以及更多功能支持。我们对PyTorch框架的未来感到无比兴奋,甚至迫不及待地想看到我们可以使用该框架解决哪些其他有影响力的挑战。

译者介绍

朱先忠,51CTO社区编辑,51CTO专家博客、讲师,潍坊一所高校计算机教师,自由编程界老兵一枚。早期专注各种微软技术(编著成ASP.NET AJX、Cocos 2d-X相关三本技术图书),近十多年投身于开源世界(熟悉流行全栈Web开发技术),了解基于OneNet/AliOS+Arduino/ESP32/树莓派等物联网开发技术与Scala+Hadoop+Spark+Flink等大数据开发技术。

原文标题:Case Study: PathAI Uses PyTorch to Improve Patient Outcomes with AI-powered Pathology​,作者:Logan Kilpatrick, Harshith Padigela, Syed Ashar Javed, Robert Egger​

责任编辑:华轩 来源: 51CTO
相关推荐

2022-05-30 10:53:48

机器学习医疗行业变革

2017-10-26 12:32:23

机器学习大数据药物

2020-07-09 18:35:34

AWS机器学习

2020-01-15 16:58:46

机器学习数据库人工智能

2021-02-07 09:26:55

机器学习建筑能源ML

2020-12-25 15:24:24

人工智能

2024-04-17 08:00:00

2022-06-02 15:42:05

Python机器学习

2021-06-17 10:27:03

人工智能AI机器学习

2018-11-06 09:00:00

2019-04-25 14:00:24

编程语言机器学习Java

2016-04-11 14:35:59

机器学习数据挖掘数据模型

2023-03-28 18:16:38

2020-10-30 08:00:00

PyTorch机器学习人工智能

2024-06-03 10:42:25

2020-08-20 10:49:49

人工智能机器学习技术

2021-01-25 11:04:54

人工智能药物开发神经网络

2022-06-28 10:22:00

机器学习网络攻击黑客

2017-11-28 08:46:29

HPEInfoSight机器学习
点赞
收藏

51CTO技术栈公众号