Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 63 additions & 58 deletions django/db/models/fields/related.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from django.db import connection, connections, router, transaction
from django.db.backends import utils
from django.db.models import signals
from django.db.models import signals, Q
from django.db.models.fields import (AutoField, Field, IntegerField,
PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist)
from django.db.models.related import RelatedObject, PathInfo
Expand Down Expand Up @@ -444,14 +444,25 @@ def get_or_create(self, **kwargs):
# remove() and clear() are only provided if the ForeignKey can have a value of null.
if rel_field.null:
def remove(self, *objs):
# If there aren't any objects, there is nothing to do.
if not objs:
return

val = rel_field.get_foreign_related_value(self.instance)

old_ids = set()
for obj in objs:
# Is obj actually part of this descriptor set?
if rel_field.get_local_related_value(obj) == val:
setattr(obj, rel_field.name, None)
obj.save()
old_ids.add(obj.pk)
else:
raise rel_field.rel.to.DoesNotExist("%r is not related to %r." % (obj, self.instance))

db = router.db_for_write(self.model, instance=self.instance)
with transaction.commit_on_success_unless_managed(using=db):
for obj in self.using(db).filter(pk__in=old_ids):
setattr(obj, rel_field.name, None)
obj.save(using=db)
remove.alters_data = True

def clear(self):
Expand Down Expand Up @@ -516,6 +527,7 @@ def __init__(self, model=None, query_field_name=None, instance=None, symmetrical
self.instance = instance
self.symmetrical = symmetrical
self.source_field = source_field
self.target_field = through._meta.get_field(target_field_name)
self.source_field_name = source_field_name
self.target_field_name = target_field_name
self.reverse = reverse
Expand Down Expand Up @@ -552,6 +564,19 @@ def __call__(self, **kwargs):
)
do_not_call_in_templates = True

def _build_clear_filters(self, qs):
filters = Q(**{
self.source_field_name: self.related_val,
'%s__in' % self.target_field_name: qs
})

if self.symmetrical:
filters |= Q(**{
self.target_field_name: self.related_val,
'%s__in' % self.source_field_name: qs
})
return filters

def get_queryset(self):
try:
return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
Expand Down Expand Up @@ -605,18 +630,20 @@ def add(self, *objs):

def remove(self, *objs):
self._remove_items(self.source_field_name, self.target_field_name, *objs)

# If this is a symmetrical m2m relation to self, remove the mirror entry in the m2m table
if self.symmetrical:
self._remove_items(self.target_field_name, self.source_field_name, *objs)
remove.alters_data = True

def clear(self):
self._clear_items(self.source_field_name)
db = router.db_for_write(self.through, instance=self.instance)

# If this is a symmetrical m2m relation to self, clear the mirror entry in the m2m table
if self.symmetrical:
self._clear_items(self.target_field_name)
signals.m2m_changed.send(sender=self.through, action="pre_clear",
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=None, using=db)
filters = self._build_clear_filters(self.using(db))
self.through._default_manager.using(db).filter(filters).delete()

signals.m2m_changed.send(sender=self.through, action="post_clear",
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=None, using=db)
clear.alters_data = True

def create(self, **kwargs):
Expand Down Expand Up @@ -702,55 +729,33 @@ def _remove_items(self, source_field_name, target_field_name, *objs):
# *objs - objects to remove

# If there aren't any objects, there is nothing to do.
if objs:
# Check that all the objects are of the right type
old_ids = set()
for obj in objs:
if isinstance(obj, self.model):
fk_val = self.through._meta.get_field(
target_field_name).get_foreign_related_value(obj)[0]
old_ids.add(fk_val)
else:
old_ids.add(obj)
# Work out what DB we're operating on
db = router.db_for_write(self.through, instance=self.instance)
# Send a signal to the other end if need be.
if self.reverse or source_field_name == self.source_field_name:
# Don't send the signal when we are deleting the
# duplicate data row for symmetrical reverse entries.
signals.m2m_changed.send(sender=self.through, action="pre_remove",
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=old_ids, using=db)
# Remove the specified objects from the join table
self.through._default_manager.using(db).filter(**{
source_field_name: self.related_val[0],
'%s__in' % target_field_name: old_ids
}).delete()
if self.reverse or source_field_name == self.source_field_name:
# Don't send the signal when we are deleting the
# duplicate data row for symmetrical reverse entries.
signals.m2m_changed.send(sender=self.through, action="post_remove",
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=old_ids, using=db)
if not objs:
return

# Check that all the objects are of the right type
old_ids = set()
for obj in objs:
if isinstance(obj, self.model):
fk_val = self.target_field.get_foreign_related_value(obj)[0]
old_ids.add(fk_val)
else:
old_ids.add(obj)

def _clear_items(self, source_field_name):
db = router.db_for_write(self.through, instance=self.instance)
# source_field_name: the PK colname in join table for the source object
if self.reverse or source_field_name == self.source_field_name:
# Don't send the signal when we are clearing the
# duplicate data rows for symmetrical reverse entries.
signals.m2m_changed.send(sender=self.through, action="pre_clear",
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=None, using=db)
self.through._default_manager.using(db).filter(**{
source_field_name: self.related_val
}).delete()
if self.reverse or source_field_name == self.source_field_name:
# Don't send the signal when we are clearing the
# duplicate data rows for symmetrical reverse entries.
signals.m2m_changed.send(sender=self.through, action="post_clear",
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=None, using=db)

# Send a signal to the other end if need be.
signals.m2m_changed.send(sender=self.through, action="pre_remove",
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=old_ids, using=db)

old_vals_qs = self.using(db).filter(**{
'%s__in' % self.target_field.related_field.attname: old_ids})
filters = self._build_clear_filters(old_vals_qs)
self.through._default_manager.using(db).filter(filters).delete()

signals.m2m_changed.send(sender=self.through, action="post_remove",
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=old_ids, using=db)

return ManyRelatedManager

Expand Down
117 changes: 117 additions & 0 deletions tests/custom_managers/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ def setUp(self):
self.b2 = Book.published_objects.create(
title="How to be smart", author="Albert Einstein", is_published=False)
self.p1 = Person.objects.create(first_name="Bugs", last_name="Bunny", fun=True)
self.fun = self.p1
self.p2 = Person.objects.create(first_name="Droopy", last_name="Dog", fun=False)
self.boring = self.p2

def test_manager(self):
# Test a custom `Manager` method.
Expand Down Expand Up @@ -189,3 +191,118 @@ def test_related_manager_m2m(self):
],
lambda c: c.first_name
)

def test_related_manager_fk_removal(self):
self.fun.favorite_book = self.b1
self.fun.save()
self.boring.favorite_book = self.b1
self.boring.save()

# TODO: Uncomment once FK RelatedManager.remove() is fixed.
## Check that the fun manager DOESN'T remove boring people.
self.b1.favorite_books(manager='fun_people').remove(self.boring)
self.assertQuerysetEqual(
self.b1.favorite_books(manager='boring_people').all(), [
self.boring.first_name,
],
lambda c: c.first_name
)
# Check that the boring manager DOES remove boring people.
self.b1.favorite_books(manager='boring_people').remove(self.boring)
self.assertQuerysetEqual(
self.b1.favorite_books(manager='boring_people').all(), [
],
lambda c: c.first_name
)
self.boring.favorite_book = self.b1
self.boring.save()

# Check that the fun manager ONLY clears fun people.
self.b1.favorite_books(manager='fun_people').clear()
self.assertQuerysetEqual(
self.b1.favorite_books(manager='boring_people').all(), [
self.boring.first_name,
],
lambda c: c.first_name
)
self.assertQuerysetEqual(
self.b1.favorite_books(manager='fun_people').all(), [
],
lambda c: c.first_name
)

def test_related_manager_gfk_removal(self):
self.fun.favorite_thing = self.b1
self.fun.save()
self.boring.favorite_thing = self.b1
self.boring.save()

# Check that the fun manager DOESN'T remove boring people.
self.b1.favorite_things(manager='fun_people').remove(self.boring)
self.assertQuerysetEqual(
self.b1.favorite_things(manager='boring_people').all(), [
self.boring.first_name,
],
lambda c: c.first_name
)

# Check that the boring manager DOES remove boring people.
self.b1.favorite_things(manager='boring_people').remove(self.boring)
self.assertQuerysetEqual(
self.b1.favorite_things(manager='boring_people').all(), [
],
lambda c: c.first_name
)
self.boring.favorite_thing = self.b1
self.boring.save()

# Check that the fun manager ONLY clears fun people.
self.b1.favorite_things(manager='fun_people').clear()
self.assertQuerysetEqual(
self.b1.favorite_things(manager='boring_people').all(), [
self.boring.first_name,
],
lambda c: c.first_name
)
self.assertQuerysetEqual(
self.b1.favorite_things(manager='fun_people').all(), [
],
lambda c: c.first_name
)

def test_related_manager_m2m_removal(self):
self.b1.authors.add(self.fun)
self.b1.authors.add(self.boring)

# Check that the fun manager DOESN'T remove boring people.
self.b1.authors(manager='fun_people').remove(self.boring)
self.assertQuerysetEqual(
self.b1.authors(manager='boring_people').all(), [
self.boring.first_name,
],
lambda c: c.first_name
)

# Check that the boring manager DOES remove boring people.
self.b1.authors(manager='boring_people').remove(self.boring)
self.assertQuerysetEqual(
self.b1.authors(manager='boring_people').all(), [
],
lambda c: c.first_name
)
self.b1.authors.add(self.boring)


# Check that the fun manager ONLY clears fun people.
self.b1.authors(manager='fun_people').clear()
self.assertQuerysetEqual(
self.b1.authors(manager='boring_people').all(), [
self.boring.first_name,
],
lambda c: c.first_name
)
self.assertQuerysetEqual(
self.b1.authors(manager='fun_people').all(), [
],
lambda c: c.first_name
)