Easier Classes

Python Classes Without All The Cruft

Trey Hunner / @treyhunner

Python Morsels
Truthful Technology

Pythonic Classes

Friendly Classes

  • have nice string representations
  • can be compared to each other
  • embrace operator overloading

Case 1: Month class

Comparability and Orderability


>>> m = Month(2018, 6)
>>> m
Month(year=2018, month=6)
>>> m == Month(2018, 6)
True
>>> m != Month(2018, 6)
False
>>> m < Month(2018, 12)
True
>>> m > Month(2019, 1)
False
>>> m <= Month(2018, 12)
True
>>> m >= Month(2018, 6)
True
        

class Month:
    """Skeleton of the methods a comparable, orderable Month has to define.

    Every method is an unimplemented placeholder that returns None.
    """

    def __init__(self, year, month):
        """Placeholder: would store the year and the month number."""

    def __repr__(self):
        """Placeholder: would build the ``Month(year=..., month=...)`` text."""

    def __eq__(self, other):
        """Placeholder for ``==`` (Python derives ``!=`` from it)."""

    def __lt__(self, other):
        """Placeholder for ``<``."""

    def __gt__(self, other):
        """Placeholder for ``>``."""

    def __le__(self, other):
        """Placeholder for ``<=``."""

    def __ge__(self, other):
        """Placeholder for ``>=``."""
        

class Month:
    """A calendar month, compared and ordered by ``(year, month)``.

    All six rich comparisons are written out by hand. Each one returns
    NotImplemented for operands that are not Month instances, so Python can
    try the other operand's reflected method.
    """

    def __init__(self, year, month):
        self.year = year
        self.month = month

    def __repr__(self):
        return f"Month(year={self.year}, month={self.month})"

    def __eq__(self, other):
        if isinstance(other, Month):
            return (self.year, self.month) == (other.year, other.month)
        return NotImplemented

    def __lt__(self, other):
        if isinstance(other, Month):
            return (self.year, self.month) < (other.year, other.month)
        return NotImplemented

    def __gt__(self, other):
        if isinstance(other, Month):
            return (self.year, self.month) > (other.year, other.month)
        return NotImplemented

    def __le__(self, other):
        if isinstance(other, Month):
            return (self.year, self.month) <= (other.year, other.month)
        return NotImplemented

    def __ge__(self, other):
        if isinstance(other, Month):
            return (self.year, self.month) >= (other.year, other.month)
        return NotImplemented
        

from functools import total_ordering


@total_ordering
class Month:
    """A calendar month, compared and ordered by ``(year, month)``.

    Only ``__eq__`` and ``__lt__`` are written by hand; ``total_ordering``
    derives the remaining ordering methods from them.
    """

    def __init__(self, year, month):
        self.year = year
        self.month = month

    def __repr__(self):
        return f"Month(year={self.year}, month={self.month})"

    def __eq__(self, other):
        if isinstance(other, Month):
            return (self.year, self.month) == (other.year, other.month)
        return NotImplemented

    def __lt__(self, other):
        if isinstance(other, Month):
            return (self.year, self.month) < (other.year, other.month)
        return NotImplemented
        

Case 2: Point class

Iterability and Immutability


>>> p = Point(1, 2, 3)
>>> p
Point(x=1, y=2, z=3)
>>> p == Point(1, 2, 3)
True
>>> x, y, z = p
>>> x
1
>>> p.x = 4
Traceback (most recent call last):
  File "<stdin>", line 1
  File "<string>", line 3
AttributeError: object is immutable
>>> {Point(1, 2, 3), Point(1, 2, 3)}
{Point(x=1, y=2, z=3)}
        

class Point:
    """Skeleton of the methods an immutable, iterable, hashable Point needs.

    Every method is an unimplemented placeholder that returns None.
    """

    def __init__(self, x, y, z):
        """Placeholder: would store the three coordinates."""

    def __repr__(self):
        """Placeholder: would build the ``Point(x=..., y=..., z=...)`` text."""

    def __eq__(self, other):
        """Placeholder for ``==``."""

    def __iter__(self):
        """Placeholder: would make ``x, y, z = point`` work."""

    def __setattr__(self, name, value):
        """Placeholder: would reject attribute assignment."""

    def __hash__(self):
        """Placeholder: would make points usable in sets and as dict keys."""

        
        

class Point:
    """An immutable, iterable, hashable point with x, y and z coordinates.

    Instances unpack into their coordinates (``x, y, z = point``), compare
    equal when all coordinates match, and can be stored in sets and used as
    dictionary keys. Setting or deleting any attribute after construction
    raises AttributeError.
    """

    def __init__(self, x, y, z):
        # __setattr__ below rejects every assignment, including the ones made
        # during construction, so the coordinates have to be stored through
        # the base-class implementation instead of ``self.x = x``.
        object.__setattr__(self, "x", x)
        object.__setattr__(self, "y", y)
        object.__setattr__(self, "z", z)

    def __repr__(self):
        return f"Point(x={self.x}, y={self.y}, z={self.z})"

    def __eq__(self, other):
        """Compare coordinates; defer to the other operand if it is not a Point."""
        if not isinstance(other, Point):
            return NotImplemented
        return tuple(self) == tuple(other)

    def __iter__(self):
        """Yield x, y and z in that order."""
        yield from (self.x, self.y, self.z)

    def __setattr__(self, attribute, value):
        """Reject attribute assignment.

        Raises:
            AttributeError: always.
        """
        raise AttributeError("object is immutable")

    def __delattr__(self, attribute):
        """Reject attribute deletion.

        Defined separately from __setattr__ because ``del obj.name`` passes
        only the attribute name, not a value.

        Raises:
            AttributeError: always.
        """
        raise AttributeError("object is immutable")

    def __hash__(self):
        # Hash the same tuple __eq__ compares, so equal points hash equally.
        return hash(tuple(self))
        

NamedTuple


from typing import NamedTuple

class Point(NamedTuple):
    """Point with x, y and z coordinates, declared as a typed named tuple.

    The three annotated fields are the whole definition; no methods are
    written by hand. Instances are tuples.
    """
    x: float
    y: float
    z: float
        

>>> p = Point(1, 2, 3)
>>> p
Point(x=1, y=2, z=3)
>>> p == Point(1, 2, 3)
True
>>> p.x = 4
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: cannot set attribute
>>> 
        

namedtuples are tuples


>>> Point(1, 2, 3) < Point(4, 5, 6)
True
>>> Point(1, 2, 3) + Point(4, 5, 6)
(1, 2, 3, 4, 5, 6)
>>> Point(1, 2, 3) * 2
(1, 2, 3, 1, 2, 3)
>>> len(Point(1, 2, 3))
3
        

from typing import NamedTuple

class Point(NamedTuple):
    """Named-tuple Point with the inherited tuple operations disabled.

    Ordering, concatenation, repetition and ``len()`` are overridden so that
    each raises TypeError. Equality, unpacking and the field attributes are
    left untouched.
    """
    x: float
    y: float
    z: float

    # Ordering comparisons.
    def __lt__(self, other):
        raise TypeError

    def __le__(self, other):
        raise TypeError

    def __gt__(self, other):
        raise TypeError

    def __ge__(self, other):
        raise TypeError

    # Concatenation and repetition.
    def __add__(self, other):
        raise TypeError

    def __mul__(self, other):
        raise TypeError

    def __rmul__(self, other):
        raise TypeError

    # len() calls __len__ with no extra argument, so it must take only self.
    def __len__(self):
        raise TypeError
        

attrs


$ pip install attrs
        

import attr

# auto_attribs=True turns the annotated names below into attrs attributes.
@attr.s(auto_attribs=True)
class Point:
    """Point with x, y and z coordinates, with its methods generated by attrs."""
    x: float
    y: float
    z: float
        

>>> p = Point(1, 2, 3)
>>> p
Point(x=1, y=2, z=3)
>>> p == Point(1, 2, 3)
True
>>> p < Point(4, 5, 6)
True
        

import attr

@attr.s(auto_attribs=True, cmp=False, frozen=True)
class Point:
    """Frozen attrs Point that is iterable, with hand-written eq and hash.

    ``cmp=False`` switches off the comparison methods attrs would generate,
    so equality and hashing are defined explicitly below.
    """

    x: float
    y: float
    z: float

    def __iter__(self):
        """Yield x, y and z in that order."""
        for coordinate in (self.x, self.y, self.z):
            yield coordinate

    def __eq__(self, other):
        """Compare coordinates; defer to the other operand if it is not a Point."""
        if isinstance(other, Point):
            return tuple(self) == tuple(other)
        return NotImplemented

    def __hash__(self):
        coordinates = tuple(self)
        return hash(coordinates)
        

dataclasses


from dataclasses import dataclass

@dataclass
class Point:
    """Point with x, y and z coordinates, declared as a plain dataclass.

    No ``order`` option is passed to the decorator, so ``<`` between two
    points is not supported.
    """
    x: float
    y: float
    z: float
        

>>> p = Point(1, 2, 3)
>>> p
Point(x=1, y=2, z=3)
>>> p == Point(1, 2, 3)
True
>>> p < Point(4, 5, 6)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: '<' not supported between instances of 'Point' and 'Point'
        

from dataclasses import dataclass

@dataclass
class Point:
    """Point with x, y and z coordinates, declared as a plain dataclass.

    The class defines no ``__iter__`` and is not frozen, so instances cannot
    be unpacked and their fields can be reassigned.
    """
    x: float
    y: float
    z: float
        

>>> p = Point(1, 2, 3)
>>> x, y, z = p
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: cannot unpack non-iterable Point object
>>> p.x = 4
>>> p
Point(x=4, y=2, z=3)


        

from dataclasses import dataclass

@dataclass
class Point:
    """Mutable dataclass Point that unpacks into its three coordinates."""

    x: float
    y: float
    z: float

    def __iter__(self):
        """Yield x, y and z in that order."""
        for coordinate in (self.x, self.y, self.z):
            yield coordinate
        

from dataclasses import dataclass, astuple

@dataclass(frozen=True)
class Point:
    """Frozen dataclass Point that unpacks into its field values."""

    x: float
    y: float
    z: float

    def __iter__(self):
        """Yield every field value in declaration order."""
        for value in astuple(self):
            yield value
        

>>> p = Point(1, 2, 3)
>>> x, y, z = p
>>> x
1
>>> y
2

from dataclasses import dataclass, astuple

@dataclass(frozen=True)
class Point:
    """Immutable Point: fields cannot be reassigned, values can be unpacked."""

    x: float
    y: float
    z: float

    def __iter__(self):
        """Yield the field values in the order the fields are declared."""
        for field_value in astuple(self):
            yield field_value
        

>>> p = Point(1, 2, 3)
>>> p.x = 4
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<string>", line 3, in __setattr__
dataclasses.FrozenInstanceError: cannot assign to field 'x'
        

dataclasses

  • Built into the standard library (as of Python 3.7)
  • Available as a third-party backport library for Python 3.6
  • While attrs supports both Python 2 and Python 3, dataclasses only work on Python 3
  • dataclasses is simpler, but less feature-rich than attrs

Friendly Class Recipes

What you get out of the box


from dataclasses import dataclass

@dataclass
class Point:
    """Point with x, y and z coordinates; the baseline dataclass recipe.

    Uses the decorator with no options and defines no methods of its own.
    """
    x: float
    y: float
    z: float
        

>>> p = Point(1, 2, 3)
>>> p
Point(x=1, y=2, z=3)
>>> p == Point(1, 2, 3)
True


        

Immutability


from dataclasses import dataclass

# frozen=True is the only change from the baseline recipe.
@dataclass(frozen=True)
class Point:
    """Immutable Point: assigning to a field raises FrozenInstanceError."""
    x: float
    y: float
    z: float
        

>>> p = Point(1, 2, 3)
>>> p.x = 4
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<string>", line 3, in __setattr__
dataclasses.FrozenInstanceError: cannot assign to field 'x'
        

Iterability


from dataclasses import dataclass, astuple

@dataclass
class Point:
    """Dataclass Point that can be unpacked or converted with ``tuple()``."""

    x: float
    y: float
    z: float

    def __iter__(self):
        """Yield every field value in declaration order."""
        for value in astuple(self):
            yield value
        

>>> p = Point(1, 2, 3)
>>> x, y, z = p
>>> tuple(p)
(1, 2, 3)
        

Orderability


from dataclasses import dataclass

# order=True adds the ordering comparisons on top of the default equality.
@dataclass(order=True)
class Month:
    """A year/month pair that sorts by year first, then by month."""
    year: int
    month: int
        

>>> Month(2019, 12) < Month(2020, 1)
True
>>> months = [Month(2018, 6), Month(2018, 1), Month(2019, 10)]
>>> print(*sorted(months), sep='\n')
Month(year=2018, month=1)
Month(year=2018, month=6)
Month(year=2019, month=10)
        

Friendly operations


from dataclasses import dataclass, astuple

@dataclass
class Vector:
    """Three-component vector supporting element-wise ``+`` and ``-``."""

    x: float
    y: float
    z: float

    def __iter__(self):
        """Yield the components in declaration order."""
        for component in astuple(self):
            yield component

    def __add__(self, other):
        """Return a new Vector holding the pairwise sums with *other*."""
        sums = [mine + theirs for mine, theirs in zip(self, other)]
        return Vector(*sums)

    def __sub__(self, other):
        """Return a new Vector holding the pairwise differences with *other*."""
        differences = [mine - theirs for mine, theirs in zip(self, other)]
        return Vector(*differences)

>>> Vector(1, 2, 3) + Vector(4, 5, 6)
Vector(x=5, y=7, z=9)
>>> Vector(4, 5, 6) - Vector(1, 2, 3)
Vector(x=3, y=3, z=3)
        

Should I make a class?


class Matrix:
    """Matrix-like view over a string of whitespace-separated numbers.

    Each line of the string is one row. Nothing is cached: ``rows`` re-parses
    the string on every access, and ``columns`` is built from ``rows``.
    """

    def __init__(self, string):
        self.string = string

    @property
    def rows(self):
        """Return the parsed rows as a list of lists of floats."""
        parsed = []
        for line in self.string.splitlines():
            parsed.append([float(token) for token in line.split()])
        return parsed

    @property
    def columns(self):
        """Return the transposed rows as a list of lists, one per column."""
        return list(map(list, zip(*self.rows)))
        

>>> matrix = Matrix("9 8 7\n5 3 2\n6 6 7")
>>> matrix.rows
[[9.0, 8.0, 7.0], [5.0, 3.0, 2.0], [6.0, 6.0, 7.0]]
>>> matrix.columns
[[9.0, 5.0, 6.0], [8.0, 3.0, 6.0], [7.0, 2.0, 7.0]]
        

def matrix_from_string(string):
    """Parse lines of whitespace-separated numbers into a list of float lists."""
    matrix = []
    for line in string.splitlines():
        matrix.append(list(map(float, line.split())))
    return matrix

def transpose(matrix):
    """Swap rows and columns of a list of lists, returning a new list of lists."""
    return list(map(list, zip(*matrix)))
        

>>> matrix = matrix_from_string("9 8 7\n5 3 2\n6 6 7")
>>> matrix
[[9.0, 8.0, 7.0], [5.0, 3.0, 2.0], [6.0, 6.0, 7.0]]
>>> transpose(matrix)
[[9.0, 5.0, 6.0], [8.0, 3.0, 6.0], [7.0, 2.0, 7.0]]
        

When making classes, consider...

  • You don't always need custom classes for your data (see Jack Diederich's Stop Writing Classes)
  • Custom classes can make things very handy, but they can be a step backward unless you make them friendly
  • Friendly classes can require a lot of boilerplate code
  • To avoid distracting boilerplate code, use dataclasses or attrs for creating custom data-heavy classes

Dive a little deeper into classes

http://trey.io/nbpy2018

Trey Hunner
Python Team Trainer

Advanced Recipes

Metadata on fields!


from dataclasses import dataclass, field, fields

@dataclass
class Point:
    """Colored point whose iteration covers only the coordinate fields.

    Fields tagged with ``metadata={'iter': True}`` take part in iteration;
    ``color`` carries no such tag and is skipped.
    """

    x: float = field(metadata={'iter': True})
    y: float = field(metadata={'iter': True})
    z: float = field(metadata={'iter': True})
    color: str

    def __iter__(self):
        """Return an iterator over the values of the fields tagged 'iter'."""
        tagged = (spec for spec in fields(self) if spec.metadata.get('iter'))
        return (getattr(self, spec.name) for spec in tagged)
        

>>> x, y, z = Point(1, 2, 3, color='red')
>>> (x, y, z)
(1, 2, 3)
        

Non-comparable fields!


from dataclasses import dataclass, field, fields

@dataclass
class Point:
    """Colored point where color is ignored by comparison and by iteration.

    ``color`` is declared with ``compare=False``, so two points with the same
    coordinates are equal whatever their colors. Only fields tagged with
    ``metadata={'iter': True}`` take part in iteration.
    """

    x: float = field(metadata={'iter': True})
    y: float = field(metadata={'iter': True})
    z: float = field(metadata={'iter': True})
    color: str = field(compare=False)

    def __iter__(self):
        """Return an iterator over the values of the fields tagged 'iter'."""
        tagged = (spec for spec in fields(self) if spec.metadata.get('iter'))
        return (getattr(self, spec.name) for spec in tagged)
        

>>> p1 = Point(1, 2, 3, color='red')
>>> p2 = Point(1, 2, 3, color='blue')
>>> p1 == p2
True
        

Dynamic default values!


from dataclasses import dataclass, field, fields
import random

def random_color():
    """Return 'purple', 'blue' or 'red', picked at random."""
    return random.choice(['purple', 'blue', 'red'])


@dataclass
class Point:
    """Point whose color defaults to a fresh random choice per instance.

    ``color`` uses ``default_factory=random_color``, so the factory is called
    each time a Point is created without an explicit color. The field is
    declared with ``compare=False`` and is therefore ignored by ``==``.
    """

    x: float = field(metadata={'iter': True})
    y: float = field(metadata={'iter': True})
    z: float = field(metadata={'iter': True})
    color: str = field(default_factory=random_color, compare=False)

    def __iter__(self):
        """Return an iterator over the values of the fields tagged 'iter'.

        ``color`` carries no 'iter' metadata, so only x, y and z are produced.
        """
        return (
            getattr(self, spec.name)
            for spec in fields(self)
            if spec.metadata.get('iter')
        )
        

>>> p1 = Point(1, 2, 3)
>>> p1.color
'blue'
        

Auto-created fields


from dataclasses import dataclass, field, fields
import random

def random_color():
    """Return 'purple', 'blue' or 'red', picked at random."""
    return random.choice(['purple', 'blue', 'red'])


@dataclass
class Point:
    """Point whose color is always generated and can never be passed in.

    ``color`` is declared with ``init=False``, so it is not a parameter of the
    generated ``__init__``; ``__post_init__`` fills it in instead. It is also
    declared with ``compare=False`` and is therefore ignored by ``==``.
    """

    x: float = field(metadata={'iter': True})
    y: float = field(metadata={'iter': True})
    z: float = field(metadata={'iter': True})
    color: str = field(compare=False, init=False)

    def __post_init__(self):
        """Assign a random color to the newly created instance."""
        self.color = random_color()

    def __iter__(self):
        """Return an iterator over the values of the fields tagged 'iter'.

        ``color`` carries no 'iter' metadata, so only x, y and z are produced.
        """
        return (
            getattr(self, spec.name)
            for spec in fields(self)
            if spec.metadata.get('iter')
        )
        

>>> point = Point(1, 2, 3)
>>> point.color
'purple'
>>> point = Point(1, 2, 3, color='blue')
TypeError: __init__() got an unexpected keyword argument 'color'