基于纯SQL训练机器学习模型

译文 精选
数据库 SQL Server
在本文中,我将分享如何在开源分布式SQL数据库TiDB上用纯SQL训练机器学习模型。

译者 | 朱先忠

审校 | 梁策 孙淑娟

在​​《用纯SQL在BigQuery上实现深层神经网络》​​一文中,作者声称使用纯SQL方式实现了一个深层神经网络模型。但在我打开他的​​GitHub代码仓库​​分析后发现,他是使用Python来实现迭代训练的,而这并不是真正的纯SQL方式。

在本文中,我将分享我是如何在​​开源分布式SQL数据库TiDB​​上用纯SQL方式训练机器学习模型的。主要步骤包括:

  1. 选择Iris数据集
  2. (https://scikit-learn.org/stable/auto_examples/datasets/plot_iris_dataset.html)。
  3. 选择softmax逻辑回归模型用于训练。
  4. 编写SQL语句来实现模型推理。
  5. 开始模型训练。

在测试中,我训练了一个softmax逻辑回归模型。在测试期间,我发现TiDB不允许在递归公共表表达式(CTE)中使用子查询和聚合函数。通过修改TiDB的代码,我绕过了这些限制,成功地训练了一个模型,并在Iris数据集上获得了98%的准确率。

为什么我选择TiDB来实现机器学习模型?

TiDB 5.1引入了许多新功能,包括ANSI SQL 99标准的通用表表达式(CTE)。我们可以使用CTE作为临时视图的语句来解耦复杂的SQL语句,并更高效地开发代码。此外,递归CTE可以引用自身,这对于改进SQL功能非常重要。CTE和窗口函数使SQL成为一种图灵完备的语言。

【说明】因为递归CTE可以“迭代”,所以我想尝试一下,看看是否可以使用纯SQL在TiDB上实现机器学习模型训练和推理。

鸢尾花(Iris)数据集

我选择使用scikit-learn的Iris数据集。该数据集包含3种类型,每种类型有50条记录,一共150条。每个记录有4个特征:萼片长度(SL)、萼片宽度(SW)、花瓣长度(PL)和花瓣宽度(PW)。我们可以利用这些特征来预测鸢尾花是否属于山鸢尾(Iris-setosa) 、 变色鸢尾(Iris-versicolor)和维吉尼亚鸢尾(Iris-virginica)。

以CSV格式下载数据后,我将其导入了TiDB数据库。使用的SQL脚本如下:

create table iris(sl float, sw float, pl float, pw float, type varchar(16));
LOAD DATA LOCAL INFILE 'iris.csv' INTO TABLE iris FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' ;
select * from iris limit 10;
+------+------+------+------+------------------+
| sl | sw | pl | pw | type |
+-------+--------+-------+--------+------------+
| 5.1 | 3.5 | 1.4 | 0.2 | Iris-setosa |
| 4.9 | 3 | 1.4 | 0.2 | Iris-setosa |
| 4.7 | 3.2 | 1.3 | 0.2 | Iris-setosa |
| 4.6 | 3.1 | 1.5 | 0.2 | Iris-setosa |
| 5 | 3.6 | 1.4 | 0.2 | Iris-setosa |
| 5.4 | 3.9 | 1.7 | 0.4 | Iris-setosa |
| 4.6 | 3.4 | 1.4 | 0.3 | Iris-setosa |
| 5 | 3.4 | 1.5 | 0.2 | Iris-setosa |
| 4.4 | 2.9 | 1.4 | 0.2 | Iris-setosa |
| 4.9 | 3.1 | 1.5 | 0.1 | Iris-setosa |
+---- --+-------+-------+-------+--------------+
10 rows in set (0.00 sec)
select type, count(*) from iris group by type;
+-------------------+------------------+
| type | count(*) |
+-------------------+-----------------+
| Iris-versicolor | 50 |
| Iris-setosa | 50 |
| Iris-virginica | 50 |
+-------------------+----------------+
3 rows in set (0.00 sec)

Softmax逻辑回归

我选择了一个简单的机器学习模型:用于多类分类的Softmax逻辑回归。在Softmax回归中:

成本函数是:

梯度是:

因此,我们可以使用梯度下降来升级梯度:

模型推理

我编写了一条SQL语句来实现推理。基于上面定义的模型和数据,输入数据x有五个维度(SL、SW、PL、PW和一个常数1.0),输出使用了一种热编码。SQL脚本如下:

create table data(
x0 decimal(35, 30), x1 decimal(35, 30), x2 decimal(35, 30), x3 decimal(35, 30), x4 decimal(35, 30),
y0 decimal(35, 30), y1 decimal(35, 30), y2 decimal(35, 30)
);
insert into data
select
sl, sw, pl, pw, 1.0,
case when type='Iris-setosa'then 1 else 0 end,
case when type='Iris-versicolor'then 1 else 0 end,
case when type='Iris-virginica'then 1 else 0 end
from iris;

共有15个参数(3种类型*5个维度)。SQL脚本如下:

create table weight(
w00 decimal(35, 30), w01 decimal(35, 30), w02 decimal(35, 30), w03 decimal(35, 30), w04 decimal(35, 30),
w10 decimal(35, 30), w11 decimal(35, 30), w12 decimal(35, 30), w13 decimal(35, 30), w14 decimal(35, 30),
w20 decimal(35, 30), w21 decimal(35, 30), w22 decimal(35, 30), w23 decimal(35, 30), w24 decimal(35, 30));

我将输入数据初始化为0.1、0.2、0.3。为了便于演示,我使用了不同的数字。将它们全部初始化为0.1是可以的。SQL脚本如下:

insert into weight values (
0.1, 0.1, 0.1, 0.1, 0.1,
0.2, 0.2, 0.2, 0.2, 0.2,
0.3, 0.3, 0.3, 0.3, 0.3);

接下来,我编写了一条SQL语句来计算数据推断结果的准确性。为了更好地理解,我使用伪代码来描述这个过程:

weight = (   
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24
)
for data(x0, x1, x2, x3, x4, y0, y1, y2) in all Data:
exp0 = exp(x0 * w00, x1 * w01, x2 * w02, x3 * w03, x4 * w04)
exp1 = exp(x0 * w10, x1 * w11, x2 * w12, x3 * w13, x4 * w14)
exp2 = exp(x0 * w20, x1 * w21, x2 * w22, x3 * w23, x4 * w24)
sum_exp = exp0 + exp1 + exp2

// softmax
p0 = exp0 / sum_exp
p1 = exp1 / sum_exp
p2 = exp2 / sum_exp

//推理结果
r0 = p0 > p1 and p0 > p2
r1 = p1 > p0 and p1 > p2
r2 = p2 > p0 and p2 > p1

data.correct = (y0 == r0 and y1 == r1 and y2 == r2)
return sum(Data.correct) / count(Data)

在上面的代码中,我计算了每行数据中的元素。为了对样本进行推断:

  1. 我计算出加权向量的EXP。
  2. 并且计算出softmax值。
  3. 然后,选择p0、p1和p2中最大的一个作为1,并将其余的设置为0。

如果样本的推断结果与其原始分类一致,则预测正确。然后,我将所有样本的正确数量相加,得到最终的准确率。

下面的代码显示了SQL语句的实现。我将每一行数据加上一个权重(只有一行数据),计算每一行的推断结果,并将正确的样本数相加:

select sum(y0 = r0 and y1 = r1 and y2 = r2) / count(*)
from
(select
y0, y1, y2,
p0 > p1 and p0 > p2 as r0, p1 > p0 and p1 > p2 as r1, p2 > p0 and p2 > p1 as r2
from
(select
y0, y1, y2,
e0/(e0+e1+e2) as p0, e1/(e0+e1+e2) as p1, e2/(e0+e1+e2) as p2
from
(select
y0, y1, y2,
exp(
w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4
) as e0,
exp(
w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4
) as e1,
exp(
w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4
) as e2
from data, weight) t1
)t2
    )t3;

上面的SQL语句几乎一步一步地实现了伪代码的计算过程。我得到了如下结果:

+-------------------------------------------------------------+  
| sum(y0 = r0 and y1 = r1 and y2 = r2)/count(*) |
+-------------------------------------------------------------+
| 0.3333 |
+-------------------------------------------------------------+
1 row in set (0.01 sec)

接下来,我开始学习模型参数。

模型训练

注意:为了简化问题,我没有考虑“训练集”和“验证集”问题,而是把所有的数据仅用于进行训练。

我编写了伪代码,然后在此基础上编写了一条SQL语句:

weight = (     
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24
)
for iter in iterations:
sum00 = 0
sum01 = 0
...
sum23 = 0
sum24 = 0
for data(x0, x1, x2, x3, x4, y0, y1, y2) in all Data:
exp0 = exp(x0 * w00, x1 * w01, x2 * w02, x3 * w03, x4 * w04)
exp1 = exp(x0 * w10, x1 * w11, x2 * w12, x3 * w13, x4 * w14)
exp2 = exp(x0 * w20, x1 * w21, x2 * w22, x3 * w23, x4 * w24)
sum_exp = exp0 + exp1 + exp2
// softmax
p0 = y0 - exp0 / sum_exp
p1 = y1 - exp1 / sum_exp
p2 = y2 - exp2 / sum_exp
sum00 += p0 * x0
sum01 += p0 * x1
sum02 += p0 * x2
...
sum23 += p2 * x3
sum24 += p2 * x4
w00 = w00 + learning_rate * sum00 / Data.size
w01 = w01 + learning_rate * sum01 / Data.size
...
w23 = w23 + learning_rate * sum23 / Data.size
    w24 = w24 + learning_rate * sum24 / Data.size

因为我手动扩展了sum和w向量,所以这段代码看起来有点麻烦。然后,我开始编写SQL训练代码。首先,我编写了一条只用一次迭代的SQL语句。

我设置了如下所示的学习速率和样本数:

set @lr = 0.1;  
Query OK, 0 rows affected (0.00 sec)
set @dsize = 150;
Query OK, 0 rows affected (0.00 sec)

代码迭代了一次:

select 
w00 + @lr * sum(d00) / @dsize as w00, w01 + @lr * sum(d01) / @dsize as w01, w02 + @lr * sum(d02) / @dsize as w02, w03 + @lr * sum(d03) / @dsize as w03, w04 + @lr * sum(d04) / @dsize as w04 ,
w10 + @lr * sum(d10) / @dsize as w10, w11 + @lr * sum(d11) / @dsize as w11, w12 + @lr * sum(d12) / @dsize as w12, w13 + @lr * sum(d13) / @dsize as w13, w14 + @lr * sum(d14) / @dsize as w14,
w20 + @lr * sum(d20) / @dsize as w20, w21 + @lr * sum(d21) / @dsize as w21, w22 + @lr * sum(d22) / @dsize as w22, w23 + @lr * sum(d23) / @dsize as w23, w24 + @lr * sum(d24) / @dsize as w24
from
(select
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
p0 * x0 as d00, p0 * x1 as d01, p0 * x2 as d02, p0 * x3 as d03, p0 * x4 as d04,
p1 * x0 as d10, p1 * x1 as d11, p1 * x2 as d12, p1 * x3 as d13, p1 * x4 as d14,
p2 * x0 as d20, p2 * x1 as d21, p2 * x2 as d22, p2 * x3 as d23, p2 * x4 as d24
from
(select
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
x0, x1, x2, x3, x4,
y0 - e0/(e0+e1+e2) as p0, y1 - e1/(e0+e1+e2) as p1, y2 - e2/(e0+e1+e2) as p2
from
(select
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
x0, x1, x2, x3, x4, y0, y1, y2,
exp(
w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4
) as e0,
exp(
w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4
) as e1,
exp(
w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4
) as e2
from data, weight) t1
)t2
    )t3;

一次迭代后,输出结果是模型参数,如下所示:

以下是核心代码部分,我使用递归CTE进行迭代训练:

set @num_iterations = 1000;
Query OK, 0 rows affected (0.00 sec)

其核心思想是,每次迭代的输入都是前一次迭代的结果,此外我添加了一个增量迭代变量来控制迭代次数。总体框架代码是:

with recursive cte(iter, weight) as
(
select 1, init_weight
union all
select iter+1, new_weight
from cte
where ites < @num_iterations

接下来,我将迭代的SQL语句与这个迭代框架结合在一起。为了提高计算精度,我在中间结果中添加了类型转换:

with recursive weight( iter,
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24) as
(
select 1,
cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast (0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)),
cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)),
cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30))
union all
select
iter + 1,
w00 + @lr * cast(sum(d00) as DECIMAL(35, 30)) / @dsize as w00, w01 + @lr * cast(sum(d01) as DECIMAL(35, 30)) / @dsize as w01, w02 + @lr * cast(sum(d02) as DECIMAL(35, 30)) / @dsize as w02, w03 + @lr * cast(sum(d03) as DECIMAL(35, 30)) / @dsize as w03, w04 + @lr * cast(sum(d04) as DECIMAL(35, 30)) / @dsize as w04 ,
w10 + @lr * cast(sum(d10) as DECIMAL(35, 30)) / @dsize as w10, w11 + @lr * cast(sum(d11) as DECIMAL(35, 30)) / @dsize as w11, w12 + @lr * cast(sum(d12) as DECIMAL(35, 30)) / @dsize as w12, w13 + @lr * cast(sum(d13) as DECIMAL(35, 30)) / @dsize as w13, w14 + @lr * cast(sum(d14) as DECIMAL(35, 30)) / @dsize as w14,
w20 + @lr * cast(sum(d20) as DECIMAL(35, 30)) / @dsize as w20, w21 + @lr * cast(sum(d21) as DECIMAL(35, 30)) / @dsize as w21, w22 + @lr * cast(sum(d22) as DECIMAL(35, 30)) / @dsize as w22, w23 + @lr * cast(sum(d23) as DECIMAL(35, 30)) / @dsize as w23, w24 + @lr * cast(sum(d24) as DECIMAL(35, 30)) / @dsize as w24
from
(select
iter, w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
p0 * x0 as d00, p0 * x1 as d01, p0 * x2 as d02, p0 * x3 as d03, p0 * x4 as d04,
p1 * x0 as d10, p1 * x1 as d11, p1 * x2 as d12, p1 * x3 as d13, p1 * x4 as d14,
p2 * x0 as d20, p2 * x1 as d21, p2 * x2 as d22, p2 * x3 as d23, p2 * x4 as d24
from
(select
iter, w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
x0, x1, x2, x3, x4,
y0 - e0/(e0+e1+e2) as p0, y1 - e1/(e0+e1+e2) as p1, y2 - e2/(e0+e1+e2) as p2
from
(select
iter, w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
x0, x1, x2, x3, x4, y0, y1, y2,
exp(
w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4
) as e0,
exp(
w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4
) as e1,
exp(
w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4
) as e2
from data, weight where iter < @num_iterations) t1
)t2
)t3
having count(*) > 0
)
select * from weight where iter = @num_iterations;

这个代码块和上面一次迭代的代码块之间有两个区别。在此代码块中:

  • 在 data join weight后面,我添加了where iter <@num_iterations以便控制迭代次数和要输出的iter + 1 as iter列。
  • 添加了count(*)>0,以防止聚合在最后没有输入数据时输出数据。此错误可能会导致迭代失败。

上述代码运行结果是:

ERROR 3577 (HY000): In recursive query block of Recursive Common Table Expression 'weight', the recursive table must be referenced only once, and not in any subquery

这表明递归CTE不允许在递归部分使用子查询。不过,我可以合并上面所有的子查询。但是,即使在我手动合并它们之后还是得到了以下错误提示:

ERROR 3575 (HY000): Recursive Common Table Expression 'cte' can contain neither aggregation nor window functions in recursive query block

这表明不允许使用聚合函数。然后,我决定改变TiDB的实现代码。

根据​​提案​​​中的介绍,递归CTE的实现遵循了TiDB的基本执行框架。在咨询​​PingCAP​​的研发人员黄文军(Wenjun Huang)之后,我了解到子查询和聚合函数不被允许的原因有两个:

  • MySQL不允许这样做。
  • 如果允许,会有很多复杂的特殊情形需要克服。

但我只是想测试一下这些功能。为此,我暂时删除了​​diff​​中对子查询和聚合函数的检查。

最后,我再次执行修改后的代码,输出结果如下:

成功了!经过1000次迭代,我得到了参数。

接下来,我使用新参数重新计算正确的速率:

+--------------------------------------------------------------+
| sum(y0 = r0 and y1 = r1 and y2 = r2) / count(*) |
+--------------------------------------------------------------+
| 0.9867 |
+--------------------------------------------------------------+
1 row in set (0.02 sec)

这一次,准确率达到了98%。

结论

通过使用TiDB 5.1中的递归CTE,我成功地使用纯SQL在TiDB上训练了softmax逻辑回归模型。

在测试期间,我发现TiDB的递归CTE不允许子查询和聚合函数,所以我修改了TiDB的代码以绕过这些限制。最后,我成功地训练了一个模型,并在Iris数据集上获得了98%的准确率。

最后,作为补充,在我的上述工作中还总结了下面几个想法:

  • 在做了一些测试之后,我发现PostgreSQL和MySQL都不支持递归CTE中的聚合函数,可能是因为有一些棘手的情形难以处理吧。
  • 在这次测试中,我手动扩展了向量的所有维度。事实上,我还编写了一个不需要扩展所有维度的实现。例如,数据表的模式是(idx,dim,value),但在这个实现中,权重表需要连接两次。这意味着需要在CTE中访问两次,为此还需要修改TiDB执行器的实现代码。由于这一原因,我没有在本文中讨论这一实现。但事实上,这种实现更通用,可以用它来处理MNIST数据集等更多维度的模型。

原文标题:I Trained a Machine Learning Model in Pure SQL,作者:Mingcong Han


责任编辑:华轩 来源: 51CTO
相关推荐

2020-08-10 15:05:02

机器学习人工智能计算机

2017-03-24 15:58:46

互联网

2024-11-04 00:24:56

2024-11-26 09:33:44

2024-12-26 00:46:25

机器学习LoRA训练

2022-03-25 11:11:50

机器学习TiDBSQL

2018-11-07 09:00:00

机器学习模型Amazon Sage

2024-06-06 08:00:00

2023-11-17 08:46:26

2022-09-19 15:37:51

人工智能机器学习大数据

2018-03-09 09:00:00

前端JavaScript机器学习

2021-04-09 14:49:02

人工智能机器学习

2020-01-02 14:13:01

机器学习模型部署预测

2021-04-22 08:00:00

人工智能机器学习数据

2021-11-02 23:00:35

机器学习JAVA编程

2022-09-20 23:42:15

机器学习Python数据集

2019-05-07 11:18:51

机器学习人工智能计算机

2021-12-06 20:23:40

机器学习数学

2023-09-05 10:41:28

人工智能机器学习

2023-01-09 08:00:00

迁移学习机器学习数据集
点赞
收藏

51CTO技术栈公众号