dry: only apply dry to sequences that request it #860

Closed · wants to merge 2 commits
aphrodite/modeling/layers/sampler.py: 60 changes (34 additions, 26 deletions)
@@ -631,18 +631,26 @@ def _apply_dry(

Reference: https://github.com/oobabooga/text-generation-webui/pull/5677
"""
-    if torch.all(multipliers == 0):
+    dry_indices = torch.nonzero(multipliers).squeeze(-1)
+    if len(dry_indices) == 0:
return logits

# DRY needs to be applied to both input AND output tokens
input_ids = torch.cat((input_token_ids, output_token_ids), dim=1)
vocab_size = logits.size(-1)

-    def compute_z_array(s: List[int], end: int, search_start: int) -> List[int]:
+    def compute_z_array_gpu(
+            s: torch.Tensor, end: int,
+            search_start: int) -> torch.Tensor:
"""
-        Compute Z array using two-pointer technique for linear time complexity
+        GPU-optimized version of Z array computation using two-pointer technique
"""
-        z = [0] * len(s)
+        n = len(s)
+        z = torch.zeros(n, dtype=torch.int64, device=s.device)
+
+        if end < search_start:
+            return z
+
right = end - 1
left = end - 1

@@ -662,51 +670,51 @@ def compute_z_array(s: List[int], end: int, search_start: int) -> List[int]:
right -= 1
if left == right:
break
-            z[right] = min(z[end - (helper - right)], right - left)
+            z[right] = min(z[end - (helper - right)].item(), right - left)
if left >= search_start and right - z[right] <= left:
break

return z

-    # Process each sequence in the batch
-    for i, (input_ids_row, logits_row) in enumerate(zip(input_ids, logits)):
-        multiplier = multipliers[i].item()
-        if multiplier == 0:
-            continue
-
-        seq_breakers = set(sequence_breakers_ids[i].tolist())
-        input_ids_list = input_ids_row.tolist()
-        last_token = input_ids_list[-1]
+    # Process only sequences that have DRY enabled
+    for idx in dry_indices:
+        input_ids_row = input_ids[idx]
+        logits_row = logits[idx]
+        seq_breakers = set(sequence_breakers_ids[idx].tolist())
+        last_token = input_ids_row[-1].item()

if last_token in seq_breakers:
continue

-        range_limit = ranges[i].item()
+        range_limit = ranges[idx].item()
if range_limit == 0:
search_start = 0
else:
-            search_start = max(0, len(input_ids_list) - range_limit)
+            search_start = max(0, len(input_ids_row) - range_limit)

# Find max match length based on sequence breakers
max_match_length = 0
-        MAX_LENGTH = min(len(input_ids_list), 1000) # Prevent overflow
+        MAX_LENGTH = min(len(input_ids_row), 1000) # Prevent overflow
while (max_match_length < MAX_LENGTH and
-               input_ids_list[len(input_ids_list) - max_match_length - 1]
+               input_ids_row[len(input_ids_row) - max_match_length - 1].item()
not in seq_breakers):
max_match_length += 1

-        z_array = compute_z_array(
-            input_ids_list, len(input_ids_list) - 1, search_start)
-
-        z_array = [min(length, max_match_length) for length in z_array]
+        z_array = compute_z_array_gpu(
+            input_ids_row, len(input_ids_row) - 1, search_start)
+
+        z_array = torch.minimum(
+            z_array, torch.tensor(max_match_length, device=z_array.device))

penalties = {}
-        allowed_length = allowed_lengths[i]
-        base = bases[i]
+        allowed_length = allowed_lengths[idx]
+        base = bases[idx]
+        multiplier = multipliers[idx].item()

-        for idx, match_length in enumerate(z_array[:-1]):
+        for idx2, match_length in enumerate(z_array[:-1]):
+            match_length = match_length.item() # Convert tensor to int
if match_length >= allowed_length:
-                next_token = input_ids_list[idx + 1]
+                next_token = input_ids_row[idx2 + 1].item()
if (next_token >= vocab_size or next_token in
seq_breakers):
continue
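
Note on the change itself: the new gating computes dry_indices once with torch.nonzero and then touches only those rows, instead of looping over the whole batch and skipping rows whose multiplier is zero. Below is a minimal, self-contained sketch of that pattern; the penalty formula follows the usual DRY scheme from the referenced text-generation-webui PR (multiplier * base ** (match_length - allowed_length)), and the function name, shapes and inputs are illustrative rather than taken from the aphrodite code.

import torch

def dry_penalize_requested_rows(
    logits: torch.Tensor,           # (batch, vocab)
    match_lengths: torch.Tensor,    # (batch,) longest repeated-suffix length per row
    next_tokens: torch.Tensor,      # (batch,) token that would extend the repetition
    multipliers: torch.Tensor,      # (batch,) 0.0 means DRY is disabled for that row
    bases: torch.Tensor,            # (batch,)
    allowed_lengths: torch.Tensor,  # (batch,)
) -> torch.Tensor:
    # Visit only the rows that actually request DRY (nonzero multiplier).
    dry_indices = torch.nonzero(multipliers).squeeze(-1)
    if len(dry_indices) == 0:
        return logits

    for idx in dry_indices:
        match_length = match_lengths[idx].item()
        allowed = allowed_lengths[idx].item()
        if match_length < allowed:
            continue
        # Usual DRY penalty: grows exponentially with the excess match length.
        penalty = (multipliers[idx].item() *
                   bases[idx].item() ** (match_length - allowed))
        logits[idx, next_tokens[idx]] -= penalty
    return logits

With multipliers = torch.tensor([0.0, 0.8]), for example, only the second row is ever visited, which is the behaviour the PR title describes.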
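
For background on the helper being ported: a Z-array gives, for each position, the length of a longest match found with a two-pointer scan in linear time. The sketch below is the classic prefix-oriented formulation on a plain Python list, purely as a reference point for reading compute_z_array_gpu, which is a suffix-oriented, tensor-based variant; the function name and example values here are illustrative.

from typing import List

def z_array(s: List[int]) -> List[int]:
    """Classic Z-algorithm: z[i] is the length of the longest common
    prefix of s and s[i:], computed in O(n) with a left/right window."""
    n = len(s)
    z = [0] * n
    if n == 0:
        return z
    z[0] = n
    left, right = 0, 0  # rightmost match window seen so far
    for i in range(1, n):
        if i < right:
            # Reuse work already done inside the current window.
            z[i] = min(right - i, z[i - left])
        # Extend the match naively beyond the window.
        while i + z[i] < n and s[z[i]] == s[i + z[i]]:
            z[i] += 1
        if i + z[i] > right:
            left, right = i, i + z[i]
    return z

# Repeated token pattern: z[3] == 4 reflects the repeated [7, 8, 9, 7] block.
print(z_array([7, 8, 9, 7, 8, 9, 7]))  # [7, 0, 0, 4, 0, 0, 1]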