机器学习中常用的损失函数你知多少?

开发 开发工具 机器学习
本文作者将常用的损失函数分为了两大类:分类和回归。然后又分别对这两类进行了细分和讲解,其中回归中包含了一种不太常见的损失函数:平均偏差误差,可以用来确定模型中存在正偏差还是负偏差。

机器通过损失函数进行学习。这是一种评估特定算法对给定数据建模程度的方法。如果预测值与实际结果偏离较远,损失函数会得到一个非常大的值。在一些优化函数的辅助下,损失函数逐渐学会减少预测值的误差。本文将介绍几种损失函数及其在机器学习和深度学习领域的应用。

[[243642]]

损失函数和优化

没有一个适合所有机器学习算法的损失函数。针对特定问题选择损失函数涉及到许多因素,比如所选机器学习算法的类型、是否易于计算导数以及数据集中异常值所占比例。

从学习任务的类型出发,可以从广义上将损失函数分为两大类——回归损失和分类损失。在分类任务中,我们要从类别值有限的数据集中预测输出,比如给定一个手写数字图像的大数据集,将其分为 0~9 中的一个。而回归问题处理的则是连续值的预测问题,例如给定房屋面积、房间数量以及房间大小,预测房屋价格。

  1. NOTE  
  2.         n        - Number of training examples. 
  3.         i        - ith training example in a data set. 
  4.         y(i)     - Ground truth label for ith training example. 
  5.         y_hat(i) - Prediction for ith training example. 

回归损失

1. 均方误差/平方损失/L2 损失

数学公式:

均方误差

顾名思义,均方误差(MSE)度量的是预测值和实际观测值间差的平方的均值。它只考虑误差的平均大小,不考虑其方向。但由于经过平方,与真实值偏离较多的预测值会比偏离较少的预测值受到更为严重的惩罚。再加上 MSE 的数学特性很好,这使得计算梯度变得更容易。

  1. import numpy as np 
  2. y_hat = np.array([0.000, 0.166, 0.333]) 
  3. y_true = np.array([0.000, 0.254, 0.998]) 
  4. def rmse(predictions, targets): 
  5.     differences = predictions - targets 
  6.     differencesdifferences_squared = differences ** 2 
  7.     mean_of_differences_squared = differences_squared.mean() 
  8.     rmse_val = np.sqrt(mean_of_differences_squared) 
  9.     return rmse_val 
  10. print("d is: " + str(["%.8f" % elem for elem in y_hat])) 
  11. print("p is: " + str(["%.8f" % elem for elem in y_true])) 
  12. rmsermse_val = rmse(y_hat, y_true) 
  13. print("rms error is: " + str(rmse_val)) 

2. 平均绝对误差/L1 损失

数学公式:

平均绝对误差

平均绝对误差(MAE)度量的是预测值和实际观测值之间绝对差之和的平均值。和 MSE 一样,这种度量方法也是在不考虑方向的情况下衡量误差大小。但和 MSE 的不同之处在于,MAE 需要像线性规划这样更复杂的工具来计算梯度。此外,MAE 对异常值更加稳健,因为它不使用平方。

  1. import numpy as np 
  2. y_hat = np.array([0.000, 0.166, 0.333]) 
  3. y_true = np.array([0.000, 0.254, 0.998]) 
  4.  
  5. print("d is: " + str(["%.8f" % elem for elem in y_hat])) 
  6. print("p is: " + str(["%.8f" % elem for elem in y_true])) 
  7.  
  8. def mae(predictions, targets): 
  9.     differences = predictions - targets 
  10.     absolute_differences = np.absolute(differences) 
  11.     mean_absolute_differences = absolute_differences.mean() 
  12.     return mean_absolute_differences 
  13. maemae_val = mae(y_hat, y_true) 
  14. print ("mae error is: " + str(mae_val)) 

3. 平均偏差误差(mean bias error)

与其它损失函数相比,这个函数在机器学习领域没有那么常见。它与 MAE 相似,唯一的区别是这个函数没有用绝对值。用这个函数需要注意的一点是,正负误差可以互相抵消。尽管在实际应用中没那么准确,但它可以确定模型存在正偏差还是负偏差。

数学公式:

平均偏差误差

二、分类损失

1. Hinge Loss/多分类 SVM 损失

简言之,在一定的安全间隔内(通常是 1),正确类别的分数应高于所有错误类别的分数之和。因此 hinge loss 常用于***间隔分类(maximum-margin classification),最常用的是支持向量机。尽管不可微,但它是一个凸函数,因此可以轻而易举地使用机器学习领域中常用的凸优化器。

数学公式:

SVM 损失(Hinge Loss)

思考下例,我们有三个训练样本,要预测三个类别(狗、猫和马)。以下是我们通过算法预测出来的每一类的值:

Hinge loss/多分类 SVM 损失

计算这 3 个训练样本的 hinge loss:

  1. ## 1st training example 
  2. max(0, (1.49) - (-0.39) + 1) + max(0, (4.21) - (-0.39) + 1) 
  3. max(0, 2.88) + max(0, 5.6) 
  4. 2.88 + 5.6 
  5. 8.48 (High loss as very wrong prediction) 
  6. ## 2nd training example 
  7. max(0, (-4.61) - (3.28)+ 1) + max(0, (1.46) - (3.28)+ 1) 
  8. max(0, -6.89) + max(0, -0.82) 
  9. 0 + 0 
  10. 0 (Zero loss as correct prediction) 
  11. ## 3rd training example 
  12. max(0, (1.03) - (-2.27)+ 1) + max(0, (-2.37) - (-2.27)+ 1) 
  13. max(0, 4.3) + max(0, 0.9) 
  14. 4.3 + 0.9 
  15. 5.2 (High loss as very wrong prediction) 

交叉熵损失/负对数似然:

这是分类问题中最常见的设置。随着预测概率偏离实际标签,交叉熵损失会逐渐增加。

数学公式:

交叉熵损失

注意,当实际标签为 1(y(i)=1) 时,函数的后半部分消失,而当实际标签是为 0(y(i=0)) 时,函数的前半部分消失。简言之,我们只是把对真实值类别的实际预测概率的对数相乘。还有重要的一点是,交叉熵损失会重重惩罚那些置信度高但是错误的预测值。

  1. import numpy as np 
  2. predictions = np.array([[0.25,0.25,0.25,0.25], 
  3.                         [0.01,0.01,0.01,0.96]]) 
  4. targets = np.array([[0,0,0,1], 
  5.                    [0,0,0,1]]) 
  6. def cross_entropy(predictions, targets, epsilon=1e-10): 
  7.     predictions = np.clip(predictions, epsilon, 1. - epsilon) 
  8.     N = predictions.shape[0] 
  9.     ce_loss = -np.sum(np.sum(targets * np.log(predictions + 1e-5)))/N 
  10.     return ce_loss 
  11. cross_entropycross_entropy_loss = cross_entropy(predictions, targets) 
  12. print ("Cross entropy loss is: " + str(cross_entropy_loss)) 

【本文是51CTO专栏机构“机器之心”的原创文章,微信公众号“机器之心( id: almosthuman2014)”】

戳这里,看该作者更多好文

责任编辑:赵宁宁 来源: 51CTO专栏
相关推荐

2018-06-21 15:17:15

机器学习

2017-04-18 15:49:24

人工智能机器学习数据

2021-05-22 23:08:08

深度学习函数算法

2023-11-29 14:34:15

机器学习统计学

2024-06-27 00:46:10

机器学习向量相似度

2020-06-08 07:00:00

数据安全加密机密计算

2016-01-28 19:58:43

创业IT建设

2024-11-05 12:56:06

机器学习函数MSE

2016-08-30 13:23:26

DevOpsOpenStackIaaS

2022-10-28 15:19:28

机器学习距离度量数据集

2023-03-30 08:00:56

MySQL日期函数

2018-06-26 09:24:02

流量陷阱费用

2010-04-01 09:46:04

Oracle日期函数

2020-05-08 07:00:00

Linux色码文件类型

2021-05-08 05:40:32

Excel数据技巧

2023-04-11 08:49:42

排序函数SQL

2022-06-30 08:31:54

排序函数SQL

2023-11-28 12:08:56

机器学习算法人工智能

2012-02-13 22:50:59

集群高可用

2024-08-06 10:07:15

点赞
收藏

51CTO技术栈公众号