diff options
Diffstat (limited to 'src/client_ext/generate.py')
-rw-r--r-- | src/client_ext/generate.py | 435 |
1 files changed, 435 insertions, 0 deletions
diff --git a/src/client_ext/generate.py b/src/client_ext/generate.py new file mode 100644 index 0000000..16ed400 --- /dev/null +++ b/src/client_ext/generate.py @@ -0,0 +1,435 @@ +from __future__ import annotations +from lark import Lark, Token +from dataclasses import dataclass, field as dataclass_field +from collections import defaultdict +from typing import Optional +import sys + +# TODO: rename mixedCase + +wanted_types = [ + 'User', + 'Chat', + 'Message', + 'Error', + 'Game', + 'PhotoSize', + 'Ok', + 'TdlibParameters', + 'PhoneNumberAuthenticationSettings' +] + +wanted_classes = [ + 'AuthorizationState', + 'MessageContent', + 'Update' +] + +wanted_methods = [ + 'set_tdlib_parameters', + 'get_network_statistics', + 'get_application_config', + 'set_database_encryption_key', + 'set_authentication_phone_number', + 'check_authentication_code', + 'check_authentication_password' +] + +if not "REMOVE 'NOT' IF YOU WANT EVERYTHING TO BE RENDERED": + with open('everything.json') as f: + import json + everything = json.load(f) + wanted_classes = list(everything['classes'].keys()) + wanted_types = everything['types'] + wanted_methods = everything['methods'] + + +import re +# https://stackoverflow.com/a/1176023/6938271 +MIXED_2_SNAKE_CASE = re.compile(r'(?<!^)(?=[A-Z])') + +CLASS_EXCLUDE_ALWAYS = [ + 'JsonValue' +] + +BOXED_TYPES = [ + 'RichText', + 'PageBlock' +] + +TYPE_EXCLUDE_ALWAYS = [ + 'JsonValueNull', + 'JsonValueBoolean', + 'JsonValueNumber', + 'JsonValueString', + 'JsonValueArray', + 'JsonValueObject' +] + + +def eprint(*args, **kwargs): + print(*args, file=sys.stderr, **kwargs) + + +def to_snake_case(ident): + return MIXED_2_SNAKE_CASE.sub('_', ident).lower() + + +def to_camel_case(ident): + if len(ident) == 0: + return '' + return ident[0].upper() + ident[1:] + + +def escape_doc(doc): + return doc.translate(str.maketrans({"\"": '\\"', "\\": "\\\\"})).replace('\n', ' \\n') + + +@dataclass +class Mod: + depth: Optional[int] = None + exclude: bool = False + + 
@dataclass
class Field:
    '''One parameter of a TL declaration (struct field or method argument).'''
    name: str
    type_: str
    doc: str
    mod: Mod = dataclass_field(default_factory=Mod)
    optional: bool = False
    doc_modifier: Optional[object] = None
    orig_name: str = None

    def __post_init__(self):
        # Keep the schema name for the wire format; `name` becomes the
        # snake_case identifier used in the generated code.
        self.doc = escape_doc(self.doc)
        self.orig_name = self.name
        self.name = to_snake_case(self.name)
        self.check_optional()

    def check_optional(self):
        '''Mark the field optional when its doc contains a known phrase.'''
        optional_heuristics = (
            'may be null',
            'only available to bots',
            'bots only',
            'or null',
        )
        if any(phrase in self.doc for phrase in optional_heuristics):
            self.optional = True

    def is_literally_type(self) -> bool:
        '''True when the field is named 'type' and so needs renaming.'''
        return self.name == 'type'

    def is_jsonvalue(self, all_types: dict) -> bool:
        '''Return the `exclude` flag of this field's type in all_types.'''
        return all_types[self.type_].exclude

    def get_typename(self) -> str:
        '''Return the type to render for this field.

        Side effect: for excluded types this stores the original type in
        `doc_modifier`, so it must be called before get_doc().
        '''
        typename = self.type_
        if self.optional:
            typename = f'Option<{typename}>'
        if self.mod.exclude:
            self.doc_modifier = f' \\n\\nOriginal type: {typename}'
            typename = 'SerdeJsonValue'
        return typename

    def get_doc(self):
        '''Return the escaped doc plus the original-type note, if any.'''
        return self.doc + (self.doc_modifier or '')

    def serde_rename(self):
        '''Return a serde rename attribute carrying the schema name.'''
        return f'#[serde(rename = "{self.orig_name}")]'

    def get_name(self):
        '''Return the identifier to emit ('type' becomes 'type_').'''
        return 'type_' if self.is_literally_type() else self.name


@dataclass
class Type:
    '''A concrete TL constructor, rendered as a struct unless excluded.'''
    name: str
    doc: str
    fields: list[Field]
    exclude: bool = False

    def __post_init__(self):
        self.doc = escape_doc(self.doc)
        self.name = to_camel_case(self.name)
        if self.name not in wanted_types:
            self.exclude = True


@dataclass
class Class:
    '''A TL base class together with the names of its member types.'''
    name: str
    doc: Optional[str]
    members: list[str]

    def __post_init__(self):
        # doc is None when the class is first seen through one of its
        # members rather than through its own //@class declaration.
        if self.doc is not None:
            self.doc = escape_doc(self.doc)
        self.name = to_camel_case(self.name)


@dataclass
class Method:
    '''A TL function, rendered as one method of the ClientExt trait.'''
    name: str
    orig_name: str
    doc: str
    params: list[Field]
    ret: str
    orig_ret: str

    def __post_init__(self):
        self.doc = escape_doc(self.doc)
        self.name = to_snake_case(self.name)


# TL primitive name -> (emitted type, Mod.depth). Only int64 carries a depth,
# which later selects a `deserialize_i64_<depth>` helper.
_PRIMITIVE_TYPES = {
    'string': ('String', None),
    'int32': ('i32', None),
    'int53': ('i64', None),
    'int64': ('i64', 0),
    'double': ('f64', None),
    'bytes': ('String', None),
    'Bool': ('bool', None),
}


def convert_param_type(raw_type) -> tuple[str, Mod]:
    '''
    Convert a parsed TL type into (emitted type name, Mod).

    A Token is a primitive or a named type; anything else is treated as a
    `vector` node whose first child is the element type.
    '''
    if not isinstance(raw_type, Token):
        inner, mod = convert_param_type(raw_type.children[0])
        if mod.depth is not None:
            # One more Vec layer around an int64.
            mod.depth += 1
        return 'Vec<' + inner + '>', mod

    primitive = _PRIMITIVE_TYPES.get(raw_type.value)
    if primitive is not None:
        emitted, depth = primitive
        return emitted, Mod(depth=depth)

    final_type = to_camel_case(raw_type)
    mod = Mod()
    if final_type not in wanted_types and final_type not in wanted_classes:
        mod.exclude = True
    if final_type in BOXED_TYPES:
        return f'Box<{final_type}>', mod
    return final_type, mod


def parse_decl(decl) -> tuple:  # -> (description, cname, list[Field], cname)
    '''
    Unpack a `decl` tree node.

    Returns (description, declared name, fields, result name). Raises
    KeyError when a parameter or the declaration itself has no doc entry.
    '''
    doc_nodes, cname1, param_nodes, cname2 = decl.children
    docs = {
        node.children[0]: ''.join(node.children[1].children)
        for node in doc_nodes.children
    }
    fields = []
    for param in param_nodes.children:
        name, raw_type = param.children
        type_, mod = convert_param_type(raw_type)
        # A parameter literally named 'description' is documented under
        # 'param_description'; 'description' holds the declaration's doc.
        doc_key = 'param_description' if name == 'description' else name
        fields.append(Field(name=name, type_=type_, doc=docs[doc_key], mod=mod))
    return docs['description'], cname1, fields, cname2


import pickle

try:
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # safe while tree_cache.pkl is a cache written by this script itself.
    with open('tree_cache.pkl', "rb") as tree:
        parsed = pickle.load(tree)
    eprint('using cached tree')
except Exception:
    # Missing or unreadable cache: parse the schema from scratch. Catching
    # Exception rather than everything lets Ctrl-C and SystemExit through.
    grammar = r'''
start: type_decls "---functions---" "\n"* decls
type_decls: (decl | class_decl)+
decls: decl+

decl: docstring CNAME params "=" type ";" "\n"+
params: param*
param: CNAME ":" type

class_decl: "//@class" CNAME nameddoc "\n"+

docdescr: ((LF "//-")? DOCTEXT)*
nameddoc: "@" CNAME " " docdescr
docstring: ("//" nameddoc+ "\n"?)+

?type: CNAME | vector
vector: "vector" "<" type ">"


DOCTEXT: /[^\n@]+/
WHITESPACE: (" ")

%import common.CNAME
%ignore WHITESPACE
%import common.LF
    '''
    lark = Lark(grammar)

    with open("td_api.tl") as f:
        # Skip the first 14 lines of the schema file.
        # NOTE(review): presumably a header the grammar cannot parse - confirm.
        for _ in range(14):
            f.readline()
        text = f.read()
    parsed = lark.parse(text)
    eprint('lark parsed')

    with open("tree_cache.pkl", "wb") as tree:
        pickle.dump(parsed, tree)
    eprint("pickled cache")

parsed_types, parsed_methods = parsed.children

#print(parsed_types)

types = {}
classes = {}
methods = {}

for decl in parsed_types.children:
    if decl.data == 'decl':
        description, type_name, fields, base_type = parse_decl(decl)
        type_name = to_camel_case(type_name.value)
        base_type = to_camel_case(base_type.value)

        types[type_name] = Type(name=type_name, doc=description, fields=fields)
        if type_name != base_type:
            owner = classes.get(base_type)
            if owner is None:
                classes[base_type] = Class(name=base_type, doc=None, members=[type_name])
            else:
                owner.members.append(type_name)
    else:
        classname, doc = decl.children
        # Use plain-str keys in the same form as the decl branch above, so
        # the dict never mixes Token and str keys.
        classname = to_camel_case(str(classname))
        doc = ''.join(doc.children[1].children)
        known = classes.get(classname)
        if known is None:
            classes[classname] = Class(name=classname, doc=doc, members=[])
        else:
            # Assigning after construction skips __post_init__, so escape here.
            known.doc = escape_doc(doc)

classes = {k: v for k, v in classes.items() if k not in CLASS_EXCLUDE_ALWAYS}
types = {k: v for k, v in types.items() if k not in TYPE_EXCLUDE_ALWAYS}

for w in wanted_classes:
    if w not in classes:
        eprint(f'WARN: {w} class is wanted, but not found')

for w in wanted_types:
    if w not in types:
        eprint(f'WARN: {w} type is wanted, but not found')

eprint('parsed types & classes')

for decl in parsed_methods.children:
    docs, name, params, ret = parse_decl(decl)
    snake_name = to_snake_case(name)
    orig_ret = to_camel_case(ret)
    if orig_ret in types:
        # Structs that will not be rendered fall back to raw JSON.
        ret = 'SerdeJsonValue' if types[orig_ret].exclude else orig_ret
    elif orig_ret in classes:
        ret = orig_ret
    else:
        ret = 'SerdeJsonValue'

    methods[name] = Method(name=snake_name, orig_name=name, doc=docs, params=params, ret=ret, orig_ret=orig_ret)


known_method_names = {m.name for m in methods.values()}
for w in wanted_methods:
    if w not in known_method_names:
        eprint(f'WARN: {w} method is wanted, but not found')
eprint('parsed methods')

rendered_classes = [cls for cls in classes.values() if cls.name in wanted_classes]

# Members of a wanted class get a struct only when they have fields;
# field-less members become unit variants of the enum.
for class_ in rendered_classes:
    for member in class_.members:
        member_type = types[member]
        member_type.exclude = not member_type.fields

# NOTE(review): the leading whitespace inside the emitted literals below is
# reproduced as found; it may have been collapsed in the source this was
# reconstructed from - confirm against the real file.
for type_ in [tp for tp in types.values() if not tp.exclude]:
    print('#[derive(Serialize, Deserialize, Debug, Clone)]')
    print(f'#[doc="{type_.doc}"]')
    print(f'pub struct {type_.name} {{')
    for fld in type_.fields:
        # get_typename() must run before get_doc(): it sets doc_modifier.
        typename = fld.get_typename()
        doc = fld.get_doc()
        name = fld.name
        if fld.is_literally_type():
            print(' #[serde(rename="type")]')
            name = 'type_'

        print(f' #[doc="{doc}"]')
        if fld.mod.depth is not None:
            print(f' #[serde(deserialize_with="deserialize_i64_{fld.mod.depth}")]')
        print(f' pub {name}: {typename},')

    print('}')

eprint('rendered types')


for class_ in rendered_classes:
    print('#[derive(Serialize, Deserialize, Debug, Clone)]')
    if class_.doc is not None:
        # Classes seen only through their members have no doc to emit.
        print(f'#[doc="{class_.doc}"]')
    print('#[serde(tag="@type")]')
    print(f'pub enum {class_.name} {{')
    for member in class_.members:
        member_type = types[member]
        if not member_type.fields:
            print(f' #[doc="{member_type.doc}"]')
            member_type.exclude = True
            print(f' {member},')
        else:
            member_type.exclude = False
            print(f' {member}({member}),')
    print('}')

eprint('rendered classes')


print('pub trait ClientExt: ClientLike {')
for method in [m for m in methods.values() if m.name in wanted_methods]:
    print(f' #[doc="{method.doc}"]')
    for param in method.params:
        param.get_typename()  # called for its doc_modifier side effect
        doc = param.get_doc()
        name = param.get_name()
        # NOTE(review): the "parameters: " heading is emitted once per
        # parameter, not once per method - confirm that is intended.
        print(' #[doc=" \\n\\n"]')  # ensure newline
        print(' #[doc="parameters: "]')
        print(f' #[doc=" * `{name}`: {doc}"]')
    if method.ret != method.orig_ret:
        print(' #[doc=" \\n\\n"]')  # ensure newline
        print(f' #[doc="Original return type: `{method.orig_ret}`"]')

    print(f' fn {method.name}(&self,')
    for param in method.params:
        typename = param.get_typename()
        print(f' {param.get_name()}: {typename},')
    print(f' ) -> ResponseFuture<{method.ret}> {{')
    print(' self.send(json!({')
    for param in method.params:
        # The JSON key must be the schema's own parameter name; the value is
        # the local Rust identifier.
        print(f' "{param.orig_name}": {param.get_name()},')
    print(f' "@type": "{method.orig_name}"')
    print(' }))')
    print(' }')
print('}')

eprint('rendered methods')

from os import path
if not path.exists('everything.json'):
    # First run only: dump every known name so that the "render everything"
    # switch has a file to read.
    with open('everything.json', 'w') as f:
        import json
        json.dump(dict(
            classes={c.name: c.members for c in classes.values()},
            types=[t.name for t in types.values()],
            methods=[m.name for m in methods.values()]
        ), f, indent=4)
\ No newline at end of file |