diff options
Diffstat (limited to 'libsolve.py')
-rw-r--r-- | libsolve.py | 298 |
1 files changed, 298 insertions, 0 deletions
diff --git a/libsolve.py b/libsolve.py new file mode 100644 index 0000000..d932dd6 --- /dev/null +++ b/libsolve.py @@ -0,0 +1,298 @@ +from __future__ import annotations +from typing import Iterable, Union +from fractions import Fraction +from itertools import zip_longest + + +def is_scalar(obj): + return isinstance(obj, (Fraction, int)) + +POLY_COLOR = None +# POLY_COLOR = '36' # Uncomment to have color + + +# may not work. see +# https://en.wikipedia.org/wiki/Unicode_subscripts_and_superscripts +# to test with your font +superscripts = '⁰¹²³⁴⁵⁶⁷⁸⁹' # must be <sup>0123456789</sup> + + +def int2sup(n): + res = [] + while n > 0: + res.append(superscripts[n % 10]) + n //= 10 + return ''.join(reversed(res)) + + +class Poly: + def __init__(self, vec: Iterable, letter='x'): + ''' + vec: big-endian coefficients + ''' + self.vec = list(vec) + self.letter = letter + self._trim_zeros() + + def _trim_zeros(self): + ''' + Trim trailing zeros in coefficients: [0, 1, 1, 0] -> [0, 1, 1] + Needed for multiplication to work correctly + ''' + len_zeros = 0 + for c in reversed(self.vec): + if c != 0: + break + len_zeros += 1 + if len_zeros > 0: + self.vec = self.vec[:-len_zeros] + + def __mul__(self, rhs: Union[Poly, Fraction, int]): + if is_scalar(rhs): + return Poly([c * rhs for c in self.vec]) + + elif isinstance(rhs, Poly): + # no fft 4 u + result = [0] * (self.deg() + rhs.deg() + 1) + self._trim_zeros() + rhs._trim_zeros() + for i in range(len(self.vec)): + for j in range(len(rhs.vec)): + if all(c != 0 for c in (self.vec[i], rhs.vec[j])): + result[i + j] += self.vec[i] * rhs.vec[j] + return Poly(result) + + else: + raise TypeError(f'{type(rhs)} not supported') + + def __truediv__(self, rhs: Union[Fraction, int]): + if is_scalar(rhs): + return Poly([c * Fraction(1, rhs) for c in self.vec]) + else: + raise TypeError(f'{type(rhs)} not supported') + + def __add__(self, rhs: Union[Poly, Fraction, int]): + if is_scalar(rhs): + return Poly([self.vec[0] + rhs] + self.vec[1:]) + elif isinstance(rhs, Poly): + result = [] + for pair in zip_longest(self.vec, rhs.vec): + result.append(sum(c for c in pair if c is not None)) + return Poly(result) + + def __sub__(self, rhs: Union[Poly, Fraction, int]): + return self + (-rhs) + + def __neg__(self): + return Poly(-c for c in self.vec) + + def deg(self): + ''' + Get degree of poly + ''' + d = len(self.vec) - 1 + for c in reversed(self.vec): + if c != 0: + return d + else: + d -= 1 + return max(d, 0) + + def shift(self, deg): + ''' + divide by x^deg, drop rest + ''' + for i in range(deg): + del self.vec[0] + + def __repr__(self): + return self.get_hint() + + def get_hint(self): + ''' + Get raw text hints for outer structures to align items + ''' + rev = list(reversed(self.vec)) + if all(r == 0 for r in rev): + return '0' + d = self.deg() + res = [] + for i in range(len(rev)): + if rev[i] == 0: + continue + if i == 0: + if rev[i] < 0: + res.append('-') + else: + if rev[i] < 0: + res.append(' - ') + else: + res.append(' + ') + + if abs(rev[i]) != 1 or i == d: + res.append(str(abs(rev[i]))) + if i != d: + res.append(self.letter) + if d - i != 1: + res.append(int2sup(d - i)) + return ''.join(res) + + +class Row: + ''' + Matrix row. + ''' + + def __init__(self, it: Iterable): + def make_poly(obj): + if isinstance(obj, (int, Fraction)): + return Poly([obj]) + elif isinstance(obj, Poly): + return obj + else: + raise TypeError(f'{type(obj)} not supported') + + self.lst = list(map(make_poly, it)) + self.hints = 1 * len(self.lst) + + def __add__(self, obj: Row): + ''' + Add another row + ''' + assert isinstance(obj, Row) + assert len(obj.lst) == len(self.lst) + pairs = zip(self.lst, obj.lst) + return Row(map(lambda x: x[0] + x[1], pairs)) + + def __mul__(self, k: Union[int, Fraction]): + ''' + Multiply by int or fractions.Fraction + ''' + assert isinstance(k, (int, Fraction)) + return Row(map(lambda x: x * k, self.lst)) + + def __iter__(self): + return iter(self.lst) + + def __getitem__(self, key): + return self.lst[key] + + def __setitem__(self, key, val): + self.lst[key] = val + + def __len__(self): + return len(self.lst) + + def __repr__(self): + parts = [] + for el, hint in zip(self.lst, self.hints): + part = '{el: >{hint}}'.format(el=repr(el), hint=hint) + if el.deg() > 0 and POLY_COLOR is not None: + part = '\x1b[' + POLY_COLOR + 'm' + part + '\x1b[0m' + parts.append(part) + + return ('[' + ' '.join(parts) + ']') + + def get_hints(self): + ''' + Get hints for other structures to align items + ''' + return [len(el.get_hint()) for el in self.lst] + + def set_hints(self, hints): + ''' + Set hints to align items within row + ''' + self.hints = hints + + +class Matrix: + ''' + Matrix. You can apply three elementary transforms to self. + ''' + + def __init__(self, rows): + def make_row(obj): + if isinstance(obj, Row): + return obj + else: + return Row(obj) + + self.rows = list(map(make_row, rows)) + assert all(len(row) == len(self.rows[0]) for row in self.rows) + + def make_S(self, i: int, j: int, lbd: Union[int, Fraction], axis=0): + ''' + if axis == 0, do transform on rows, else on cols + M[i] = M[i] + M[j] * lbd + ''' + if axis == 0: + self.rows[i] = self.rows[i] + self.rows[j] * lbd + else: + for row in self.rows: + row[i] = row[i] + row[j] * lbd + + def make_U(self, i: int, j: int, axis=0): + ''' + if axis == 0, do transform on rows, else on cols + Swap M[i] and M[j] + ''' + if axis == 0: + self.rows[i], self.rows[j] = self.rows[j], self.rows[i] + else: + for row in self.rows: + row[i], row[j] = row[j], row[i] + + def make_D(self, i: int, lbd: Union[int, Fraction], axis=0): + ''' + if axis == 0, do transform on rows, else on cols + Multiply M[i] by rational lbd + ''' + if axis == 0: + self.rows[i] = self.rows[i] * lbd + else: + for row in self.rows: + row[i] = row[i] * lbd + + def det(self) -> Poly: + ''' + Get determinant of matrix + ''' + assert all(len(row) == len(self.rows) for row in self.rows) + if len(self.rows) == 1: + return self.rows[0][0] + res = Poly([0]) + try: + for i in range(len(self.rows[0])): + cofactor = Matrix([ + self.rows[j][:i] + self.rows[j][i + 1:] + for j in range(1, len(self.rows)) + ]) + res += cofactor.det() * (-1) ** i * self.rows[0][i] + except Exception as e: + print(self) + raise e + + return res + + def __sub__(self, oth): + return self + (-oth) + + def __neg__(self): + rows = [row * (-1) for row in self.rows] + return Matrix(rows) + + def __add__(self, oth): + rows = [row1 + row2 for row1, row2 in zip(self.rows, oth.rows)] + return Matrix(rows) + + def __repr__(self): + hints = [row.get_hints() for row in self.rows] + max_hints = [] + for i in range(len(hints)): + max_hints.append(max(hints[j][i] for j in range(len(hints)))) + + for row in self.rows: + row.set_hints(max_hints) + + return '\n'.join(repr(row) for row in self.rows) |