Skip to content

Commit c7bd48c

Browse files
committed
[soc2010/query-refactor] Improved the ListField implementation, and added an EmbeddedModelField.
git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2010/query-refactor@13564 bcc190cf-cafb-0310-a4f2-bffc1f526a37
1 parent 9b263c6 commit c7bd48c

File tree

6 files changed

+137
-4
lines changed

6 files changed

+137
-4
lines changed

django/contrib/mongodb/compiler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,10 @@ def update(self, result_type):
172172

173173
vals = {}
174174
for field, o, value in self.query.values:
175+
if hasattr(value, 'prepare_database_save'):
176+
value = value.prepare_database_save(field)
177+
else:
178+
value = field.get_db_prep_save(value, connection=self.connection)
175179
if hasattr(value, "evaluate"):
176180
assert value.connector in (value.ADD, value.SUB)
177181
assert not value.negated

django/db/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from django.db.models.fields.files import FileField, ImageField
1414
from django.db.models.fields.related import (ForeignKey, OneToOneField,
1515
ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel)
16-
from django.db.models.fields.structures import ListField
16+
from django.db.models.fields.structures import ListField, EmbeddedModel
1717
from django.db.models import signals
1818

1919
# Admin stages.

django/db/models/fields/structures.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1+
from django.core.exceptions import ValidationError
2+
from django.db.models.loading import cache
13
from django.db.models.fields import Field
4+
from django.db.models.fields.subclassing import SubfieldBase
25

36

47
class ListField(Field):
8+
__metaclass__ = SubfieldBase
9+
510
def __init__(self, field_type):
611
self.field_type = field_type
7-
super(ListField, self).__init__()
12+
super(ListField, self).__init__(default=[])
813

914
def get_prep_lookup(self, lookup_type, value):
1015
return self.field_type.get_prep_lookup(lookup_type, value)
@@ -19,3 +24,53 @@ def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
1924
return self.field_type.get_db_prep_lookup(
2025
lookup_type, value, connection=connection, prepared=prepared
2126
)
27+
28+
def to_python(self, value):
29+
try:
30+
value = iter(value)
31+
except TypeError:
32+
raise ValidationError("Value should be iterable")
33+
return [
34+
self.field_type.to_python(v)
35+
for v in value
36+
]
37+
38+
39+
class EmbeddedModel(Field):
40+
__metaclass__ = SubfieldBase
41+
42+
def __init__(self, to):
43+
self.to = to
44+
super(EmbeddedModel, self).__init__()
45+
46+
def get_db_prep_save(self, value, connection):
47+
data = {}
48+
if not isinstance(value, self.to):
49+
raise ValidationError("Value must be an instance of %s, got %s "
50+
"instead" % (self.to, value))
51+
if type(value) is not self.to:
52+
data["_cls"] = (value._meta.app_label, value._meta.object_name)
53+
for field in value._meta.fields:
54+
# If the field is a OneToOneField that makes the inheritance link,
55+
# ignore it.
56+
if field.rel and field.rel.parent_link:
57+
continue
58+
data[field.column] = field.get_db_prep_save(
59+
getattr(value, field.name), connection=connection
60+
)
61+
return data
62+
63+
def to_python(self, value):
64+
if isinstance(value, self.to):
65+
return value
66+
try:
67+
value = dict(value)
68+
except TypeError:
69+
raise ValidationError("Value should be a dict")
70+
71+
if "_cls" in value:
72+
cls = cache.get_model(*value.pop("_cls"))
73+
else:
74+
cls = self.to
75+
76+
return cls(**value)

django/utils/encoding.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def is_protected_type(obj):
4747
return isinstance(obj, (
4848
types.NoneType,
4949
int, long,
50+
list,
5051
datetime.datetime, datetime.date, datetime.time,
5152
float, Decimal)
5253
)

tests/regressiontests/mongodb/models.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,23 @@ class Post(models.Model):
3131
magic_numbers = models.ListField(
3232
models.IntegerField()
3333
)
34+
35+
36+
class Revision(models.Model):
37+
number = models.IntegerField()
38+
content = models.TextField()
39+
40+
41+
class AuthenticatedRevision(Revision):
42+
# This is a really stupid way to add optional authentication, but it serves
43+
# its purpose.
44+
author = models.CharField(max_length=100)
45+
46+
47+
class WikiPage(models.Model):
48+
id = models.NativeAutoField(primary_key=True)
49+
title = models.CharField(max_length=255)
50+
51+
revisions = models.ListField(
52+
models.EmbeddedModel(Revision)
53+
)

tests/regressiontests/mongodb/tests.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from django.core.exceptions import ValidationError
12
from django.db import connection, UnsupportedDatabaseOperation
23
from django.db.models import Count, Sum, F, Q
34
from django.test import TestCase
45

5-
from models import Artist, Group, Post
6+
from models import Artist, Group, Post, WikiPage, Revision, AuthenticatedRevision
67

78

89
class MongoTestCase(TestCase):
@@ -398,6 +399,9 @@ def test_unsupported_ops(self):
398399
)
399400

400401
def test_list_field(self):
402+
p = Post()
403+
self.assertEqual(p.tags, [])
404+
401405
p = Post.objects.create(
402406
title="Django ORM grows MongoDB support",
403407
tags=["python", "django", "mongodb", "web"]
@@ -428,7 +432,7 @@ def test_list_field(self):
428432
lambda p: p.title
429433
)
430434

431-
self.assertRaises(ValueError,
435+
self.assertRaises(ValidationError,
432436
lambda: Post.objects.create(magic_numbers=["a"])
433437
)
434438

@@ -448,3 +452,52 @@ def test_list_field(self):
448452
],
449453
lambda p: p.title,
450454
)
455+
456+
def test_embedded_model(self):
457+
page = WikiPage(title="Django")
458+
page.revisions.append(
459+
Revision(number=1, content="Django is a Python")
460+
)
461+
page.revisions.append(
462+
Revision(number=2, content="Django is a Python web framework.")
463+
)
464+
465+
page.save()
466+
467+
page = WikiPage.objects.get(pk=page.pk)
468+
self.assertEqual(len(page.revisions), 2)
469+
self.assertEqual(
470+
[(r.number, r.content) for r in page.revisions],
471+
[(1, "Django is a Python"), (2, "Django is a Python web framework.")]
472+
)
473+
474+
self.assertEqual(Revision.objects.count(), 0)
475+
476+
self.assertRaises(ValidationError,
477+
lambda: WikiPage.objects.create(title="Python", revisions=14)
478+
)
479+
self.assertRaises(ValidationError,
480+
lambda: WikiPage.objects.create(title="Python", revisions=[14])
481+
)
482+
483+
page = WikiPage.objects.create(title="Python", revisions=[
484+
Revision(number=1, content="Python was created by Guido van Rossum.")
485+
])
486+
page = WikiPage.objects.get(pk=page.pk)
487+
self.assertEqual(len(page.revisions), 1)
488+
489+
page.revisions.append(
490+
AuthenticatedRevision(number=2, content="Python is a trap.", author="Rasmus Lerdorf"),
491+
)
492+
493+
page.save()
494+
self.assertEqual(len(page.revisions), 2)
495+
self.assertEqual(page.revisions[-1].author, "Rasmus Lerdorf")
496+
497+
page = WikiPage.objects.get(pk=page.pk)
498+
self.assertEqual(len(page.revisions), 2)
499+
self.assertTrue(isinstance(page.revisions[-1], AuthenticatedRevision))
500+
self.assertEqual(page.revisions[-1].author, "Rasmus Lerdorf")
501+
502+
page.revisions.append(14)
503+
self.assertRaises(ValidationError, page.save)

0 commit comments

Comments
 (0)