summaryrefslogtreecommitdiffstats
path: root/libsolve.py
diff options
context:
space:
mode:
Diffstat (limited to 'libsolve.py')
-rw-r--r--libsolve.py298
1 files changed, 298 insertions, 0 deletions
diff --git a/libsolve.py b/libsolve.py
new file mode 100644
index 0000000..d932dd6
--- /dev/null
+++ b/libsolve.py
@@ -0,0 +1,298 @@
+from __future__ import annotations
+from typing import Iterable, Union
+from fractions import Fraction
+from itertools import zip_longest
+
+
+def is_scalar(obj):
+ return isinstance(obj, (Fraction, int))
+
+POLY_COLOR = None
+# POLY_COLOR = '36' # Uncomment to have color
+
+
+# may not work. see
+# https://en.wikipedia.org/wiki/Unicode_subscripts_and_superscripts
+# to test with your font
+superscripts = '⁰¹²³⁴⁵⁶⁷⁸⁹' # must be <sup>0123456789</sup>
+
+
+def int2sup(n):
+ res = []
+ while n > 0:
+ res.append(superscripts[n % 10])
+ n //= 10
+ return ''.join(reversed(res))
+
+
+class Poly:
+ def __init__(self, vec: Iterable, letter='x'):
+ '''
+ vec: big-endian coefficients
+ '''
+ self.vec = list(vec)
+ self.letter = letter
+ self._trim_zeros()
+
+ def _trim_zeros(self):
+ '''
+ Trim trailing zeros in coefficients: [0, 1, 1, 0] -> [0, 1, 1]
+ Needed for multiplication to work correctly
+ '''
+ len_zeros = 0
+ for c in reversed(self.vec):
+ if c != 0:
+ break
+ len_zeros += 1
+ if len_zeros > 0:
+ self.vec = self.vec[:-len_zeros]
+
+ def __mul__(self, rhs: Union[Poly, Fraction, int]):
+ if is_scalar(rhs):
+ return Poly([c * rhs for c in self.vec])
+
+ elif isinstance(rhs, Poly):
+ # no fft 4 u
+ result = [0] * (self.deg() + rhs.deg() + 1)
+ self._trim_zeros()
+ rhs._trim_zeros()
+ for i in range(len(self.vec)):
+ for j in range(len(rhs.vec)):
+ if all(c != 0 for c in (self.vec[i], rhs.vec[j])):
+ result[i + j] += self.vec[i] * rhs.vec[j]
+ return Poly(result)
+
+ else:
+ raise TypeError(f'{type(rhs)} not supported')
+
+ def __truediv__(self, rhs: Union[Fraction, int]):
+ if is_scalar(rhs):
+ return Poly([c * Fraction(1, rhs) for c in self.vec])
+ else:
+ raise TypeError(f'{type(rhs)} not supported')
+
+ def __add__(self, rhs: Union[Poly, Fraction, int]):
+ if is_scalar(rhs):
+ return Poly([self.vec[0] + rhs] + self.vec[1:])
+ elif isinstance(rhs, Poly):
+ result = []
+ for pair in zip_longest(self.vec, rhs.vec):
+ result.append(sum(c for c in pair if c is not None))
+ return Poly(result)
+
+ def __sub__(self, rhs: Union[Poly, Fraction, int]):
+ return self + (-rhs)
+
+ def __neg__(self):
+ return Poly(-c for c in self.vec)
+
+ def deg(self):
+ '''
+ Get degree of poly
+ '''
+ d = len(self.vec) - 1
+ for c in reversed(self.vec):
+ if c != 0:
+ return d
+ else:
+ d -= 1
+ return max(d, 0)
+
+ def shift(self, deg):
+ '''
+ divide by x^deg, drop rest
+ '''
+ for i in range(deg):
+ del self.vec[0]
+
+ def __repr__(self):
+ return self.get_hint()
+
+ def get_hint(self):
+ '''
+ Get raw text hints for outer structures to align items
+ '''
+ rev = list(reversed(self.vec))
+ if all(r == 0 for r in rev):
+ return '0'
+ d = self.deg()
+ res = []
+ for i in range(len(rev)):
+ if rev[i] == 0:
+ continue
+ if i == 0:
+ if rev[i] < 0:
+ res.append('-')
+ else:
+ if rev[i] < 0:
+ res.append(' - ')
+ else:
+ res.append(' + ')
+
+ if abs(rev[i]) != 1 or i == d:
+ res.append(str(abs(rev[i])))
+ if i != d:
+ res.append(self.letter)
+ if d - i != 1:
+ res.append(int2sup(d - i))
+ return ''.join(res)
+
+
+class Row:
+ '''
+ Matrix row.
+ '''
+
+ def __init__(self, it: Iterable):
+ def make_poly(obj):
+ if isinstance(obj, (int, Fraction)):
+ return Poly([obj])
+ elif isinstance(obj, Poly):
+ return obj
+ else:
+ raise TypeError(f'{type(obj)} not supported')
+
+ self.lst = list(map(make_poly, it))
+ self.hints = 1 * len(self.lst)
+
+ def __add__(self, obj: Row):
+ '''
+ Add another row
+ '''
+ assert isinstance(obj, Row)
+ assert len(obj.lst) == len(self.lst)
+ pairs = zip(self.lst, obj.lst)
+ return Row(map(lambda x: x[0] + x[1], pairs))
+
+ def __mul__(self, k: Union[int, Fraction]):
+ '''
+ Multiply by int or fractions.Fraction
+ '''
+ assert isinstance(k, (int, Fraction))
+ return Row(map(lambda x: x * k, self.lst))
+
+ def __iter__(self):
+ return iter(self.lst)
+
+ def __getitem__(self, key):
+ return self.lst[key]
+
+ def __setitem__(self, key, val):
+ self.lst[key] = val
+
+ def __len__(self):
+ return len(self.lst)
+
+ def __repr__(self):
+ parts = []
+ for el, hint in zip(self.lst, self.hints):
+ part = '{el: >{hint}}'.format(el=repr(el), hint=hint)
+ if el.deg() > 0 and POLY_COLOR is not None:
+ part = '\x1b[' + POLY_COLOR + 'm' + part + '\x1b[0m'
+ parts.append(part)
+
+ return ('[' + ' '.join(parts) + ']')
+
+ def get_hints(self):
+ '''
+ Get hints for other structures to align items
+ '''
+ return [len(el.get_hint()) for el in self.lst]
+
+ def set_hints(self, hints):
+ '''
+ Set hints to align items within row
+ '''
+ self.hints = hints
+
+
+class Matrix:
+ '''
+ Matrix. You can apply three elementary transforms to self.
+ '''
+
+ def __init__(self, rows):
+ def make_row(obj):
+ if isinstance(obj, Row):
+ return obj
+ else:
+ return Row(obj)
+
+ self.rows = list(map(make_row, rows))
+ assert all(len(row) == len(self.rows[0]) for row in self.rows)
+
+ def make_S(self, i: int, j: int, lbd: Union[int, Fraction], axis=0):
+ '''
+ if axis == 0, do transform on rows, else on cols
+ M[i] = M[i] + M[j] * lbd
+ '''
+ if axis == 0:
+ self.rows[i] = self.rows[i] + self.rows[j] * lbd
+ else:
+ for row in self.rows:
+ row[i] = row[i] + row[j] * lbd
+
+ def make_U(self, i: int, j: int, axis=0):
+ '''
+ if axis == 0, do transform on rows, else on cols
+ Swap M[i] and M[j]
+ '''
+ if axis == 0:
+ self.rows[i], self.rows[j] = self.rows[j], self.rows[i]
+ else:
+ for row in self.rows:
+ row[i], row[j] = row[j], row[i]
+
+ def make_D(self, i: int, lbd: Union[int, Fraction], axis=0):
+ '''
+ if axis == 0, do transform on rows, else on cols
+ Multiply M[i] by rational lbd
+ '''
+ if axis == 0:
+ self.rows[i] = self.rows[i] * lbd
+ else:
+ for row in self.rows:
+ row[i] = row[i] * lbd
+
+ def det(self) -> Poly:
+ '''
+ Get determinant of matrix
+ '''
+ assert all(len(row) == len(self.rows) for row in self.rows)
+ if len(self.rows) == 1:
+ return self.rows[0][0]
+ res = Poly([0])
+ try:
+ for i in range(len(self.rows[0])):
+ cofactor = Matrix([
+ self.rows[j][:i] + self.rows[j][i + 1:]
+ for j in range(1, len(self.rows))
+ ])
+ res += cofactor.det() * (-1) ** i * self.rows[0][i]
+ except Exception as e:
+ print(self)
+ raise e
+
+ return res
+
+ def __sub__(self, oth):
+ return self + (-oth)
+
+ def __neg__(self):
+ rows = [row * (-1) for row in self.rows]
+ return Matrix(rows)
+
+ def __add__(self, oth):
+ rows = [row1 + row2 for row1, row2 in zip(self.rows, oth.rows)]
+ return Matrix(rows)
+
+ def __repr__(self):
+ hints = [row.get_hints() for row in self.rows]
+ max_hints = []
+ for i in range(len(hints)):
+ max_hints.append(max(hints[j][i] for j in range(len(hints))))
+
+ for row in self.rows:
+ row.set_hints(max_hints)
+
+ return '\n'.join(repr(row) for row in self.rows)