diff --git a/src/stpr/__init__.py b/src/stpr/__init__.py
index d736465..67300cb 100644
--- a/src/stpr/__init__.py
+++ b/src/stpr/__init__.py
@@ -1,6 +1,6 @@
__version__ = "0.0.1"
-from .transformer import fn, seq, parallel, parallelFor, fork, race
+from .transformer import fn, seq, parallel, parallelFor, fork, race, _dump, _fn_id
from .functions import run, wait, start, exec, async_partial
from .runtime import _call, _to_aiter, _start, _run_sync, _await, _to_acm
@@ -9,4 +9,6 @@
from .channels import Channel, select
-from .io import open
\ No newline at end of file
+from .io import open
+
+from . import net
\ No newline at end of file
diff --git a/src/stpr/_astdump.py b/src/stpr/_astdump.py
index 00dad28..b464134 100644
--- a/src/stpr/_astdump.py
+++ b/src/stpr/_astdump.py
@@ -23,6 +23,8 @@ def _get_func(node: ast.AST) -> str:
return node.value
if node.__class__ == ast.Subscript:
return f'{_get_func(node.value)}[{_get_func(node.slice)}]'
+ if node.__class__ == ast.Call:
+ return f'call'
raise Exception('?? %s' % node)
diff --git a/src/stpr/_debug.py b/src/stpr/_debug.py
index 0f9409f..3898406 100644
--- a/src/stpr/_debug.py
+++ b/src/stpr/_debug.py
@@ -1,3 +1,5 @@
+import datetime
+
import time
from enum import Enum
@@ -5,7 +7,7 @@
DEBUG = False
-_TS = None
+_TS = time.time()
def _ts():
@@ -48,4 +50,5 @@ def _print(msg: str, color: Color = None, background: Color = None) -> None:
cstr1 += f'\033[{40 + background.value}m'
if color or background:
cstr2 = '\033[0m'
- print(f'{cstr1}{msg}{cstr2}')
+ date = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")
+ print(f'{date} {cstr1}{msg}{cstr2}')
diff --git a/src/stpr/channels.py b/src/stpr/channels.py
index feb1452..9e7ab2a 100644
--- a/src/stpr/channels.py
+++ b/src/stpr/channels.py
@@ -1,5 +1,7 @@
import asyncio
-from typing import TypeVar, Tuple, Optional, AsyncIterator
+from abc import ABC, abstractmethod
+from typing import TypeVar, Tuple, Optional, AsyncIterator, Iterable, List, Callable, Awaitable, \
+ Generic
import stpr
from stpr.types import T
@@ -13,6 +15,43 @@ def __init__(self, exception: Optional[Exception]) -> None:
_GUARD = _Guard(StopAsyncIteration())
+class ChannelCallback(Generic[T], ABC):
+ """
+ An abstract base class for channel callbacks.
+
+ This class can be overridden in order to receive notifications from a channel using the channel
+ callback mechanism.
+ """
+ @abstractmethod
+ def value_appended(self, channel: 'Channel[T]', value: T) -> None:
+ """
+ This method is invoked when a value is added to a channel.
+
+ :param channel: The channel that the value was added to.
+ :param value: The value that was added
+ """
+ pass
+
+ @abstractmethod
+ def channel_error(self, channel: 'Channel[T]', error: Exception) -> None:
+ """
+ This method is called when a channel receives an error.
+
+ :param channel: The channel that received the error.
+ :param error: The error.
+ """
+ pass
+
+ @abstractmethod
+ def channel_closed(self, channel: 'Channel[T]') -> None:
+ """
+ This method is called when a channel is closed.
+
+ :param channel: The channel that was closed.
+ """
+ pass
+
+
class Channel(AsyncIterator[T]):
"""
This class implements a channel.
@@ -60,9 +99,16 @@ def producer(c: Channel[int]) -> None:
In the above example, the channel is closed automatically when the `with` statement completes.
Any exceptions thrown by `f()` will propagate out of `producer()`, but also be raised in the
consumer function after all other values in the channel are consumed.
+
+ Iterating a channel concurrently (that is, having multiple concurrent for loops for the same
+ channel instance) will cause each iteration to produce distinct values in a nondeterministic
+ fashion. For deterministic behaviour, there should be no concurrent iteration on a channel.
+ If deterministic concurrent iteration on the values added to a channel is needed, use the
+ :meth:`~.split` method.
"""
def __init__(self):
self._q = asyncio.Queue()
+ self._callbacks: List[ChannelCallback[T]] = []
async def __aenter__(self):
pass
@@ -82,6 +128,21 @@ async def __anext__(self) -> T:
raise value.exception
return value
+ def add_callback(self, cb: ChannelCallback) -> None:
+ """
+ Adds a callback to this reactive.
+
+ A callback can be used to get asynchronous notifications when a value is added to a
+ Channel or when the channel is closed/fails.
+
+ :param cb: An instance of :class:`.ChannelCallback`.
+ """
+ if cb is None:
+ raise ValueError(f'Invalid callback {cb}')
+ if not asyncio.iscoroutinefunction(cb):
+ raise ValueError(f'Callback {cb} is not a coroutine.')
+ self._callbacks.append(cb)
+
async def append(self, value: T) -> None:
"""
Appends a value to this channel.
@@ -89,13 +150,12 @@ async def append(self, value: T) -> None:
:param value: The value to append.
"""
await self._q.put(value)
+ if self._callbacks:
+ self._notify_append(value)
def _append_now(self, value: T) -> None:
self._q.put_nowait(value)
- def __iadd__(self, value: T) -> None:
- self._q.put_nowait(value)
-
def close(self) -> None:
"""
Closes this channel.
@@ -105,6 +165,8 @@ def close(self) -> None:
that is used to consume values from this channel.
"""
self._q.put_nowait(_GUARD)
+ if self._callbacks:
+ self._notify_close()
def fail(self, exception: Exception) -> None:
"""
@@ -114,6 +176,61 @@ def fail(self, exception: Exception) -> None:
made to consume that value, ``exception`` is raised instead.
"""
self._q.put_nowait(_Guard(exception))
+ if self._callbacks:
+ self._notify_failure(exception)
+
+ def _notify_append(self, value: T) -> None:
+ for cb in self._callbacks:
+ cb.value_appended(self, value)
+
+ def _notify_closed(self, value: T) -> None:
+ for cb in self._callbacks:
+ cb.channel_closed(self)
+
+ def _notify_failure(self, e: Exception) -> None:
+ for cb in self._callbacks:
+ cb.channel_error(self, e)
+
+ def split(self) -> 'Channel[T]':
+ """
+ Splits this channel.
+
+ This method returns a new channel that will receive all the values that
+ this channel receives, including errors. The returned channel will be
+ closed when this channel is closed. The returned channel can, in turn
+ be split further.
+
+ Iterating on both this channel and the channel returned by this method
+ will produce the values added on the channel in sequence for each
+ iteration provided that at most one iteration is active on each at a
+ given time.
+ :return:
+ """
+ split = _SplitChannel()
+ self.add_callback(split)
+ return split
+
+ async def drain(self, dest: List[T]) -> None:
+ """
+ Drains this channel into a list.
+
+ One by one, remove all items from this channel until the channel is
+ empty and add them to the given list. This method can be used to
+ periodically bulk-collect items that have accumulated in this channel.
+
+ :param dest: A list to drain this channel into.
+ """
+
+
+class _SplitChannel(Channel[T], ChannelCallback):
+ def value_appended(self, channel: 'Channel[T]', value: T):
+ self.append(value)
+
+ def channel_error(self, channel: 'Channel[T]', error: Exception):
+ self.fail(error)
+
+ def channel_closed(self, channel: 'Channel[T]'):
+ self.close()
async def select(*args: Channel[T]) -> Channel[Tuple[T, Channel[T]]]:
diff --git a/src/stpr/functions.py b/src/stpr/functions.py
index 1b1cfc9..4fd65f0 100644
--- a/src/stpr/functions.py
+++ b/src/stpr/functions.py
@@ -44,7 +44,10 @@ def _start(fn) -> concurrent.futures.Future:
raise ValueError('Cannot run %s' % fn)
#print(f'running {coro} in {id(_LOOP)}')
- return asyncio.run_coroutine_threadsafe(coro, _LOOP)
+ loop = asyncio.get_running_loop()
+ if loop is None:
+ loop = _LOOP
+ return asyncio.run_coroutine_threadsafe(coro, loop)
def run(fn) -> object | None:
diff --git a/src/stpr/net.py b/src/stpr/net.py
index d86eb9d..2758243 100644
--- a/src/stpr/net.py
+++ b/src/stpr/net.py
@@ -1,6 +1,6 @@
import asyncio
from contextlib import AbstractAsyncContextManager
-from socket import socket
+import socket
from stpr.channels import Channel
diff --git a/src/stpr/reactive.py b/src/stpr/reactive.py
index ad4223e..5b4a3e2 100644
--- a/src/stpr/reactive.py
+++ b/src/stpr/reactive.py
@@ -127,7 +127,7 @@ async def update(self, value: T, force: bool = False) -> None:
async def _notify(self, value: T, old: T) -> None:
for cb in self._callbacks:
try:
- await cb(self, value, old)
+ asyncio.create_task(cb(self, value, old))
except TypeError as e:
raise TypeError(f'Failed to invoke callback {cb}: {e}')
diff --git a/src/stpr/runtime.py b/src/stpr/runtime.py
index a937367..1a0af75 100644
--- a/src/stpr/runtime.py
+++ b/src/stpr/runtime.py
@@ -81,5 +81,5 @@ async def _call(fn, *args, **kwargs):
#print('Calling coro %s with %s, %s' % (fn, args, kwargs))
return await fn(*args, **kwargs)
else:
- #print('Wrap-calling %s with %s' % (fn, args))
+ #print('Wrap-calling %s with %s, %s' % (fn, args, kwargs))
return await _run_sync(fn, *args, **kwargs)
diff --git a/src/stpr/transformer.py b/src/stpr/transformer.py
index d38d1c6..016b16d 100644
--- a/src/stpr/transformer.py
+++ b/src/stpr/transformer.py
@@ -8,6 +8,7 @@
from enum import Enum
from typing import Tuple, List, Optional, Coroutine, Type, Dict, Callable, Set, Union, Iterable
+from stpr.functions import start
from stpr._astdump import astdump, astdumps
from stpr._debug import debug_print, Color, _ts, DEBUG, _print
@@ -16,6 +17,8 @@
_SAFE_MODULES.add('stpr.transformer')
_SAFE_MODULES.add('stpr.types')
_SAFE_MODULES.add('stpr.reactive')
+_SAFE_MODULES.add('stpr.channels')
+_SAFE_MODULES.add('stpr.functions')
class _Ref:
@@ -79,8 +82,8 @@ def _copy_params(src: ast.AST, dst: ast.AST) -> None:
return dst
-def _transform(node: ast.AST, frame, sp_mod_name: str):
- t = _Transformer(frame, sp_mod_name)
+def _transform(node: ast.AST, frame, sp_mod_name: str, autosync=True):
+ t = _Transformer(frame, sp_mod_name, autosync)
t.visit(node)
debug_print(frame, Color.BLUE)
return node
@@ -478,11 +481,12 @@ def var_names(self) -> Set[str]:
class _Transformer(ast.NodeTransformer):
- def __init__(self, outer_frame, sp_mod_name: str):
+ def __init__(self, outer_frame, sp_mod_name: str, autosync: bool = True):
self.outer_frame = outer_frame
self.sp_mod_name = sp_mod_name
self.context_stack = []
self.crt_context = None
+ self.autosync = autosync
def get_global(self, name: str) -> _Ref:
return _find_global(self.outer_frame, name)
@@ -716,6 +720,8 @@ def visit_Call(self, node: ast.AST) -> ast.AST:
return self._make_parallel_fn(node)
elif ref.value == race:
return self._make_race_fn(node)
+ elif ref.value == start:
+ return self._make_start_fn(node)
else:
raise Exception(f'Unhandled sp function: {ref.value}')
elif ref.is_coro():
@@ -800,6 +806,8 @@ def visit_With_item(self, node: ast.AST, item_ix: int) -> ast.AST:
inspect.iscoroutinefunction(val) or inspect.iscoroutine(val):
r = ast.AsyncWith()
_copy_params(node, r)
+ if inspect.iscoroutinefunction(val):
+ item.context_expr = _wrap_await(item.context_expr)
r.items = [item]
if body_visited:
r.body = body
@@ -914,6 +922,14 @@ def _make_race_fn(self, call: ast.Call) -> ast.AST:
call.args[i] = self._to_coro(call.args[i])
return ast.Await(call)
+ def _make_start_fn(self, call: ast.Call) -> ast.AST:
+ if len(call.keywords) > 0:
+ raise Exception('start() does not support keyword arguments.')
+ if len(call.args) != 1:
+ raise Exception('start() needs exactly one argument.')
+ call.args[0] = self._to_coro(call.args[0])
+ return call
+
def _make_parallel(self, nodes: List[ast.AST],
r_nodes: Optional[List[ast.AST]] = None) -> List[ast.AST]:
if nodes is None:
@@ -1060,7 +1076,7 @@ def _get_code(t: Tuple[object]) -> types.CodeType:
raise RuntimeError('Code object not found in %s' % t)
-def fn(*args, crt_frame=None, stpr_module_name=None, debug=False):
+def fn(*args, crt_frame=None, stpr_module_name=None, debug=False, autosync=True):
"""
Decorator for Stpr functions.
@@ -1105,7 +1121,7 @@ def inner(f):
if debug:
_print('Before instrumentation', Color.BLUE)
astdump(t)
- t2 = _transform(t, crt_frame.f_back, stpr_module_name)
+ t2 = _transform(t, crt_frame.f_back, stpr_module_name, autosync)
nlines = t2.body[0].end_lineno - t2.body[0].lineno
t2.body[0].lineno = lineno
t2.body[0].end_lineno = lineno + nlines
@@ -1123,6 +1139,45 @@ def inner(f):
print(ex)
astdump(t2)
raise
+ f.__SP_CC = True
+ r = types.FunctionType(_get_code(code.co_consts), f.__globals__, f.__name__,
+ f.__defaults__)
+ r.__doc__ = f.__doc__
+ r.__annotations__ = f.__annotations__
+ return r
+ return f
+
+ if len(args) > 1:
+ raise TypeError()
+ if len(args) == 0:
+ return inner
+ else:
+ return inner(args[0])
+
+
+def _fn_id(*args, crt_frame=None, stpr_module_name=None, debug=False):
+ if crt_frame is None:
+ crt_frame = inspect.currentframe()
+
+ def inner(f):
+ print(f'Defaults: {f.__defaults__}')
+ nonlocal debug, crt_frame, stpr_module_name
+ if DEBUG:
+ debug = True
+ if debug:
+ _print('Instrumenting %s' % f, Color.BLUE)
+ if not hasattr(f, '__SP_CC'):
+ lineno = f.__code__.co_firstlineno
+ t = ast.parse(textwrap.dedent(inspect.getsource(f)))
+ up = _MyUnparser()
+ if debug:
+ print(up.visit(t))
+ try:
+ code = compile(t, inspect.getfile(f), 'exec')
+ except Exception as ex:
+ print(ex)
+ astdump(t)
+ raise
return types.FunctionType(_get_code(code.co_consts), f.__globals__, f.__name__,
f.__defaults__)
f.__SP_CC = True
@@ -1136,6 +1191,18 @@ def inner(f):
return inner(args[0])
+def _dump(f):
+ """
+ Prints the AST of a function.
+
+ """
+ def inner(f):
+ t = ast.parse(textwrap.dedent(inspect.getsource(f)))
+ astdump(t)
+
+ return inner(f)
+
+
class seq:
"""
Runs statements sequentially.
@@ -1274,7 +1341,7 @@ def _transformer(cls) -> Callable:
return _Transformer._parallel
@staticmethod
- async def _fn(*coros) -> Tuple[...]:
+ async def _fn(*coros) -> Tuple:
tasks = []
async with parallel() as ctx:
for coro in coros:
@@ -1426,4 +1493,4 @@ def _transformer(cls) -> Callable:
_SP_CMS = [seq, parallel, parallelFor, fork]
-_SP_FNS = [parallel, race]
+_SP_FNS = [parallel, race, start]
diff --git a/web/image-src/logo.svg b/web/image-src/logo.svg
index 63356bf..558b5c9 100644
--- a/web/image-src/logo.svg
+++ b/web/image-src/logo.svg
@@ -24,11 +24,11 @@
inkscape:deskcolor="#d1d1d1"
inkscape:document-units="mm"
showgrid="false"
- inkscape:zoom="3.4772035"
- inkscape:cx="161.91172"
- inkscape:cy="240.99826"
+ inkscape:zoom="2.4587542"
+ inkscape:cx="144.38206"
+ inkscape:cy="261.10784"
inkscape:window-width="1798"
- inkscape:window-height="1025"
+ inkscape:window-height="1145"
inkscape:window-x="122"
inkscape:window-y="27"
inkscape:window-maximized="1"
@@ -77,6 +77,42 @@
width="107.43045"
height="130.50816"
id="rect4010-6-7-25" />
+
+
+
+
+
+
+
- S
- I
- I
- R
-
+ inkscape:export-ydpi="85.55">
+ S
+ I
+ I
+ R
+
+
+
+ style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:3;stroke-linecap:butt;stroke-linejoin:miter;stroke-dasharray:none;stroke-opacity:1"
+ d="m 31.155119,52.844511 h 9.429632"
+ id="path4419" />
-
S
+ id="tspan6268">S
diff --git a/web/image-src/main.blend b/web/image-src/main.blend
index a5d07f7..1cf8860 100644
Binary files a/web/image-src/main.blend and b/web/image-src/main.blend differ