Pythonで数独の問題を解くコード

このページでは「Pythonで数独の問題を解く方法」のコードを掲載しています。

コードの説明やアルゴリズムについては、解説記事「Pythonで数独の問題を解く方法」をご覧ください。

 

Pythonのコード

import numpy as np
import random
import copy


# Print a 9x9 grid; pass a 9x9 array as the argument.
def table_print(field):
    """Print ``field`` as a 9x9 grid with a border around each 3x3 box.

    ``field`` is indexed as ``field[(row, col)]`` (for example a NumPy
    array) and every value is displayed through ``int()``.
    """
    border = " +----------+----------+----------+"
    print(border)
    for row in range(9):
        groups = [
            "".join(str(int(field[(row, col)])) + "  " for col in range(start, start + 3))
            for start in (0, 3, 6)
        ]
        print(" | " + "| ".join(groups) + "|")
        if row % 3 == 2:
            # Close off each band of three rows.
            print(border)


def row_num_list(x, y, field):
    """Return, in ascending order, the digits 1-9 not yet placed in row ``x``.

    ``y`` is accepted so all candidate helpers share one signature; it is
    not used. Blank cells (value 0) are ignored.
    """
    used = {field[(x, col)] for col in range(9)}
    return [digit for digit in range(1, 10) if digit not in used]


def col_num_list(x, y, field):
    """Return, in ascending order, the digits 1-9 not yet placed in column ``y``.

    ``x`` is accepted so all candidate helpers share one signature; it is
    not used. Blank cells (value 0) are ignored.
    """
    used = {field[(row, y)] for row in range(9)}
    return [digit for digit in range(1, 10) if digit not in used]


def box_num_list(x, y, field):
    """Return, in ascending order, the digits 1-9 not yet placed in the 3x3 box of (x, y).

    Blank cells (value 0) are ignored.
    """
    # Top-left corner of the box that contains (x, y).
    top = x - x % 3
    left = y - y % 3
    used = {
        field[(row, col)]
        for row in range(top, top + 3)
        for col in range(left, left + 3)
    }
    return [digit for digit in range(1, 10) if digit not in used]


def ok_row_num(x, y, field):
    """Return the digit that only cell (x, y) can take within row ``x``, or 0.

    For every blank cell of row ``x`` the candidate digits allowed by its
    row, column and 3x3 box are computed. The first candidate of cell (x, y)
    that no other blank cell in the row can take is returned. Returns 0 when
    (x, y) is not a blank cell of the row, or has no such digit.
    """
    cell_candidates = {}  # (row, col) of each blank cell -> its candidate list
    for col in range(9):
        if field[(x, col)] != 0:
            continue
        allowed = (set(row_num_list(x, col, field))
                   & set(col_num_list(x, col, field))
                   & set(box_num_list(x, col, field)))
        cell_candidates[(x, col)] = list(allowed)

    # Count how many blank cells of the row list each digit as a candidate.
    occurrences = {}
    for candidates in cell_candidates.values():
        for digit in candidates:
            occurrences[digit] = occurrences.get(digit, 0) + 1

    for digit in cell_candidates.get((x, y), []):
        if occurrences[digit] == 1:
            return digit
    return 0


def ok_col_num(x, y, field):
    """Return the digit that only cell (x, y) can take within column ``y``, or 0.

    For every blank cell of column ``y`` the candidate digits allowed by its
    row, column and 3x3 box are computed. The first candidate of cell (x, y)
    that no other blank cell in the column can take is returned. Returns 0
    when (x, y) is not a blank cell of the column, or has no such digit.
    """
    cell_candidates = {}  # (row, col) of each blank cell -> its candidate list
    for row in range(9):
        if field[(row, y)] != 0:
            continue
        allowed = (set(row_num_list(row, y, field))
                   & set(col_num_list(row, y, field))
                   & set(box_num_list(row, y, field)))
        cell_candidates[(row, y)] = list(allowed)

    # Count how many blank cells of the column list each digit as a candidate.
    occurrences = {}
    for candidates in cell_candidates.values():
        for digit in candidates:
            occurrences[digit] = occurrences.get(digit, 0) + 1

    for digit in cell_candidates.get((x, y), []):
        if occurrences[digit] == 1:
            return digit
    return 0


def ok_box_num(x, y, field):
    """Return the digit that only cell (x, y) can take within its 3x3 box, or 0.

    For every blank cell of the box containing (x, y) the candidate digits
    allowed by its row, column and box are computed. The first candidate of
    cell (x, y) that no other blank cell in the box can take is returned.
    Returns 0 when (x, y) is not blank, or has no such digit.
    """
    # Top-left corner of the box that contains (x, y).
    top = x - x % 3
    left = y - y % 3

    cell_candidates = {}  # (row, col) of each blank cell -> its candidate list
    for row in range(top, top + 3):
        for col in range(left, left + 3):
            if field[(row, col)] != 0:
                continue
            allowed = (set(row_num_list(row, col, field))
                       & set(col_num_list(row, col, field))
                       & set(box_num_list(row, col, field)))
            cell_candidates[(row, col)] = list(allowed)

    # Count how many blank cells of the box list each digit as a candidate.
    occurrences = {}
    for candidates in cell_candidates.values():
        for digit in candidates:
            occurrences[digit] = occurrences.get(digit, 0) + 1

    for digit in cell_candidates.get((x, y), []):
        if occurrences[digit] == 1:
            return digit
    return 0


def ok_num_list(x, y, field):
    """Return the candidate digits for cell (x, y).

    The candidates are the digits that appear in none of the cell's row,
    column or 3x3 box. A "hidden single" is a candidate that no other blank
    cell in the row (then column, then box) can take. If one exists, it is
    returned alone as a one-element list. Otherwise the full candidate list
    is returned; an empty list means the cell has no legal digit.
    """
    candidates = list(set(row_num_list(x, y, field))
                      & set(col_num_list(x, y, field))
                      & set(box_num_list(x, y, field)))
    # Check the row first, then the column, then the box. Each finder returns
    # 0 when there is no hidden single, and 0 is never a candidate. Later
    # finders are only evaluated when the earlier ones found nothing.
    for find_hidden_single in (ok_row_num, ok_col_num, ok_box_num):
        digit = find_hidden_single(x, y, field)
        if digit in candidates:
            return [digit]
    return candidates

# 再帰を使ったsolver
def solve_func(emp_dic):
    """Return the length of every value in ``emp_dic`` as a float64 NumPy array.

    The array follows the dict's insertion order. ``main`` inserts all 81
    cells row by row, so index ``9*x + y`` corresponds to cell (x, y). Filled
    cells hold an 11-element sentinel list, so they get the largest length.
    An empty dict yields an empty float64 array.
    """
    # Build the array in one pass: calling np.append inside a loop copies the
    # whole array on every iteration, which is quadratic.
    return np.array([len(v) for v in emp_dic.values()], dtype=float)
def solver(emp, emp_dic, field, dic_length):
    """Fill the blank cells of ``field`` by constraint propagation plus backtracking.

    Args:
        emp: list of (x, y) coordinates of the cells that are still blank.
        emp_dic: dict mapping cells to candidate lists; filled cells hold an
            11-element list of zeros as a sentinel. Presumably it contains all
            81 cells in row-major order, as built by ``main`` -- the
            ``z -> (x, y)`` mapping below relies on that.
        field: the 9x9 grid (0 = blank). Writes go only to the copy
            ``cp_field``; ``field`` itself is not assigned to.
        dic_length: ignored -- it is recomputed from ``emp_dic`` on entry.

    Returns:
        ``(emp, emp_dic, field, dic_length)``.
        On success, these are the updated copies with an empty ``emp``.
        On a detected contradiction, the original ``emp``, ``emp_dic`` and
        ``field`` are returned, with ``dic_length`` recomputed. A
        contradiction is a cell with no candidates, a sentinel in the work
        list, or every guess for the branching cell failing.
        Callers detect failure by ``len(emp)`` not shrinking.
    """
    # NOTE(review): the dic_length argument is discarded here.
    dic_length = solve_func(emp_dic)
    # Work on copies so the caller's state survives a failed branch.
    cp_emp = copy.copy(emp)
    cp_emp_dic = copy.copy(emp_dic)
    cp_field = copy.copy(field)
    cp_dic_length = copy.copy(dic_length)
    # Blank-cell counts before/after a pass, used to detect a pass with no progress.
    before_emp = len(cp_emp)
    after_emp = len(cp_emp)
    # Refresh every blank cell's candidates; an empty list is a contradiction.
    for vx, vy in cp_emp:
        l = ok_num_list(vx, vy, cp_field)
        cp_emp_dic[(vx, vy)] = l
        if(len(l) == 0):
            return emp, emp_dic, field, dic_length
    while(len(cp_emp) > 0):
        # Place every cell whose candidate list has exactly one digit.
        # NOTE(review): cp_emp is mutated (remove) while being iterated, so the
        # element after each removed one is skipped in this pass; the outer
        # while loop revisits it. Candidate lists are also not refreshed
        # between placements in this pass, so two cells with the same stale
        # single candidate could both receive it -- TODO confirm whether this
        # can let an invalid grid through.
        for x, y in cp_emp:
            length = len(cp_emp_dic[(x, y)])
            # Defensive: a list starting with 0 is the filled-cell sentinel.
            if(cp_emp_dic[(x, y)][0] == 0):
                return emp, emp_dic, field, dic_length
            # The membership test appears always true for the element being visited.
            if( (length == 1) and ((x, y) in cp_emp) ):
                n = cp_emp_dic[(x, y)][0]
                cp_field[(x, y)] = n
                cp_emp.remove((x, y))
                # Mark the cell as filled: length 11 exceeds any real candidate count (<= 9).
                cp_emp_dic[(x, y)] = [0,0,0,0,0,0,0,0,0,0,0]
                z = (9*x) + y
                cp_dic_length[z] = 11
        # Recompute candidates against the updated grid.
        for vx, vy in cp_emp:
            l = ok_num_list(vx, vy, cp_field)
            cp_emp_dic[(vx, vy)] = l
            if(len(l) == 0):
                return emp, emp_dic, field, dic_length
        cp_dic_length = solve_func(cp_emp_dic)
        tmp = after_emp
        before_emp = tmp
        after_emp = len(cp_emp)
        # No cell was placed in this pass: guess, starting with the blank cell
        # that has the fewest candidates.
        if(before_emp == after_emp):
            cp_dic_length = solve_func(cp_emp_dic)
            copy_emp = copy.copy(cp_emp)
            copy_emp_dic = copy.copy(cp_emp_dic)
            copy_field = copy.copy(cp_field)
            copy_dic_length = copy.copy(cp_dic_length)
            # Flat index -> (row, col); relies on row-major order in emp_dic.
            z = np.argmin(copy_dic_length)
            x = int(z/9)
            y = z % 9
            l = copy_emp_dic[(x, y)]
            num = len(l)
            count = 0
            # Try each candidate in turn, recursing on the resulting grid.
            while(count < num):
                n = l[count]
                copy_field[(x, y)] = n
                copy_emp.remove((x, y))
                copy_emp_dic[(x, y)] = [0,0,0,0,0,0,0,0,0,0,0]
                copy_dic_length[z] = 11
                ce = len(copy_emp)
                copy_emp, copy_emp_dic, copy_field, copy_dic_length = solver(copy_emp, copy_emp_dic, copy_field, copy_dic_length)
                # Unchanged blank count means the recursive call failed: restore
                # state and try the next candidate.
                # NOTE(review): if this guess filled the last blank cell (ce == 0), a
                # successful call also leaves len(copy_emp) == 0 == ce and is read as a
                # failure. Probably unreachable for a consistent grid, since a lone
                # blank cell then has a single candidate -- confirm.
                if(ce == len(copy_emp)):
                    copy_emp = copy.copy(cp_emp)
                    copy_emp_dic = copy.copy(cp_emp_dic)
                    copy_field = copy.copy(cp_field)
                    copy_dic_length = copy.copy(cp_dic_length)
                    count += 1
                    continue
                break
            # Every candidate failed: report failure to the caller.
            if(count >= num):
                return emp, emp_dic, field, dic_length
            # Adopt the state from the successful branch.
            cp_emp = copy_emp
            cp_emp_dic = copy_emp_dic
            cp_field = copy_field
            cp_dic_length = copy_dic_length
    return cp_emp, cp_emp_dic, cp_field, cp_dic_length


# Enter the puzzle here.
def set_field(field):
    """Return the puzzle to solve as a 9x9 NumPy integer array.

    Enter the puzzle by editing the rows below; write 0 for every blank
    cell. The ``field`` argument is not read: the grid written here is
    always returned.
    """
    puzzle = [
        [0, 0, 0,   0, 0, 0,   0, 0, 0],
        [0, 0, 0,   0, 0, 0,   0, 0, 0],
        [0, 0, 0,   0, 0, 0,   0, 0, 0],

        [0, 0, 0,   0, 0, 0,   0, 0, 0],
        [0, 0, 0,   0, 0, 0,   0, 0, 0],
        [0, 0, 0,   0, 0, 0,   0, 0, 0],

        [0, 0, 0,   0, 0, 0,   0, 0, 0],
        [0, 0, 0,   0, 0, 0,   0, 0, 0],
        [0, 0, 0,   0, 0, 0,   0, 0, 0],
    ]
    return np.array(puzzle)

# 実行部
def main():
    """Load the puzzle from ``set_field``, solve it, and print it before and after.

    If blank cells remain after ``solver`` returns, the candidate table
    ``emp_dic`` is printed as well.
    """
    # --- Set up the grid and the per-cell candidate table ---
    field = set_field(np.zeros((9, 9), dtype=int))
    emp = []      # coordinates of blank cells, row-major
    emp_dic = {}  # every cell, row-major -> candidate list (11 zeros = filled)
    for index in range(81):
        i, j = divmod(index, 9)
        if field[(i, j)] != 0:
            emp_dic[(i, j)] = [0] * 11
            continue
        emp.append((i, j))
        emp_dic[(i, j)] = ok_num_list(i, j, field)

    # --- Show the puzzle ---
    print("ans-sudoku.py Before")
    table_print(field)

    # --- Solve ---
    emp, emp_dic, field, dic_length = solver(emp, emp_dic, field, solve_func(emp_dic))

    # --- Show the answer ---
    print("")
    print("Answer")
    table_print(field)
    if emp:
        # Unsolved cells remain: dump the candidate table for inspection.
        print(emp_dic)

# Run the solver only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()

スポンサードサーチ