Scribbling

Python: Operator Overloading 본문

Computer Science/Python

Python: Operator Overloading

focalpoint 2022. 4. 20. 10:53

 

In this post, we add operator overloading to the custom 'Vector' class from the previous post.

https://focalpoint.tistory.com/300

 

Python: Sequence Protocol

To learn sequence protocol in Python, we create a custom vector class. from array import array import math import reprlib class Vector: typecode = 'd' def __init__(self, components): self._component..

focalpoint.tistory.com

 

Vector Class is as below.

from array import array
import math
import reprlib
import numbers
import functools
import operator
import itertools

class Vector:
    """Multidimensional vector whose components are stored as floats.

    Components live in an ``array`` of ``typecode``. The class supports
    iteration, ``len()``, indexing and slicing, ``==``, ``hash()``, ``abs()``,
    ``bool()``, ``bytes()``, ``format()`` and read-only ``x``, ``y``, ``z``,
    ``t`` shortcuts for the first four components.
    """

    typecode = 'd'  # array typecode used for storage and by frombytes()

    def __init__(self, components):
        """Build the vector from an iterable of numbers."""
        self._components = array(self.typecode, components)

    def __iter__(self):
        """Iterate over the components."""
        return iter(self._components)

    def __repr__(self):
        """Return ``Vector([...])``; long vectors are abbreviated by reprlib."""
        components = reprlib.repr(self._components)
        start = components.find('[')
        if start == -1:
            # An empty array is rendered as "array('d')" with no list part;
            # slicing from -1 would yield the invalid text 'Vector()'.
            return 'Vector([])'
        # Drop the leading "array('d', " and the trailing ")".
        return 'Vector({})'.format(components[start:-1])

    def __str__(self):
        """Return the components formatted as a tuple, e.g. ``(3.0, 4.0)``."""
        return str(tuple(self))

    def __bytes__(self):
        """Return the typecode byte followed by the raw bytes of the array."""
        return (bytes([ord(self.typecode)]) + bytes(self._components))

    def __eq__(self, other):
        """Return True when *other* has the same length and equal components.

        *other* only needs to support ``len()`` and iteration.
        """
        return len(self) == len(other) and all((a == b for a, b in zip(self, other)))

    def __hash__(self):
        """Combine the hashes of all components with XOR (0 for no components)."""
        # Equivalent generator form: (hash(x) for x in self)
        hashes = map(hash, self)
        return functools.reduce(operator.xor, hashes, 0)

    def __abs__(self):
        """Return the Euclidean norm of the vector."""
        return math.sqrt(sum(x * x for x in self))

    def __bool__(self):
        """Return False only when the magnitude is zero."""
        return bool(abs(self))

    def __len__(self):
        """Return the number of components."""
        return len(self._components)

    def __getitem__(self, index):
        """Return one component for an integer index, or a new vector for a slice.

        Raises:
            TypeError: if *index* is neither a slice nor an integer.
        """
        cls = type(self)
        if isinstance(index, slice):
            return cls(self._components[index])
        elif isinstance(index, numbers.Integral):
            return self._components[index]
        else:
            msg = f'{cls.__name__!r} indices must be integers'
            raise TypeError(msg)

    shortcut_names = 'xyzt'  # single-letter aliases for components 0..3

    def __getattr__(self, name):
        """Resolve the ``x``/``y``/``z``/``t`` shortcuts.

        Python calls this only after normal attribute lookup has failed.

        Raises:
            AttributeError: if *name* is not a shortcut, or the vector has too
                few components for that shortcut.
        """
        cls = type(self)
        if len(name) == 1:
            pos = cls.shortcut_names.find(name)
            if 0 <= pos < len(self._components):
                return self._components[pos]
        msg = f'{cls.__name__!r} object has no attribute {name!r}'
        raise AttributeError(msg)

    def __setattr__(self, name, value):
        """Reject assignment to any single lowercase-letter attribute.

        This keeps the shortcut names read-only and prevents instance
        attributes that would shadow them.

        Raises:
            AttributeError: if *name* is a single lowercase letter.
        """
        cls = type(self)
        if len(name) == 1:
            if name in cls.shortcut_names:
                error = 'readonly attribute {attr_name!r}'
            elif name.islower():
                error = 'can`t set attributes `a` to `z` in {cls_name!r}'
            else:
                error = ''
            if error:
                msg = error.format(cls_name=cls.__name__, attr_name=name)
                raise AttributeError(msg)
        super().__setattr__(name, value)

    def angle(self, n):
        """Return the *n*-th angular coordinate, in radians.

        The angle is ``atan2(r, self[n-1])`` where ``r`` is the norm of the
        components from index *n* onward. For the last angle, a negative final
        component maps the result to ``2*pi - a``.
        """
        r = math.sqrt(sum(x * x for x in self[n:]))
        a = math.atan2(r, self[n-1])
        if (n == len(self) - 1) and (self[-1] < 0):
            return math.pi * 2 - a
        else:
            return a

    @property
    def angles(self):
        """List of all angular coordinates, ``angle(1)`` .. ``angle(len - 1)``."""
        return [self.angle(n) for n in range(1, len(self))]

    def __format__(self, fmt_spec=''):
        """Format the vector for ``format()`` and f-strings.

        A *fmt_spec* ending in ``'h'`` renders the magnitude followed by the
        angles inside ``<...>``; otherwise the components are rendered inside
        ``(...)``. The rest of *fmt_spec* is applied to every number.
        """
        if fmt_spec.endswith('h'):
            fmt_spec = fmt_spec[:-1]
            coords = itertools.chain([abs(self)], self.angles)
            outer_fmt = '<{}>'
        else:
            coords = self
            outer_fmt = '({})'
        components = (format(c, fmt_spec) for c in coords)
        return outer_fmt.format(','.join(components))

    @classmethod
    def frombytes(cls, octests):
        """Rebuild a vector from the byte layout produced by ``bytes(vector)``.

        The first byte is read as the array typecode; the remaining bytes are
        reinterpreted as items of that type.
        """
        typecode = chr(octests[0])
        memv = memoryview(octests[1:]).cast(typecode)
        return cls(memv)

 

 

Unary Operators: Always return a new object

    def __abs__(self):
        """Return the Euclidean norm: the square root of the sum of squares."""
        squares = (component * component for component in self)
        return math.sqrt(sum(squares))
    
    def __neg__(self):
        """Return a new Vector in which every component is negated."""
        negated = [-component for component in self]
        return Vector(negated)
    
    def __pos__(self):
        """Return a new Vector built by iterating over this one's components."""
        return Vector(self)

 

 

Overloading + for Vector Addition

    def __add__(self, other):
        """Return the component-wise sum of ``self`` and iterable *other*.

        The shorter operand is padded with 0.0, so the result has the length
        of the longer one.
        """
        return Vector(
            left + right
            for left, right in itertools.zip_longest(self, other, fillvalue=0.0)
        )

With the above implementation, below code works like a charm.

# Vectors of different lengths can be added: the shorter one is padded with 0.0.
v1 = Vector([3, 4, 5, 6])
v3 = Vector([1, 2])
print(v1+v3)  # prints (4.0, 6.0, 5.0, 6.0)

However, we also need to implement __radd__ method for cases like below.

# The left operand is a list, which cannot add a Vector by itself, so this
# expression only works once Vector defines __radd__.
v1 = Vector([3, 4, 5, 6])
print([1, 2] + v1)
    def __add__(self, other):
        """Add *other* (any iterable of numbers) to this vector, component-wise.

        Missing components of the shorter operand count as 0.0.
        """
        padded = itertools.zip_longest(self, other, fillvalue=0.0)
        sums = (mine + theirs for mine, theirs in padded)
        return Vector(sums)

    def __radd__(self, other):
        """Reflected addition (``other + self``); delegates to ``self + other``."""
        return self + other

 

In addition, by catching TypeError and returning NotImplemented, __add__ lets the Python interpreter try the other operand's __radd__ method when __add__ cannot handle the operand.

    def __add__(self, other):
        """Return the component-wise sum, padding the shorter operand with 0.0.

        Returns ``NotImplemented`` when *other* is not iterable or its items
        cannot be added to floats, so Python can try the reflected method of
        *other*.
        """
        try:
            padded = itertools.zip_longest(self, other, fillvalue=0.0)
            # The generator is consumed inside Vector(), so the construction
            # has to stay inside the try block to catch the item-level TypeError.
            total = Vector(mine + theirs for mine, theirs in padded)
        except TypeError:
            return NotImplemented
        return total

    def __radd__(self, other):
        """Reflected addition (``other + self``); delegates to ``self + other``."""
        return self + other

 

 

Overloading * for Scalar Multiplication

    def __mul__(self, scalar):
        """Return a new Vector with every component multiplied by *scalar*.

        Returns ``NotImplemented`` unless *scalar* is a ``numbers.Real``.
        """
        if not isinstance(scalar, numbers.Real):
            return NotImplemented
        scaled = (component * scalar for component in self)
        return Vector(scaled)
    
    def __rmul__(self, scalar):
        """Reflected multiplication (``scalar * self``); delegates to ``self * scalar``."""
        return self * scalar

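Note: Use an ABC (in this case, numbers.Real) in isinstance() so that every real-number type (int, float, fractions.Fraction, ...) is accepted, not just one concrete class.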

    def __matmul__(self, other):
        """Return the dot product of ``self`` and iterable *other*.

        Components are paired positionally with ``zip``; if the lengths
        differ, the extra components of the longer operand are ignored.
        Returns ``NotImplemented`` when *other* is not iterable or its items
        cannot be multiplied, so Python can try the reflected method.
        """
        try:
            # Pair components one-to-one; a nested "for a ... for b ..." would
            # sum every combination (the Cartesian product) instead.
            return sum(a * b for a, b in zip(self, other))
        except TypeError:
            return NotImplemented
    
    def __rmatmul__(self, other):
        """Reflected ``@`` (``other @ self``); delegates to ``self @ other``."""
        return self @ other

 

In Python 3, you do not need to implement the __ne__ method if you have an __eq__ method: by default, __ne__ delegates to __eq__ and inverts the result.

    def __eq__(self, other):
        """Compare with another Vector by length and by components.

        Returns ``NotImplemented`` for any operand that is not a Vector, which
        lets Python fall back to the other operand's comparison.
        """
        if not isinstance(other, Vector):
            return NotImplemented
        if len(self) != len(other):
            return False
        return all(mine == theirs for mine, theirs in zip(self, other))

 

 

Full Code: 

from array import array
import math
import reprlib
import numbers
import functools
import operator
import itertools

class Vector:
    """Multidimensional vector of floats with operator overloading.

    Components live in an ``array`` of ``typecode``. Besides the sequence
    behaviour (iteration, ``len()``, indexing, slicing), the class supports
    unary ``-``/``+``, ``abs()``, ``+`` with any iterable of numbers, ``*``
    with a real scalar, ``@`` (dot product) and ``==`` with other vectors.
    """

    typecode = 'd'  # array typecode used for storage and by frombytes()

    def __init__(self, components):
        """Build the vector from an iterable of numbers."""
        self._components = array(self.typecode, components)

    def __iter__(self):
        """Iterate over the components."""
        return iter(self._components)

    def __repr__(self):
        """Return ``Vector([...])``; long vectors are abbreviated by reprlib."""
        components = reprlib.repr(self._components)
        start = components.find('[')
        if start == -1:
            # An empty array is rendered as "array('d')" with no list part;
            # slicing from -1 would yield the invalid text 'Vector()'.
            return 'Vector([])'
        # Drop the leading "array('d', " and the trailing ")".
        return 'Vector({})'.format(components[start:-1])

    def __str__(self):
        """Return the components formatted as a tuple, e.g. ``(3.0, 4.0)``."""
        return str(tuple(self))

    def __bytes__(self):
        """Return the typecode byte followed by the raw bytes of the array."""
        return (bytes([ord(self.typecode)]) + bytes(self._components))

    def __hash__(self):
        """Combine the hashes of all components with XOR (0 for no components)."""
        # Equivalent generator form: (hash(x) for x in self)
        hashes = map(hash, self)
        return functools.reduce(operator.xor, hashes, 0)

    def __bool__(self):
        """Return False only when the magnitude is zero."""
        return bool(abs(self))

    def __len__(self):
        """Return the number of components."""
        return len(self._components)

    def __getitem__(self, index):
        """Return one component for an integer index, or a new vector for a slice.

        Raises:
            TypeError: if *index* is neither a slice nor an integer.
        """
        cls = type(self)
        if isinstance(index, slice):
            return cls(self._components[index])
        elif isinstance(index, numbers.Integral):
            return self._components[index]
        else:
            msg = f'{cls.__name__!r} indices must be integers'
            raise TypeError(msg)

    shortcut_names = 'xyzt'  # single-letter aliases for components 0..3

    def __getattr__(self, name):
        """Resolve the ``x``/``y``/``z``/``t`` shortcuts.

        Python calls this only after normal attribute lookup has failed.

        Raises:
            AttributeError: if *name* is not a shortcut, or the vector has too
                few components for that shortcut.
        """
        cls = type(self)
        if len(name) == 1:
            pos = cls.shortcut_names.find(name)
            if 0 <= pos < len(self._components):
                return self._components[pos]
        msg = f'{cls.__name__!r} object has no attribute {name!r}'
        raise AttributeError(msg)

    def __setattr__(self, name, value):
        """Reject assignment to any single lowercase-letter attribute.

        This keeps the shortcut names read-only and prevents instance
        attributes that would shadow them.

        Raises:
            AttributeError: if *name* is a single lowercase letter.
        """
        cls = type(self)
        if len(name) == 1:
            if name in cls.shortcut_names:
                error = 'readonly attribute {attr_name!r}'
            elif name.islower():
                error = 'can`t set attributes `a` to `z` in {cls_name!r}'
            else:
                error = ''
            if error:
                msg = error.format(cls_name=cls.__name__, attr_name=name)
                raise AttributeError(msg)
        super().__setattr__(name, value)

    def angle(self, n):
        """Return the *n*-th angular coordinate, in radians.

        The angle is ``atan2(r, self[n-1])`` where ``r`` is the norm of the
        components from index *n* onward. For the last angle, a negative final
        component maps the result to ``2*pi - a``.
        """
        r = math.sqrt(sum(x * x for x in self[n:]))
        a = math.atan2(r, self[n-1])
        if (n == len(self) - 1) and (self[-1] < 0):
            return math.pi * 2 - a
        else:
            return a

    @property
    def angles(self):
        """List of all angular coordinates, ``angle(1)`` .. ``angle(len - 1)``."""
        return [self.angle(n) for n in range(1, len(self))]

    def __format__(self, fmt_spec=''):
        """Format the vector for ``format()`` and f-strings.

        A *fmt_spec* ending in ``'h'`` renders the magnitude followed by the
        angles inside ``<...>``; otherwise the components are rendered inside
        ``(...)``. The rest of *fmt_spec* is applied to every number.
        """
        if fmt_spec.endswith('h'):
            fmt_spec = fmt_spec[:-1]
            coords = itertools.chain([abs(self)], self.angles)
            outer_fmt = '<{}>'
        else:
            coords = self
            outer_fmt = '({})'
        components = (format(c, fmt_spec) for c in coords)
        return outer_fmt.format(','.join(components))

    @classmethod
    def frombytes(cls, octests):
        """Rebuild a vector from the byte layout produced by ``bytes(vector)``.

        The first byte is read as the array typecode; the remaining bytes are
        reinterpreted as items of that type.
        """
        typecode = chr(octests[0])
        memv = memoryview(octests[1:]).cast(typecode)
        return cls(memv)

    def __abs__(self):
        """Return the Euclidean norm of the vector."""
        return math.sqrt(sum(x * x for x in self))

    def __neg__(self):
        """Return a new Vector with every component negated."""
        return Vector(-x for x in self)

    def __pos__(self):
        """Return a new Vector with the same components."""
        return Vector(self)

    def __add__(self, other):
        """Return the component-wise sum of ``self`` and iterable *other*.

        The shorter operand is padded with 0.0. Returns ``NotImplemented``
        when *other* is not iterable or its items cannot be added to floats,
        so Python can try the reflected method of *other*.
        """
        try:
            pairs = itertools.zip_longest(self, other, fillvalue=0.0)
            # The generator is consumed inside Vector(), so the construction
            # must stay inside the try block.
            return Vector(a + b for a, b in pairs)
        except TypeError:
            return NotImplemented

    def __radd__(self, other):
        """Reflected addition (``other + self``); delegates to ``self + other``."""
        return self + other

    def __mul__(self, scalar):
        """Return a new Vector scaled by *scalar*.

        Returns ``NotImplemented`` unless *scalar* is a ``numbers.Real``.
        """
        if isinstance(scalar, numbers.Real):
            return Vector(x * scalar for x in self)
        else:
            return NotImplemented

    def __rmul__(self, scalar):
        """Reflected multiplication (``scalar * self``); delegates to ``self * scalar``."""
        return self * scalar

    def __matmul__(self, other):
        """Return the dot product of ``self`` and iterable *other*.

        Components are paired positionally with ``zip``; if the lengths
        differ, the extra components of the longer operand are ignored, which
        matches the zero padding used by ``+``. Returns ``NotImplemented``
        when *other* is not iterable or its items cannot be multiplied.
        """
        try:
            # Pair components one-to-one; a nested "for a ... for b ..." would
            # sum every combination (the Cartesian product) instead.
            return sum(a * b for a, b in zip(self, other))
        except TypeError:
            return NotImplemented

    def __rmatmul__(self, other):
        """Reflected ``@`` (``other @ self``); delegates to ``self @ other``."""
        return self @ other

    def __eq__(self, other):
        """Compare with another Vector by length and by components.

        Returns ``NotImplemented`` for any operand that is not a Vector.
        """
        if isinstance(other, Vector):
            return len(self) == len(other) and all((a == b for a, b in zip(self, other)))
        else:
            return NotImplemented

 

 

Summary:

- Unary operators must return a new object.

- To support operands of a different data type, return NotImplemented when a TypeError is raised so that the other operand's reflected method (e.g. __radd__) gets called.

- To decide whether an operand of a different data type is acceptable, use either 'Duck Typing' or isinstance(). Duck typing is more flexible; isinstance() is more explicit.

 

'Computer Science > Python' 카테고리의 다른 글

Python: Context Manager  (0) 2022.04.29
Python: Iterator, Generator  (0) 2022.04.22
Python: Inheritance  (0) 2022.04.19
Python: Interfaces  (0) 2022.04.18
Python: ABC Class  (0) 2022.04.07