干货 | 手把手教你用115行代码做个数独解析器!

人工智能 机器学习
这里有一份数独解析教程,等待你查收~ 喜欢收藏硬核干货的小伙伴看过来。

 

你也是数独爱好者吗?

Aakash Jhawar和许多人一样,乐于挑战新的难题。上学的时候,他每天早上都要玩数独。长大后,随着科技的进步,我们可以让计算机来帮我们解数独了!只需要点击数独的图片,它就会为你填满全部九宫格。

叮~ 这里有一份数独解析教程,等待你查收~ 喜欢收藏硬核干货的小伙伴看过来~

我们都知道,数独由9×9的格子组成,每行、列、宫各自都要填上1-9的数字,要做到每行、列、宫里的数字都不重复。

可以将解析数独的整个过程分成3步:

第一步:从图像中提取数独

第二步:提取图像中出现的每个数字

第三步:用算法计算数独的解

第一步:从图像中提取数独

首先需要进行图像处理。

1、对图像进行预处理

首先,我们应用高斯模糊的内核大小(高度,宽度)为9的图像。注意,内核大小必须是正的和奇数的,并且内核必须是平方的。然后使用11个最近邻像素自适应阈值。 

  1. proc = cv2.GaussianBlur(img.copy(),(9,9),0)  
  2. proc = cv2.adaptiveThreshold(proc,255,cv2.ADAPTIVE_THRESH_GAUSSIAN_C,cv2.THRESH_BINARY,11,2) 

为了使网格线具有非零像素值,我们颠倒颜色。此外,把图像放大,以增加网格线的大小。 

  1. proc = cv2.bitwise_not(proc,proc)     
  2. kernel = np.array([[0。,1.,0.],[1.,1.,1.],[0.,1.,0.]] ,np.uint8)  
  3. proc = cv2.dilate(proc,kernel) 

       

阈值化后的数独图像

2、找出最大多边形的角

下一步是寻找图像中最大轮廓的4个角。所以需要找到所有的轮廓线,按面积降序排序,然后选择面积最大的那个。 

  1. _, contours, h = cv2.findContours(img.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)  
  2. contours = sorted(contours, key=cv2.contourArea, reverse=True 
  3. polygon = contours[0] 

使用的操作符。带有max和min的itemgetter允许我们获得该点的索引。每个点都是有1个坐标的数组,然后[0]和[1]分别用于获取x和y。

右下角点具有最大的(x + y)值;左上角有点最小(x + y)值;左下角则具有最小的(x - y)值;右上角则具有最大的(x - y)值。 

  1. bottom_right, _ = max(enumerate([pt[0][0] + pt[0][1] for pt in  
  2.                       polygon]), key=operator.itemgetter(1))  
  3. top_left, _ = min(enumerate([pt[0][0] + pt[0][1] for pt in  
  4.                   polygon]), key=operator.itemgetter(1))  
  5. bottom_left, _ = min(enumerate([pt[0][0] - pt[0][1] for pt in  
  6.                      polygon]), key=operator.itemgetter(1))  
  7. top_right, _ = max(enumerate([pt[0][0] - pt[0][1] for pt in  
  8.                    polygon]), key=operator.itemgetter(1)) 

现在我们有了4个点的坐标,然后需要使用索引返回4个点的数组。每个点都在自己的一个坐标数组中。 

  1. [polygon[top_left][0], polygon[top_right][0], polygon[bottom_right][0], polygon[bottom_left][0]] 

最大多边形的四个角

3、裁剪和变形图像

有了数独的4个坐标后,我们需要剪裁和弯曲一个矩形部分,从一个图像变成一个类似大小的正方形。由左上、右上、右下和左下点描述的矩形。

注意:将数据类型显式设置为float32或‘getPerspectiveTransform’会引发错误。 

  1. top_left, top_right, bottom_right, bottom_left = crop_rect[0], crop_rect[1], crop_rect[2], crop_rect[3]  
  2. src = np.array([top_left, top_right, bottom_right, bottom_left], dtypefloat32 )   
  3. side = max([  distance_between(bottom_right, top_right),   
  4.             distance_between(top_left, bottom_left),  
  5.             distance_between(bottom_right, bottom_left),     
  6.             distance_between(top_left, top_right) ]) 

用计算长度的边来描述一个正方形,这是要转向的新视角。然后要做的是通过比较之前和之后的4个点来得到用于倾斜图像的变换矩阵。最后,再对原始图像进行变换。 

  1. dst = np.array([[0, 0], [side - 1, 0], [side - 1, side - 1], [0, side - 1]], dtypefloat32 )  
  2. m = cv2.getPerspectiveTransform(src, dst)  
  3. cv2.warpPerspective(img, m, (int(side), int(side))) 

裁剪和变形后的数独图像

4、从正方形图像中推断网格

从正方形图像推断出81个单元格。我们在这里交换 j 和 i ,这样矩形就被存储在从左到右读取的列表中,而不是自上而下。 

  1. squares = []   
  2. side = img.shape[:1]   
  3. sideside = side[0] / 9 
  4. for j in range(9):  
  5.     for i in range(9):  
  6.         p1 = (i * side, j * side)  #Top left corner of a box    
  7.         p2 = ((i + 1) * side, (j + 1) * side)  #Bottom right corner       
  8.          squares.append((p1, p2)) return squares 

5、得到每一位数字

下一步是从其单元格中提取数字并构建一个数组。 

  1. digits = []  
  2. img = pre_process_image(img.copy(), skip_dilate=True 
  3. for square in squares:  
  4.     digits.append(extract_digit(img, square, size)) 

extract_digit 是从一个数独方块中提取一个数字(如果有的话)的函数。它从整个方框中得到数字框,使用填充特征查找来获得框中间的最大特征,以期在边缘找到一个属于该数字的像素,用于定义中间的区域。接下来,需要缩放并填充数字,让适合用于机器学习的数字大小的平方。同时,我们必须忽略任何小的边框。 

  1. def extract_digit(img, rect, size):  
  2.     digit = cut_from_rect(img, rect)  
  3.     h, w = digit.shape[:2]  
  4.     margin = int(np.mean([h, w]) / 2.5)  
  5.     _, bbox, seed = find_largest_feature(digit, [margin, margin], [w  
  6.     - margin, h - margin])  
  7.     digit = cut_from_rect(digit, bbox)   
  8.     w = bbox[1][0] - bbox[0][0]  
  9.     h = bbox[1][1] - bbox[0][1]  
  10.     if w > 0 and h > 0 and (w * h) > 100 and len(digit) > 0:  
  11.         return scale_and_centre(digit, size, 4)  
  12.     else:  
  13.         return np.zeros((size, size), np.uint8) 

      

最后的数独的形象

现在,我们有了最终的数独预处理图像,下一个任务是提取图像中的每一位数字,并将其存储在一个矩阵中,然后通过某种算法计算出数独的解。

第二步:提取图像中出现的每个数字

对于数字识别,我们将在MNIST数据集上训练神经网络,该数据集包含60000张0到9的数字图像。从导入所有库开始。 

  1. import numpy  
  2. import cv2from keras.datasets   
  3. import mnistfrom keras.models   
  4. import Sequentialfrom keras.layers   
  5. import Densefrom keras.layers   
  6. import Dropoutfrom keras.layers   
  7. import Flattenfrom keras.layers.convolutional   
  8. import Conv2Dfrom keras.layers.convolutional   
  9. import MaxPooling2Dfrom keras.utils   
  10. import np_utilsfrom keras   
  11. import backend as K  
  12. import matplotlib.pyplot as plt 

需要修复随机种子以确保可重复性。 

  1. K.set_image_dim_ordering( th )  
  2. seed = 7numpy.random.seed(seed)  
  3. (X_train, y_train), (X_test, y_test) = mnist.load_data() 

然后将图像重塑为样本*像素*宽度*高度,并输入从0-255规范化为0-1。在此之后,对输出进行热编码。 

  1. X_trainX_train = X_train.reshape(X_train.shape[0], 1, 28,  
  2.                            28).astype( float32 )  
  3. X_testX_test = X_test.reshape(X_test.shape[0], 1, 28,  
  4.                            28).astype( float32 )   
  5. X_trainX_train = X_train / 255  
  6. X_testX_test = X_test / 255 
  7. y_train = np_utils.to_categorical(y_train)  
  8. y_test = np_utils.to_categorical(y_test)  
  9. num_classes = y_test.shape[1] 

接下来,我们将创建一个模型来预测手写数字。 

  1. model = Sequential()  
  2. model.add(Conv2D(32, (5, 5), input_shape=(1, 28, 28),  
  3.           activationrelu ))  
  4. model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Conv2D(16, (3,  
  5.           3), activationrelu ))  
  6. model.add(MaxPooling2D(pool_size=(2, 2))) 
  7. model.add(Dropout(0.2)) 
  8. model.add(Flatten())  
  9. model.add(Dense(128, activationrelu ))  
  10. model.add(Dense(64, activationrelu ))  
  11. model.add(Dense(num_classes, activationsoftmax )) 

模型总结

在创建模型之后,需要进行编译,使其适合数据集并对其进行评估。 

  1. model.compile(losscategorical_crossentropy , optimizeradam ,  
  2.                metrics=[ accuracy ])  
  3. model.fit(X_train, y_train, validation_data=(X_test, y_test),  
  4.                epochs=10batch_size=200 
  5. scores = model.evaluate(X_test, y_test, verbose=0 
  6. print("Large CNN Error: %.2f%%" % (100-scores[1]*100)) 

现在,可以测试上面创建的模型了。 

  1. test_images = X_test[1:5]  
  2. test_imagestest_images = test_images.reshape(test_images.shape[0], 28, 28)  
  3. print ("Test images shape: {}".format(test_images.shape))  
  4. for i, test_image in enumerate(test_images, start=1):  
  5.     org_image = test_image  
  6.     test_imagetest_image = test_image.reshape(1,1,28,28)  
  7.     prediction = model.predict_classes(test_image, verbose=0 
  8.     print ("Predicted digit: {}".format(prediction[0]))  
  9.     plt.subplot(220+i)  
  10.     plt.axis( off )  
  11.     plt.title("Predicted digit: {}".format(prediction[0]))  
  12.     plt.imshow(org_image, cmap=plt.get_cmap( gray ))  
  13. plt.show() 

手写体数字分类模型的预测数字

神经网络的精度为98.314%!最后,保存序列模型,这样就不必在需要使用它的时候反复训练了。 

  1. # serialize model to JSON  
  2. modelmodel_json = model.to_json()  
  3. with open("model.json", "w") as json_file:  
  4.     json_file.write(model_json) 
  5. # serialize weights to HDF5  
  6. model.save_weights("model.h5")  
  7. print("Saved model to disk") 

更多关于手写数字识别的信息:

https://github.com/aakashjhawar/Handwritten-Digit-Recognition

下一步是加载预先训练好的模型。 

  1. json_file = open( model.json ,  r )  
  2. loaded_model_json = json_file.read()  
  3. json_file.close()  
  4. loaded_model = model_from_json(loaded_model_json)  
  5. loaded_model.load_weights("model.h5") 

调整图像大小,并将图像分割成9x9的小图像。每个小图像的数字都是从1-9。 

  1. sudoku = cv2.resize(sudoku, (450,450))  
  2. grid = np.zeros([9,9])  
  3. for i in range(9):  
  4.     for j in range(9):  
  5.         image = sudoku[i*50:(i+1)*50,j*50:(j+1)*50]  
  6.         if image.sum() > 25000:      
  7.             grid[i][j] = identify_number(image)  
  8.         else:  
  9.             grid[i][j] = 0      
  10. gridgrid =  grid.astype(int) 

identify_number 函数拍摄数字图像并预测图像中的数字。 

  1. def identify_number(image):  
  2.     image_resize = cv2.resize(image, (28,28))    # For plt.imshow  
  3.     image_resizeimage_resize_2 = image_resize.reshape(1,1,28,28)    # For input to model.predict_classes  
  4. #    cv2.imshow( number , image_test_1)  
  5.     loaded_modelloaded_model_pred = loaded_model.predict_classes(image_resize_2   
  6.                                                       , verbose = 0 
  7.     return loaded_model_pred[0]  

完成以上步骤后,数独网格看起来是这样的:

提取的数独

第三步:用回溯算法计算数独的解

我们将使用回溯算法来计算数独的解。

在网格中搜索仍未分配的条目。如果找到引用参数行,col 将被设置为未分配的位置,而 true 将被返回。如果没有未分配的条目保留,则返回false。“l” 是 solve_sudoku 函数传递的列表变量,用于跟踪行和列的递增。 

  1. def find_empty_location(arr,l):  
  2.     for row in range(9):  
  3.         for col in range(9):  
  4.             if(arr[row][col]==0):  
  5.                 l[0]=row  
  6.                 l[1]=col  
  7.                 return True  
  8.     return False 

返回一个boolean,指示指定行的任何赋值项是否与给定数字匹配。 

  1. def used_in_row(arr,row,num):  
  2.     for i in range(9):     
  3.         if(arr[row][i] == num):    
  4.             return True  
  5.     return False 

返回一个boolean,指示指定列中的任何赋值项是否与给定数字匹配。 

  1. def used_in_col(arr,col,num):  
  2.     for i in range(9):    
  3.         if(arr[i][col] == num):   
  4.             return True  
  5.     return False 

返回一个boolean,指示指定的3x3框内的任何赋值项是否与给定的数字匹配。 

  1. def used_in_box(arr,row,col,num):  
  2.     for i in range(3):  
  3.         for j in range(3):  
  4.             if(arr[i+row][j+col] == num):      
  5.             return True  
  6.      return False 

检查将num分配给给定的(row, col)是否合法。检查“ num”是否尚未放置在当前行,当前列和当前3x3框中。 

  1. def check_location_is_safe(arr,row,col,num):  
  2.     return not used_in_row(arr,row,num) and   
  3.            not used_in_col(arr,col,num) and   
  4.            not used_in_box(arr,row - row%3,col - col%3,num) 

采用部分填入的网格,并尝试为所有未分配的位置分配值,以满足数独解决方案的要求(跨行、列和框的非重复)。“l” 是一个列表变量,在 find_empty_location 函数中保存行和列的记录。将我们从上面的函数中得到的行和列赋值给列表值。 

  1. def solve_sudoku(arr):  
  2.     l=[0,0]   
  3.     if(not find_empty_location(arr,l)):  
  4.         return True   
  5.     row=l[0]  
  6.     col=l[1]   
  7.     for num in range(1,10):   
  8.         if(check_location_is_safe(arr,row,col,num)):   
  9.             arr[row][col]=num   
  10.             if(solve_sudoku(arr)):   
  11.                 return True   
  12.             # failure, unmake & try again   
  13.             arr[row][col] = 0   
  14.     return False 

最后一件事是print the grid。 

  1. def print_grid(arr):  
  2.     for i in range(9):  
  3.         for j in range(9):    
  4.             print (arr[i][j])   
  5.          print (  ) 

最后,把所有的函数整合在主函数中。 

  1. def sudoku_solver(grid):  
  2.     if(solve_sudoku(grid)):  
  3.         print( --- )  
  4.     else:  
  5.         print ("No solution exists")  
  6.     gridgrid = grid.astype(int) 
  7.      return grid 

这个函数的输出将是最终解出的数独。

最终的解决方案

当然,这个解决方案绝不是万无一失的,处理图像时仍然会出现一些问题,要么无法解析,要么解析错误导致无法处理。不过,我们的目标是探索新技术,从这个角度来看,这个项目还是有价值的。 

 

责任编辑:庞桂玉 来源: 机器学习算法与Python学习
相关推荐

2017-10-29 21:43:25

人脸识别

2017-10-27 10:29:35

人脸识别UbuntuPython

2021-08-09 13:31:25

PythonExcel代码

2022-10-19 14:30:59

2021-02-06 14:55:05

大数据pandas数据分析

2021-02-04 09:00:57

SQLDjango原生

2011-03-28 16:14:38

jQuery

2022-08-04 10:39:23

Jenkins集成CD

2022-06-30 16:10:26

Python计时器装饰器

2020-08-12 07:41:39

SQL 优化语句

2009-04-22 09:17:19

LINQSQL基础

2021-01-21 09:10:29

ECharts柱状图大数据

2021-01-08 10:32:24

Charts折线图数据可视化

2021-05-10 06:48:11

Python腾讯招聘

2021-08-02 23:15:20

Pandas数据采集

2012-01-11 13:40:35

移动应用云服务

2021-12-11 20:20:19

Python算法线性

2020-03-08 22:06:16

Python数据IP

2021-02-02 13:31:35

Pycharm系统技巧Python

2022-05-11 10:45:21

SpringMVC框架Map
点赞
收藏

51CTO技术栈公众号