diff --git a/pypika/queries.py b/pypika/queries.py index 2adc1b15..1ec7baa6 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -994,21 +994,24 @@ def orderby(self, *fields: Any, **kwargs: Any) -> "QueryBuilder": @builder def join( - self, item: Union[Table, "QueryBuilder", AliasedQuery, Selectable], how: JoinType = JoinType.inner + self, + item: Union[Table, "QueryBuilder", AliasedQuery, Selectable], + how: JoinType = JoinType.inner, + force_index: Optional[str] = None, ) -> "Joiner": if isinstance(item, Table): - return Joiner(self, item, how, type_label="table") + return Joiner(self, item, how, type_label="table", force_index=force_index) elif isinstance(item, QueryBuilder): if item.alias is None: self._tag_subquery(item) - return Joiner(self, item, how, type_label="subquery") + return Joiner(self, item, how, type_label="subquery", force_index=force_index) elif isinstance(item, AliasedQuery): - return Joiner(self, item, how, type_label="table") + return Joiner(self, item, how, type_label="table", force_index=force_index) elif isinstance(item, Selectable): - return Joiner(self, item, how, type_label="subquery") + return Joiner(self, item, how, type_label="subquery", force_index=force_index) raise ValueError("Cannot join on type '%s'" % type(item)) @@ -1539,12 +1542,17 @@ def _set_sql(self, **kwargs: Any) -> str: class Joiner: def __init__( - self, query: QueryBuilder, item: Union[Table, "QueryBuilder", AliasedQuery], how: JoinType, type_label: str + self, + query: QueryBuilder, + item: Union[Table, "QueryBuilder", AliasedQuery], + how: JoinType, type_label: str, + force_index: Optional[str] = None, ) -> None: self.query = query self.item = item self.how = how self.type_label = type_label + self.force_index = force_index def on(self, criterion: Optional[Criterion], collate: Optional[str] = None) -> QueryBuilder: if criterion is None: @@ -1553,7 +1561,7 @@ def on(self, criterion: Optional[Criterion], collate: Optional[str] = None) -> Q "{type} JOIN but was not supplied.".format(type=self.type_label) ) - self.query.do_join(JoinOn(self.item, self.how, criterion, collate)) + self.query.do_join(JoinOn(self.item, self.how, criterion, collate, self.force_index)) return self.query def on_field(self, *fields: Any) -> QueryBuilder: @@ -1567,7 +1575,7 @@ def on_field(self, *fields: Any) -> QueryBuilder: consituent = Field(field, table=self.query._from[0]) == Field(field, table=self.item) criterion = consituent if criterion is None else criterion & consituent - self.query.do_join(JoinOn(self.item, self.how, criterion)) + self.query.do_join(JoinOn(self.item, self.how, criterion, self.force_index)) return self.query def using(self, *fields: Any) -> QueryBuilder: @@ -1579,21 +1587,25 @@ def using(self, *fields: Any) -> QueryBuilder: def cross(self) -> QueryBuilder: """Return cross join""" - self.query.do_join(Join(self.item, JoinType.cross)) + self.query.do_join(Join(self.item, JoinType.cross, self.force_index)) return self.query class Join: - def __init__(self, item: Term, how: JoinType) -> None: + def __init__(self, item: Term, how: JoinType, force_index: Optional[str] = None) -> None: self.item = item self.how = how + self.force_index = force_index def get_sql(self, **kwargs: Any) -> str: sql = "JOIN {table}".format( table=self.item.get_sql(subquery=True, with_alias=True, **kwargs), ) + if self.force_index: + sql = sql + " FORCE INDEX ({index})".format(index=self.force_index) + if self.how.value: return "{type} {join}".format(join=sql, type=self.how.value) return sql @@ -1618,8 +1630,15 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl class JoinOn(Join): - def __init__(self, item: Term, how: JoinType, criteria: QueryBuilder, collate: Optional[str] = None) -> None: - super().__init__(item, how) + def __init__( + self, + item: Term, + how: JoinType, + criteria: QueryBuilder, + collate: Optional[str] = None, + force_index: Optional[str] = None, + ) -> None: + super().__init__(item, how, force_index) self.criterion = criteria self.collate = collate @@ -1661,8 +1680,8 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl class JoinUsing(Join): - def __init__(self, item: Term, how: JoinType, fields: Sequence[Field]) -> None: - super().__init__(item, how) + def __init__(self, item: Term, how: JoinType, fields: Sequence[Field], force_index: Optional[str] = None) -> None: + super().__init__(item, how, force_index) self.fields = fields def get_sql(self, **kwargs: Any) -> str: diff --git a/pypika/tests/test_joins.py b/pypika/tests/test_joins.py index 6e54f883..482bfe19 100644 --- a/pypika/tests/test_joins.py +++ b/pypika/tests/test_joins.py @@ -346,6 +346,20 @@ def test_temporal_join(self): str(query), ) + def test_join_with_force_index(self): + table_a, table_b = Tables("a", "b") + + q1 = ( + Query.from_(table_a) + .select(table_b.ouch) + .join(table_b, force_index='PRIMARY') + .on(table_a.foo == table_b.boo) + ) + + self.assertEqual( + 'SELECT "b"."ouch" FROM "a" JOIN "b" FORCE INDEX (PRIMARY) ON "a"."foo"="b"."boo"', + str(q1), + ) class JoinBehaviorTests(unittest.TestCase): table_abc, table_efg, table_hij, table_klm = Tables("abc", "efg", "hij", "klm")