-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathlexical_constraints.py
112 lines (87 loc) · 3.66 KB
/
lexical_constraints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import pickle
from typing import List, Optional, Set
class Literal:
    """A single token phrase whose contiguous occurrence is tracked step by step.

    ``pointer`` is the index (within ``tokens``) of the last phrase token
    matched by the most recently consumed words, or -1 when no prefix of the
    phrase is currently matched. ``satisfy`` becomes True once the whole
    phrase has been consumed and stays True afterwards.
    """

    def __init__(self, tokens: List[int]):
        self.tokens = tokens
        self.pointer = -1
        self.satisfy = False

    def advance(self, word_id: int):
        """Consume one generated token id and update the match state.

        On a mismatch the match does not simply reset: it falls back to the
        longest suffix of the recent words that is still a prefix of the
        phrase (the same fallback KMP uses, computed directly). This makes
        overlapping starts work, e.g. phrase ``[a, b]`` inside ``a a b``.

        :param word_id: The token id just generated.
        """
        if self.satisfy:
            # Satisfaction is sticky. Without this guard, tokens[pointer + 1]
            # would index one past the end of a fully matched phrase.
            return
        # Phrase tokens matched so far, followed by the new word.
        window = self.tokens[:self.pointer + 1] + [word_id]
        self.pointer = -1
        # Try the longest candidate first; k == len(window) is the ordinary
        # "word matches the next expected token" case.
        for k in range(min(len(window), len(self.tokens)), 0, -1):
            if window[-k:] == self.tokens[:k]:
                self.pointer = k - 1
                break
        if self.pointer == len(self.tokens) - 1:
            self.satisfy = True
class Clause:
    """A disjunction of token phrases, satisfied once any one phrase is matched."""

    def __init__(self, idx: int, phrases: List[List[int]]):
        # Ordering key for the clause; ConstrainedHypothesis passes the
        # clause's position in its constraint list.
        self.idx = idx
        self.literals = [Literal(p) for p in phrases]
        self.satisfy = False

    def advance(self, word_id: int):
        """Feed ``word_id`` to every phrase of a not-yet-satisfied clause.

        :param word_id: The token id just generated.
        """
        if self.satisfy:
            # Satisfaction is sticky, so there is nothing left to track. This
            # also avoids re-advancing a literal that already consumed its
            # whole phrase.
            return
        # Advance every literal (no short-circuit) so each keeps its own
        # match position up to date.
        for literal in self.literals:
            literal.advance(word_id)
            if literal.satisfy:
                self.satisfy = True

    def __str__(self):
        phrases = [literal.tokens for literal in self.literals]
        return f'clause(id={self.idx}, phrases={phrases}, satisfy={self.satisfy})'
class ConstrainedHypothesis:
    """Constraint-tracking state for one decoding hypothesis.

    ``constraint_list`` is a list of clauses; each clause is a list of
    token-id phrases and is met as soon as any one of its phrases has been
    matched. ``advance`` does not modify the instance; it returns an updated
    copy.
    """

    def __init__(self,
                 constraint_list: List[List[List[int]]],
                 eos_tokens: List[int]) -> None:
        self.clauses = [Clause(idx=position, phrases=phrases)
                        for position, phrases in enumerate(constraint_list)]
        # When None, avoid() never bans end-of-sequence ids.
        self.eos_tokens = eos_tokens

    def __len__(self) -> int:
        """
        :return: The number of constraints (clauses).
        """
        return len(self.clauses)

    def __str__(self) -> str:
        return '\n'.join(map(str, self.clauses))

    def num_met(self) -> int:
        """
        :return: How many clauses are currently satisfied.
        """
        return sum(1 for clause in self.clauses if clause.satisfy)

    def advance(self, word_id: int) -> 'ConstrainedHypothesis':
        """Return a copy of this hypothesis after consuming ``word_id``.

        The copy is taken with a pickle round trip so ``self`` is left
        untouched; clauses that are already satisfied are not advanced.
        """
        successor = pickle.loads(pickle.dumps(self))
        for clause in successor.clauses:
            if not clause.satisfy:
                clause.advance(word_id)
        return successor

    def avoid(self) -> Set[int]:
        """
        Compute the token ids to ban at the next generation step.

        The unmet clause with the smallest ``idx`` is the active one: the next
        expected token and the first token of each of its phrases are
        protected. The same tokens of every other unmet clause are banned
        unless protected. End-of-sequence ids (if any) are also banned while
        at least one clause is unmet.

        :return: the tokens to avoid for next generation
        """
        pending = sorted((c for c in self.clauses if not c.satisfy),
                         key=lambda c: c.idx)
        permitted: Set[int] = set()
        blocked: Set[int] = set()
        for rank, clause in enumerate(pending):
            bucket = permitted if rank == 0 else blocked
            for literal in clause.literals:
                # An unmet clause cannot contain a fully matched phrase.
                assert literal.pointer < len(literal.tokens) - 1 and not literal.satisfy
                bucket.add(literal.tokens[literal.pointer + 1])
                bucket.add(literal.tokens[0])
        banned = blocked - permitted
        if self.eos_tokens is not None and pending:
            banned.update(self.eos_tokens)
        return banned
def init_batch(raw_constraints: List[List[List[List[int]]]],
               eos_tokens: List[int]) -> List[Optional[ConstrainedHypothesis]]:
    """
    Build one ConstrainedHypothesis per entry of ``raw_constraints``.

    :param raw_constraints: One constraint list per hypothesis. Each constraint
        list is a list of clauses, and each clause is a list of token-id phrases.
    :param eos_tokens: End-of-sequence token ids handed to every hypothesis;
        ``ConstrainedHypothesis.avoid`` bans them while any clause is unmet
        (``None`` disables that).
    :return: A list of ConstrainedHypothesis objects, one per element of
        ``raw_constraints`` and in the same order.
    """
    return [ConstrainedHypothesis(raw_list, eos_tokens) for raw_list in raw_constraints]