diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 56ead9f9a1b..c5e119f5afc 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,7 +29,7 @@ jobs: services: postgres: - image: postgres:17 + image: postgres:latest env: POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres diff --git a/cms/models/placeholdermodel.py b/cms/models/placeholdermodel.py index 2db2faa59e1..f8dcceab7ef 100644 --- a/cms/models/placeholdermodel.py +++ b/cms/models/placeholdermodel.py @@ -3,7 +3,7 @@ from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.models import ContentType -from django.db import connection, models +from django.db import connection, models, transaction from django.template.defaultfilters import title from django.utils.encoding import force_str from django.utils.translation import gettext_lazy as _ @@ -684,17 +684,22 @@ def delete_plugin(self, instance): :param instance: Plugin to add. It's position parameter needs to be set. :type instance: :class:`cms.models.pluginmodel.CMSPlugin` instance """ - instance.get_descendants().delete() - instance.delete() - last_plugin = self.get_last_plugin(instance.language) - - if last_plugin: - self._shift_plugin_positions( - instance.language, - start=instance.position, - offset=last_plugin.position, - ) - self._recalculate_plugin_positions(instance.language) + with transaction.atomic(): + # We're using raw sql - make the whole operation atomic + plugins = self.get_plugins(language=instance.language).count() # 1st hit: Count plugins + descendants = instance._get_descendants_ids() # 2nd hit: Get descendant ids + to_delete = [instance.pk] + descendants # Instance plus descendants pk + self.cmsplugin_set.filter(pk__in=to_delete).delete() # 3rd hit: Delete all plugins in one query + + last_position = instance.position + len(descendants) # Last position of deleted plugins + if last_position < plugins: + # Close the gap in the plugin tree (2 hits) + self._shift_plugin_positions( + instance.language, + start=instance.position, + offset=plugins, + ) + self._recalculate_plugin_positions(instance.language) def get_last_plugin(self, language): return self.get_plugins(language).last() @@ -756,39 +761,15 @@ def _recalculate_plugin_positions(self, language): cursor = _get_database_cursor('write') db_vendor = _get_database_vendor('write') - if db_vendor == 'sqlite': - sql = ( - 'CREATE TEMPORARY TABLE temp AS ' - 'SELECT ID, (' - 'SELECT COUNT(*)+1 FROM {0} t WHERE ' - 'placeholder_id={0}.placeholder_id AND language={0}.language ' - 'AND {0}.position > t.position' - ') AS new_position ' - 'FROM {0} WHERE placeholder_id=%s AND language=%s' - ) - sql = sql.format(connection.ops.quote_name(CMSPlugin._meta.db_table)) - cursor.execute(sql, [self.pk, language]) - - sql = ( - 'UPDATE {0} ' - 'SET position = (SELECT new_position FROM temp WHERE id={0}.id) ' - 'WHERE placeholder_id=%s AND language=%s' - ) - sql = sql.format(connection.ops.quote_name(CMSPlugin._meta.db_table)) - cursor.execute(sql, [self.pk, language]) - - sql = 'DROP TABLE temp' - sql = sql.format(connection.ops.quote_name(CMSPlugin._meta.db_table)) - cursor.execute(sql) - elif db_vendor == 'postgresql': + if db_vendor in ('sqlite', 'postgresql'): sql = ( 'UPDATE {0} ' - 'SET position = RowNbrs.RowNbr ' + 'SET position = subquery.new_pos ' 'FROM (' - 'SELECT ID, ROW_NUMBER() OVER (ORDER BY position) AS RowNbr ' - 'FROM {0} WHERE placeholder_id=%s AND language=%s ' - ') RowNbrs ' - 'WHERE {0}.id=RowNbrs.id' + ' SELECT ID, ROW_NUMBER() OVER (ORDER BY position, id) AS new_pos ' + ' FROM {0} WHERE placeholder_id=%s AND language=%s ' + ') subquery ' + 'WHERE {0}.id=subquery.id' ) sql = sql.format(connection.ops.quote_name(CMSPlugin._meta.db_table)) cursor.execute(sql, [self.pk, language]) diff --git a/cms/models/pluginmodel.py b/cms/models/pluginmodel.py index f18f6427a78..2fb765cb74e 100644 --- a/cms/models/pluginmodel.py +++ b/cms/models/pluginmodel.py @@ -2,7 +2,7 @@ import os import warnings from datetime import date -from functools import lru_cache +from functools import cache from django.core.exceptions import ObjectDoesNotExist from django.db import connection, connections, models, router @@ -19,7 +19,7 @@ from cms.utils.urlutils import admin_reverse -@lru_cache(maxsize=None) +@cache def _get_descendants_cte(): db_vendor = _get_database_vendor('read') if db_vendor == 'oracle': @@ -60,7 +60,7 @@ def _get_database_cursor(action): return _get_database_connection(action).cursor() -@lru_cache(maxsize=None) +@cache def plugin_supports_cte(): # This has to be as function because when it's a var it evaluates before # db is connected and we get OperationalError. MySQL version is retrieved diff --git a/cms/tests/test_placeholder.py b/cms/tests/test_placeholder.py index 773c2d1b229..6005c3d75a2 100644 --- a/cms/tests/test_placeholder.py +++ b/cms/tests/test_placeholder.py @@ -1378,7 +1378,10 @@ def test_delete(self): for plugin in self.get_plugins().filter(parent__isnull=True): for plugin_id in [plugin.pk] + tree[plugin.pk]: plugin_tree_all.remove(plugin_id) + + plugin.refresh_from_db() self.placeholder.delete_plugin(plugin) + new_tree = self.get_plugins().values_list('pk', 'position') expected = [(pk, pos) for pos, pk in enumerate(plugin_tree_all, 1)] self.assertSequenceEqual(new_tree, expected) @@ -1610,6 +1613,24 @@ def test_move_to_placeholder_bottom(self): class PlaceholderNestedPluginTests(PlaceholderFlatPluginTests): + """ + Same tests as for PlaceholderFlatPluginTests but now with a different plugin tree: + + :: + + Parent 1 + Parent 2 + Child + Parent 1 + Parent 2 + Child + Parent 1 + Parent 2 + Child + Parent 1 + Parent 2 + Child + """ def create_plugins(self, placeholder): for i in range(1, 12, 3): @@ -1654,7 +1675,10 @@ def test_delete_single(self): for plugin in self.get_plugins().filter(parent__isnull=True): for plugin_id in [plugin.pk] + tree[plugin.pk]: plugin_tree_all.remove(plugin_id) + + plugin.refresh_from_db() self.placeholder.delete_plugin(plugin) + new_tree = self.get_plugins().values_list('pk', 'position') expected = [(pk, pos) for pos, pk in enumerate(plugin_tree_all, 1)] self.assertSequenceEqual(new_tree, expected)