summaryrefslogtreecommitdiffstats
path: root/4/libsolve2.py
diff options
context:
space:
mode:
Diffstat (limited to '4/libsolve2.py')
-rw-r--r--4/libsolve2.py299
1 files changed, 299 insertions, 0 deletions
diff --git a/4/libsolve2.py b/4/libsolve2.py
new file mode 100644
index 0000000..d723ba3
--- /dev/null
+++ b/4/libsolve2.py
@@ -0,0 +1,299 @@
+from __future__ import annotations
+from typing import Iterable, Union
+from fractions import Fraction
+from itertools import zip_longest
+
+
+def is_scalar(obj):
+ return isinstance(obj, (Fraction, int))
+
+
+class PolyRenderBase:
+ def render(self, poly: Poly) -> str:
+ raise NotImplementedError()
+
+ def hint(self, poly: Poly) -> int:
+ raise NotImplementedError()
+
+
+class Poly:
+ def __init__(self, vec: Iterable, letter='x', renderer=None):
+ '''
+ vec: big-endian coefficients
+ '''
+ from render import UnicodeRender
+ self.vec = list(vec)
+ self.letter = letter
+ self.renderer = renderer or UnicodeRender()
+ self._trim_zeros()
+
+ def _trim_zeros(self):
+ '''
+ Trim trailing zeros in coefficients: [0, 1, 1, 0] -> [0, 1, 1]
+ Needed for multiplication to work correctly
+ '''
+ len_zeros = 0
+ for c in reversed(self.vec):
+ if c != 0:
+ break
+ len_zeros += 1
+ if len_zeros > 0:
+ self.vec = self.vec[:-len_zeros]
+ if len(self.vec) == 0:
+ self.vec = [0]
+
+ def __mul__(self, rhs: Union[Poly, Fraction, int]):
+ if is_scalar(rhs):
+ return Poly([c * rhs for c in self.vec])
+
+ elif isinstance(rhs, Poly):
+ # no fft 4 u
+ result = [0] * (self.deg() + rhs.deg() + 1)
+ self._trim_zeros()
+ rhs._trim_zeros()
+ for i in range(len(self.vec)):
+ for j in range(len(rhs.vec)):
+ if all(c != 0 for c in (self.vec[i], rhs.vec[j])):
+ result[i + j] += self.vec[i] * rhs.vec[j]
+ return Poly(result)
+
+ else:
+ raise TypeError(f'{type(rhs)} not supported')
+
+ def __truediv__(self, rhs: Union[Fraction, int]):
+ if is_scalar(rhs):
+ return Poly([c * Fraction(1, rhs) for c in self.vec])
+ else:
+ raise TypeError(f'{type(rhs)} not supported')
+
+ def __add__(self, rhs: Union[Poly, Fraction, int]):
+ if is_scalar(rhs):
+ return Poly([self.vec[0] + rhs] + self.vec[1:])
+ elif isinstance(rhs, Poly):
+ result = []
+ for pair in zip_longest(self.vec, rhs.vec):
+ result.append(sum(c for c in pair if c is not None))
+ return Poly(result)
+
+ def __sub__(self, rhs: Union[Poly, Fraction, int]):
+ return self + (-rhs)
+
+ def __neg__(self):
+ return Poly(-c for c in self.vec)
+
+ def __eq__(self, other):
+ return self.vec == other.vec
+
+ def deg(self):
+ '''
+ Get degree of poly
+ '''
+ return len(self.vec) - 1
+
+ def __repr__(self):
+ return self.renderer.render(self)
+
+ def hint(self):
+ return self.renderer.hint(self)
+
+
+class Row:
+ '''
+ Matrix row.
+ '''
+
+ def __init__(self, it: Iterable, dtype=Poly):
+ def make_poly(obj):
+ if isinstance(obj, (int, Fraction)):
+ return Poly([obj])
+ elif isinstance(obj, Poly):
+ return obj
+ else:
+ raise TypeError(f'{type(obj)} not supported')
+
+ self.lst = list(map(make_poly, it))
+ self.hints = [1] * len(self.lst)
+
+ def __add__(self, obj: Row):
+ '''
+ Add another row
+ '''
+ assert isinstance(obj, Row)
+ assert len(obj.lst) == len(self.lst)
+ pairs = zip(self.lst, obj.lst)
+ return Row(map(lambda x: x[0] + x[1], pairs))
+
+ def __mul__(self, k: Union[int, Fraction, Poly]):
+ '''
+ Multiply by int or fractions.Fraction
+ '''
+ assert isinstance(k, (int, Fraction, Poly))
+ return Row(map(lambda x: x * k, self.lst))
+
+ def __iter__(self):
+ return iter(self.lst)
+
+ def __getitem__(self, key):
+ return self.lst[key]
+
+ def __setitem__(self, key, val):
+ self.lst[key] = val
+
+ def __len__(self):
+ return len(self.lst)
+
+ def __repr__(self):
+ parts = []
+ for el, hint in zip(self.lst, self.hints):
+ part = '{el: >{hint}}'.format(el=repr(el), hint=hint)
+ if el.deg() > 0 and POLY_COLOR is not None:
+ part = '\x1b[' + POLY_COLOR + 'm' + part + '\x1b[0m'
+ parts.append(part)
+
+ return ('[' + ' '.join(parts) + ']')
+
+ def cloned(self):
+ return Row(self.lst)
+
+ def get_hints(self):
+ '''
+ Get hints for other structures to align items
+ '''
+ return [el.hint() for el in self.lst]
+
+ def set_hints(self, hints):
+ '''
+ Set hints to align items within row
+ '''
+ self.hints = hints
+
+
+class Matrix:
+ '''
+ Matrix. You can apply three elementary transforms to self.
+ '''
+
+ def __init__(self, rows):
+ def make_row(obj):
+ if isinstance(obj, Row):
+ return obj.cloned()
+ else:
+ return Row(obj)
+
+ self.rows = list(map(make_row, rows))
+ assert all(len(row) == len(self.rows[0]) for row in self.rows)
+
+ def det(self) -> Poly:
+ '''
+ Get determinant of matrix
+ '''
+ assert all(len(row) == len(self.rows) for row in self.rows)
+ if len(self.rows) == 1:
+ return self.rows[0][0]
+ res = Poly([0])
+ try:
+ for i in range(len(self.rows[0])):
+ cofactor = Matrix([
+ self.rows[j][:i] + self.rows[j][i + 1:]
+ for j in range(1, len(self.rows))
+ ])
+ res += cofactor.det() * (-1) ** i * self.rows[0][i]
+ except Exception as e:
+ print(self)
+ raise e
+
+ return res
+
+ def __sub__(self, oth):
+ return self + (-oth)
+
+ def __neg__(self):
+ rows = [row * (-1) for row in self.rows]
+ return Matrix(rows)
+
+ def __add__(self, oth):
+ rows = [row1 + row2 for row1, row2 in zip(self.rows, oth.rows)]
+ return Matrix(rows)
+
+ def __getitem__(self, idx):
+ return self.rows[idx]
+
+ def __setitem__(self, idx, val):
+ self.rows[idx] = val
+
+ @property
+ def shape(self):
+ return len(self.rows), len(self.rows[0])
+
+ def __matmul__(self, other):
+ assert self.shape[1] == other.shape[0]
+ rows = [
+ [
+ sum((
+ self.rows[j][i] * other.rows[i][k]
+ for i in range(self.shape[1])
+ ), Poly([0]))
+ for k in range(other.shape[1])
+ ] for j in range(self.shape[0])
+ ]
+ return Matrix(rows)
+
+ def __mul__(self, other: Union[int, Fraction, Poly]):
+ if not isinstance(other, (Poly, Fraction, int)):
+ raise TypeError(f'{type(other)} not supported for mul')
+
+ return Matrix([
+ row * other for row in self.rows
+ ])
+
+ def __repr__(self):
+ hints = [row.get_hints() for row in self.rows]
+ max_hints = [
+ max(hints[j][i] for j in range(len(hints)))
+ for i in range(len(hints[0]))
+ ]
+
+ for row in self.rows:
+ row.set_hints(max_hints)
+
+ return '\n'.join(repr(row) for row in self.rows)
+
+ def cloned(self):
+ return Matrix(self.rows)
+
+ def to_tex(self):
+ parts = []
+ for row in self.rows:
+ parts.append(' & '.join(map(str, row.lst)))
+ return '\\begin{bmatrix}\n' + '\\\\\n'.join(parts) + '\n\\end{bmatrix}'
+
+ def rank(self):
+ tri = self.triangled()
+ return self.shape[0] - sum(
+ all(el == Poly([0]) for el in row) for row in tri.rows
+ )
+
+ def triangled(self, swap=False):
+ copy = self.cloned()
+ for i in range(copy.shape[1]):
+ for j in range(i, copy.shape[0]):
+ assert(copy[j][i].deg() == 0)
+ if copy[j][i].vec[0] != 0:
+ # check we are main variable
+ if not all(copy[j][k].vec[0] == 0 for k in range(i)):
+ continue
+
+ for k in range(copy.shape[0]):
+ assert(copy[k][i].deg() == 0)
+ if k == j:
+ continue
+ coef = -Fraction(copy[k][i].vec[0]) / Fraction(copy[j][i].vec[0])
+ copy[k] = copy[k] + copy[j] * coef
+
+ if swap:
+ copy[j], copy[i] = copy[i], copy[j]
+ break
+ return copy
+
+
+x = Poly([0, 1])