# TODO: _get_type maps Piston types to their backing Python types, so a type-mismatch error that should read "expected _ptn_t_string but got _ptn_t_bool" reads "expected str but got bool" instead.
class PistonRuntimeError(Exception):
    """Error raised by the Piston runtime helpers (see ``_raise_runtime``)."""
def _raise_runtime(m):
    """Raise a PistonRuntimeError carrying message *m*; never returns normally."""
    error = PistonRuntimeError(m)
    raise error
def _assert(c,m):None if c else _raise_runtime(m)
def _extend(c,n=None,s=False):
def d(f):
fn=n if n else f.__name__;setattr(c,fn,staticmethod(f))if s else setattr(c,fn,f)
def w(*a,**k):return f(*a,**k)
return w
return d
# type util
class PistonTypeError(Exception):
    """Error raised by the type helpers when a value's type does not match what is expected."""
def _get_ptn_parent_type(t):return t._t()
def _get_type(t):
    """Normalize *t*, either a class or a value, to a Python type.

    Piston classes and values map to their backing Python type through ``_t``,
    e.g. ``_ptn_t_integer`` -> ``int``. Any other class is returned unchanged,
    and any other value is replaced by its class.
    """
    cls = t if isinstance(t, type) else type(t)
    if issubclass(cls, _ptn_type):
        return _get_ptn_parent_type(t)
    return cls
def _cmp_type(t, o):
    """Return whether *t* and *o* normalize (via ``_get_type``) to the same Python type."""
    left = _get_type(t)
    right = _get_type(o)
    return left == right
def _ensure_iterable(o):return o if isinstance(o,(list,tuple))else[o]
def _type_of(t, o):
    """Return whether *t*'s normalized type equals that of any entry of *o*.

    *o* may be a single type or value, or a list/tuple of them.
    """
    wanted = _get_type(t)
    candidates = [_get_type(entry) for entry in _ensure_iterable(o)]
    return wanted in candidates
def _type_mismatch(t, o):
    """Raise PistonTypeError saying *t* does not match the expected type(s) *o*."""
    expected = "|".join([_get_type(e).__name__ for e in _ensure_iterable(o)])
    actual = _get_type(t).__name__
    raise PistonTypeError("type mismatch. expected " + expected + " but got " + actual)
def _type_list_mismatch(i, t, o):
    """Raise PistonTypeError for list element *t* at position *i* not matching type(s) *o*.

    *i* is printed as given; ``_assert_list_type`` passes a 1-based position.
    """
    expected = "|".join([_get_type(e).__name__ for e in _ensure_iterable(o)])
    actual = _get_type(t).__name__
    raise PistonTypeError(
        "type mismatch at index " + str(i) + ". expected " + expected + " but got " + actual
    )
def _type_ptn_mismatch(t):
    """Raise PistonTypeError because *t* is not a Piston type or value."""
    actual = _get_type(t).__name__
    raise PistonTypeError("type mismatch. expected a piston type, but got " + actual)
def _assert_type(t, o):
    """Raise PistonTypeError unless *t* matches one of the type(s) in *o*."""
    if _type_of(t, o):
        return None
    _type_mismatch(t, o)
def _assert_list_type(l,t):[_type_list_mismatch(i+1,e,t)if not _type_of(e,t)else None for i,e in enumerate(l)]
def _cast_type(v,t):ns=dict(v=v,t=t);exec("try:t(v);r=False\nexcept:r=True",ns);return _type_mismatch(v,t)if(ns["r"])else t(v)
def _is_ptn_type(t):
    """Return whether *t* is a Piston type class or an instance of one."""
    if not isinstance(t, type):
        t = type(t)
    return issubclass(t, _ptn_type)
def _is_ptn_type_list(l):
    """Return whether *l* is non-empty and consists only of Piston values.

    *l* may be a single value or a list/tuple. Only *instances* of Piston types
    count; Piston classes themselves do not. An empty list yields False.
    """
    items = _ensure_iterable(l)
    # Check every element. The previous walrus-in-comprehension only kept the
    # result for the last element, so mixed lists were accepted.
    return len(items) > 0 and all(isinstance(e, _ptn_type) for e in items)
def _assert_ptn_type(t):
    """Raise PistonTypeError unless *t* is a Piston type class or value."""
    if not _is_ptn_type(t):
        _type_ptn_mismatch(t)
def _assert_ptn_type_list(l):[_type_ptn_mismatch(e)if not _is_ptn_type(e)else None for e in l]
def _is_types_list(l):
    """Return whether *l* is non-empty and consists only of classes.

    *l* may be a single value or a list/tuple. An empty list yields False.
    """
    items = _ensure_iterable(l)
    # Check every element. The previous walrus-in-comprehension only kept the
    # result for the last element.
    return len(items) > 0 and all(isinstance(e, type) for e in items)
def _length_error(e,l):raise ValueError("expected length "+str(len(e))+" but got "+str(l))
def _assert_length(e,l):_length_error(e,l)if len(e)!=l else None
def _assert_min_length(e,l):_length_error(e,l)if len(e)<l else None
def _assert_max_length(e,l):_length_error(e,l)if len(e)>l else None
# match util
def _value_cb(v):return lambda:v
def _call_fn(cb,args,n):
n_args=cb.__code__.co_argcount
if n_args<=0:return cb()
else:
template="res=call({args})";str_args=[f'"{arg}"'if isinstance(arg,str)else str(arg)for arg in args];arg_dif=n_args-n
if arg_dif<=0:str_args=str_args[:arg_dif]
else:str_args.extend(["None"]*arg_dif)
f_call=template.format(args=",".join(str_args));ns=dict(__name__="match");ns["call"]=cb;exec(f_call,ns);return ns["res"]
class _Matchable(object):
def get_values(self):raise NotImplementedError
def get_n_values(self):return len(self.get_values())
def match(self,*cases,default):
for case in cases:
_assert_type(case,(tuple,list));_assert_length(case,2);comp,call=case;call=call if callable(call)else _value_cb(call)
if self.__eq__(comp):return _call_fn(call,self.get_values(),self.get_n_values())
if default:call=default if callable(default)else _value_cb(default);return _call_fn(call,self.get_values(),self.get_n_values())
# types
class _ptn_type(object):pass
class _ptn_t_integer(_ptn_type):
    """Piston integer: wraps a Python ``int`` stored in ``v``."""

    def __init__(s, v):
        # _cast_type raises PistonTypeError when int(v) fails.
        s.v = _cast_type(v, int)

    def __str__(s):
        return str(s.v)

    def __repr__(s):
        return str(s)

    def __len__(s):
        # The "length" of an integer is its bit length (0 for zero).
        return s.v.bit_length()

    def _cast(s, t):
        """Return ``t(s.v)`` for a Piston type class *t*."""
        _assert(issubclass(t, _ptn_type), "casting type must be a piston primitive type")
        return t(s.v)

    @staticmethod
    def _t():
        """Return the backing Python type, ``int``."""
        return int

    @staticmethod
    def _init():
        """Return the default integer, 0."""
        return _ptn_t_integer(0)

    def __int__(s):
        # Added so int(x) works like float(x) already does.
        return s.v

    def __float__(s):
        return float(s.v)

    def __bool__(s):
        # Any non-zero value is truthy, matching Python ints. Previously
        # negative values were treated as falsy.
        return s.v != 0
class _ptn_t_float(_ptn_type):
    """Piston float: wraps a Python ``float`` stored in ``v``."""

    def __init__(s, v):
        # _cast_type raises PistonTypeError when float(v) fails.
        s.v = _cast_type(v, float)

    def __str__(s):
        return str(s.v)

    def __repr__(s):
        return str(s)

    def __len__(s):
        # NOTE(review): this is the Python object's size in bytes
        # (float.__sizeof__), not a numeric length. Confirm this is intended.
        return s.v.__sizeof__()

    def _cast(s, t):
        """Return ``t(s.v)`` for a Piston type class *t*."""
        _assert(issubclass(t, _ptn_type), "casting type must be a piston primitive type")
        return t(s.v)

    @staticmethod
    def _t():
        """Return the backing Python type, ``float``."""
        return float

    @staticmethod
    def _init():
        """Return the default float, 0.0."""
        return _ptn_t_float(0.0)

    def __int__(s):
        return int(s.v)

    def __float__(s):
        # Added so float(x) works like int(x) already does.
        return s.v

    def __bool__(s):
        # Any non-zero value (including NaN) is truthy, matching Python floats.
        # Previously negative values were treated as falsy.
        return s.v != 0
class _ptn_t_string(_ptn_type):
    """Piston string: wraps a Python ``str`` stored in ``v``."""

    @staticmethod
    def _t():
        """Return the backing Python type, ``str``."""
        return str

    @staticmethod
    def _init():
        """Return the default string, which is empty."""
        return _ptn_t_string("")

    def __init__(s, v):
        # str(v) is attempted; a failure surfaces as PistonTypeError from _cast_type.
        s.v = _cast_type(v, str)

    def _cast(s, t):
        """Convert to Piston type class *t* by constructing it from the raw text."""
        _assert(issubclass(t, _ptn_type), "casting type must be a piston primitive type")
        return t(s.v)

    def __str__(s):
        return str(s.v)

    def __repr__(s):
        return str(s)

    def __len__(s):
        return len(s.v)

    def __bool__(s):
        return bool(s.v)

    def __int__(s):
        # Raises ValueError when the text is not an integer literal.
        return int(s.v)

    def __float__(s):
        # Raises ValueError when the text is not a float literal.
        return float(s.v)
class _ptn_t_bool(_ptn_type):
    """Piston boolean: wraps a Python ``bool`` stored in ``v``."""

    def __init__(s, v):
        # Uses Python truthiness, so e.g. any non-empty str (even "False") becomes True.
        s.v = _cast_type(v, bool)

    def __str__(s):
        return str(s.v)

    def __repr__(s):
        return str(s)

    def __len__(s):
        # bool is an int subclass: bit_length() is 1 for True and 0 for False.
        return s.v.bit_length()

    def _cast(s, t):
        """Return ``t(s.v)`` for a Piston type class *t*."""
        _assert(issubclass(t, _ptn_type), "casting type must be a piston primitive type")
        return t(s.v)

    @staticmethod
    def _t():
        """Return the backing Python type, ``bool``."""
        return bool

    @staticmethod
    def _init():
        """Return the default boolean, False."""
        return _ptn_t_bool(False)

    def __int__(s):
        return 1 if s.v else 0

    def __float__(s):
        return 1.0 if s.v else 0.0

    def __bool__(s):
        # Explicit truthiness, like the sibling types. Without it, bool()
        # only worked incidentally through __len__.
        return s.v
class _ptn_t_array(_ptn_type):
    """Piston array: a homogeneous sequence of Piston values.

    Attributes:
        at: the element type (a class).
        ai: the element values.
        al: the declared number of elements.
    """

    def __init__(s, v, l=1):
        """Build an array from *v*.

        *v* may be one of the following:
          * a Piston type class: *l* default-initialised elements of that type;
          * a non-empty list/tuple: elements must share the first element's
            (backing) type, and the first element must be a Piston value;
          * another ``_ptn_t_array``: its element type and element list are shared;
          * any other value: that value followed by ``l - 1`` defaults of its type.
        """
        if isinstance(v, type):
            # Reject non-Piston classes up front with a PistonTypeError.
            _assert_ptn_type(v)
            s.at = v
            # ``_init`` (not ``init``) is the default-value hook. Build a fresh
            # element per slot so the elements do not alias one object.
            s.ai = [s.at._init() for _ in range(l)]
            s.al = l
        elif isinstance(v, (list, tuple)):
            _assert_min_length(v, 1)
            _assert_ptn_type(v[0])
            s.at = type(v[0])
            _assert_list_type(v, s.at)
            s.ai = v
            s.al = len(v)
        elif isinstance(v, _ptn_t_array):
            s.at = v.at
            s.ai = v.ai
            s.al = v.al
        else:
            s.at = type(v)
            # Fresh default per padding slot (no aliasing).
            s.ai = [v] + [s.at._init() for _ in range(l - 1)]
            s.al = l

    def __str__(s):
        return str(s.ai)

    def __repr__(s):
        return str(s)

    def __len__(s):
        return s.al

    def _cast(s, t):
        """Return a new array with every element cast to Piston type *t*; arrays and tuples are rejected."""
        _assert(not issubclass(t, (_ptn_t_array, _ptn_t_tuple)), "cannot cast array to array or tuple")
        c = [e._cast(t) for e in s.ai]
        return _ptn_t_array(c)

    @staticmethod
    def _t():
        """Return the backing Python type, ``list``."""
        return list
# TODO: decide how a tuple should be cast to another type (_ptn_t_tuple._cast is not implemented yet).
class _ptn_t_tuple(_ptn_type):
    """Piston tuple: a fixed-length, heterogeneous sequence of Piston values.

    Attributes:
        tt: per-position element types.
        ti: the element values.
        tl: the number of elements.
    """

    def __init__(s, v):
        """Build a tuple from *v*, which is one value or a list/tuple of values.

        *v* may hold Piston values, or Piston type classes. In the second case
        each position is default-initialised, and array/tuple classes are
        rejected. Anything else raises PistonRuntimeError.
        """
        v = _ensure_iterable(v)
        if _is_ptn_type_list(v):
            s.tt = [type(e) for e in v]
            s.ti = v
            s.tl = len(v)
        elif _is_types_list(v):
            _assert_ptn_type_list(v)
            s.tt = v
            items = []
            for e in v:
                if issubclass(e, (_ptn_t_array, _ptn_t_tuple)):
                    _raise_runtime("tuples cannot contain tuples or arrays")
                items.append(e._init())
            s.ti = items
            s.tl = len(v)
        else:
            _raise_runtime("invalid tuple initialization " + str(v))

    def __str__(s):
        return str(s.ti)

    def __repr__(s):
        return str(s)

    def __len__(s):
        return s.tl

    def _cast(s, t):
        """Casting tuples is not implemented yet; always raises."""
        # Message fixed: this is the tuple class, not the array class.
        _assert(not issubclass(t, (_ptn_t_array, _ptn_t_tuple)), "cannot cast tuple to array or tuple")
        _raise_runtime("not implemented yet. how?")

    @staticmethod
    def _t():
        """Return the backing Python type, ``tuple``."""
        return tuple
def _int(v=None):
    """Return a Piston integer holding *v*, or the default integer when *v* is None."""
    if v is None:
        return _ptn_t_integer._init()
    return _ptn_t_integer(v)
def _float(v=None):
    """Return a Piston float holding *v*, or the default float when *v* is None."""
    if v is None:
        return _ptn_t_float._init()
    return _ptn_t_float(v)
def _string(v=None):
    """Return a Piston string holding *v*, or the default (empty) string when *v* is None."""
    if v is None:
        return _ptn_t_string._init()
    return _ptn_t_string(v)
def _bool(v=None):
    """Return a Piston boolean from *v*, or the default boolean when *v* is None."""
    if v is None:
        return _ptn_t_bool._init()
    return _ptn_t_bool(v)
def _array(vt, l=None):
    """Return a Piston array built from *vt*.

    *l* is the length; when None, the ``_ptn_t_array`` constructor's default is used.
    """
    if l is None:
        return _ptn_t_array(vt)
    return _ptn_t_array(vt, l)
def _tuple(vt):
    """Return a Piston tuple built from *vt* (Piston values or Piston type classes)."""
    result = _ptn_t_tuple(vt)
    return result