Skip to content

Commit

Permalink
Fix index class argument name not work (#1856)
Browse files Browse the repository at this point in the history
* Fix index class argument `name` not work

* refactor: add `field_names` property to Index class

* Check custom index name in generated schema

* Add `index_name` and `get_sql` back to Index class for aerich
  • Loading branch information
waketzheng authored Jan 22, 2025
1 parent de48e77 commit 948ccdb
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 25 deletions.
5 changes: 4 additions & 1 deletion tests/fields/test_db_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,10 @@ class TestIndexAliasChar(TestIndexAlias):

class TestModelWithIndexes(test.TestCase):
def test_meta(self):
self.assertEqual(ModelWithIndexes._meta.indexes, [Index(fields=("f1", "f2"))])
self.assertEqual(
ModelWithIndexes._meta.indexes,
[Index(fields=("f1", "f2")), Index(fields=("f3",), name="model_with_indexes__f3")],
)
self.assertTrue(ModelWithIndexes._meta.fields_map["id"].index)
self.assertTrue(ModelWithIndexes._meta.fields_map["indexed"].index)
self.assertTrue(ModelWithIndexes._meta.fields_map["unique_indexed"].unique)
5 changes: 5 additions & 0 deletions tests/schema/test_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,11 @@ async def test_create_index(self):
sql = self.get_sql("CREATE INDEX")
self.assertIsNotNone(re.search(r"idx_tournament_created_\w+", sql))

async def test_create_index_with_custom_name(self):
await self.init_for("tests.testmodels")
sql = self.get_sql("f3")
self.assertIn("model_with_indexes__f3", sql)

async def test_fk_bad_model_name(self):
with self.assertRaisesRegex(
ConfigurationError, 'ForeignKeyField accepts model name in format "app.Model"'
Expand Down
2 changes: 2 additions & 0 deletions tests/testmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,11 +1059,13 @@ class ModelWithIndexes(Model):
unique_indexed = fields.CharField(max_length=16, unique=True)
f1 = fields.CharField(max_length=16)
f2 = fields.CharField(max_length=16)
f3 = fields.CharField(max_length=16)
u1 = fields.IntField()
u2 = fields.IntField()

class Meta:
indexes = [
Index(fields=["f1", "f2"]),
Index(fields=["f3"], name="model_with_indexes__f3"),
]
unique_together = [("u1", "u2")]
11 changes: 10 additions & 1 deletion tests/utils/test_describe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,7 +1568,16 @@ def test_describe_indexes_serializable(self):

self.assertEqual(
val["indexes"],
[{"fields": ["f1", "f2"], "expressions": [], "name": None, "type": "", "extra": ""}],
[
{"fields": ["f1", "f2"], "expressions": [], "name": None, "type": "", "extra": ""},
{
"fields": ["f3"],
"expressions": [],
"name": "model_with_indexes__f3",
"type": "",
"extra": "",
},
],
)

def test_describe_indexes_not_serializable(self):
Expand Down
28 changes: 6 additions & 22 deletions tortoise/backends/base/schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Type, cast

from pypika_tortoise.context import DEFAULT_SQL_CONTEXT

from tortoise.exceptions import ConfigurationError
from tortoise.fields import JSONField, TextField, UUIDField
from tortoise.fields.relational import OneToOneFieldInstance
Expand Down Expand Up @@ -348,31 +346,17 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:

if model._meta.indexes:
for index in model._meta.indexes:
if not isinstance(index, Index):
if isinstance(index, Index):
idx_sql = index.get_sql(self, model, safe)
else:
fields = []
for field in index:
field_object = model._meta.fields_map[field]
fields.append(field_object.source_field or field)
idx_sql = self._get_index_sql(model, fields, safe=safe)

_indexes.append(self._get_index_sql(model, fields, safe=safe))
else:
if index.fields:
fields = [f for f in index.fields]
elif index.expressions:
fields = [
f"({expression.get_sql(DEFAULT_SQL_CONTEXT)})"
for expression in index.expressions
]
else:
raise ConfigurationError(
"At least one field or expression is required to define an index."
)

_indexes.append(
self._get_index_sql(
model, fields, safe=safe, index_type=index.INDEX_TYPE, extra=index.extra
)
)
if idx_sql:
_indexes.append(idx_sql)

field_indexes_sqls = [val for val in list(dict.fromkeys(_indexes)) if val]

Expand Down
37 changes: 36 additions & 1 deletion tortoise/indexes.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING, Any, Type

from pypika_tortoise.context import DEFAULT_SQL_CONTEXT
from pypika_tortoise.terms import Term, ValueWrapper

from tortoise.exceptions import ConfigurationError

if TYPE_CHECKING:
from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.models import Model


class Index:
INDEX_TYPE = ""
Expand Down Expand Up @@ -46,6 +51,36 @@ def describe(self) -> dict:
"extra": self.extra,
}

def index_name(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]") -> str:
# This function is required by aerich
return self.name or schema_generator._generate_index_name("idx", model, self.field_names)

def get_sql(
self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool
) -> str:
# This function is required by aerich
return schema_generator._get_index_sql(
model,
self.field_names,
safe,
index_name=self.name,
index_type=self.INDEX_TYPE,
extra=self.extra,
)

@property
def field_names(self) -> list[str]:
if self.fields:
return list(self.fields)
elif self.expressions:
return [
f"({expression.get_sql(DEFAULT_SQL_CONTEXT)})" for expression in self.expressions
]
else:
raise ConfigurationError(
"At least one field or expression is required to define an index."
)

def __repr__(self) -> str:
argument = ""
if self.expressions:
Expand Down

0 comments on commit 948ccdb

Please sign in to comment.