'''
Polynomials with rational (int / Fraction) coefficients, and matrices whose
entries are such polynomials, with the three elementary row/column transforms.
'''
from __future__ import annotations

from fractions import Fraction
from itertools import zip_longest
from typing import Iterable, Union


def is_scalar(obj):
    ''' True if obj is an exact rational scalar: an int (bool included) or a Fraction. '''
    return isinstance(obj, (Fraction, int))


# ANSI SGR color code (e.g. '36' for cyan) that Row.__repr__ uses to highlight
# non-constant entries. None disables coloring; set POLY_COLOR = '36' to enable.
POLY_COLOR = None


class PolyRenderBase:
    '''
    Interface for objects that turn a Poly into text.

    render() produces the text returned by repr()/str() of a Poly; hint()
    returns an int that Matrix.__repr__ uses as a column width to align entries.
    '''

    def render(self, poly: Poly) -> str:
        raise NotImplementedError()

    def hint(self, poly: Poly) -> int:
        raise NotImplementedError()


class Poly:
    '''
    Polynomial in one variable with int / Fraction coefficients.

    vec holds the coefficients lowest degree first: vec[i] is the coefficient
    of x**i, so [1, 0, 3] is 1 + 3*x**2. Zero coefficients of the highest
    powers are trimmed; the zero polynomial is stored as [0].
    '''

    def __init__(self, vec: Iterable, letter='x', renderer=None):
        '''
        vec: coefficients, lowest degree first (vec[i] multiplies x**i).
        letter: variable name; this module only stores it (presumably it is
            read by the renderer).
        renderer: PolyRenderBase-like object used by repr()/hint(); defaults
            to render.UnicodeRender().
        '''
        self.vec = list(vec)
        self.letter = letter
        if renderer is None:
            # Imported lazily so that a Poly given an explicit renderer does
            # not require the project's `render` module.
            from render import UnicodeRender
            renderer = UnicodeRender()
        self.renderer = renderer
        self._trim_zeros()

    def _like(self, vec: Iterable) -> Poly:
        '''
        Build a new Poly from vec that keeps this poly's letter and renderer,
        so arithmetic results stay in the operand's variable.
        '''
        # NOTE(review): the renderer instance is shared between polys; this
        # assumes renderers keep no per-poly state (the interface passes the
        # poly explicitly).
        return Poly(vec, letter=self.letter, renderer=self.renderer)

    def _trim_zeros(self):
        '''
        Drop zero coefficients of the highest powers: [0, 1, 1, 0] -> [0, 1, 1].
        An empty or all-zero vec becomes [0].

        Multiplication relies on len(vec) - 1 being the degree.
        '''
        end = len(self.vec)
        while end > 0 and self.vec[end - 1] == 0:
            end -= 1
        if end < len(self.vec):
            self.vec = self.vec[:end]
        if not self.vec:
            self.vec = [0]

    def __mul__(self, rhs: Union[Poly, Fraction, int]):
        '''
        Multiply by a scalar (int / Fraction) or by another Poly.

        Raises TypeError for any other operand type.
        '''
        if is_scalar(rhs):
            return self._like(c * rhs for c in self.vec)
        if isinstance(rhs, Poly):
            # Schoolbook O(n*m) product, no FFT. Vectors may have been mutated
            # externally (e.g. via shift), so re-trim before sizing the result.
            self._trim_zeros()
            rhs._trim_zeros()
            result = [0] * (len(self.vec) + len(rhs.vec) - 1)
            for i, a in enumerate(self.vec):
                if a == 0:
                    continue
                for j, b in enumerate(rhs.vec):
                    if b != 0:
                        result[i + j] += a * b
            return self._like(result)
        raise TypeError(f'{type(rhs)} not supported')

    # Polynomial multiplication is commutative: lets `2 * p` work.
    __rmul__ = __mul__

    def __truediv__(self, rhs: Union[Fraction, int]):
        '''
        Divide by a nonzero int / Fraction scalar (exact, via Fraction).

        Raises TypeError for non-scalars, ZeroDivisionError for 0.
        '''
        if not is_scalar(rhs):
            raise TypeError(f'{type(rhs)} not supported')
        inverse = Fraction(1, rhs)
        return self._like(c * inverse for c in self.vec)

    def __add__(self, rhs: Union[Poly, Fraction, int]):
        '''
        Add a scalar (to the constant term) or another Poly.

        Raises TypeError for any other operand type.
        '''
        if is_scalar(rhs):
            return self._like([self.vec[0] + rhs] + self.vec[1:])
        if isinstance(rhs, Poly):
            return self._like(
                a + b for a, b in zip_longest(self.vec, rhs.vec, fillvalue=0)
            )
        raise TypeError(f'{type(rhs)} not supported')

    # Addition is commutative: lets `3 + p` and `sum(polys)` work.
    __radd__ = __add__

    def __sub__(self, rhs: Union[Poly, Fraction, int]):
        return self + (-rhs)

    def __rsub__(self, lhs: Union[Fraction, int]):
        ''' Support `scalar - poly`. '''
        return (-self) + lhs

    def __neg__(self):
        return self._like(-c for c in self.vec)

    def deg(self):
        ''' Degree of the poly (index of the highest nonzero coefficient); 0 for the zero poly. '''
        for d in range(len(self.vec) - 1, -1, -1):
            if self.vec[d] != 0:
                return d
        return 0

    def shift(self, deg):
        '''
        Divide in place by x**deg, dropping the remainder (the coefficients of
        powers below deg).

        Shifting by at least the number of coefficients leaves the zero poly;
        deg <= 0 is a no-op.
        '''
        del self.vec[:max(deg, 0)]
        self._trim_zeros()

    def __repr__(self):
        return self.renderer.render(self)

    def hint(self):
        ''' Display-width hint from the renderer, used to align matrix columns. '''
        return self.renderer.hint(self)


class Row:
    '''
    Matrix row: a list of Poly entries plus per-entry display widths (hints).
    '''

    def __init__(self, it: Iterable, dtype=Poly):
        '''
        it: entries; ints / Fractions are wrapped into constant Polys, Polys
            are used as-is (not copied).
        dtype: unused; kept for backward compatibility.

        Raises TypeError for any other entry type.
        '''
        def make_poly(obj):
            if is_scalar(obj):
                return Poly([obj])
            if isinstance(obj, Poly):
                return obj
            raise TypeError(f'{type(obj)} not supported')

        self.lst = [make_poly(obj) for obj in it]
        self.hints = [1] * len(self.lst)

    def __add__(self, obj: Row):
        ''' Element-wise sum with another row of the same length. '''
        assert isinstance(obj, Row)
        assert len(obj.lst) == len(self.lst)
        return Row(a + b for a, b in zip(self.lst, obj.lst))

    def __mul__(self, k: Union[int, Fraction]):
        ''' Multiply every entry by an int or fractions.Fraction. '''
        assert is_scalar(k)
        return Row(el * k for el in self.lst)

    def __iter__(self):
        return iter(self.lst)

    def __getitem__(self, key):
        # A slice returns a plain list of entries, not a Row.
        return self.lst[key]

    def __setitem__(self, key, val):
        self.lst[key] = val

    def __len__(self):
        return len(self.lst)

    def __repr__(self):
        parts = []
        for el, width in zip(self.lst, self.hints):
            part = '{el: >{hint}}'.format(el=repr(el), hint=width)
            if el.deg() > 0 and POLY_COLOR is not None:
                part = '\x1b[' + POLY_COLOR + 'm' + part + '\x1b[0m'
            parts.append(part)
        return '[' + ' '.join(parts) + ']'

    def get_hints(self):
        ''' Per-entry width hints, so a Matrix can align columns across rows. '''
        return [el.hint() for el in self.lst]

    def set_hints(self, hints):
        ''' Set the widths used by __repr__ to right-align each entry. '''
        self.hints = hints


class Matrix:
    '''
    Matrix of Poly entries, stored as a list of Row objects.

    The three elementary transforms make_S (add a multiple), make_U (swap) and
    make_D (scale) act in place on rows (axis=0) or columns (axis != 0).
    '''

    def __init__(self, rows):
        ''' rows: Row objects or iterables of entries; all rows must have equal length. '''
        def make_row(obj):
            return obj if isinstance(obj, Row) else Row(obj)

        self.rows = [make_row(obj) for obj in rows]
        assert all(len(row) == len(self.rows[0]) for row in self.rows)

    def make_S(self, i: int, j: int, lbd: Union[int, Fraction], axis=0):
        '''
        if axis == 0, do transform on rows, else on cols
        M[i] = M[i] + M[j] * lbd
        '''
        if axis == 0:
            self.rows[i] = self.rows[i] + self.rows[j] * lbd
        else:
            for row in self.rows:
                row[i] = row[i] + row[j] * lbd

    def make_U(self, i: int, j: int, axis=0):
        '''
        if axis == 0, do transform on rows, else on cols
        Swap M[i] and M[j]
        '''
        if axis == 0:
            self.rows[i], self.rows[j] = self.rows[j], self.rows[i]
        else:
            for row in self.rows:
                row[i], row[j] = row[j], row[i]

    def make_D(self, i: int, lbd: Union[int, Fraction], axis=0):
        '''
        if axis == 0, do transform on rows, else on cols
        Multiply M[i] by rational lbd
        '''
        if axis == 0:
            self.rows[i] = self.rows[i] * lbd
        else:
            for row in self.rows:
                row[i] = row[i] * lbd

    def det(self) -> Poly:
        '''
        Determinant via cofactor expansion along the first row.

        O(n!) time, so only suitable for small matrices. The 0x0 matrix has
        determinant 1. For a 1x1 matrix the entry itself (not a copy) is
        returned.
        '''
        assert all(len(row) == len(self.rows) for row in self.rows)
        if not self.rows:
            return Poly([1])
        if len(self.rows) == 1:
            return self.rows[0][0]
        # Start from int 0: Poly.__radd__ turns it into a Poly that keeps the
        # entries' letter/renderer.
        res = 0
        try:
            for i, el in enumerate(self.rows[0]):
                minor = Matrix([row[:i] + row[i + 1:] for row in self.rows[1:]])
                res = res + minor.det() * (-1) ** i * el
        except Exception:
            # Debug aid: dump the matrix being expanded before re-raising
            # (this prints once per recursion level on the failing path).
            print(self)
            raise
        return res

    def __sub__(self, oth):
        return self + (-oth)

    def __neg__(self):
        return Matrix([row * (-1) for row in self.rows])

    def __add__(self, oth):
        ''' Element-wise sum; both matrices must have the same shape. '''
        assert len(self.rows) == len(oth.rows)
        return Matrix([row1 + row2 for row1, row2 in zip(self.rows, oth.rows)])

    def __getitem__(self, idx):
        return self.rows[idx]

    def __setitem__(self, idx, val):
        self.rows[idx] = val

    @property
    def shape(self):
        ''' (number of rows, number of columns); needs at least one row. '''
        return len(self.rows), len(self.rows[0])

    def __matmul__(self, other):
        ''' Matrix product self @ other. '''
        assert self.shape[1] == other.shape[0]
        inner = range(self.shape[1])
        # sum() starts from int 0; Poly.__radd__ absorbs it. An empty inner
        # dimension yields 0, which Row turns into the constant Poly 0.
        rows = [
            [
                sum(row[i] * other.rows[i][k] for i in inner)
                for k in range(other.shape[1])
            ]
            for row in self.rows
        ]
        return Matrix(rows)

    def __repr__(self):
        ''' One line per row; entries right-aligned to the widest entry of their column. '''
        if not self.rows:
            return ''
        hints = [row.get_hints() for row in self.rows]
        max_hints = [max(column) for column in zip(*hints)]
        for row in self.rows:
            row.set_hints(max_hints)
        return '\n'.join(repr(row) for row in self.rows)

    def to_tex(self):
        ''' LaTeX bmatrix source; entries are formatted with str(). '''
        parts = [' & '.join(str(el) for el in row) for row in self.rows]
        return '\\begin{bmatrix}\n' + '\\\\\n'.join(parts) + '\n\\end{bmatrix}'

    def triangulate(self, swap=False):
        '''
        In-place elimination on a matrix of constant polys (asserted).

        For each column i, rows j >= i are searched for a pivot: a nonzero
        entry in column i whose row has zero constant terms in all earlier
        columns. That row is used to zero column i in every other row
        (Gauss-Jordan style; pivots are not normalized to 1). If swap is
        true, the pivot row is then swapped into row i.
        '''
        for i in range(self.shape[1]):
            for j in range(i, self.shape[0]):
                assert self[j][i].deg() == 0
                if self[j][i].vec[0] != 0:
                    # check we are main variable: row j must lead with column i
                    if not all(self[j][k].vec[0] == 0 for k in range(i)):
                        continue
                    # Row j is never modified in the loop below (k == j is skipped).
                    pivot = Fraction(self[j][i].vec[0])
                    for k in range(self.shape[0]):
                        assert self[k][i].deg() == 0
                        if k == j:
                            continue
                        coef = -Fraction(self[k][i].vec[0]) / pivot
                        self[k] = self[k] + self[j] * coef
                    if swap:
                        self.make_U(j, i)
                    break


class Permutation:
    ''' Permutation of positions 0..n-1: position i is sent to perm[i]. '''

    def __init__(self, perm: Iterable):
        self.perm = list(perm)

    def apply(self, now: list):
        ''' Return a new list in which now[i] is placed at index perm[i]. '''
        assert len(self.perm) == len(now)
        res = [0] * len(self.perm)
        for src, dst in enumerate(self.perm):
            res[dst] = now[src]
        return res