-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgetting_a_position2.py
131 lines (112 loc) · 3.8 KB
/
getting_a_position2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import chess.pgn
import io
import all_possible_moves
# Move vocabulary for the legal-move mask, built once at import time:
# get_possible_move_list emits one mask entry per element of this list, in
# this order. Presumably a list of move strings such as "e2e4" plus the
# castling tokens "0-0"/"0-0-0" — TODO confirm against all_possible_moves.
all_moves = all_possible_moves.get_all_possible_moves_in_chess()
def hash_position(data):
    """Return the 64-bit FNV-1 hash of the ASCII string *data*.

    Every byte is folded in by multiplying the running value by the FNV prime
    (truncated to 64 bits) and then XOR-ing the byte into it. Non-ASCII input
    raises ``UnicodeEncodeError`` from ``str.encode``.
    """
    fnv_prime = 0x100000001b3
    mask_64 = 0xffffffffffffffff
    digest = 0xcbf29ce484222325  # FNV 64-bit offset basis
    for byte in data.encode("ascii"):
        digest = ((digest * fnv_prime) & mask_64) ^ byte
    return digest
# Numeric code for every symbol that can appear in the text form of a board
# (as split by get_positions_and_moves_from_pgn).
piece_dict = {
    # Lowercase (black) pieces: each code is the uppercase code plus 10.
    "p": 11, "k": 14, "q": 19, "r": 15, "n": 13, "b": 13.5,
    # Uppercase (white) pieces.
    "P": 1, "K": 4, "Q": 9, "R": 5, "N": 3, "B": 3.5,
    # "." marks an empty square.
    ".": 0,
}
def get_positions_and_moves_from_pgn(a):
    """Yield ``(position, move)`` pairs for the first game in a PGN string.

    For every mainline move of the first game in *a*, the board *before* that
    move is encoded as a list of rows:

    * one row per line of ``str(board)``, each symbol mapped through
      ``piece_dict``;
    * four rows of 8 identical ints: side to move (1 = white, 0 = black),
      white castling rights, black castling rights, and how many times this
      piece placement was already seen earlier in the game (0 the first time);
    * the legal-move mask rows from ``get_possible_move_list`` for that same
      board.

    *move* is ``"0-0"`` / ``"0-0-0"`` for castling, otherwise the first four
    characters of the move's string form (a promotion piece is dropped).

    Args:
        a: PGN text. Only the first game is read.

    Yields:
        ``(position, move)`` tuples. Nothing is yielded when *a* holds no game.
    """
    first_game = chess.pgn.read_game(io.StringIO(a))
    if first_game is None:
        # read_game returns None when there is no game in the input.
        return
    board = first_game.board()
    position_count_dict = {}
    for move in first_game.mainline_moves():
        position_string = str(board)
        position = [
            [piece_dict[symbol] for symbol in line.split()]
            for line in position_string.splitlines()
        ]

        # Scalar features are broadcast across a row of 8 to match the board
        # width. Reading the side to move from the board (not the move index)
        # stays correct for games that start from a FEN with black to move.
        position.append([int(board.turn == chess.WHITE)] * 8)
        position.append([int(board.has_castling_rights(chess.WHITE))] * 8)
        position.append([int(board.has_castling_rights(chess.BLACK))] * 8)

        # Repetition counter keyed on the board text only: 0 on the first
        # sighting, then 1, 2, ... NOTE: that text holds piece placement only,
        # so side to move / castling / en-passant state is not distinguished.
        hashed_string = hash_position(position_string)
        if hashed_string in position_count_dict:
            position_count_dict[hashed_string] += 1
        else:
            position_count_dict[hashed_string] = 0
        position.append([position_count_dict[hashed_string]] * 8)

        # Castling checks need the pre-move board (the king must still be on
        # its starting square).
        is_kingside = board.is_kingside_castling(move)
        is_queenside = board.is_queenside_castling(move)
        if is_kingside:
            yield_move = "0-0"
        elif is_queenside:
            yield_move = "0-0-0"
        else:
            yield_move = str(move)[:4]

        # The legal-move mask must describe the same position as the features
        # and the move label above, so it is built before the move is pushed.
        position.extend(
            get_possible_move_list(board, is_kingside, is_queenside))

        board.push(move)
        yield position, yield_move
def get_possible_move_list(board, is_kingside_castling, is_queenside_castling):
    """Return a one-hot mask of *board*'s legal moves, reshaped into rows of 8.

    The mask has one entry per element of the module-level ``all_moves`` list:
    1 if that move string is legal on *board*, else 0. Each legal move is
    reduced to the first four characters of its string form (so promotion
    suffixes are dropped); castling moves are additionally reported as
    ``"0-0"`` (kingside) and ``"0-0-0"`` (queenside).

    Args:
        board: the position whose legal moves are encoded.
        is_kingside_castling: unused; kept so existing callers keep working.
            Castling availability is read from ``board.legal_moves``.
        is_queenside_castling: unused; see ``is_kingside_castling``.

    Returns:
        A list of rows. All mask entries except the last two are split into
        consecutive rows of 8 (the final row is shorter only if that count is
        not a multiple of 8); each of the last two entries is then repeated
        across its own row of 8.
    """
    legal_moves = set()
    for move in board.legal_moves:
        legal_moves.add(str(move)[:4])
        # The predicates must be *called* per move: testing the bound methods
        # themselves is always truthy and would mark castling legal everywhere.
        if board.is_kingside_castling(move):
            legal_moves.add("0-0")
        elif board.is_queenside_castling(move):
            legal_moves.add("0-0-0")

    one_hot_encoded_list = [int(item in legal_moves) for item in all_moves]

    # Sanity check: a legal move missing from all_moves would silently vanish
    # from the mask, so report it.
    unknown_moves = legal_moves.difference(all_moves)
    if unknown_moves:
        print("Error. Something is wrong.")
        print(sorted(legal_moves))
        print(sum(one_hot_encoded_list), "count")
        print(len(legal_moves), "legal_moves")
        for item in sorted(unknown_moves):
            print(item)
        print("---------------------------------------------------------------")

    # The original author notes all_moves has 4034 entries (TODO confirm),
    # which is not divisible by 8: chunk everything but the last two entries
    # into rows of 8 (including the final row), then broadcast each of the last
    # two entries across a full row.
    body, tail = one_hot_encoded_list[:-2], one_hot_encoded_list[-2:]
    return_list = [body[start:start + 8] for start in range(0, len(body), 8)]
    return_list.extend([bit] * 8 for bit in tail)
    return return_list