diff options
Diffstat (limited to '4/libsolve2.py')
-rw-r--r-- | 4/libsolve2.py | 299 |
1 files changed, 299 insertions, 0 deletions
diff --git a/4/libsolve2.py b/4/libsolve2.py new file mode 100644 index 0000000..d723ba3 --- /dev/null +++ b/4/libsolve2.py @@ -0,0 +1,299 @@ +from __future__ import annotations +from typing import Iterable, Union +from fractions import Fraction +from itertools import zip_longest + + +def is_scalar(obj): + return isinstance(obj, (Fraction, int)) + + +class PolyRenderBase: + def render(self, poly: Poly) -> str: + raise NotImplementedError() + + def hint(self, poly: Poly) -> int: + raise NotImplementedError() + + +class Poly: + def __init__(self, vec: Iterable, letter='x', renderer=None): + ''' + vec: big-endian coefficients + ''' + from render import UnicodeRender + self.vec = list(vec) + self.letter = letter + self.renderer = renderer or UnicodeRender() + self._trim_zeros() + + def _trim_zeros(self): + ''' + Trim trailing zeros in coefficients: [0, 1, 1, 0] -> [0, 1, 1] + Needed for multiplication to work correctly + ''' + len_zeros = 0 + for c in reversed(self.vec): + if c != 0: + break + len_zeros += 1 + if len_zeros > 0: + self.vec = self.vec[:-len_zeros] + if len(self.vec) == 0: + self.vec = [0] + + def __mul__(self, rhs: Union[Poly, Fraction, int]): + if is_scalar(rhs): + return Poly([c * rhs for c in self.vec]) + + elif isinstance(rhs, Poly): + # no fft 4 u + result = [0] * (self.deg() + rhs.deg() + 1) + self._trim_zeros() + rhs._trim_zeros() + for i in range(len(self.vec)): + for j in range(len(rhs.vec)): + if all(c != 0 for c in (self.vec[i], rhs.vec[j])): + result[i + j] += self.vec[i] * rhs.vec[j] + return Poly(result) + + else: + raise TypeError(f'{type(rhs)} not supported') + + def __truediv__(self, rhs: Union[Fraction, int]): + if is_scalar(rhs): + return Poly([c * Fraction(1, rhs) for c in self.vec]) + else: + raise TypeError(f'{type(rhs)} not supported') + + def __add__(self, rhs: Union[Poly, Fraction, int]): + if is_scalar(rhs): + return Poly([self.vec[0] + rhs] + self.vec[1:]) + elif isinstance(rhs, Poly): + result = [] + for pair in zip_longest(self.vec, rhs.vec): + result.append(sum(c for c in pair if c is not None)) + return Poly(result) + + def __sub__(self, rhs: Union[Poly, Fraction, int]): + return self + (-rhs) + + def __neg__(self): + return Poly(-c for c in self.vec) + + def __eq__(self, other): + return self.vec == other.vec + + def deg(self): + ''' + Get degree of poly + ''' + return len(self.vec) - 1 + + def __repr__(self): + return self.renderer.render(self) + + def hint(self): + return self.renderer.hint(self) + + +class Row: + ''' + Matrix row. + ''' + + def __init__(self, it: Iterable, dtype=Poly): + def make_poly(obj): + if isinstance(obj, (int, Fraction)): + return Poly([obj]) + elif isinstance(obj, Poly): + return obj + else: + raise TypeError(f'{type(obj)} not supported') + + self.lst = list(map(make_poly, it)) + self.hints = [1] * len(self.lst) + + def __add__(self, obj: Row): + ''' + Add another row + ''' + assert isinstance(obj, Row) + assert len(obj.lst) == len(self.lst) + pairs = zip(self.lst, obj.lst) + return Row(map(lambda x: x[0] + x[1], pairs)) + + def __mul__(self, k: Union[int, Fraction, Poly]): + ''' + Multiply by int or fractions.Fraction + ''' + assert isinstance(k, (int, Fraction, Poly)) + return Row(map(lambda x: x * k, self.lst)) + + def __iter__(self): + return iter(self.lst) + + def __getitem__(self, key): + return self.lst[key] + + def __setitem__(self, key, val): + self.lst[key] = val + + def __len__(self): + return len(self.lst) + + def __repr__(self): + parts = [] + for el, hint in zip(self.lst, self.hints): + part = '{el: >{hint}}'.format(el=repr(el), hint=hint) + if el.deg() > 0 and POLY_COLOR is not None: + part = '\x1b[' + POLY_COLOR + 'm' + part + '\x1b[0m' + parts.append(part) + + return ('[' + ' '.join(parts) + ']') + + def cloned(self): + return Row(self.lst) + + def get_hints(self): + ''' + Get hints for other structures to align items + ''' + return [el.hint() for el in self.lst] + + def set_hints(self, hints): + ''' + Set hints to align items within row + ''' + self.hints = hints + + +class Matrix: + ''' + Matrix. You can apply three elementary transforms to self. + ''' + + def __init__(self, rows): + def make_row(obj): + if isinstance(obj, Row): + return obj.cloned() + else: + return Row(obj) + + self.rows = list(map(make_row, rows)) + assert all(len(row) == len(self.rows[0]) for row in self.rows) + + def det(self) -> Poly: + ''' + Get determinant of matrix + ''' + assert all(len(row) == len(self.rows) for row in self.rows) + if len(self.rows) == 1: + return self.rows[0][0] + res = Poly([0]) + try: + for i in range(len(self.rows[0])): + cofactor = Matrix([ + self.rows[j][:i] + self.rows[j][i + 1:] + for j in range(1, len(self.rows)) + ]) + res += cofactor.det() * (-1) ** i * self.rows[0][i] + except Exception as e: + print(self) + raise e + + return res + + def __sub__(self, oth): + return self + (-oth) + + def __neg__(self): + rows = [row * (-1) for row in self.rows] + return Matrix(rows) + + def __add__(self, oth): + rows = [row1 + row2 for row1, row2 in zip(self.rows, oth.rows)] + return Matrix(rows) + + def __getitem__(self, idx): + return self.rows[idx] + + def __setitem__(self, idx, val): + self.rows[idx] = val + + @property + def shape(self): + return len(self.rows), len(self.rows[0]) + + def __matmul__(self, other): + assert self.shape[1] == other.shape[0] + rows = [ + [ + sum(( + self.rows[j][i] * other.rows[i][k] + for i in range(self.shape[1]) + ), Poly([0])) + for k in range(other.shape[1]) + ] for j in range(self.shape[0]) + ] + return Matrix(rows) + + def __mul__(self, other: Union[int, Fraction, Poly]): + if not isinstance(other, (Poly, Fraction, int)): + raise TypeError(f'{type(other)} not supported for mul') + + return Matrix([ + row * other for row in self.rows + ]) + + def __repr__(self): + hints = [row.get_hints() for row in self.rows] + max_hints = [ + max(hints[j][i] for j in range(len(hints))) + for i in range(len(hints[0])) + ] + + for row in self.rows: + row.set_hints(max_hints) + + return '\n'.join(repr(row) for row in self.rows) + + def cloned(self): + return Matrix(self.rows) + + def to_tex(self): + parts = [] + for row in self.rows: + parts.append(' & '.join(map(str, row.lst))) + return '\\begin{bmatrix}\n' + '\\\\\n'.join(parts) + '\n\\end{bmatrix}' + + def rank(self): + tri = self.triangled() + return self.shape[0] - sum( + all(el == Poly([0]) for el in row) for row in tri.rows + ) + + def triangled(self, swap=False): + copy = self.cloned() + for i in range(copy.shape[1]): + for j in range(i, copy.shape[0]): + assert(copy[j][i].deg() == 0) + if copy[j][i].vec[0] != 0: + # check we are main variable + if not all(copy[j][k].vec[0] == 0 for k in range(i)): + continue + + for k in range(copy.shape[0]): + assert(copy[k][i].deg() == 0) + if k == j: + continue + coef = -Fraction(copy[k][i].vec[0]) / Fraction(copy[j][i].vec[0]) + copy[k] = copy[k] + copy[j] * coef + + if swap: + copy[j], copy[i] = copy[i], copy[j] + break + return copy + + +x = Poly([0, 1]) |