summary refs log tree commit diff stats
path: root/src/generate.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/generate.py')
-rw-r--r-- src/generate.py 454
1 files changed, 454 insertions, 0 deletions
diff --git a/src/generate.py b/src/generate.py
new file mode 100644
index 0000000..da86816
--- /dev/null
+++ b/src/generate.py
@@ -0,0 +1,454 @@
+from __future__ import annotations
+from lark import Lark, Token
+from dataclasses import dataclass, field as dataclass_field
+from collections import defaultdict
+from typing import Optional
+import sys
+
+# Names (CamelCase) of TDLib types that should be rendered as Rust structs.
+# Types outside this set are marked excluded and replaced by SerdeJsonValue
+# wherever they are referenced.
+wanted_types = {
+    'User',
+    'Chat',
+    'Message',
+    'Error',
+    'Ok',
+    'TdlibParameters',
+    'PhoneNumberAuthenticationSettings',
+}
+
+# Names (CamelCase) of TDLib classes that should be rendered as Rust enums.
+wanted_classes = {
+    'AuthorizationState',
+    #'MessageContent',
+    'Update',
+    'UserStatus',
+}
+
+# snake_case names of the TDLib methods that should be rendered on the trait.
+wanted_methods = {
+    'get_me',
+    'set_tdlib_parameters',
+    'get_network_statistics',
+    'get_application_config',
+    'get_authorization_state',
+    'set_database_encryption_key',
+    'set_authentication_phone_number',
+    'check_authentication_code',
+    'check_authentication_password',
+}
+
+# Manual toggle: a non-empty string is truthy, so `not "<text>"` is False and
+# the branch is skipped. Deleting the `not` makes the script render everything
+# listed in everything.json (written at the end of this script when absent).
+if not "REMOVE 'NOT' IF YOU WANT EVERYTHING TO BE RENDERED":
+    import json
+
+    with open('everything.json') as f:
+        everything = json.load(f)
+    # Keep the same container type as the hand-written defaults above: the
+    # rest of the script only iterates these and tests membership, and sets
+    # make those membership tests O(1).
+    wanted_classes = set(everything['classes'])
+    wanted_types = set(everything['types'])
+    wanted_methods = set(everything['methods'])
+
+
+import re
+
+# Zero-width pattern matching the position in front of every uppercase ASCII
+# letter except one at the very start of the string. Substituting '_' at those
+# positions is the first half of a mixedCase -> snake_case conversion.
+# Taken from https://stackoverflow.com/a/1176023/6938271
+MIXED_2_SNAKE_CASE = re.compile(r'(?<!^)(?=[A-Z])')
+
+# Classes that are dropped from the collected class table unconditionally.
+CLASS_EXCLUDE_ALWAYS = ['JsonValue']
+
+# Type names that get wrapped in Box<...> when used as a field/parameter type.
+BOXED_TYPES = ['RichText', 'PageBlock']
+
+# Types that are dropped from the collected type table unconditionally.
+TYPE_EXCLUDE_ALWAYS = [
+    'JsonValueNull',
+    'JsonValueBoolean',
+    'JsonValueNumber',
+    'JsonValueString',
+    'JsonValueArray',
+    'JsonValueObject',
+]
+
+
+def eprint(*args, **kwargs):
+    """Print like the built-in ``print`` but to standard error.
+
+    Progress messages and warnings go through here so that standard output
+    carries only the generated code.
+    """
+    stream = sys.stderr
+    print(*args, file=stream, **kwargs)
+
+
+def to_snake_case(ident):
+    """Convert a mixedCase/CamelCase identifier to snake_case.
+
+    An underscore is inserted in front of every uppercase letter that is not
+    the first character, then the whole result is lower-cased.
+    """
+    with_underscores = MIXED_2_SNAKE_CASE.sub('_', ident)
+    return with_underscores.lower()
+
+
+def to_camel_case(ident):
+    """Return *ident* with its first character upper-cased.
+
+    The rest of the identifier is left untouched; an empty identifier yields
+    an empty string.
+    """
+    # Slicing (rather than indexing) makes the empty string fall out naturally.
+    return ident[:1].upper() + ident[1:]
+
+
+def escape_doc(doc):
+    """Escape documentation text for use inside a double-quoted string.
+
+    Backslashes and double quotes are backslash-escaped in a single pass, then
+    every real newline is replaced by a space followed by the two characters
+    backslash and ``n``.
+    """
+    escapes = str.maketrans({'\\': '\\\\', '"': '\\"'})
+    escaped = doc.translate(escapes)
+    return escaped.replace('\n', ' \\n')
+
+
+@dataclass
+class Mod:
+    # Rendering modifiers that accompany a converted field/parameter type.
+
+    # Number of vector levels wrapped around an int64 value (0 for a bare
+    # int64); None when the type does not involve int64. The struct renderer
+    # uses it to pick the `deserialize_i64_<depth>` helper.
+    depth: Optional[int] = None
+    # True when the referenced type is in neither wanted_types nor
+    # wanted_classes, so it is rendered as SerdeJsonValue instead.
+    exclude: bool = False
+
+
+@dataclass
+class Field:
+    """One field of a TDLib type, or one parameter of a TDLib method.
+
+    ``name`` is the snake_case name from the schema, ``type_`` the already
+    converted Rust type, ``doc`` the description (escaped on construction so
+    it can be embedded in a ``#[doc="..."]`` attribute) and ``mod`` the
+    rendering modifiers that came with the type.
+    """
+
+    name: str
+    type_: str
+    doc: str
+    mod: Mod = dataclass_field(default_factory=Mod)
+    optional: bool = False
+    doc_modifier: Optional[object] = None
+
+    def __post_init__(self):
+        self.doc = escape_doc(self.doc)
+        # Schema names are expected to be snake_case already.
+        assert self.name == to_snake_case(self.name), self.name
+        self.check_optional()
+
+    def check_optional(self):
+        """Mark the field optional when its description hints at nullability."""
+        optional_heuristics = (
+            'may be null',
+            'only available to bots',
+            'bots only',
+            'or null',
+        )
+        if any(hint in self.doc for hint in optional_heuristics):
+            self.optional = True
+
+    def is_literally_type(self) -> bool:
+        """Return True when the field is named ``type`` (a Rust keyword)."""
+        return self.name == 'type'
+
+    def is_jsonvalue(self, all_types: dict) -> bool:
+        """Return the ``exclude`` flag of this field's type in *all_types*.
+
+        Raises KeyError when ``type_`` is not a key of *all_types*.
+        """
+        return all_types[self.type_].exclude
+
+    def get_typename(self) -> str:
+        """Return the Rust type to render for this field.
+
+        Excluded types are replaced by ``SerdeJsonValue`` and optional fields
+        are wrapped in ``Option<...>``. As a side effect, for excluded types
+        ``doc_modifier`` is set to a note naming the original type, so call
+        this before ``get_doc``.
+        """
+        original = self.type_
+        typename = 'SerdeJsonValue' if self.mod.exclude else original
+        if self.optional:
+            typename = f'Option<{typename}>'
+        if self.mod.exclude:
+            # Report the type that was replaced, not the replacement itself.
+            self.doc_modifier = f' \\n\\nOriginal type: {original}'
+
+        return typename
+
+    def get_doc(self):
+        """Return the escaped description plus the original-type note, if any."""
+        return self.doc + (self.doc_modifier or '')
+
+    def serde_rename(self):
+        """Return a serde attribute renaming the field to its schema name."""
+        # Field has no `orig_name` attribute; `name` is the schema name, while
+        # get_name() is the possibly adjusted Rust identifier.
+        return f'#[serde(rename = "{self.name}")]'
+
+    def get_name(self):
+        """Return the Rust identifier: ``type_`` for ``type``, else ``name``."""
+        return 'type_' if self.is_literally_type() else self.name
+
+
+@dataclass
+class Type:
+    """A TDLib type declaration that may be rendered as a Rust struct.
+
+    On construction the description is escaped, the name as given is kept in
+    ``non_camel_name``, ``name`` is replaced by its capitalised form, and the
+    type is marked excluded when that name is not in ``wanted_types``.
+    """
+
+    name: str
+    doc: str
+    fields: list[Field]
+    exclude: bool = False
+    non_camel_name: Optional[str] = None
+
+    def __post_init__(self):
+        self.doc = escape_doc(self.doc)
+        declared_name = self.name
+        self.non_camel_name = declared_name
+        self.name = to_camel_case(declared_name)
+        # An explicitly requested exclusion is never undone here.
+        if not self.exclude:
+            self.exclude = self.name not in wanted_types
+
+
+@dataclass
+class Class:
+    """A TDLib class: a named group of member types, rendered as a Rust enum.
+
+    ``members`` holds the capitalised names of the member types. On
+    construction the name as given is kept in ``non_camel_name`` and ``name``
+    is replaced by its capitalised form.
+    """
+
+    name: str
+    doc: Optional[str]
+    members: list[str]
+    non_camel_name: Optional[str] = None
+
+    def __post_init__(self):
+        # `doc` is declared Optional and the collector passes None for a class
+        # first seen through one of its members; escaping None would raise
+        # AttributeError, so None is left untouched.
+        if self.doc is not None:
+            self.doc = escape_doc(self.doc)
+        self.non_camel_name = self.name
+        self.name = to_camel_case(self.name)
+
+
+@dataclass
+class Method:
+    # A TDLib API function, rendered as one method of the generated trait.
+
+    # Rust method name; normalised to snake_case in __post_init__.
+    name: str
+    # Name as declared in the schema; emitted as the "@type" value of the request.
+    orig_name: str
+    # Description; escaped in __post_init__.
+    doc: str
+    # Parameters; the collector passes the Field objects built by parse_decl.
+    params: list[Field]
+    # Return type to render (may be SerdeJsonValue).
+    ret: str
+    # Capitalised return type as declared in the schema.
+    orig_ret: str
+
+    def __post_init__(self):
+        self.doc = escape_doc(self.doc)
+        self.name = to_snake_case(self.name)
+
+
+def convert_param_type(raw_type) -> tuple[str, Mod]:
+    '''
+    Translate a parsed schema type into its Rust spelling.
+
+    raw_type is either a Token naming the type or a subtree whose first child
+    is the element type of a vector. Returns the Rust type together with a
+    fresh Mod: depth counts the Vec levels around an int64 (None when no int64
+    is involved) and exclude is set for named types that are in neither
+    wanted_types nor wanted_classes.
+    '''
+    # Anything that is not a Token is a vector wrapper: convert the element
+    # type and add one nesting level.
+    if not isinstance(raw_type, Token):
+        inner, mod = convert_param_type(raw_type.children[0])
+        if mod.depth is not None:
+            mod.depth += 1
+        return 'Vec<' + inner + '>', mod
+
+    # Schema primitive -> (Rust type, int64 nesting depth or None).
+    primitives = {
+        'string': ('String', None),
+        'int32': ('i32', None),
+        'int53': ('i64', None),
+        'int64': ('i64', 0),
+        'double': ('f64', None),
+        'bytes': ('String', None),
+        'Bool': ('bool', None),
+    }
+    primitive = primitives.get(raw_type.value)
+    if primitive is not None:
+        rust_type, depth = primitive
+        return rust_type, Mod(depth=depth)
+
+    # Everything else is a reference to a schema-defined type or class.
+    final_type = to_camel_case(raw_type)
+    is_wanted = final_type in wanted_types or final_type in wanted_classes
+    mod = Mod()
+    if not is_wanted:
+        mod.exclude = True
+    if final_type in BOXED_TYPES:
+        return f'Box<{final_type}>', mod
+    return final_type, mod
+
+
+def parse_decl(decl) -> tuple: # -> (description, cname, list[Field], cname)
+    '''
+    Unpack one parsed `decl` node.
+
+    Returns the declaration's description, its leading name, one Field per
+    parameter and its trailing (result) name. A parameter literally called
+    `description` takes its documentation from the `param_description` tag,
+    because the `description` tag documents the declaration itself.
+    Raises KeyError when a required documentation tag is missing.
+    '''
+    doc_nodes, leading_name, param_nodes, trailing_name = decl.children
+
+    # tag name -> documentation text (the text arrives as a list of pieces)
+    docs = {}
+    for node in doc_nodes.children:
+        tag = node.children[0]
+        docs[tag] = ''.join(node.children[1].children)
+
+    fields = []
+    for param in param_nodes.children:
+        name, raw_type = param.children
+        type_, mod = convert_param_type(raw_type)
+        doc = docs[name]
+        if name == 'description':
+            doc = docs['param_description']
+        fields.append(Field(name=name, type_=type_, doc=doc, mod=mod))
+
+    return docs['description'], leading_name, fields, trailing_name
+
+
+# Imported up front so both the cache-read and the cache-write path can use it.
+import pickle
+
+# Load the parse tree from the cache when possible, otherwise parse td_api.tl
+# with lark and write the cache for the next run.
+# NOTE(review): unpickling can execute arbitrary code; tree_cache.pkl must only
+# ever be a file that this script wrote itself.
+try:
+    with open('tree_cache.pkl', "rb") as tree:
+        parsed = pickle.load(tree)
+        eprint('using cached tree')
+except Exception:
+    # A missing, unreadable or corrupt cache just means "parse again". Unlike
+    # a bare `except:`, this lets KeyboardInterrupt and SystemExit propagate.
+    grammar = r'''
+start: type_decls "---functions---" "\n"* decls
+type_decls: (decl | class_decl)+
+decls: decl+
+
+decl: docstring CNAME params "=" type ";" "\n"+
+params: param*
+param: CNAME ":" type
+
+class_decl: "//@class" CNAME nameddoc "\n"+
+
+docdescr: ((LF "//-")? DOCTEXT)*
+nameddoc: "@" CNAME " " docdescr
+docstring: ("//" nameddoc+ "\n"?)+
+
+?type: CNAME | vector
+vector: "vector" "<" type ">"
+
+
+DOCTEXT: /[^\n@]+/
+WHITESPACE: (" ")
+
+%import common.CNAME
+%ignore WHITESPACE
+%import common.LF
+    '''
+    lark = Lark(grammar)
+
+    with open("td_api.tl") as f:
+        # The first 14 lines are skipped unparsed -- presumably a header the
+        # grammar above does not cover; confirm against the td_api.tl in use.
+        for _ in range(14):
+            f.readline()
+        text = f.read()
+    parsed = lark.parse(text)
+    eprint('lark parsed')
+
+    with open("tree_cache.pkl", "wb") as tree:
+        pickle.dump(parsed, tree)
+    eprint("pickled cache")
+
+# The grammar's start rule yields two children: the type/class declarations
+# and the function declarations.
+parsed_types, parsed_methods = parsed.children
+
+# All three tables are keyed by name: capitalised names for types and classes.
+types = dict()
+classes = dict()
+methods = dict()
+
+for decl in parsed_types.children:
+    if decl.data == 'decl':
+        description, type_name, fields, base_type = parse_decl(decl)
+        type_name_orig = type_name
+        base_type_orig = base_type
+        type_name = to_camel_case(type_name.value)
+        base_type = to_camel_case(base_type.value)
+
+        types[type_name] = Type(name=type_name_orig, doc=description, fields=fields)
+        # A type whose result name differs from its own name is a member of
+        # the class named by the result.
+        if type_name != base_type:
+            if base_type not in classes:
+                # The class declaration has not been seen (yet): create a
+                # placeholder with an empty description. None is not used
+                # because the description is escaped on construction.
+                classes[base_type] = Class(name=base_type_orig, doc='', members=[type_name])
+            else:
+                classes[base_type].members.append(type_name)
+
+    else:
+        classname, doc = decl.children
+        class_name_camel = to_camel_case(classname)
+        doc = ''.join(doc.children[1].children)
+        # Test the same (capitalised) key that is used for storing.
+        if class_name_camel not in classes:
+            classes[class_name_camel] = Class(name=classname, doc=doc, members=[])
+        else:
+            # Assigning after construction bypasses the escaping done in
+            # __post_init__, so escape explicitly.
+            classes[class_name_camel].doc = escape_doc(doc)
+
+classes = {k: v for k, v in classes.items() if k not in CLASS_EXCLUDE_ALWAYS}
+types = {k: v for k, v in types.items() if k not in TYPE_EXCLUDE_ALWAYS}
+
+# Report every requested class/type name that the schema did not define.
+for wanted in wanted_classes:
+    if wanted not in classes:
+        eprint(f'WARN: {wanted} class is wanted, but not found')
+
+for wanted in wanted_types:
+    if wanted not in types:
+        eprint(f'WARN: {wanted} type is wanted, but not found')
+
+eprint('parsed types & classes')
+
+# Build the method table. The return type is kept only when it will actually
+# be rendered; otherwise the method returns raw JSON (SerdeJsonValue).
+for decl in parsed_methods.children:
+    docs, name, params, ret = parse_decl(decl)
+    snake_name = to_snake_case(name)
+    orig_ret = to_camel_case(ret)
+    if orig_ret in types:
+        ret = 'SerdeJsonValue' if types[orig_ret].exclude else orig_ret
+    elif orig_ret in classes and orig_ret in wanted_classes:
+        # Only classes in wanted_classes are rendered as enums; any other
+        # class would leave the signature naming a type that is never emitted.
+        ret = orig_ret
+    else:
+        ret = 'SerdeJsonValue'
+
+    methods[name] = Method(name=snake_name, orig_name=name, doc=docs, params=params, ret=ret, orig_ret=orig_ret)
+
+
+# Collect the names once instead of rebuilding the list per wanted method.
+known_method_names = {m.name for m in methods.values()}
+for w in wanted_methods:
+    if w not in known_method_names:
+        eprint(f'WARN: {w} method is wanted, but not found')
+eprint('parsed methods')
+
+# Members of wanted classes are rendered as structs whenever they carry
+# fields, whether or not they are listed in wanted_types; field-less members
+# become plain enum variants and need no struct.
+for class_ in classes.values():
+    if class_.name not in wanted_classes:
+        continue
+    for member in class_.members:
+        type_ = types[member]
+        type_.exclude = not type_.fields
+
+# Fixed header of the generated Rust module. The literal starts with a
+# newline for readability; lstrip() removes it before printing.
+rust_prelude = '''
+#![allow(unused)]
+use serde::Deserializer;
+use tdlib_rs::client::ClientLike;
+
+use serde_derive::{Serialize, Deserialize};
+use serde_json::{json, Value as SerdeJsonValue};
+use tdlib_rs::Client;
+use tdlib_rs::client::ResponseFuture;
+use super::{deserialize_i64_0, deserialize_i64_1};
+'''
+print(rust_prelude.lstrip())
+
+# Emit one Rust struct per type that is not excluded.
+for type_ in types.values():
+    if type_.exclude:
+        continue
+    print('#[derive(Serialize, Deserialize, Debug, Clone)]')
+    print(f'#[doc="{type_.doc}"]')
+    print(f'pub struct {type_.name} {{')
+    for field in type_.fields:
+        # get_typename() must run before get_doc(): it fills in the
+        # original-type note that get_doc() appends.
+        typename = field.get_typename()
+        doc = field.get_doc()
+        if field.is_literally_type():
+            # `type` is a Rust keyword: keep it as the serialized name only.
+            print(' #[serde(rename="type")]')
+
+        print(f' #[doc="{doc}"]')
+        depth = field.mod.depth
+        if depth is not None:
+            print(f' #[serde(deserialize_with="deserialize_i64_{depth}")]')
+        print(f' pub {field.get_name()}: {typename},')
+
+    print('}')
+
+eprint('rendered types')
+
+
+# Emit one Rust enum, tagged by "@type", per wanted class.
+for class_ in classes.values():
+    if class_.name not in wanted_classes:
+        continue
+    print('#[derive(Serialize, Deserialize, Debug, Clone)]')
+    print(f'#[doc="{class_.doc}"]')
+    print('#[serde(tag="@type")]')
+    print(f'pub enum {class_.name} {{')
+    for member in class_.members:
+        type_ = types[member]
+        # The rename is only needed when capitalising changed the name.
+        if type_.name != type_.non_camel_name:
+            print(f' #[serde(rename = "{type_.non_camel_name}")]')
+        has_fields = len(type_.fields) != 0
+        type_.exclude = not has_fields
+        if has_fields:
+            # Variant wrapping the struct of the same name.
+            print(f' {member}({member}),')
+        else:
+            # Unit variant: it carries the type's documentation itself.
+            print(f' #[doc="{type_.doc}"]')
+            print(f' {member},')
+    print('}')
+
+eprint('rendered classes')
+
+
+# Emit the ClientExt trait with one default method per wanted TDLib method.
+print('pub trait ClientExt: ClientLike {')
+for method in methods.values():
+    if method.name not in wanted_methods:
+        continue
+
+    # Documentation: method description, then one entry per parameter.
+    print(f' #[doc="{method.doc}"]')
+    for param in method.params:
+        # Called for its side effect only: it fills in the original-type
+        # note that get_doc() appends.
+        param.get_typename()
+        doc = param.get_doc()
+        name = param.get_name()
+        print(' #[doc=" \\n\\n"]') # ensure newline
+        print(' #[doc="parameters: "]')
+        print(f' #[doc=" * `{name}`: {doc}"]')
+    if method.ret != method.orig_ret:
+        print(' #[doc=" \\n\\n"]') # ensure newline
+        print(f' #[doc="Original return type: `{method.orig_ret}`"]')
+
+    # Signature.
+    print(f' fn {method.name}(&self,')
+    for param in method.params:
+        print(f' {param.get_name()}: {param.get_typename()},')
+    print(f' ) -> ResponseFuture<{method.ret}> {{')
+
+    # Body: a JSON request keyed by the schema names, with the Rust
+    # identifiers as values and the schema method name as "@type".
+    print(' self.send(json!({')
+    for param in method.params:
+        print(f' "{param.name}": {param.get_name()},')
+    print(f' "@type": "{method.orig_name}"')
+    print(' }))')
+    print(' }')
+print('}')
+
+eprint('rendered methods')
+
+from os import path
+
+# Write the full inventory of parsed names once; an existing file is never
+# overwritten.
+if not path.exists('everything.json'):
+    import json
+
+    inventory = {
+        'classes': {c.name: c.members for c in classes.values()},
+        'types': [t.name for t in types.values()],
+        'methods': [m.name for m in methods.values()],
+    }
+    with open('everything.json', 'w') as f:
+        json.dump(inventory, f, indent=4)