Skip to content

Commit

Permalink
Code cleanup and simplification
Browse files Browse the repository at this point in the history
  • Loading branch information
dnicolodi committed Mar 31, 2022
1 parent af8ee98 commit 9f38a83
Showing 1 changed file with 20 additions and 24 deletions.
44 changes: 20 additions & 24 deletions beanquery/query_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,19 +368,19 @@ def get_function(self, name, operands):
Args:
name: A string, the name of the function to access.
"""
try:
key = tuple([name] + [operand.dtype for operand in operands])
return self.functions[key](operands)
except KeyError:
# If not found with the operands, try just looking it up by name.
try:
return self.functions[name](operands)
except KeyError as exc:
signature = '{}({})'.format(name,
', '.join(operand.dtype.__name__
for operand in operands))
raise CompilationError('Unknown function "{}" in {}'.format(
signature, self.context_name)) from exc

key = tuple([name] + [operand.dtype for operand in operands])
func = self.functions.get(key)
if func is not None:
return func(operands)

# If not found with the operands, try just looking it up by name.
func = self.functions.get(name)
if func is not None:
return func(operands)

sig = '{}({})'.format(name, ', '.join(operand.dtype.__name__ for operand in operands))
raise CompilationError('Unknown function "{sig}" in {self.context_name}')


class AttributeColumn(EvalColumn):
Expand Down Expand Up @@ -567,23 +567,19 @@ def compile_targets(targets, environ):
target_names.add(name)
c_targets.append(EvalTarget(c_expr, name, is_aggregate(c_expr)))

# Figure out if this query is an aggregate query and check validity of each
# target's aggregation type.
for index, c_target in enumerate(c_targets):
columns, aggregates = get_columns_and_aggregates(c_target.c_expr)
columns, aggregates = get_columns_and_aggregates(c_expr)

# Check for mixed aggregates and non-aggregates.
if columns and aggregates:
raise CompilationError(
"Mixed aggregates and non-aggregates are not allowed")

if aggregates:
# Check for aggregates of aggregates.
for aggregate in aggregates:
for child in aggregate.childnodes():
if is_aggregate(child):
raise CompilationError(
"Aggregates of aggregates are not allowed")
# Check for aggregates of aggregates.
for aggregate in aggregates:
for child in aggregate.childnodes():
if is_aggregate(child):
raise CompilationError(
"Aggregates of aggregates are not allowed")

return c_targets, names

Expand Down

0 comments on commit 9f38a83

Please sign in to comment.