-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathtrie.py
161 lines (109 loc) · 3.81 KB
/
trie.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from pykit import dictutil
class TrieNode(dict):
    """A node of the trie: a dict mapping a branch key to its child node.

    Besides the child mapping, every node carries the bookkeeping
    attributes initialised in `__init__` (`n`, `char`, `outstanding`,
    `is_outstanding`, `is_eol`).
    """

    def __init__(self, *args, **kwargs):
        """Build the node; all arguments are forwarded to `dict`."""
        super(TrieNode, self).__init__(*args, **kwargs)

        # How many items share the prefix this node stands for.
        # `None`: more items with this prefix may still arrive, so the
        # count is not final yet.
        self.n = None

        # The key of the branch leading to this node, a single char.
        self.char = ''

        # The child that may still receive more input strings sharing its
        # prefix, or `None` if there is no such child.
        self.outstanding = None

        # Whether this node is the outstanding node of its parent.
        # A freshly created node always is.
        self.is_outstanding = True

        # Whether some input string ends exactly at this node.
        # Every leaf is an end-of-line node, but not vice versa.
        self.is_eol = False

    def __str__(self):
        """Render the subtree rooted at this node as indented text.

        A node is printed as `<mark><char>,<n>`: the mark is `*` for an
        outstanding node and a blank otherwise, the `<char>,` part is
        omitted when `char` is empty, and a falsy count is shown as `?`.
        Children follow after `: ` in sorted key order, the first one on
        the same line and the others aligned beneath it.
        """
        marker = '*' if self.is_outstanding else ' '
        label = '' if self.char == '' else str(self.char) + ','

        ordered = sorted(self.keys())
        separator = ': ' if len(ordered) > 0 else ''

        head = '{mark}{c}{n}{colon}'.format(
            mark=marker, c=label, n=(self.n or '?'), colon=separator)
        pad = ' ' * len(head)

        # Every line of every child is shifted right by the width of `head`.
        rendered = [pad + str(self[key]).replace('\n', '\n' + pad)
                    for key in ordered]
        body = '\n'.join(rendered)

        # Drop the padding of the very first line: `head` takes its place.
        return head + body[len(head):]
def _trie_node(parent, char):
    """Create the node for branch `char` and make it `parent`'s outstanding node.

    The new node is only recorded in `parent.outstanding`; it is NOT stored
    in `parent`'s child mapping here.

    Returns the newly created `TrieNode`.
    """
    child = TrieNode()
    parent.outstanding = child
    child.char = char
    return child
def make_trie(sorted_iterable, node_max_num=1):
    """Build a counting trie from strings that arrive in sorted order.

    While building, every subtree that can no longer receive input is passed
    to `_squash`, which fixes its count and drops its children when the
    count does not exceed `node_max_num`.

    :param sorted_iterable: iterable of strings (or other sliceable
        sequences of hashable items), expected to be sorted.
    :param node_max_num: threshold handed to `_squash`.
    :return: the root `TrieNode`.
    """
    t = TrieNode()

    for _s in sorted_iterable:

        # Find the longest common prefix of `_s` and any string seen so far.
        node = t

        # Bind `i` up front: for an empty `_s` the loop below never assigns
        # it, and `_s[i:]` would otherwise raise UnboundLocalError (or reuse
        # the index left over from the previous string).
        i = 0
        for i, c in enumerate(_s):
            if c in node:
                node = node[c]
            else:
                break

        # `node` is now at the end of the longest common prefix and `i` is
        # the index of the first char of `_s` outside that prefix.
        #
        # NOTE(review): if the walk never breaks (`_s` is already fully
        # present, e.g. a duplicate of the previous string), `i` stays on
        # the last index, so the last char is appended once more below the
        # existing node -- confirm this is intended.

        # Because input is sorted and `node` is the longest common prefix,
        # the prefix of the child `node.outstanding` can never be a prefix
        # of any following string, so it is safe to finalize it.
        _squash(node.outstanding, node_max_num)

        for c in _s[i:]:
            node[c] = _trie_node(node, c)
            node = node[c]

        node.is_eol = True

        # Only the node where a string ends is counted as 1 here; the
        # counts of inner nodes are computed by `_squash`.
        node.n = 1

    _squash(t, node_max_num)

    return t
def sharding(sorted_iterable, size, accuracy=None, joiner=''.join):
    """Split sorted input into consecutive shards of roughly `size` items.

    A trie is built with `make_trie` and walked with `dictutil.depth_iter`;
    a shard is closed as soon as its running count has reached `size`.

    :param sorted_iterable: the sorted strings to shard.
    :param size: item count at which the current shard is closed.
    :param accuracy: passed to `make_trie` as `node_max_num`; defaults to
        `size / 10`.
    :param joiner: callable turning a path of trie keys into a shard key.
    :return: list of `(start_key, count)` tuples. The first `start_key` is
        `None`; the last tuple is always appended, even with a count of 0.
    """
    if accuracy is None:
        accuracy = size / 10

    trie = make_trie(sorted_iterable, node_max_num=accuracy)

    def _is_counted(ks, v):
        # Visit only nodes where a string ends, or nodes without children.
        return v.is_eol or len(v) == 0

    shards = []
    start_key = None
    count = 0

    # The root is wrapped under an empty key -- presumably so that
    # `depth_iter` can yield the root node itself as well.
    wrapped = {'': trie}

    for path, node in dictutil.depth_iter(wrapped, is_allowed=_is_counted):

        if count >= size:
            shards.append((start_key, count))

            # Strip the empty key of the wrapper from the path.
            start_key = joiner(path[1:])
            count = 0

        # A childless node carries its whole count; a node with children
        # reaches here only because a string ends on it, which counts once.
        count += node.n if len(node) == 0 else 1

    shards.append((start_key, count))

    return shards
def _squash(node, node_max_num):
# If the number of strings with prefix at `node`(children of `node`) is smaller than `node_max_num`,
# squash all children to reduce memory cost.
if node is None:
return
_squash(node.outstanding, node_max_num)
total = node.n or 0
for subnode in node.values():
total += subnode.n
if total <= node_max_num:
for c in node.keys():
del node[c]
node.n = total
node.is_outstanding = False