LLM实践系列-细聊LLM的拒绝采样
最近学强化的过程中,总是遇到“拒绝采样”这个概念,我尝试科普一下,争取用最大白话的方式让每个感兴趣的同学都理解其中思想。
拒绝采样是 LLM 从统计学借鉴过来的一个概念。其实大家很早就接触过这个概念,每个刷过 leetcode 的同学大概率都遇到过这样一个问题:“如何用一枚骰子获得 1/7 的概率?”
答案很简单:把骰子扔两次,获得 6 * 6 = 36 种可能的结果,丢弃最后一个结果,剩下的 35 个结果平分成 7 份,对应的概率值便为 1/7 。使用这种思想,我们可以利用一枚骰子获得任意 1/N 的概率。
在这个问题中,我们可以看到拒绝采样的一些关键要素:
- 采样:从易于采样的分布(两个骰子的所有可能结果)中生成样本;
- 缩放:(扔两次骰子)获得更大的样本分布;
- 拒绝:丢弃(拒绝)不符合条件的样本(第36种情况);
- 接受:对于剩下的样本,重新调整概率(通过分组),获得目标概率分布。
用大白话来总结就是:我们想获得某个分布(1/7)的样本,但却没有办法。于是我们对另外一个分布(1/6)进行采样,但这个分布不能涵盖原始分布,需要我们缩放这个分布(扔两次)来包裹起来目标分布。然后,我们以某种规则拒绝明显不是目标分布的采样点,剩下的采样点就可以看作是从目标分布采样出来的了。
统计学的拒绝采样
LLM 的拒绝采样
LLM 的拒绝采样操作起来非常简单:让自己的模型针对 prompt 生成多个候选 response,然后用 reward_model 筛选出来高质量的 response (也可以是 pair 对),拿来再次进行训练。
解剖这个过程:
- 提议分布是我们自己的模型,目标分布是最好的语言模型;
- prompt + response = 一个采样结果;
- do_sample 多次 = 缩放提议分布(也可以理解为扔多次骰子);
- 采样结果得到 reward_model 的认可 = 符合目标分布。
经过这一番操作,我们能获得很多的训练样本,“这些样本既符合最好的语言模型的说话习惯,又不偏离原始语言模型的表达习惯”,学习它们就能让我们的模型更接近最好的语言模型。
统计学与 LLM 的映射关系
统计学的拒绝采样有几个关键要素:
- 原始分布采样困难,提议分布采样简单;
- 提议分布缩放后能涵盖原始分布;
- 有办法判断从提议分布获取的样本是否属于原始分布,这需要我们知道原始分布的密度函数。
LLM 的拒绝采样也有几个对应的关键要素:
- 我们不知道最好的语言模型怎么说话,但我们知道自己的语言模型如何说话;
- 让自己的语言模型反复说话,得到的语料大概率会包括最好的语言模型的说话方式;
- reward_model 可以判断某句话是否属于最好的语言模型的说话方式。
目前为止,是不是看上去很有道理,很好理解。但其实这里有一个致命的逻辑漏洞:为什么我们的模型反复 do_sample,就一定能覆盖最好的语言模型呢?这不合逻辑啊,狗嘴里采样多少次也吐不出象牙啊。
紧接着,就需要我们引出另一个概念了:RLHF 的优化目标是什么?
RLHF 与拒绝采样