Skip to content

Commit 70d5cdb

Browse files
Merge pull request #492 from OddCoincidence/inplace-ops
Support magic methods for in-place operations
2 parents 71f32ee + ddc7da4 commit 70d5cdb

File tree

8 files changed

+346
-70
lines changed

8 files changed

+346
-70
lines changed

tests/.travis-runner.sh

+2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ pip install pipenv
1212
# Build outside of the test runner
1313
if [ $CODE_COVERAGE = "true" ]
1414
then
15+
find . -name '*.gcda' -delete
16+
1517
export CARGO_INCREMENTAL=0
1618
export RUSTFLAGS="-Zprofile -Ccodegen-units=1 -Cinline-threshold=0 -Clink-dead-code -Coverflow-checks=off -Zno-landing-pads"
1719

tests/snippets/inplace_ops.py

+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
class InPlace:
2+
def __init__(self, val):
3+
self.val = val
4+
5+
def __ipow__(self, other):
6+
self.val **= other
7+
return self
8+
9+
def __imul__(self, other):
10+
self.val *= other
11+
return self
12+
13+
def __imatmul__(self, other):
14+
# I guess you could think of an int as a 1x1 matrix
15+
self.val *= other
16+
return self
17+
18+
def __itruediv__(self, other):
19+
self.val /= other
20+
return self
21+
22+
def __ifloordiv__(self, other):
23+
self.val //= other
24+
return self
25+
26+
def __imod__(self, other):
27+
self.val %= other
28+
return self
29+
30+
def __iadd__(self, other):
31+
self.val += other
32+
return self
33+
34+
def __isub__(self, other):
35+
self.val -= other
36+
return self
37+
38+
def __ilshift__(self, other):
39+
self.val <<= other
40+
return self
41+
42+
def __irshift__(self, other):
43+
self.val >>= other
44+
return self
45+
46+
def __iand__(self, other):
47+
self.val &= other
48+
return self
49+
50+
def __ixor__(self, other):
51+
self.val ^= other
52+
return self
53+
54+
def __ior__(self, other):
55+
self.val |= other
56+
return self
57+
58+
59+
i = InPlace(2)
60+
i **= 3
61+
assert i.val == 8
62+
63+
i = InPlace(2)
64+
i *= 2
65+
assert i.val == 4
66+
67+
i = InPlace(2)
68+
i @= 2
69+
assert i.val == 4
70+
71+
i = InPlace(1)
72+
i /= 2
73+
assert i.val == 0.5
74+
75+
i = InPlace(1)
76+
i //= 2
77+
assert i.val == 0
78+
79+
i = InPlace(10)
80+
i %= 3
81+
assert i.val == 1
82+
83+
i = InPlace(1)
84+
i += 1
85+
assert i.val == 2
86+
87+
i = InPlace(2)
88+
i -= 1
89+
assert i.val == 1
90+
91+
i = InPlace(2)
92+
i <<= 3
93+
assert i.val == 16
94+
95+
i = InPlace(16)
96+
i >>= 3
97+
assert i.val == 2
98+
99+
i = InPlace(0b010101)
100+
i &= 0b111000
101+
assert i.val == 0b010000
102+
103+
i = InPlace(0b010101)
104+
i ^= 0b111000
105+
assert i.val == 0b101101
106+
107+
i = InPlace(0b010101)
108+
i |= 0b111000
109+
assert i.val == 0b111101

tests/snippets/list.py

+4
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,7 @@
7070
assert not 1 in a
7171

7272
assert_raises(ValueError, lambda: a.remove(10), 'Remove not exist element')
73+
74+
foo = bar = [1]
75+
foo += [2]
76+
assert (foo, bar) == ([1, 2], [1, 2])

vm/src/bytecode.rs

+1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ pub enum Instruction {
7171
},
7272
BinaryOperation {
7373
op: BinaryOperator,
74+
inplace: bool,
7475
},
7576
LoadAttr {
7677
name: String,

vm/src/compile.rs

+5-4
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ impl Compiler {
608608
self.compile_expression(value)?;
609609

610610
// Perform operation:
611-
self.compile_op(op);
611+
self.compile_op(op, true);
612612
self.compile_store(target)?;
613613
}
614614
ast::Statement::Delete { targets } => {
@@ -757,7 +757,7 @@ impl Compiler {
757757
Ok(())
758758
}
759759

760-
fn compile_op(&mut self, op: &ast::Operator) {
760+
fn compile_op(&mut self, op: &ast::Operator, inplace: bool) {
761761
let i = match op {
762762
ast::Operator::Add => bytecode::BinaryOperator::Add,
763763
ast::Operator::Sub => bytecode::BinaryOperator::Subtract,
@@ -773,7 +773,7 @@ impl Compiler {
773773
ast::Operator::BitXor => bytecode::BinaryOperator::Xor,
774774
ast::Operator::BitAnd => bytecode::BinaryOperator::And,
775775
};
776-
self.emit(Instruction::BinaryOperation { op: i });
776+
self.emit(Instruction::BinaryOperation { op: i, inplace });
777777
}
778778

779779
fn compile_test(
@@ -852,13 +852,14 @@ impl Compiler {
852852
self.compile_expression(b)?;
853853

854854
// Perform operation:
855-
self.compile_op(op);
855+
self.compile_op(op, false);
856856
}
857857
ast::Expression::Subscript { a, b } => {
858858
self.compile_expression(a)?;
859859
self.compile_expression(b)?;
860860
self.emit(Instruction::BinaryOperation {
861861
op: bytecode::BinaryOperator::Subscript,
862+
inplace: false,
862863
});
863864
}
864865
ast::Expression::Unop { op, a } => {

vm/src/frame.rs

+25-11
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,9 @@ impl Frame {
316316
vm.call_method(&dict_obj, "__setitem__", vec![key, value])?;
317317
Ok(None)
318318
}
319-
bytecode::Instruction::BinaryOperation { ref op } => self.execute_binop(vm, op),
319+
bytecode::Instruction::BinaryOperation { ref op, inplace } => {
320+
self.execute_binop(vm, op, *inplace)
321+
}
320322
bytecode::Instruction::LoadAttr { ref name } => self.load_attr(vm, name),
321323
bytecode::Instruction::StoreAttr { ref name } => self.store_attr(vm, name),
322324
bytecode::Instruction::DeleteAttr { ref name } => self.delete_attr(vm, name),
@@ -893,27 +895,39 @@ impl Frame {
893895
&mut self,
894896
vm: &mut VirtualMachine,
895897
op: &bytecode::BinaryOperator,
898+
inplace: bool,
896899
) -> FrameResult {
897900
let b_ref = self.pop_value();
898901
let a_ref = self.pop_value();
899902
let value = match *op {
903+
bytecode::BinaryOperator::Subtract if inplace => vm._isub(a_ref, b_ref),
900904
bytecode::BinaryOperator::Subtract => vm._sub(a_ref, b_ref),
905+
bytecode::BinaryOperator::Add if inplace => vm._iadd(a_ref, b_ref),
901906
bytecode::BinaryOperator::Add => vm._add(a_ref, b_ref),
907+
bytecode::BinaryOperator::Multiply if inplace => vm._imul(a_ref, b_ref),
902908
bytecode::BinaryOperator::Multiply => vm._mul(a_ref, b_ref),
903-
bytecode::BinaryOperator::MatrixMultiply => {
904-
vm.call_method(&a_ref, "__matmul__", vec![b_ref])
905-
}
909+
bytecode::BinaryOperator::MatrixMultiply if inplace => vm._imatmul(a_ref, b_ref),
910+
bytecode::BinaryOperator::MatrixMultiply => vm._matmul(a_ref, b_ref),
911+
bytecode::BinaryOperator::Power if inplace => vm._ipow(a_ref, b_ref),
906912
bytecode::BinaryOperator::Power => vm._pow(a_ref, b_ref),
907-
bytecode::BinaryOperator::Divide => vm._div(a_ref, b_ref),
908-
bytecode::BinaryOperator::FloorDivide => {
909-
vm.call_method(&a_ref, "__floordiv__", vec![b_ref])
910-
}
913+
bytecode::BinaryOperator::Divide if inplace => vm._itruediv(a_ref, b_ref),
914+
bytecode::BinaryOperator::Divide => vm._truediv(a_ref, b_ref),
915+
bytecode::BinaryOperator::FloorDivide if inplace => vm._ifloordiv(a_ref, b_ref),
916+
bytecode::BinaryOperator::FloorDivide => vm._floordiv(a_ref, b_ref),
917+
// TODO: Subscript should probably have its own op
918+
bytecode::BinaryOperator::Subscript if inplace => unreachable!(),
911919
bytecode::BinaryOperator::Subscript => self.subscript(vm, a_ref, b_ref),
912-
bytecode::BinaryOperator::Modulo => vm._modulo(a_ref, b_ref),
913-
bytecode::BinaryOperator::Lshift => vm.call_method(&a_ref, "__lshift__", vec![b_ref]),
914-
bytecode::BinaryOperator::Rshift => vm.call_method(&a_ref, "__rshift__", vec![b_ref]),
920+
bytecode::BinaryOperator::Modulo if inplace => vm._imod(a_ref, b_ref),
921+
bytecode::BinaryOperator::Modulo => vm._mod(a_ref, b_ref),
922+
bytecode::BinaryOperator::Lshift if inplace => vm._ilshift(a_ref, b_ref),
923+
bytecode::BinaryOperator::Lshift => vm._lshift(a_ref, b_ref),
924+
bytecode::BinaryOperator::Rshift if inplace => vm._irshift(a_ref, b_ref),
925+
bytecode::BinaryOperator::Rshift => vm._rshift(a_ref, b_ref),
926+
bytecode::BinaryOperator::Xor if inplace => vm._ixor(a_ref, b_ref),
915927
bytecode::BinaryOperator::Xor => vm._xor(a_ref, b_ref),
928+
bytecode::BinaryOperator::Or if inplace => vm._ior(a_ref, b_ref),
916929
bytecode::BinaryOperator::Or => vm._or(a_ref, b_ref),
930+
bytecode::BinaryOperator::And if inplace => vm._iand(a_ref, b_ref),
917931
bytecode::BinaryOperator::And => vm._and(a_ref, b_ref),
918932
}?;
919933

vm/src/obj/objlist.rs

+16
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,21 @@ fn list_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
181181
}
182182
}
183183

184+
fn list_iadd(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
185+
arg_check!(
186+
vm,
187+
args,
188+
required = [(zelf, Some(vm.ctx.list_type())), (other, None)]
189+
);
190+
191+
if objtype::isinstance(other, &vm.ctx.list_type()) {
192+
get_mut_elements(zelf).extend_from_slice(&get_elements(other));
193+
Ok(zelf.clone())
194+
} else {
195+
Ok(vm.ctx.not_implemented())
196+
}
197+
}
198+
184199
fn list_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
185200
arg_check!(vm, args, required = [(o, Some(vm.ctx.list_type()))]);
186201

@@ -443,6 +458,7 @@ pub fn init(context: &PyContext) {
443458
The argument must be an iterable if specified.";
444459

445460
context.set_attr(&list_type, "__add__", context.new_rustfunc(list_add));
461+
context.set_attr(&list_type, "__iadd__", context.new_rustfunc(list_iadd));
446462
context.set_attr(
447463
&list_type,
448464
"__contains__",

0 commit comments

Comments
 (0)