diff --git a/pythonpattern/.project b/pythonpattern/.project new file mode 100644 index 00000000..7a840801 --- /dev/null +++ b/pythonpattern/.project @@ -0,0 +1,17 @@ + + + pythonpattern + + + + + + org.python.pydev.PyDevBuilder + + + + + + org.python.pydev.pythonNature + + diff --git a/pythonpattern/.pydevproject b/pythonpattern/.pydevproject new file mode 100644 index 00000000..0e9144a8 --- /dev/null +++ b/pythonpattern/.pydevproject @@ -0,0 +1,8 @@ + + + +/pythonpattern/src + +python 2.7 +Default + diff --git a/pythonpattern/src/abstract_factory.py b/pythonpattern/src/abstract_factory.py new file mode 100644 index 00000000..32587b8d --- /dev/null +++ b/pythonpattern/src/abstract_factory.py @@ -0,0 +1,64 @@ +# http://ginstrom.com/scribbles/2007/10/08/design-patterns-python-style/ +#Stest +"""Implementation of the abstract factory pattern""" + +import random + +class PetShop: + """A pet shop""" + + def __init__(self, animal_factory=None): + """pet_factory is our abstract factory. + We can set it at will.""" + + self.pet_factory = animal_factory + + def show_pet(self): + """Creates and shows a pet using the + abstract factory""" + + pet = self.pet_factory.get_pet() + print("This is a lovely", pet) + print("It says", pet.speak()) + print("It eats", self.pet_factory.get_food()) + +# Stuff that our factory makes + +class Dog: + def speak(self): + return "woof" + def __str__(self): + return "Dog" + +class Cat: + def speak(self): + return "meow" + def __str__(self): + return "Cat" + +# Factory classes + +class DogFactory: + def get_pet(self): + return Dog() + def get_food(self): + return "dog food" + +class CatFactory: + def get_pet(self): + return Cat() + def get_food(self): + return "cat food" + +# Create the proper family +def get_factory(): + """Let's be dynamic!""" + return random.choice([DogFactory, CatFactory])() + +# Show pets with various factories +if __name__== "__main__": + shop = PetShop() + for i in range(3): + shop.pet_factory = get_factory() + shop.show_pet() + print("=" * 20) diff --git a/pythonpattern/src/adapter.py b/pythonpattern/src/adapter.py new file mode 100644 index 00000000..cdbb241a --- /dev/null +++ b/pythonpattern/src/adapter.py @@ -0,0 +1,65 @@ +# http://ginstrom.com/scribbles/2008/11/06/generic-adapter-class-in-python/ + +import os + +class Dog(object): + def __init__(self): + self.name = "Dog" + + def bark(self): + return "woof!" + +class Cat(object): + def __init__(self): + self.name = "Cat" + + def meow(self): + return "meow!" + +class Human(object): + def __init__(self): + self.name = "Human" + + def speak(self): + return "'hello'" + +class Car(object): + def __init__(self): + self.name = "Car" + + def make_noise(self, octane_level): + return "vroom%s" % ("!" * octane_level) + +class Adapter(object): + """ + Adapts an object by replacing methods. + Usage: + dog = Dog + dog = Adapter(dog, dict(make_noise=dog.bark)) + """ + def __init__(self, obj, adapted_methods): + """We set the adapted methods in the object's dict""" + self.obj = obj + self.__dict__.update(adapted_methods) + + def __getattr__(self, attr): + """All non-adapted calls are passed to the object""" + return getattr(self.obj, attr) + +def main(): + objects = [] + dog = Dog() + objects.append(Adapter(dog, dict(make_noise=dog.bark))) + cat = Cat() + objects.append(Adapter(cat, dict(make_noise=cat.meow))) + human = Human() + objects.append(Adapter(human, dict(make_noise=human.speak))) + car = Car() + car_noise = lambda : car.make_noise(3) + objects.append(Adapter(car, dict(make_noise=car_noise))) + + for obj in objects: + print("A", obj.name, "goes", obj.make_noise()) + +if __name__ == "__main__": + main() diff --git a/pythonpattern/src/borg.py b/pythonpattern/src/borg.py new file mode 100644 index 00000000..7ba7a2cd --- /dev/null +++ b/pythonpattern/src/borg.py @@ -0,0 +1,36 @@ +class Borg: + __shared_state = {} + + def __init__(self): + self.__dict__ = self.__shared_state + + def __str__(self): + return self.state + +class YourBorg(Borg): + pass + +if __name__ == '__main__': + rm1 = Borg() + rm2 = Borg() + + rm1.state = 'Idle' + rm2.state = 'Running' + + print('rm1:', rm1) + print('rm2:', rm2) + + rm2.state = 'Zombie' + + print('rm1:', rm1) + print('rm2:', rm2) + + print('rm1 id:', id(rm1)) + print('rm2 id:', id(rm2)) + + rm3 = YourBorg() + + print('rm1:', rm1) + print('rm2:', rm2) + print('rm3:', rm3) + diff --git a/pythonpattern/src/bridge.py b/pythonpattern/src/bridge.py new file mode 100644 index 00000000..6c147be2 --- /dev/null +++ b/pythonpattern/src/bridge.py @@ -0,0 +1,40 @@ +# http://en.wikibooks.org/wiki/Computer_Science_Design_Patterns/Bridge_Pattern#Python + +# ConcreteImplementor 1/2 +class DrawingAPI1: + def drawCircle(self, x, y, radius): + print('API1.circle at {}:{} radius {}'.format(x, y, radius)) + +# ConcreteImplementor 2/2 +class DrawingAPI2: + def drawCircle(self, x, y, radius): + print('API2.circle at {}:{} radius {}'.format(x, y, radius)) + +# Refined Abstraction +class CircleShape: + def __init__(self, x, y, radius, drawingAPI): + self.__x = x + self.__y = y + self.__radius = radius + self.__drawingAPI = drawingAPI + + # low-level i.e. Implementation specific + def draw(self): + self.__drawingAPI.drawCircle(self.__x, self.__y, self.__radius) + + # high-level i.e. Abstraction specific + def resizeByPercentage(self, pct): + self.__radius *= pct + +def main(): + shapes = ( + CircleShape(1, 2, 3, DrawingAPI1()), + CircleShape(5, 7, 11, DrawingAPI2()) + ) + + for shape in shapes: + shape.resizeByPercentage(2.5) + shape.draw() + +if __name__ == "__main__": + main() diff --git a/pythonpattern/src/builder.py b/pythonpattern/src/builder.py new file mode 100644 index 00000000..d1436591 --- /dev/null +++ b/pythonpattern/src/builder.py @@ -0,0 +1,64 @@ +#!/usr/bin/python +# -*- coding : utf-8 -*- + +""" + @author: Diogenes Augusto Fernandes Herminio + https://gist.github.com/420905#file_builder_python.py +""" + +# Director +class Director(object): + def __init__(self): + self.builder = None + + def construct_building(self): + self.builder.new_building() + self.builder.build_floor() + self.builder.build_size() + + def get_building(self): + return self.builder.building + +# Abstract Builder +class Builder(object): + def __init__(self): + self.building = None + + def new_building(self): + self.building = Building() + +# Concrete Builder +class BuilderHouse(Builder): + def build_floor(self): + self.building.floor ='One' + + def build_size(self): + self.building.size = 'Big' + +class BuilderFlat(Builder): + def build_floor(self): + self.building.floor ='More than One' + + def build_size(self): + self.building.size = 'Small' + +# Product +class Building(object): + def __init__(self): + self.floor = None + self.size = None + + def __repr__(self): + return 'Floor: %s | Size: %s' % (self.floor, self.size) + +# Client +if __name__== "__main__": + director = Director() + director.builder = BuilderHouse() + director.construct_building() + building = director.get_building() + print(building) + director.builder = BuilderFlat() + director.construct_building() + building = director.get_building() + print(building) diff --git a/pythonpattern/src/chain.py b/pythonpattern/src/chain.py new file mode 100644 index 00000000..fbbda422 --- /dev/null +++ b/pythonpattern/src/chain.py @@ -0,0 +1,42 @@ +# http://www.testingperspective.com/wiki/doku.php/collaboration/chetan/designpatternsinpython/chain-of-responsibilitypattern + +class Handler: + def successor(self, successor): + self.successor = successor + +class ConcreteHandler1(Handler): + def handle(self, request): + if request > 0 and request <= 10: + print("in handler1") + else: + self.successor.handle(request) + +class ConcreteHandler2(Handler): + def handle(self, request): + if request > 10 and request <= 20: + print("in handler2") + else: + self.successor.handle(request) + +class ConcreteHandler3(Handler): + def handle(self, request): + if request > 20 and request <= 30: + print("in handler3") + else: + print('end of chain, no handler for {}'.format(request)) + +class Client: + def __init__(self): + h1 = ConcreteHandler1() + h2 = ConcreteHandler2() + h3 = ConcreteHandler3() + + h1.successor(h2) + h2.successor(h3) + + requests = [2, 5, 14, 22, 18, 3, 35, 27, 20] + for request in requests: + h1.handle(request) + +if __name__== "__main__": + client = Client() diff --git a/pythonpattern/src/command.py b/pythonpattern/src/command.py new file mode 100644 index 00000000..2499ec3c --- /dev/null +++ b/pythonpattern/src/command.py @@ -0,0 +1,36 @@ +import os + +class MoveFileCommand(object): + def __init__(self, src, dest): + self.src = src + self.dest = dest + + def execute(self): + self() + + def __call__(self): + print('renaming {} to {}'.format(self.src, self.dest)) + os.rename(self.src, self.dest) + + def undo(self): + print('renaming {} to {}'.format(self.dest, self.src)) + os.rename(self.dest, self.src) + + +if __name__ == "__main__": + undo_stack = [] + ren1 = MoveFileCommand('foo.txt', 'bar.txt') + ren2 = MoveFileCommand('bar.txt', 'baz.txt') + + # commands are just pushed into the command stack + for cmd in ren1, ren2: + undo_stack.append(cmd) + + # they can be executed later on will + for cmd in undo_stack: + cmd.execute() # foo.txt is now renamed to baz.txt + + # and can also be undone on will + for cmd in undo_stack: + undo_stack.pop().undo() # Now it's bar.txt + undo_stack.pop().undo() # and back to foo.txt diff --git a/pythonpattern/src/composite.py b/pythonpattern/src/composite.py new file mode 100644 index 00000000..f28fcbaf --- /dev/null +++ b/pythonpattern/src/composite.py @@ -0,0 +1,326 @@ +""" +A class which defines a composite object which can store +hieararchical dictionaries with names. + +This class is same as a hiearchical dictionary, but it +provides methods to add/access/modify children by name, +like a Composite. + +Created Anand B Pillai + +""" +__author__ = "Anand B Pillai" +__maintainer__ = "Anand B Pillai" +__version__ = "0.2" + + +def normalize(val): + """ Normalize a string so that it can be used as an attribute + to a Python object """ + + if val.find('-') != -1: + val = val.replace('-','_') + + return val + +def denormalize(val): + """ De-normalize a string """ + + if val.find('_') != -1: + val = val.replace('_','-') + + return val + +class SpecialDict(dict): + """ A dictionary type which allows direct attribute + access to its keys """ + + def __getattr__(self, name): + + if name in self.__dict__: + return self.__dict__[name] + elif name in self: + return self.get(name) + else: + # Check for denormalized name + name = denormalize(name) + if name in self: + return self.get(name) + else: + raise AttributeError('no attribute named %s' % name) + + def __setattr__(self, name, value): + + if name in self.__dict__: + self.__dict__[name] = value + elif name in self: + self[name] = value + else: + # Check for denormalized name + name2 = denormalize(name) + if name2 in self: + self[name2] = value + else: + # New attribute + self[name] = value + +class CompositeDict(SpecialDict): + """ A class which works like a hierarchical dictionary. + This class is based on the Composite design-pattern """ + + ID = 0 + + def __init__(self, name=''): + + if name: + self._name = name + else: + self._name = ''.join(('id#',str(self.__class__.ID))) + self.__class__.ID += 1 + + self._children = [] + # Link back to father + self._father = None + self[self._name] = SpecialDict() + + def __getattr__(self, name): + + if name in self.__dict__: + return self.__dict__[name] + elif name in self: + return self.get(name) + else: + # Check for denormalized name + name = denormalize(name) + if name in self: + return self.get(name) + else: + # Look in children list + child = self.findChild(name) + if child: + return child + else: + attr = getattr(self[self._name], name) + if attr: return attr + + raise AttributeError('no attribute named %s' % name) + + def isRoot(self): + """ Return whether I am a root component or not """ + + # If I don't have a parent, I am root + return not self._father + + def isLeaf(self): + """ Return whether I am a leaf component or not """ + + # I am a leaf if I have no children + return not self._children + + def getName(self): + """ Return the name of this ConfigInfo object """ + + return self._name + + def getIndex(self, child): + """ Return the index of the child ConfigInfo object 'child' """ + + if child in self._children: + return self._children.index(child) + else: + return -1 + + def getDict(self): + """ Return the contained dictionary """ + + return self[self._name] + + def getProperty(self, child, key): + """ Return the value for the property for child + 'child' with key 'key' """ + + # First get the child's dictionary + childDict = self.getInfoDict(child) + if childDict: + return childDict.get(key, None) + + def setProperty(self, child, key, value): + """ Set the value for the property 'key' for + the child 'child' to 'value' """ + + # First get the child's dictionary + childDict = self.getInfoDict(child) + if childDict: + childDict[key] = value + + def getChildren(self): + """ Return the list of immediate children of this object """ + + return self._children + + def getAllChildren(self): + """ Return the list of all children of this object """ + + l = [] + for child in self._children: + l.append(child) + l.extend(child.getAllChildren()) + + return l + + def getChild(self, name): + """ Return the immediate child object with the given name """ + + for child in self._children: + if child.getName() == name: + return child + + def findChild(self, name): + """ Return the child with the given name from the tree """ + + # Note - this returns the first child of the given name + # any other children with similar names down the tree + # is not considered. + + for child in self.getAllChildren(): + if child.getName() == name: + return child + + def findChildren(self, name): + """ Return a list of children with the given name from the tree """ + + # Note: this returns a list of all the children of a given + # name, irrespective of the depth of look-up. + + children = [] + + for child in self.getAllChildren(): + if child.getName() == name: + children.append(child) + + return children + + def getPropertyDict(self): + """ Return the property dictionary """ + + d = self.getChild('__properties') + if d: + return d.getDict() + else: + return {} + + def getParent(self): + """ Return the person who created me """ + + return self._father + + def __setChildDict(self, child): + """ Private method to set the dictionary of the child + object 'child' in the internal dictionary """ + + d = self[self._name] + d[child.getName()] = child.getDict() + + def setParent(self, father): + """ Set the parent object of myself """ + + # This should be ideally called only once + # by the father when creating the child :-) + # though it is possible to change parenthood + # when a new child is adopted in the place + # of an existing one - in that case the existing + # child is orphaned - see addChild and addChild2 + # methods ! + self._father = father + + def setName(self, name): + """ Set the name of this ConfigInfo object to 'name' """ + + self._name = name + + def setDict(self, d): + """ Set the contained dictionary """ + + self[self._name] = d.copy() + + def setAttribute(self, name, value): + """ Set a name value pair in the contained dictionary """ + + self[self._name][name] = value + + def getAttribute(self, name): + """ Return value of an attribute from the contained dictionary """ + + return self[self._name][name] + + def addChild(self, name, force=False): + """ Add a new child 'child' with the name 'name'. + If the optional flag 'force' is set to True, the + child object is overwritten if it is already there. + + This function returns the child object, whether + new or existing """ + + if type(name) != str: + raise ValueError('Argument should be a string!') + + child = self.getChild(name) + if child: + # print 'Child %s present!' % name + # Replace it if force==True + if force: + index = self.getIndex(child) + if index != -1: + child = self.__class__(name) + self._children[index] = child + child.setParent(self) + + self.__setChildDict(child) + return child + else: + child = self.__class__(name) + child.setParent(self) + + self._children.append(child) + self.__setChildDict(child) + + return child + + def addChild2(self, child): + """ Add the child object 'child'. If it is already present, + it is overwritten by default """ + + currChild = self.getChild(child.getName()) + if currChild: + index = self.getIndex(currChild) + if index != -1: + self._children[index] = child + child.setParent(self) + # Unset the existing child's parent + currChild.setParent(None) + del currChild + + self.__setChildDict(child) + else: + child.setParent(self) + self._children.append(child) + self.__setChildDict(child) + +if __name__=="__main__": + window = CompositeDict('Window') + frame = window.addChild('Frame') + tfield = frame.addChild('Text Field') + tfield.setAttribute('size','20') + + btn = frame.addChild('Button1') + btn.setAttribute('label','Submit') + + btn = frame.addChild('Button2') + btn.setAttribute('label','Browse') + + # print(window) + # print(window.Frame) + # print(window.Frame.Button1) + # print(window.Frame.Button2) + print(window.Frame.Button1.label) + print(window.Frame.Button2.label) diff --git a/pythonpattern/src/decorator.py b/pythonpattern/src/decorator.py new file mode 100644 index 00000000..df666440 --- /dev/null +++ b/pythonpattern/src/decorator.py @@ -0,0 +1,21 @@ +# http://stackoverflow.com/questions/3118929/implementing-the-decorator-pattern-in-python + +class foo(object): + def f1(self): + print("original f1") + def f2(self): + print("original f2") + +class foo_decorator(object): + def __init__(self, decoratee): + self._decoratee = decoratee + def f1(self): + print("decorated f1") + self._decoratee.f1() + def __getattr__(self, name): + return getattr(self._decoratee, name) + +u = foo() +v = foo_decorator(u) +v.f1() +v.f2() diff --git a/pythonpattern/src/facade.py b/pythonpattern/src/facade.py new file mode 100644 index 00000000..c7088102 --- /dev/null +++ b/pythonpattern/src/facade.py @@ -0,0 +1,58 @@ +'''http://dpip.testingperspective.com/?p=26''' + +import time + +SLEEP = 0.5 + +# Complex Parts +class TC1: + def run(self): + print("###### In Test 1 ######") + time.sleep(SLEEP) + print("Setting up") + time.sleep(SLEEP) + print("Running test") + time.sleep(SLEEP) + print("Tearing down") + time.sleep(SLEEP) + print("Test Finished\n") + +class TC2: + def run(self): + print("###### In Test 2 ######") + time.sleep(SLEEP) + print("Setting up") + time.sleep(SLEEP) + print("Running test") + time.sleep(SLEEP) + print("Tearing down") + time.sleep(SLEEP) + print("Test Finished\n") + +class TC3: + def run(self): + print("###### In Test 3 ######") + time.sleep(SLEEP) + print("Setting up") + time.sleep(SLEEP) + print("Running test") + time.sleep(SLEEP) + print("Tearing down") + time.sleep(SLEEP) + print("Test Finished\n") + +# Facade +class TestRunner: + def __init__(self): + self.tc1 = TC1() + self.tc2 = TC2() + self.tc3 = TC3() + self.tests = [i for i in (self.tc1, self.tc2, self.tc3)] + + def runAll(self): + [i.run() for i in self.tests] + +# Client +if __name__ == '__main__': + testrunner = TestRunner() + testrunner.runAll() diff --git a/pythonpattern/src/factory_method.py b/pythonpattern/src/factory_method.py new file mode 100644 index 00000000..9127fd4a --- /dev/null +++ b/pythonpattern/src/factory_method.py @@ -0,0 +1,30 @@ +#encoding=utf-8 +'''http://ginstrom.com/scribbles/2007/10/08/design-patterns-python-style/''' + +class GreekGetter: + """A simple localizer a la gettext""" + def __init__(self): + self.trans = dict(dog="σκύλος", cat="γάτα") + + def get(self, msgid): + """We'll punt if we don't have a translation""" + try: + return self.trans[msgid] + except KeyError: + return str(msgid) + +class EnglishGetter: + """Simply echoes the msg ids""" + def get(self, msgid): + return str(msgid) + +def get_localizer(language="English"): + """The factory method""" + languages = dict(English=EnglishGetter, Greek=GreekGetter) + return languages[language]() + +# Create our localizers +e, g = get_localizer("English"), get_localizer("Greek") +# Localize some text +for msgid in "dog parrot cat bear".split(): + print(e.get(msgid), g.get(msgid)) diff --git a/pythonpattern/src/flyweight.py b/pythonpattern/src/flyweight.py new file mode 100644 index 00000000..206e8fae --- /dev/null +++ b/pythonpattern/src/flyweight.py @@ -0,0 +1,31 @@ +'''http://codesnipers.com/?q=python-flyweights''' + +import weakref + +class Card(object): + '''The object pool. Has builtin reference counting''' + _CardPool = weakref.WeakValueDictionary() + + '''Flyweight implementation. If the object exists in the + pool just return it (instead of creating a new one)''' + def __new__(cls, value, suit): + obj = Card._CardPool.get(value + suit, None) + if not obj: + obj = object.__new__(cls) + Card._CardPool[value + suit] = obj + obj.value, obj.suit = value, suit + return obj + + # def __init__(self, value, suit): + # self.value, self.suit = value, suit + + def __repr__(self): + return "" % (self.value, self.suit) + +if __name__ == '__main__': + # comment __new__ and uncomment __init__ to see the difference + c1 = Card('9', 'h') + c2 = Card('9', 'h') + print(c1, c2) + print(c1 == c2) + print(id(c1), id(c2)) diff --git a/pythonpattern/src/iterator.py b/pythonpattern/src/iterator.py new file mode 100644 index 00000000..cf680c02 --- /dev/null +++ b/pythonpattern/src/iterator.py @@ -0,0 +1,25 @@ +'''http://ginstrom.com/scribbles/2007/10/08/design-patterns-python-style/''' + +"""Implementation of the iterator pattern with a generator""" +def count_to(count): + """Counts by word numbers, up to a maximum of five""" + numbers = ["one", "two", "three", "four", "five"] + # enumerate() returns a tuple containing a count (from start which defaults to 0) and the values obtained from iterating over sequence + for pos, number in enumerate(numbers): + yield number + +# Test the generator +count_to_two = lambda : count_to(2) +count_to_five = lambda : count_to(5) + +print('Counting to two...') +for number in count_to_two(): + print(number, end=' ') + +print() + +print('Counting to five...') +for number in count_to_five(): + print(number, end=' ') + +print() diff --git a/pythonpattern/src/mediator.py b/pythonpattern/src/mediator.py new file mode 100644 index 00000000..57541557 --- /dev/null +++ b/pythonpattern/src/mediator.py @@ -0,0 +1,111 @@ +'''http://dpip.testingperspective.com/?p=28''' + +import time + +class TC: + def __init__(self): + self._tm = tm + self._bProblem = 0 + + def setup(self): + print("Setting up the Test") + time.sleep(1) + self._tm.prepareReporting() + + def execute(self): + if not self._bProblem: + print("Executing the test") + time.sleep(1) + else: + print("Problem in setup. Test not executed.") + + def tearDown(self): + if not self._bProblem: + print("Tearing down") + time.sleep(1) + self._tm.publishReport() + else: + print("Test not executed. No tear down required.") + + def setTM(self,TM): + self._tm = tm + + def setProblem(self, value): + self._bProblem = value + +class Reporter: + def __init__(self): + self._tm = None + + def prepare(self): + print("Reporter Class is preparing to report the results") + time.sleep(1) + + def report(self): + print("Reporting the results of Test") + time.sleep(1) + + def setTM(self,TM): + self._tm = tm + +class DB: + def __init__(self): + self._tm = None + + def insert(self): + print("Inserting the execution begin status in the Database") + time.sleep(1) + #Following code is to simulate a communication from DB to TC + import random + if random.randrange(1,4) == 3: + return -1 + + def update(self): + print("Updating the test results in Database") + time.sleep(1) + + def setTM(self,TM): + self._tm = tm + +class TestManager: + def __init__(self): + self._reporter = None + self._db = None + self._tc = None + + def prepareReporting(self): + rvalue = self._db.insert() + if rvalue == -1: + self._tc.setProblem(1) + self._reporter.prepare() + + def setReporter(self, reporter): + self._reporter = reporter + + def setDB(self, db): + self._db = db + + def publishReport(self): + self._db.update() + rvalue = self._reporter.report() + + def setTC(self,tc): + self._tc = tc + +if __name__ == '__main__': + reporter = Reporter() + db = DB() + tm = TestManager() + tm.setReporter(reporter) + tm.setDB(db) + reporter.setTM(tm) + db.setTM(tm) + # For simplification we are looping on the same test. + # Practically, it could be about various unique test classes and their objects + while (True): + tc = TC() + tc.setTM(tm) + tm.setTC(tc) + tc.setup() + tc.execute() + tc.tearDown() diff --git a/pythonpattern/src/memento.py b/pythonpattern/src/memento.py new file mode 100644 index 00000000..d09e266c --- /dev/null +++ b/pythonpattern/src/memento.py @@ -0,0 +1,91 @@ +'''code.activestate.com/recipes/413838-memento-closure/''' + +import copy + +def Memento(obj, deep=False): + state = (copy.copy, copy.deepcopy)[bool(deep)](obj.__dict__) + + def Restore(): + obj.__dict__.clear() + obj.__dict__.update(state) + return Restore + +class Transaction: + """A transaction guard. This is really just + syntactic suggar arount a memento closure. + """ + deep = False + + def __init__(self, *targets): + self.targets = targets + self.Commit() + + def Commit(self): + self.states = [Memento(target, self.deep) for target in self.targets] + + def Rollback(self): + for st in self.states: + st() + +class transactional(object): + """Adds transactional semantics to methods. Methods decorated + with @transactional will rollback to entry state upon exceptions. + """ + def __init__(self, method): + self.method = method + + def __get__(self, obj, T): + def transaction(*args, **kwargs): + state = Memento(obj) + try: + return self.method(obj, *args, **kwargs) + except: + state() + raise + return transaction + + +class NumObj(object): + def __init__(self, value): + self.value = value + + def __repr__(self): + return '<%s: %r>' % (self.__class__.__name__, self.value) + + def Increment(self): + self.value += 1 + + @transactional + def DoStuff(self): + self.value = '1111' # <- invalid value + self.Increment() # <- will fail and rollback + + +if __name__ == '__main__': + n = NumObj(-1) + print(n) + t = Transaction(n) + try: + for i in range(3): + n.Increment() + print(n) + t.Commit() + print('-- commited') + for i in range(3): + n.Increment() + print(n) + n.value += 'x' # will fail + print(n) + except: + t.Rollback() + print('-- rolled back') + print(n) + print('-- now doing stuff ...') + try: + n.DoStuff() + except: + print('-> doing stuff failed!') + import traceback + traceback.print_exc(0) + pass + print(n)