Easier Classes

Python Classes Without All The Cruft

Trey Hunner / @treyhunner

Truthful Technology
Python Morsels

Pythonic Classes

Friendly Classes

  • have nice string representations
  • can be compared to each other
  • embrace operator overloading

Case 1: Month class

Comparability and Orderability


>>> m = Month(2018, 6)
>>> m
Month(year=2018, month=6)
>>> m == Month(2018, 6)
True
>>> m != Month(2018, 6)
False
>>> m < Month(2018, 12)
True
>>> m > Month(2019, 1)
False
>>> m <= Month(2018, 12)
True
>>> m >= Month(2018, 6)
True
        

class Month:
    """Skeleton of a friendly Month: every dunder we'd have to write by hand.

    Each method body is intentionally empty, so every method returns None.
    """

    def __init__(self, year, month):
        """Would store the year and the month."""

    def __repr__(self):
        """Would return text like 'Month(year=2018, month=6)'."""

    def __eq__(self, other):
        """Would compare two months for equality."""

    def __lt__(self, other):
        """Would check whether this month comes before ``other``."""

    def __gt__(self, other):
        """Would check whether this month comes after ``other``."""

    def __le__(self, other):
        """Would check whether this month is before or equal to ``other``."""

    def __ge__(self, other):
        """Would check whether this month is after or equal to ``other``."""
        

class Month:
    """A (year, month) pair with value equality and full ordering.

    Every comparison is delegated to the ``(year, month)`` tuple, so months
    order chronologically. Comparing with anything that is not a Month
    returns NotImplemented so Python can try the reflected operation.
    """

    def __init__(self, year, month):
        self.year = year
        self.month = month

    def __repr__(self):
        return "Month(year={}, month={})".format(self.year, self.month)

    def __eq__(self, other):
        if not isinstance(other, Month):
            return NotImplemented
        mine, theirs = (self.year, self.month), (other.year, other.month)
        return mine == theirs

    def __lt__(self, other):
        if not isinstance(other, Month):
            return NotImplemented
        mine, theirs = (self.year, self.month), (other.year, other.month)
        return mine < theirs

    def __gt__(self, other):
        if not isinstance(other, Month):
            return NotImplemented
        mine, theirs = (self.year, self.month), (other.year, other.month)
        return mine > theirs

    def __le__(self, other):
        if not isinstance(other, Month):
            return NotImplemented
        mine, theirs = (self.year, self.month), (other.year, other.month)
        return mine <= theirs

    def __ge__(self, other):
        if not isinstance(other, Month):
            return NotImplemented
        mine, theirs = (self.year, self.month), (other.year, other.month)
        return mine >= theirs
        

from functools import total_ordering


@total_ordering
class Month:
    """A (year, month) pair ordered chronologically.

    Only __eq__ and __lt__ are written out; functools.total_ordering
    derives __le__, __gt__ and __ge__ from them.
    """

    def __init__(self, year, month):
        self.year = year
        self.month = month

    def __repr__(self):
        return "Month(year={}, month={})".format(self.year, self.month)

    def __eq__(self, other):
        if isinstance(other, Month):
            return (self.year, self.month) == (other.year, other.month)
        return NotImplemented

    def __lt__(self, other):
        if isinstance(other, Month):
            return (self.year, self.month) < (other.year, other.month)
        return NotImplemented
        

Case 2: Point class

Iterability and Immutability


>>> p = Point(1, 2, 3)
>>> p
Point(x=1, y=2, z=3)
>>> p == Point(1, 2, 3)
True
>>> x, y, z = p
>>> x
1
>>> p.x = 4
Traceback (most recent call last):
  File "<stdin>", line 1
  File "<string>", line 3
AttributeError: object is immutable
>>> {Point(1, 2, 3), Point(1, 2, 3)}
{Point(x=1, y=2, z=3)}
        

class Point:
    """Skeleton of an immutable, iterable, hashable Point.

    Every method body is intentionally empty: each returns None, and
    ``__setattr__`` therefore silently discards assignments.
    """

    def __init__(self, x, y, z):
        """Would store the three coordinates."""

    def __repr__(self):
        """Would return text like 'Point(x=1, y=2, z=3)'."""

    def __eq__(self, other):
        """Would compare coordinates with another Point."""

    def __iter__(self):
        """Would yield x, y and z so that ``x, y, z = p`` works."""

    def __setattr__(self, name, value):
        """Would refuse attribute assignment."""

    def __hash__(self):
        """Would hash the coordinates."""

        
        

class Point:
    """An immutable, iterable, hashable 3-D point.

    Points compare equal when their coordinates match, unpack like a tuple
    (``x, y, z = p``) and can be stored in sets or used as dict keys.
    Setting or deleting any attribute raises AttributeError.
    """

    def __init__(self, x, y, z):
        # Our __setattr__ always raises, so store the initial coordinates
        # through object's implementation instead of plain assignment.
        object.__setattr__(self, "x", x)
        object.__setattr__(self, "y", y)
        object.__setattr__(self, "z", z)

    def __repr__(self):
        return f"Point(x={self.x}, y={self.y}, z={self.z})"

    def __eq__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return tuple(self) == tuple(other)

    def __iter__(self):
        yield from (self.x, self.y, self.z)

    def __setattr__(self, attribute, value):
        raise AttributeError("object is immutable")

    def __delattr__(self, attribute):
        # __delattr__ receives only the attribute name; aliasing it to the
        # three-argument __setattr__ made ``del p.x`` raise TypeError.
        raise AttributeError("object is immutable")

    def __hash__(self):
        return hash(tuple(self))
        

namedtuple


from collections import namedtuple

# A tuple subclass with named fields x, y and z; namedtuple splits the
# space-separated field string into the individual field names.
Point = namedtuple('Point', 'x y z')
        

>>> p = Point(1, 2, 3)
>>> p
Point(x=1, y=2, z=3)
>>> p == Point(1, 2, 3)
True
>>> p.x = 4
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: cannot set attribute
>>> {Point(1, 2, 3), Point(1, 2, 3)}
{Point(x=1, y=2, z=3)}
        

from typing import NamedTuple

class Point(NamedTuple):
    """A 3-D point as a typed NamedTuple; every coordinate defaults to 0."""
    # Annotated class attributes become the tuple's fields, in order; the
    # assigned values are their defaults, so Point() == (0, 0, 0).
    x: float = 0
    y: float = 0
    z: float = 0
        

from collections import namedtuple

# Plain namedtuple: a tuple subclass whose items are also reachable as
# the attributes .x, .y and .z.
Point = namedtuple('Point', ('x', 'y', 'z'))
        

namedtuples are tuples


>>> Point(1, 2, 3) < Point(4, 5, 6)
True
>>> Point(1, 2, 3) + Point(4, 5, 6)
(1, 2, 3, 4, 5, 6)
>>> Point(1, 2, 3) * 2
(1, 2, 3, 1, 2, 3)
>>> len(Point(1, 2, 3))
3
        

from typing import NamedTuple

class Point(NamedTuple):
    """A NamedTuple that opts out of tuple behavior a point shouldn't have.

    Ordering, concatenation (either side), repetition (either side) and
    len() all raise TypeError. Tuple equality, indexing and unpacking still
    work. Because len() raises, truth testing (``bool(point)``) raises too.
    NOTE(review): CPython's namedtuple ``_make`` (used by ``_replace``)
    calls len(), so those helpers presumably fail as well — confirm.
    """
    x: float
    y: float
    z: float

    def __lt__(self, other):
        raise TypeError("Point does not support ordering")

    def __le__(self, other):
        raise TypeError("Point does not support ordering")

    def __gt__(self, other):
        raise TypeError("Point does not support ordering")

    def __ge__(self, other):
        raise TypeError("Point does not support ordering")

    def __add__(self, other):
        raise TypeError("Point does not support +")

    def __radd__(self, other):
        # Without this, ``(1, 2) + point`` falls back to tuple concatenation.
        raise TypeError("Point does not support +")

    def __mul__(self, other):
        raise TypeError("Point does not support *")

    def __rmul__(self, other):
        raise TypeError("Point does not support *")

    def __len__(self):
        # len() passes only self; the former extra ``other`` parameter
        # made len(point) fail with a misleading missing-argument error.
        raise TypeError("Point has no len()")
        

attrs


$ pip install attrs
        

import attr

@attr.s(auto_attribs=True)
class Point:
    """A 3-D point whose boilerplate methods are generated by attrs.

    auto_attribs=True makes attrs collect the annotated class attributes
    below as fields; the resulting class has an __init__, a readable repr,
    and supports both == and < between instances.
    """
    x: float
    y: float
    z: float
        

>>> p = Point(1, 2, 3)
>>> p
Point(x=1, y=2, z=3)
>>> p == Point(1, 2, 3)
True
>>> p < Point(4, 5, 6)
True
        

import attr

@attr.s(auto_attribs=True, cmp=False, frozen=True)
class Point:
    """Immutable attrs point with tuple-style iteration, equality and hash.

    cmp=False stops attrs from generating comparison methods, so the
    hand-written __eq__ in this class is used and no ordering exists.
    """
    x: float
    y: float
    z: float

    def __iter__(self):
        for coordinate in (self.x, self.y, self.z):
            yield coordinate

    def __eq__(self, other):
        if isinstance(other, Point):
            return tuple(self) == tuple(other)
        return NotImplemented

    def __hash__(self):
        coordinates = tuple(self)
        return hash(coordinates)
        

dataclasses


from dataclasses import dataclass

@dataclass(init=True, repr=True, eq=True)
class Point:
    """Mutable 3-D point; dataclass supplies __init__, __repr__ and __eq__."""

    x: float
    y: float
    z: float
        

>>> p = Point(1, 2, 3)
>>> p
Point(x=1, y=2, z=3)
>>> p == Point(1, 2, 3)
True
>>> p < Point(4, 5, 6)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: '<' not supported between instances of 'Point' and 'Point'
        

from dataclasses import dataclass

@dataclass(init=True, repr=True, eq=True)
class Point:
    """3-D point with generated methods only: mutable and not iterable."""

    x: float
    y: float
    z: float
        

>>> p = Point(1, 2, 3)
>>> x, y, z = p
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: cannot unpack non-iterable Point object
>>> p.x = 4
>>> p
Point(x=4, y=2, z=3)


        

from dataclasses import dataclass

@dataclass(init=True, repr=True, eq=True)
class Point:
    """Mutable 3-D point that also unpacks like a tuple: ``x, y, z = p``."""

    x: float
    y: float
    z: float

    def __iter__(self):
        # Yield the coordinates in field order.
        for coordinate in (self.x, self.y, self.z):
            yield coordinate
        

from dataclasses import dataclass, astuple

@dataclass(eq=True, frozen=True)
class Point:
    """Immutable, hashable 3-D point that unpacks like a tuple."""

    x: float
    y: float
    z: float

    def __iter__(self):
        # astuple() gives the field values in declaration order (note: it
        # copies them, recursing into nested dataclasses and containers).
        for value in astuple(self):
            yield value
        

>>> p = Point(1, 2, 3)
>>> x, y, z = p
>>> x
1
>>> y
2

from dataclasses import dataclass, astuple

@dataclass(frozen=True, eq=True)
class Point:
    """3-D point that refuses reassignment and iterates over its fields."""

    x: float
    y: float
    z: float

    def __iter__(self):
        field_values = astuple(self)  # copied values, in declaration order
        yield from field_values
        

>>> p = Point(1, 2, 3)
>>> p.x = 4
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<string>", line 3, in __setattr__
dataclasses.FrozenInstanceError: cannot assign to field 'x'
        

dataclasses

  • Built into the standard library (as of Python 3.7)
  • Available for Python 3.6 as a third-party backport (pip install dataclasses)
  • dataclasses is simpler, but less feature-rich than attrs
  • While attrs supports both Python 2 and Python 3, dataclasses requires Python 3.6 or later

Friendly Class Recipes

What you get out of the box


from dataclasses import dataclass

@dataclass(init=True, repr=True, eq=True)
class Point:
    """3-D point with a generated constructor, repr and field-wise ==."""

    x: float
    y: float
    z: float
        

>>> p = Point(1, 2, 3)
>>> p
Point(x=1, y=2, z=3)
>>> p == Point(1, 2, 3)
True


        

Immutability!


from dataclasses import dataclass

@dataclass(eq=True, frozen=True)
class Point:
    """3-D point whose fields cannot be reassigned after construction."""

    x: float
    y: float
    z: float
        

>>> p = Point(1, 2, 3)
>>> p.x = 4
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<string>", line 3, in __setattr__
dataclasses.FrozenInstanceError: cannot assign to field 'x'
        

Hashability!


from dataclasses import dataclass

@dataclass(frozen=True, eq=True)
class Point:
    """Frozen 3-D point; frozen + eq makes dataclass generate __hash__."""

    x: float
    y: float
    z: float
        

>>> aliveness = {Point(1, 2, 3): True, Point(4, 5, 6): False}
>>> aliveness[Point(4, 5, 6)]
False
>>> aliveness[Point(1, 2, 3)]
True


        

Iterability!


from dataclasses import dataclass, astuple

@dataclass(init=True, repr=True, eq=True)
class Point:
    """Mutable 3-D point that can be unpacked or converted with tuple()."""

    x: float
    y: float
    z: float

    def __iter__(self):
        # astuple() returns copies of the field values in declaration order.
        for value in astuple(self):
            yield value
        

>>> p = Point(1, 2, 3)
>>> x, y, z = p
>>> tuple(p)
(1, 2, 3)

        

Orderability?


from dataclasses import dataclass

@dataclass(init=True, repr=True, eq=True, order=False)
class Month:
    """A (year, month) pair; equality is generated but ordering is not."""

    year: int
    month: int
        

>>> eol_month = Month(2020, 1)
>>> eol_month
Month(year=2020, month=1)
>>> eol_month > Month(2019, 12)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: '>' not supported between instances of 'Month' and 'Month'
        

Orderability!


from dataclasses import dataclass

@dataclass(eq=True, order=True)
class Month:
    """A (year, month) pair ordered like the tuple (year, month)."""

    year: int
    month: int
        

>>> Month(2020, 1) > Month(2019, 12)
True
>>> months = [Month(2018, 6), Month(2018, 1), Month(2019, 10)]
>>> print(*sorted(months), sep='\n')
Month(year=2018, month=1)
Month(year=2018, month=6)
Month(year=2019, month=10)
        

Summability and other operations?


from dataclasses import dataclass, astuple

@dataclass(init=True, repr=True, eq=True)
class Vector:
    """3-D vector with generated methods only; it defines no arithmetic."""

    x: float
    y: float
    z: float
        

>>> p = Vector(1, 2, 3)
>>> q = Vector(4, 5, 6)
>>> p + q
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: unsupported operand type(s) for +: 'Vector' and 'Vector'
>>> q - p
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: unsupported operand type(s) for -: 'Vector' and 'Vector'

Summability and other operations!


from dataclasses import dataclass, astuple

@dataclass
class Vector:
    """Mutable 3-D vector supporting component-wise + and -.

    Adding or subtracting anything that is not a Vector returns
    NotImplemented, so Python tries the other operand's reflected method
    and otherwise raises the usual TypeError, instead of silently zipping
    an arbitrary iterable (which truncated short sequences).
    """

    x: float
    y: float
    z: float

    def __iter__(self):
        yield from astuple(self)

    def __add__(self, other):
        if not isinstance(other, Vector):
            return NotImplemented
        return Vector(*(a + b for a, b in zip(self, other)))

    def __sub__(self, other):
        if not isinstance(other, Vector):
            return NotImplemented
        return Vector(*(a - b for a, b in zip(self, other)))

>>> Vector(1, 2, 3) + Vector(4, 5, 6)
Vector(x=5, y=7, z=9)
>>> Vector(4, 5, 6) - Vector(1, 2, 3)
Vector(x=3, y=3, z=3)
        

Advanced Recipes

Metadata on fields!


from dataclasses import dataclass, field, fields

@dataclass
class Point:
    """3-D point with a color; only fields tagged ``iter`` are iterated."""

    x: float = field(metadata={'iter': True})
    y: float = field(metadata={'iter': True})
    z: float = field(metadata={'iter': True})
    color: str

    def __iter__(self):
        # Select the fields whose metadata opts them into iteration, then
        # read each one's current value off the instance.
        tagged = (f for f in fields(self) if f.metadata.get('iter'))
        return (getattr(self, f.name) for f in tagged)
        

>>> x, y, z = Point(1, 2, 3, color='red')
>>> (x, y, z)
(1, 2, 3)
        

Non-comparable fields!


from dataclasses import dataclass, field, fields

@dataclass
class Point:
    """3-D point whose color is ignored by == and skipped when iterating."""

    x: float = field(metadata={'iter': True})
    y: float = field(metadata={'iter': True})
    z: float = field(metadata={'iter': True})
    color: str = field(compare=False)

    def __iter__(self):
        # Names of the fields tagged for iteration, in declaration order;
        # their values are read lazily as the generator is consumed.
        names = [f.name for f in fields(self) if f.metadata.get('iter')]
        return (getattr(self, name) for name in names)
        

>>> p1 = Point(1, 2, 3, color='red')
>>> p2 = Point(1, 2, 3, color='blue')
>>> p1 == p2
True
        

Dynamic default values!


from dataclasses import dataclass, field, fields
import random

def random_color():
    """Return one of 'purple', 'blue' or 'red', chosen at random."""
    return random.choice(['purple', 'blue', 'red'])


@dataclass
class Point:
    """3-D point whose color defaults to a fresh random choice.

    color is excluded from comparisons, and iterating yields only the
    fields tagged with ``iter`` metadata (x, y, z).
    """

    x: float = field(metadata={'iter': True})
    y: float = field(metadata={'iter': True})
    z: float = field(metadata={'iter': True})
    # default_factory is called for each instance created without a color.
    color: str = field(default_factory=random_color, compare=False)

    def __iter__(self):
        # Must return an iterator: yield only fields opted in via metadata.
        return (
            getattr(self, f.name)
            for f in fields(self)
            if f.metadata.get('iter')
        )
        

>>> p1 = Point(1, 2, 3)
>>> p1.color
'blue'
        

Auto-created fields


from dataclasses import dataclass, field, fields
import random

def random_color():
    """Return one of 'purple', 'blue' or 'red', chosen at random."""
    return random.choice(['purple', 'blue', 'red'])


@dataclass
class Point:
    """3-D point that always picks its own random color.

    color is not an __init__ parameter and is excluded from comparisons;
    iterating yields only the fields tagged with ``iter`` metadata.
    """

    x: float = field(metadata={'iter': True})
    y: float = field(metadata={'iter': True})
    z: float = field(metadata={'iter': True})
    # init=False keeps color out of __init__; __post_init__ assigns it.
    color: str = field(compare=False, init=False)

    def __post_init__(self):
        self.color = random_color()

    def __iter__(self):
        # Must return an iterator: yield only fields opted in via metadata.
        return (
            getattr(self, f.name)
            for f in fields(self)
            if f.metadata.get('iter')
        )
        

>>> point = Point(1, 2, 3)
>>> point.color
'purple'
>>> point = Point(1, 2, 3, color='blue')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: __init__() got an unexpected keyword argument 'color'
        

Should I make a class?

(Examples inspired by exercism.io)


class Matrix:
    """Matrix-like view over a string of whitespace-separated numbers.

    Each line of ``string`` is one row and every entry is parsed with
    float(). ``rows`` and ``columns`` are recomputed from ``string`` on
    every access.
    """

    def __init__(self, string):
        self.string = string

    @property
    def rows(self):
        """The rows, each a list of floats."""
        parsed_rows = []
        for line in self.string.splitlines():
            parsed_rows.append([float(token) for token in line.split()])
        return parsed_rows

    @property
    def columns(self):
        """The columns (the transpose of ``rows``), each a list."""
        return list(map(list, zip(*self.rows)))
        

>>> matrix = Matrix("9 8 7\n5 3 2\n6 6 7")
>>> matrix.rows
[[9.0, 8.0, 7.0], [5.0, 3.0, 2.0], [6.0, 6.0, 7.0]]
>>> matrix.columns
[[9.0, 5.0, 6.0], [8.0, 3.0, 6.0], [7.0, 2.0, 7.0]]
        

def matrix_from_string(string):
    """Parse one row per line of whitespace-separated numbers.

    Returns a list of rows, each a list of floats.
    """
    matrix = []
    for line in string.splitlines():
        matrix.append([float(token) for token in line.split()])
    return matrix

def transpose(matrix):
    """Return the columns of a list of lists as a new list of lists.

    As with zip(), rows of unequal length are cut to the shortest row.
    """
    return list(map(list, zip(*matrix)))
        

>>> matrix = matrix_from_string("9 8 7\n5 3 2\n6 6 7")
>>> matrix
[[9.0, 8.0, 7.0], [5.0, 3.0, 2.0], [6.0, 6.0, 7.0]]
>>> transpose(matrix)
[[9.0, 5.0, 6.0], [8.0, 3.0, 6.0], [7.0, 2.0, 7.0]]
        

class SpaceAge:
    """An age, given in seconds, expressed in years on each planet.

    Every ``on_<planet>()`` method returns the age measured in that
    planet's years, rounded to two decimal places.
    """

    SECONDS_IN_EARTH_YEAR = 31557600.0

    def __init__(self, seconds):
        self.seconds = seconds

    def _years(self, orbital_period):
        """Return the age in years of a planet whose orbit lasts
        ``orbital_period`` Earth years, rounded to two decimals."""
        annual_seconds = self.SECONDS_IN_EARTH_YEAR * orbital_period
        return round(self.seconds / annual_seconds, 2)

    def on_earth(self):
        return self._years(1)

    def on_mercury(self):
        return self._years(0.2408467)

    def on_venus(self):
        return self._years(0.61519726)

    def on_mars(self):
        return self._years(1.8808158)

    def on_jupiter(self):
        return self._years(11.862615)

    def on_saturn(self):
        return self._years(29.447498)

    def on_uranus(self):
        return self._years(84.016846)

    def on_neptune(self):
        return self._years(164.79132)
        

>>> age = SpaceAge(seconds=1_000_000_000)
>>> age.on_earth()
31.69
>>> age.on_mars()
16.85
>>> age.on_mercury()
131.57
        

# Seconds in one Earth year (365.25 days * 86400 s).
SECONDS_IN_EARTH_YEAR = 31557600

# Orbital period of each planet, measured in Earth years.
ORBITAL_PERIOD = {
    'earth': 1,
    'mercury': 0.2408467,
    'venus': 0.61519726,
    'mars': 1.8808158,
    'jupiter': 11.862615,
    'saturn': 29.447498,
    'uranus': 84.016846,
    'neptune': 164.79132,
}


def age_on_planet(seconds_alive, planet):
    """Convert ``seconds_alive`` into years on ``planet``, rounded to 2 places.

    Raises KeyError when ``planet`` is not a key of ORBITAL_PERIOD.
    """
    seconds_per_planet_year = SECONDS_IN_EARTH_YEAR * ORBITAL_PERIOD[planet]
    planet_years = seconds_alive / seconds_per_planet_year
    return round(planet_years, 2)
        

>>> age_on_planet(seconds_alive=1_000_000_000, planet='earth')
31.69
>>> age_on_planet(seconds_alive=1_000_000_000, planet='mars')
16.85
>>> age_on_planet(seconds_alive=1_000_000_000, planet='mercury')
131.57
        

When making classes, consider...

  • You don't always need custom classes for your data (see Jack Diederich's Stop Writing Classes)
  • Custom classes can make things very handy, but they can be a step backward unless you make them friendly
  • Friendly classes can require a lot of boilerplate code
  • To avoid distracting boilerplate code, use dataclasses or attrs for creating custom data-heavy classes

Need help getting your team up to speed on Python?

Trey Hunner
Python & Django Team Trainer

Contact me: trey@truthful.technology