TensorFlow和Keras解决大数据量内存溢出问题

存储 存储软件
解决思路其实说来也简单,打破思维定式就好了,不是把所有图片读到内存中,而是只把所有图片的路径一次性读到内存中。

以前做的练手小项目导致新手产生一个惯性思维——读取训练集图片的时候把所有图读到内存中,然后分批训练。

其实这是有问题的,很容易导致OOM。现在内存一般16G,而训练集图片通常是上万张,而且RGB图,还很大,VGG16的图片一般是224x224x3,上万张图片,16G内存根本不够用。这时候又会想起——设置batch,但是那个batch的输入参数却又是图片,它只是把传进去的图片分批送到显卡,而我OOM的地方恰是那个“传进去”的图片,怎么办?

解决思路其实说来也简单,打破思维定式就好了,不是把所有图片读到内存中,而是只把所有图片的路径一次性读到内存中。

[[229304]]

大致的解决思路为:

将上万张图片的路径一次性读到内存中,自己实现一个分批读取函数,在该函数中根据自己的内存情况设置读取图片,只把这一批图片读入内存中,然后交给模型,模型再对这一批图片进行分批训练,因为内存一般大于等于显存,所以内存的批次大小和显存的批次大小通常不相同。

下面代码分别介绍Tensorflow和Keras分批将数据读到内存中的关键函数。Tensorflow对初学者不太友好,所以我个人现阶段更习惯用它的高层API Keras来做相关项目,下面的TF实现是之前不会用Keras分批读时候参考的一些列资料,在模型训练上仍使用Keras,只有分批读取用了TF的API。

TensorFlow

在input.py里写get_batch函数。

  1. def get_batch(X_train, y_train, img_w, img_h, color_type, batch_size, capacity): 
  2.    ''
  3.    Args: 
  4.        X_train: train img path list 
  5.        y_train: train labels list 
  6.        img_w: image width 
  7.        img_h: image height 
  8.        batch_size: batch size 
  9.        capacity: the maximum elements in queue 
  10.    Returns
  11.        X_train_batch: 4D tensor [batch_size, width, height, chanel],\ 
  12.                        dtype=tf.float32 
  13.        y_train_batch: 1D tensor [batch_size], dtype=int32 
  14.    ''
  15.    X_train = tf.cast(X_train, tf.string) 
  16.  
  17.    y_train = tf.cast(y_train, tf.int32)     
  18.    # make an input queue 
  19.    input_queue = tf.train.slice_input_producer([X_train, y_train]) 
  20.  
  21.    y_train = input_queue[1] 
  22.    X_train_contents = tf.read_file(input_queue[0]) 
  23.    X_train = tf.image.decode_jpeg(X_train_contents, channels=color_type) 
  24.  
  25.    X_train = tf.image.resize_images(X_train, [img_h, img_w],  
  26.                                     tf.image.ResizeMethod.NEAREST_NEIGHBOR) 
  27.  
  28.    X_train_batch, y_train_batch = tf.train.batch([X_train, y_train], 
  29.                                                  batch_size=batch_size, 
  30.                                                  num_threads=64, 
  31.                                                  capacity=capacity) 
  32.    y_train_batch = tf.one_hot(y_train_batch, 10)    return X_train_batch, y_train_batch 

在train.py文件中训练(下面不是纯TF代码,model.fit是Keras的拟合,用纯TF的替换就好了)。

  1. X_train_batch, y_train_batch = inp.get_batch(X_train, y_train,  
  2.                                             img_w, img_h, color_type,  
  3.                                             train_batch_size, capacity) 
  4. X_valid_batch, y_valid_batch = inp.get_batch(X_valid, y_valid,  
  5.                                             img_w, img_h, color_type,  
  6.                                             valid_batch_size, capacity)with tf.Session() as sess: 
  7.  
  8.    coord = tf.train.Coordinator() 
  9.    threads = tf.train.start_queue_runners(coord=coord)    
  10.  try:        
  11.  for step in np.arange(max_step):             
  12. if coord.should_stop() :                 
  13. break 
  14.            X_train, y_train = sess.run([X_train_batch,  
  15.                                             y_train_batch]) 
  16.            X_valid, y_valid = sess.run([X_valid_batch, 
  17.                                             y_valid_batch]) 
  18.               
  19.            ckpt_path = 'log/weights-{val_loss:.4f}.hdf5' 
  20.            ckpt = tf.keras.callbacks.ModelCheckpoint(ckpt_path,  
  21.                                                      monitor='val_loss',  
  22.                                                      verbose=1,  
  23.                                                      save_best_only=True,  
  24.                                                      mode='min'
  25.            model.fit(X_train, y_train, batch_size=64,  
  26.                          epochs=50, verbose=1, 
  27.                          validation_data=(X_valid, y_valid), 
  28.                          callbacks=[ckpt])             
  29.            del X_train, y_train, X_valid, y_valid     
  30. except tf.errors.OutOfRangeError: 
  31.        print('done!')    finally: 
  32.        coord.request_stop() 
  33.    coord.join(threads) 
  34.    sess.close() 

Keras

keras文档中对fit、predict、evaluate这些函数都有一个generator,这个generator就是解决分批问题的。

关键函数:fit_generator

  1. # 读取图片函数 
  2. def get_im_cv2(paths, img_rows, img_cols, color_type=1, normalize=True): 
  3.    ''
  4.    参数: 
  5.        paths:要读取的图片路径列表 
  6.        img_rows:图片行 
  7.        img_cols:图片列 
  8.        color_type:图片颜色通道 
  9.    返回:  
  10.        imgs: 图片数组 
  11.    ''
  12.    # Load as grayscale 
  13.    imgs = []    for path in paths:         
  14. if color_type == 1: 
  15.            img = cv2.imread(path, 0)         
  16. elif color_type == 3: 
  17.            img = cv2.imread(path)         
  18. # Reduce size 
  19.        resized = cv2.resize(img, (img_cols, img_rows))        
  20.  if normalize: 
  21.            resized = resized.astype('float32'
  22.            resized /= 127.5 
  23.            resized -= 1.  
  24.         
  25.        imgs.append(resized)         
  26.    return np.array(imgs).reshape(len(paths), img_rows, img_cols, color_type) 

获取批次函数,其实就是一个generator

  1. def get_train_batch(X_train, y_train, batch_size, img_w, img_h, color_type, is_argumentation): 
  2.    ''
  3.    参数: 
  4.        X_train:所有图片路径列表 
  5.        y_train: 所有图片对应的标签列表 
  6.        batch_size:批次 
  7.        img_w:图片宽 
  8.        img_h:图片高 
  9.        color_type:图片类型 
  10.        is_argumentation:是否需要数据增强 
  11.    返回:  
  12.        一个generator, 
  13. x: 获取的批次图片  
  14. y: 获取的图片对应的标签 
  15.    ''
  16.    while 1:         
  17. for i in range(0, len(X_train), batch_size): 
  18.            x = get_im_cv2(X_train[i:i+batch_size], img_w, img_h, color_type) 
  19.            y = y_train[i:i+batch_size]             
  20. if is_argumentation:                 
  21. # 数据增强 
  22.                x, y = img_augmentation(x, y)             
  23. # 最重要的就是这个yield,它代表返回,返回以后循环还是会继续,然后再返回。就比如有一个机器一直在作累加运算,但是会把每次累加中间结果告诉你一样,直到把所有数加完 
  24.            yield({'input': x}, {'output': y}) 

训练函数

  1. result = model.fit_generator(generator=get_train_batch(X_train, y_train, train_batch_size, img_w, img_h, color_type, True),  
  2.          steps_per_epoch=1351,  
  3.          epochs=50, verbose=1, 
  4.          validation_data=get_train_batch(X_valid, y_valid, valid_batch_size,img_w, img_h, color_type, False), 
  5.          validation_steps=52, 
  6.          callbacks=[ckpt, early_stop], 
  7.          max_queue_size=capacity, 
  8.          workers=1) 

就是这么简单。但是当初从0到1的过程很难熬,每天都没有进展,没有头绪,急躁占据了思维的大部,熬过了这个阶段,就会一切顺利,不是运气,而是踩过的从0到1的每个脚印累积的灵感的爆发,从0到1的脚印越多,后面的路越顺利。

责任编辑:武晓燕 来源: 人工智能LeadAI
相关推荐

2021-03-06 10:25:19

内存Java代码

2021-02-03 15:12:08

java内存溢出

2010-09-26 15:53:25

JVM内存溢出

2011-08-25 10:50:32

SQL Server数Performance

2009-12-08 15:19:58

WCF大数据量

2024-04-25 10:06:03

内存泄漏

2011-08-16 09:21:30

MySQL大数据量快速语句优化

2023-08-29 11:38:27

Java内存

2011-04-18 11:13:41

bcp数据导入导出

2024-01-31 10:11:41

Redis内存

2022-03-25 09:01:16

CSS溢出属性

2024-01-29 08:45:38

MySQL大数据分页

2010-07-29 13:30:54

Hibari

2018-09-06 16:46:33

数据库MySQL分页查询

2018-04-02 15:37:33

数据库MySQL翻页

2010-12-01 09:18:19

数据库优化

2010-05-05 10:30:46

MongoDBNoSQL

2024-07-30 15:56:42

2012-12-26 09:23:56

数据库优化

2024-09-09 09:41:03

内存溢出golang开发者
点赞
收藏

51CTO技术栈公众号