diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 8252cce6726e..4f62cfee7caf 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -2,7 +2,7 @@ from django.db import connection, connections, router, transaction from django.db.backends import utils -from django.db.models import signals +from django.db.models import signals, Q from django.db.models.fields import (AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist) from django.db.models.related import RelatedObject, PathInfo @@ -444,14 +444,25 @@ def get_or_create(self, **kwargs): # remove() and clear() are only provided if the ForeignKey can have a value of null. if rel_field.null: def remove(self, *objs): + # If there aren't any objects, there is nothing to do. + if not objs: + return + val = rel_field.get_foreign_related_value(self.instance) + + old_ids = set() for obj in objs: # Is obj actually part of this descriptor set? if rel_field.get_local_related_value(obj) == val: - setattr(obj, rel_field.name, None) - obj.save() + old_ids.add(obj.pk) else: raise rel_field.rel.to.DoesNotExist("%r is not related to %r." % (obj, self.instance)) + + db = router.db_for_write(self.model, instance=self.instance) + with transaction.commit_on_success_unless_managed(using=db): + for obj in self.using(db).filter(pk__in=old_ids): + setattr(obj, rel_field.name, None) + obj.save(using=db) remove.alters_data = True def clear(self): @@ -516,6 +527,7 @@ def __init__(self, model=None, query_field_name=None, instance=None, symmetrical self.instance = instance self.symmetrical = symmetrical self.source_field = source_field + self.target_field = through._meta.get_field(target_field_name) self.source_field_name = source_field_name self.target_field_name = target_field_name self.reverse = reverse @@ -552,6 +564,19 @@ def __call__(self, **kwargs): ) do_not_call_in_templates = True + def _build_clear_filters(self, qs): + filters = Q(**{ + self.source_field_name: self.related_val, + '%s__in' % self.target_field_name: qs + }) + + if self.symmetrical: + filters |= Q(**{ + self.target_field_name: self.related_val, + '%s__in' % self.source_field_name: qs + }) + return filters + def get_queryset(self): try: return self.instance._prefetched_objects_cache[self.prefetch_cache_name] @@ -605,18 +630,20 @@ def add(self, *objs): def remove(self, *objs): self._remove_items(self.source_field_name, self.target_field_name, *objs) - - # If this is a symmetrical m2m relation to self, remove the mirror entry in the m2m table - if self.symmetrical: - self._remove_items(self.target_field_name, self.source_field_name, *objs) remove.alters_data = True def clear(self): - self._clear_items(self.source_field_name) + db = router.db_for_write(self.through, instance=self.instance) - # If this is a symmetrical m2m relation to self, clear the mirror entry in the m2m table - if self.symmetrical: - self._clear_items(self.target_field_name) + signals.m2m_changed.send(sender=self.through, action="pre_clear", + instance=self.instance, reverse=self.reverse, + model=self.model, pk_set=None, using=db) + filters = self._build_clear_filters(self.using(db)) + self.through._default_manager.using(db).filter(filters).delete() + + signals.m2m_changed.send(sender=self.through, action="post_clear", + instance=self.instance, reverse=self.reverse, + model=self.model, pk_set=None, using=db) clear.alters_data = True def create(self, **kwargs): @@ -702,55 +729,33 @@ def _remove_items(self, source_field_name, target_field_name, *objs): # *objs - objects to remove # If there aren't any objects, there is nothing to do. - if objs: - # Check that all the objects are of the right type - old_ids = set() - for obj in objs: - if isinstance(obj, self.model): - fk_val = self.through._meta.get_field( - target_field_name).get_foreign_related_value(obj)[0] - old_ids.add(fk_val) - else: - old_ids.add(obj) - # Work out what DB we're operating on - db = router.db_for_write(self.through, instance=self.instance) - # Send a signal to the other end if need be. - if self.reverse or source_field_name == self.source_field_name: - # Don't send the signal when we are deleting the - # duplicate data row for symmetrical reverse entries. - signals.m2m_changed.send(sender=self.through, action="pre_remove", - instance=self.instance, reverse=self.reverse, - model=self.model, pk_set=old_ids, using=db) - # Remove the specified objects from the join table - self.through._default_manager.using(db).filter(**{ - source_field_name: self.related_val[0], - '%s__in' % target_field_name: old_ids - }).delete() - if self.reverse or source_field_name == self.source_field_name: - # Don't send the signal when we are deleting the - # duplicate data row for symmetrical reverse entries. - signals.m2m_changed.send(sender=self.through, action="post_remove", - instance=self.instance, reverse=self.reverse, - model=self.model, pk_set=old_ids, using=db) + if not objs: + return + + # Check that all the objects are of the right type + old_ids = set() + for obj in objs: + if isinstance(obj, self.model): + fk_val = self.target_field.get_foreign_related_value(obj)[0] + old_ids.add(fk_val) + else: + old_ids.add(obj) - def _clear_items(self, source_field_name): db = router.db_for_write(self.through, instance=self.instance) - # source_field_name: the PK colname in join table for the source object - if self.reverse or source_field_name == self.source_field_name: - # Don't send the signal when we are clearing the - # duplicate data rows for symmetrical reverse entries. - signals.m2m_changed.send(sender=self.through, action="pre_clear", - instance=self.instance, reverse=self.reverse, - model=self.model, pk_set=None, using=db) - self.through._default_manager.using(db).filter(**{ - source_field_name: self.related_val - }).delete() - if self.reverse or source_field_name == self.source_field_name: - # Don't send the signal when we are clearing the - # duplicate data rows for symmetrical reverse entries. - signals.m2m_changed.send(sender=self.through, action="post_clear", - instance=self.instance, reverse=self.reverse, - model=self.model, pk_set=None, using=db) + + # Send a signal to the other end if need be. + signals.m2m_changed.send(sender=self.through, action="pre_remove", + instance=self.instance, reverse=self.reverse, + model=self.model, pk_set=old_ids, using=db) + + old_vals_qs = self.using(db).filter(**{ + '%s__in' % self.target_field.related_field.attname: old_ids}) + filters = self._build_clear_filters(old_vals_qs) + self.through._default_manager.using(db).filter(filters).delete() + + signals.m2m_changed.send(sender=self.through, action="post_remove", + instance=self.instance, reverse=self.reverse, + model=self.model, pk_set=old_ids, using=db) return ManyRelatedManager diff --git a/tests/custom_managers/tests.py b/tests/custom_managers/tests.py index f9a9f33d8785..3321ec724546 100644 --- a/tests/custom_managers/tests.py +++ b/tests/custom_managers/tests.py @@ -13,7 +13,9 @@ def setUp(self): self.b2 = Book.published_objects.create( title="How to be smart", author="Albert Einstein", is_published=False) self.p1 = Person.objects.create(first_name="Bugs", last_name="Bunny", fun=True) + self.fun = self.p1 self.p2 = Person.objects.create(first_name="Droopy", last_name="Dog", fun=False) + self.boring = self.p2 def test_manager(self): # Test a custom `Manager` method. @@ -189,3 +191,118 @@ def test_related_manager_m2m(self): ], lambda c: c.first_name ) + + def test_related_manager_fk_removal(self): + self.fun.favorite_book = self.b1 + self.fun.save() + self.boring.favorite_book = self.b1 + self.boring.save() + + # TODO: Uncomment once FK RelatedManager.remove() is fixed. + ## Check that the fun manager DOESN'T remove boring people. + self.b1.favorite_books(manager='fun_people').remove(self.boring) + self.assertQuerysetEqual( + self.b1.favorite_books(manager='boring_people').all(), [ + self.boring.first_name, + ], + lambda c: c.first_name + ) + # Check that the boring manager DOES remove boring people. + self.b1.favorite_books(manager='boring_people').remove(self.boring) + self.assertQuerysetEqual( + self.b1.favorite_books(manager='boring_people').all(), [ + ], + lambda c: c.first_name + ) + self.boring.favorite_book = self.b1 + self.boring.save() + + # Check that the fun manager ONLY clears fun people. + self.b1.favorite_books(manager='fun_people').clear() + self.assertQuerysetEqual( + self.b1.favorite_books(manager='boring_people').all(), [ + self.boring.first_name, + ], + lambda c: c.first_name + ) + self.assertQuerysetEqual( + self.b1.favorite_books(manager='fun_people').all(), [ + ], + lambda c: c.first_name + ) + + def test_related_manager_gfk_removal(self): + self.fun.favorite_thing = self.b1 + self.fun.save() + self.boring.favorite_thing = self.b1 + self.boring.save() + + # Check that the fun manager DOESN'T remove boring people. + self.b1.favorite_things(manager='fun_people').remove(self.boring) + self.assertQuerysetEqual( + self.b1.favorite_things(manager='boring_people').all(), [ + self.boring.first_name, + ], + lambda c: c.first_name + ) + + # Check that the boring manager DOES remove boring people. + self.b1.favorite_things(manager='boring_people').remove(self.boring) + self.assertQuerysetEqual( + self.b1.favorite_things(manager='boring_people').all(), [ + ], + lambda c: c.first_name + ) + self.boring.favorite_thing = self.b1 + self.boring.save() + + # Check that the fun manager ONLY clears fun people. + self.b1.favorite_things(manager='fun_people').clear() + self.assertQuerysetEqual( + self.b1.favorite_things(manager='boring_people').all(), [ + self.boring.first_name, + ], + lambda c: c.first_name + ) + self.assertQuerysetEqual( + self.b1.favorite_things(manager='fun_people').all(), [ + ], + lambda c: c.first_name + ) + + def test_related_manager_m2m_removal(self): + self.b1.authors.add(self.fun) + self.b1.authors.add(self.boring) + + # Check that the fun manager DOESN'T remove boring people. + self.b1.authors(manager='fun_people').remove(self.boring) + self.assertQuerysetEqual( + self.b1.authors(manager='boring_people').all(), [ + self.boring.first_name, + ], + lambda c: c.first_name + ) + + # Check that the boring manager DOES remove boring people. + self.b1.authors(manager='boring_people').remove(self.boring) + self.assertQuerysetEqual( + self.b1.authors(manager='boring_people').all(), [ + ], + lambda c: c.first_name + ) + self.b1.authors.add(self.boring) + + + # Check that the fun manager ONLY clears fun people. + self.b1.authors(manager='fun_people').clear() + self.assertQuerysetEqual( + self.b1.authors(manager='boring_people').all(), [ + self.boring.first_name, + ], + lambda c: c.first_name + ) + self.assertQuerysetEqual( + self.b1.authors(manager='fun_people').all(), [ + ], + lambda c: c.first_name + )