diff --git a/src/stpr/__init__.py b/src/stpr/__init__.py index d736465..67300cb 100644 --- a/src/stpr/__init__.py +++ b/src/stpr/__init__.py @@ -1,6 +1,6 @@ __version__ = "0.0.1" -from .transformer import fn, seq, parallel, parallelFor, fork, race +from .transformer import fn, seq, parallel, parallelFor, fork, race, _dump, _fn_id from .functions import run, wait, start, exec, async_partial from .runtime import _call, _to_aiter, _start, _run_sync, _await, _to_acm @@ -9,4 +9,6 @@ from .channels import Channel, select -from .io import open \ No newline at end of file +from .io import open + +from . import net \ No newline at end of file diff --git a/src/stpr/_astdump.py b/src/stpr/_astdump.py index 00dad28..b464134 100644 --- a/src/stpr/_astdump.py +++ b/src/stpr/_astdump.py @@ -23,6 +23,8 @@ def _get_func(node: ast.AST) -> str: return node.value if node.__class__ == ast.Subscript: return f'{_get_func(node.value)}[{_get_func(node.slice)}]' + if node.__class__ == ast.Call: + return f'call' raise Exception('?? %s' % node) diff --git a/src/stpr/_debug.py b/src/stpr/_debug.py index 0f9409f..3898406 100644 --- a/src/stpr/_debug.py +++ b/src/stpr/_debug.py @@ -1,3 +1,5 @@ +import datetime + import time from enum import Enum @@ -5,7 +7,7 @@ DEBUG = False -_TS = None +_TS = time.time() def _ts(): @@ -48,4 +50,5 @@ def _print(msg: str, color: Color = None, background: Color = None) -> None: cstr1 += f'\033[{40 + background.value}m' if color or background: cstr2 = '\033[0m' - print(f'{cstr1}{msg}{cstr2}') + date = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f") + print(f'{date} {cstr1}{msg}{cstr2}') diff --git a/src/stpr/channels.py b/src/stpr/channels.py index feb1452..9e7ab2a 100644 --- a/src/stpr/channels.py +++ b/src/stpr/channels.py @@ -1,5 +1,7 @@ import asyncio -from typing import TypeVar, Tuple, Optional, AsyncIterator +from abc import ABC, abstractmethod +from typing import TypeVar, Tuple, Optional, AsyncIterator, Iterable, List, Callable, Awaitable, \ + Generic import stpr from stpr.types import T @@ -13,6 +15,43 @@ def __init__(self, exception: Optional[Exception]) -> None: _GUARD = _Guard(StopAsyncIteration()) +class ChannelCallback(Generic[T], ABC): + """ + An abstract base class for channel callbacks. + + This class can be overridden in order to receive notifications from a channel using the channel + callback mechanism. + """ + @abstractmethod + def value_appended(self, channel: 'Channel[T]', value: T) -> None: + """ + This method is invoked when a value is added to a channel. + + :param channel: The channel that the value was added to. + :param value: The value that was added + """ + pass + + @abstractmethod + def channel_error(self, channel: 'Channel[T]', error: Exception) -> None: + """ + This method is called when a channel receives an error. + + :param channel: The channel that received the error. + :param error: The error. + """ + pass + + @abstractmethod + def channel_closed(self, channel: 'Channel[T]') -> None: + """ + This method is called when a channel is closed. + + :param channel: The channel that was closed. + """ + pass + + class Channel(AsyncIterator[T]): """ This class implements a channel. @@ -60,9 +99,16 @@ def producer(c: Channel[int]) -> None: In the above example, the channel is closed automatically when the `with` statement completes. Any exceptions thrown by `f()` will propagate out of `producer()`, but also be raised in the consumer function after all other values in the channel are consumed. + + Iterating a channel concurrently (that is, having multiple concurrent for loops for the same + channel instance) will cause each iteration to produce distinct values in a nondeterministic + fashion. For deterministic behaviour, there should be no concurrent iteration on a channel. + If deterministic concurrent iteration on the values added to a channel is needed, use the + :meth:`~.split` method. """ def __init__(self): self._q = asyncio.Queue() + self._callbacks: List[ChannelCallback[T]] = [] async def __aenter__(self): pass @@ -82,6 +128,21 @@ async def __anext__(self) -> T: raise value.exception return value + def add_callback(self, cb: ChannelCallback) -> None: + """ + Adds a callback to this reactive. + + A callback can be used to get asynchronous notifications when a value is added to a + Channel or when the channel is closed/fails. + + :param cb: An instance of :class:`.ChannelCallback`. + """ + if cb is None: + raise ValueError(f'Invalid callback {cb}') + if not asyncio.iscoroutinefunction(cb): + raise ValueError(f'Callback {cb} is not a coroutine.') + self._callbacks.append(cb) + async def append(self, value: T) -> None: """ Appends a value to this channel. @@ -89,13 +150,12 @@ async def append(self, value: T) -> None: :param value: The value to append. """ await self._q.put(value) + if self._callbacks: + self._notify_append(value) def _append_now(self, value: T) -> None: self._q.put_nowait(value) - def __iadd__(self, value: T) -> None: - self._q.put_nowait(value) - def close(self) -> None: """ Closes this channel. @@ -105,6 +165,8 @@ def close(self) -> None: that is used to consume values from this channel. """ self._q.put_nowait(_GUARD) + if self._callbacks: + self._notify_close() def fail(self, exception: Exception) -> None: """ @@ -114,6 +176,61 @@ def fail(self, exception: Exception) -> None: made to consume that value, ``exception`` is raised instead. """ self._q.put_nowait(_Guard(exception)) + if self._callbacks: + self._notify_failure(exception) + + def _notify_append(self, value: T) -> None: + for cb in self._callbacks: + cb.value_appended(self, value) + + def _notify_closed(self, value: T) -> None: + for cb in self._callbacks: + cb.channel_closed(self) + + def _notify_failure(self, e: Exception) -> None: + for cb in self._callbacks: + cb.channel_error(self, e) + + def split(self) -> 'Channel[T]': + """ + Splits this channel. + + This method returns a new channel that will receive all the values that + this channel receives, including errors. The returned channel will be + closed when this channel is closed. The returned channel can, in turn + be split further. + + Iterating on both this channel and the channel returned by this method + will produce the values added on the channel in sequence for each + iteration provided that at most one iteration is active on each at a + given time. + :return: + """ + split = _SplitChannel() + self.add_callback(split) + return split + + async def drain(self, dest: List[T]) -> None: + """ + Drains this channel into a list. + + One by one, remove all items from this channel until the channel is + empty and add them to the given list. This method can be used to + periodically bulk-collect items that have accumulated in this channel. + + :param dest: A list to drain this channel into. + """ + + +class _SplitChannel(Channel[T], ChannelCallback): + def value_appended(self, channel: 'Channel[T]', value: T): + self.append(value) + + def channel_error(self, channel: 'Channel[T]', error: Exception): + self.fail(error) + + def channel_closed(self, channel: 'Channel[T]'): + self.close() async def select(*args: Channel[T]) -> Channel[Tuple[T, Channel[T]]]: diff --git a/src/stpr/functions.py b/src/stpr/functions.py index 1b1cfc9..4fd65f0 100644 --- a/src/stpr/functions.py +++ b/src/stpr/functions.py @@ -44,7 +44,10 @@ def _start(fn) -> concurrent.futures.Future: raise ValueError('Cannot run %s' % fn) #print(f'running {coro} in {id(_LOOP)}') - return asyncio.run_coroutine_threadsafe(coro, _LOOP) + loop = asyncio.get_running_loop() + if loop is None: + loop = _LOOP + return asyncio.run_coroutine_threadsafe(coro, loop) def run(fn) -> object | None: diff --git a/src/stpr/net.py b/src/stpr/net.py index d86eb9d..2758243 100644 --- a/src/stpr/net.py +++ b/src/stpr/net.py @@ -1,6 +1,6 @@ import asyncio from contextlib import AbstractAsyncContextManager -from socket import socket +import socket from stpr.channels import Channel diff --git a/src/stpr/reactive.py b/src/stpr/reactive.py index ad4223e..5b4a3e2 100644 --- a/src/stpr/reactive.py +++ b/src/stpr/reactive.py @@ -127,7 +127,7 @@ async def update(self, value: T, force: bool = False) -> None: async def _notify(self, value: T, old: T) -> None: for cb in self._callbacks: try: - await cb(self, value, old) + asyncio.create_task(cb(self, value, old)) except TypeError as e: raise TypeError(f'Failed to invoke callback {cb}: {e}') diff --git a/src/stpr/runtime.py b/src/stpr/runtime.py index a937367..1a0af75 100644 --- a/src/stpr/runtime.py +++ b/src/stpr/runtime.py @@ -81,5 +81,5 @@ async def _call(fn, *args, **kwargs): #print('Calling coro %s with %s, %s' % (fn, args, kwargs)) return await fn(*args, **kwargs) else: - #print('Wrap-calling %s with %s' % (fn, args)) + #print('Wrap-calling %s with %s, %s' % (fn, args, kwargs)) return await _run_sync(fn, *args, **kwargs) diff --git a/src/stpr/transformer.py b/src/stpr/transformer.py index d38d1c6..016b16d 100644 --- a/src/stpr/transformer.py +++ b/src/stpr/transformer.py @@ -8,6 +8,7 @@ from enum import Enum from typing import Tuple, List, Optional, Coroutine, Type, Dict, Callable, Set, Union, Iterable +from stpr.functions import start from stpr._astdump import astdump, astdumps from stpr._debug import debug_print, Color, _ts, DEBUG, _print @@ -16,6 +17,8 @@ _SAFE_MODULES.add('stpr.transformer') _SAFE_MODULES.add('stpr.types') _SAFE_MODULES.add('stpr.reactive') +_SAFE_MODULES.add('stpr.channels') +_SAFE_MODULES.add('stpr.functions') class _Ref: @@ -79,8 +82,8 @@ def _copy_params(src: ast.AST, dst: ast.AST) -> None: return dst -def _transform(node: ast.AST, frame, sp_mod_name: str): - t = _Transformer(frame, sp_mod_name) +def _transform(node: ast.AST, frame, sp_mod_name: str, autosync=True): + t = _Transformer(frame, sp_mod_name, autosync) t.visit(node) debug_print(frame, Color.BLUE) return node @@ -478,11 +481,12 @@ def var_names(self) -> Set[str]: class _Transformer(ast.NodeTransformer): - def __init__(self, outer_frame, sp_mod_name: str): + def __init__(self, outer_frame, sp_mod_name: str, autosync: bool = True): self.outer_frame = outer_frame self.sp_mod_name = sp_mod_name self.context_stack = [] self.crt_context = None + self.autosync = autosync def get_global(self, name: str) -> _Ref: return _find_global(self.outer_frame, name) @@ -716,6 +720,8 @@ def visit_Call(self, node: ast.AST) -> ast.AST: return self._make_parallel_fn(node) elif ref.value == race: return self._make_race_fn(node) + elif ref.value == start: + return self._make_start_fn(node) else: raise Exception(f'Unhandled sp function: {ref.value}') elif ref.is_coro(): @@ -800,6 +806,8 @@ def visit_With_item(self, node: ast.AST, item_ix: int) -> ast.AST: inspect.iscoroutinefunction(val) or inspect.iscoroutine(val): r = ast.AsyncWith() _copy_params(node, r) + if inspect.iscoroutinefunction(val): + item.context_expr = _wrap_await(item.context_expr) r.items = [item] if body_visited: r.body = body @@ -914,6 +922,14 @@ def _make_race_fn(self, call: ast.Call) -> ast.AST: call.args[i] = self._to_coro(call.args[i]) return ast.Await(call) + def _make_start_fn(self, call: ast.Call) -> ast.AST: + if len(call.keywords) > 0: + raise Exception('start() does not support keyword arguments.') + if len(call.args) != 1: + raise Exception('start() needs exactly one argument.') + call.args[0] = self._to_coro(call.args[0]) + return call + def _make_parallel(self, nodes: List[ast.AST], r_nodes: Optional[List[ast.AST]] = None) -> List[ast.AST]: if nodes is None: @@ -1060,7 +1076,7 @@ def _get_code(t: Tuple[object]) -> types.CodeType: raise RuntimeError('Code object not found in %s' % t) -def fn(*args, crt_frame=None, stpr_module_name=None, debug=False): +def fn(*args, crt_frame=None, stpr_module_name=None, debug=False, autosync=True): """ Decorator for Stpr functions. @@ -1105,7 +1121,7 @@ def inner(f): if debug: _print('Before instrumentation', Color.BLUE) astdump(t) - t2 = _transform(t, crt_frame.f_back, stpr_module_name) + t2 = _transform(t, crt_frame.f_back, stpr_module_name, autosync) nlines = t2.body[0].end_lineno - t2.body[0].lineno t2.body[0].lineno = lineno t2.body[0].end_lineno = lineno + nlines @@ -1123,6 +1139,45 @@ def inner(f): print(ex) astdump(t2) raise + f.__SP_CC = True + r = types.FunctionType(_get_code(code.co_consts), f.__globals__, f.__name__, + f.__defaults__) + r.__doc__ = f.__doc__ + r.__annotations__ = f.__annotations__ + return r + return f + + if len(args) > 1: + raise TypeError() + if len(args) == 0: + return inner + else: + return inner(args[0]) + + +def _fn_id(*args, crt_frame=None, stpr_module_name=None, debug=False): + if crt_frame is None: + crt_frame = inspect.currentframe() + + def inner(f): + print(f'Defaults: {f.__defaults__}') + nonlocal debug, crt_frame, stpr_module_name + if DEBUG: + debug = True + if debug: + _print('Instrumenting %s' % f, Color.BLUE) + if not hasattr(f, '__SP_CC'): + lineno = f.__code__.co_firstlineno + t = ast.parse(textwrap.dedent(inspect.getsource(f))) + up = _MyUnparser() + if debug: + print(up.visit(t)) + try: + code = compile(t, inspect.getfile(f), 'exec') + except Exception as ex: + print(ex) + astdump(t) + raise return types.FunctionType(_get_code(code.co_consts), f.__globals__, f.__name__, f.__defaults__) f.__SP_CC = True @@ -1136,6 +1191,18 @@ def inner(f): return inner(args[0]) +def _dump(f): + """ + Prints the AST of a function. + + """ + def inner(f): + t = ast.parse(textwrap.dedent(inspect.getsource(f))) + astdump(t) + + return inner(f) + + class seq: """ Runs statements sequentially. @@ -1274,7 +1341,7 @@ def _transformer(cls) -> Callable: return _Transformer._parallel @staticmethod - async def _fn(*coros) -> Tuple[...]: + async def _fn(*coros) -> Tuple: tasks = [] async with parallel() as ctx: for coro in coros: @@ -1426,4 +1493,4 @@ def _transformer(cls) -> Callable: _SP_CMS = [seq, parallel, parallelFor, fork] -_SP_FNS = [parallel, race] +_SP_FNS = [parallel, race, start] diff --git a/web/image-src/logo.svg b/web/image-src/logo.svg index 63356bf..558b5c9 100644 --- a/web/image-src/logo.svg +++ b/web/image-src/logo.svg @@ -24,11 +24,11 @@ inkscape:deskcolor="#d1d1d1" inkscape:document-units="mm" showgrid="false" - inkscape:zoom="3.4772035" - inkscape:cx="161.91172" - inkscape:cy="240.99826" + inkscape:zoom="2.4587542" + inkscape:cx="144.38206" + inkscape:cy="261.10784" inkscape:window-width="1798" - inkscape:window-height="1025" + inkscape:window-height="1145" inkscape:window-x="122" inkscape:window-y="27" inkscape:window-maximized="1" @@ -77,6 +77,42 @@ width="107.43045" height="130.50816" id="rect4010-6-7-25" /> + + + + + + + - S - I - I - R - + inkscape:export-ydpi="85.55"> + S + I + I + R + + + + style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:3;stroke-linecap:butt;stroke-linejoin:miter;stroke-dasharray:none;stroke-opacity:1" + d="m 31.155119,52.844511 h 9.429632" + id="path4419" /> - S + id="tspan6268">S diff --git a/web/image-src/main.blend b/web/image-src/main.blend index a5d07f7..1cf8860 100644 Binary files a/web/image-src/main.blend and b/web/image-src/main.blend differ