Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/stpr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__version__ = "0.0.1"

from .transformer import fn, seq, parallel, parallelFor, fork, race
from .transformer import fn, seq, parallel, parallelFor, fork, race, _dump, _fn_id
from .functions import run, wait, start, exec, async_partial

from .runtime import _call, _to_aiter, _start, _run_sync, _await, _to_acm
Expand All @@ -9,4 +9,6 @@

from .channels import Channel, select

from .io import open
from .io import open

from . import net
2 changes: 2 additions & 0 deletions src/stpr/_astdump.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def _get_func(node: ast.AST) -> str:
return node.value
if node.__class__ == ast.Subscript:
return f'{_get_func(node.value)}[{_get_func(node.slice)}]'
if node.__class__ == ast.Call:
return f'call'
raise Exception('?? %s' % node)


Expand Down
7 changes: 5 additions & 2 deletions src/stpr/_debug.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import datetime

import time
from enum import Enum


DEBUG = False


_TS = None
_TS = time.time()


def _ts():
Expand Down Expand Up @@ -48,4 +50,5 @@ def _print(msg: str, color: Color = None, background: Color = None) -> None:
cstr1 += f'\033[{40 + background.value}m'
if color or background:
cstr2 = '\033[0m'
print(f'{cstr1}{msg}{cstr2}')
date = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")
print(f'{date} {cstr1}{msg}{cstr2}')
125 changes: 121 additions & 4 deletions src/stpr/channels.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
from typing import TypeVar, Tuple, Optional, AsyncIterator
from abc import ABC, abstractmethod
from typing import TypeVar, Tuple, Optional, AsyncIterator, Iterable, List, Callable, Awaitable, \
Generic

import stpr
from stpr.types import T
Expand All @@ -13,6 +15,43 @@ def __init__(self, exception: Optional[Exception]) -> None:
_GUARD = _Guard(StopAsyncIteration())


class ChannelCallback(Generic[T], ABC):
"""
An abstract base class for channel callbacks.

This class can be overridden in order to receive notifications from a channel using the channel
callback mechanism.
"""
@abstractmethod
def value_appended(self, channel: 'Channel[T]', value: T) -> None:
"""
This method is invoked when a value is added to a channel.

:param channel: The channel that the value was added to.
:param value: The value that was added
"""
pass

@abstractmethod
def channel_error(self, channel: 'Channel[T]', error: Exception) -> None:
"""
This method is called when a channel receives an error.

:param channel: The channel that received the error.
:param error: The error.
"""
pass

@abstractmethod
def channel_closed(self, channel: 'Channel[T]') -> None:
"""
This method is called when a channel is closed.

:param channel: The channel that was closed.
"""
pass


class Channel(AsyncIterator[T]):
"""
This class implements a channel.
Expand Down Expand Up @@ -60,9 +99,16 @@ def producer(c: Channel[int]) -> None:
In the above example, the channel is closed automatically when the `with` statement completes.
Any exceptions thrown by `f()` will propagate out of `producer()`, but also be raised in the
consumer function after all other values in the channel are consumed.

Iterating a channel concurrently (that is, having multiple concurrent for loops for the same
channel instance) will cause each iteration to produce distinct values in a nondeterministic
fashion. For deterministic behaviour, there should be no concurrent iteration on a channel.
If deterministic concurrent iteration on the values added to a channel is needed, use the
:meth:`~.split` method.
"""
def __init__(self):
self._q = asyncio.Queue()
self._callbacks: List[ChannelCallback[T]] = []

async def __aenter__(self):
pass
Expand All @@ -82,20 +128,34 @@ async def __anext__(self) -> T:
raise value.exception
return value

def add_callback(self, cb: ChannelCallback) -> None:
"""
Adds a callback to this reactive.

A callback can be used to get asynchronous notifications when a value is added to a
Channel or when the channel is closed/fails.

:param cb: An instance of :class:`.ChannelCallback`.
"""
if cb is None:
raise ValueError(f'Invalid callback {cb}')
if not asyncio.iscoroutinefunction(cb):
raise ValueError(f'Callback {cb} is not a coroutine.')
self._callbacks.append(cb)

async def append(self, value: T) -> None:
"""
Appends a value to this channel.

:param value: The value to append.
"""
await self._q.put(value)
if self._callbacks:
self._notify_append(value)

def _append_now(self, value: T) -> None:
self._q.put_nowait(value)

def __iadd__(self, value: T) -> None:
self._q.put_nowait(value)

def close(self) -> None:
"""
Closes this channel.
Expand All @@ -105,6 +165,8 @@ def close(self) -> None:
that is used to consume values from this channel.
"""
self._q.put_nowait(_GUARD)
if self._callbacks:
self._notify_close()

def fail(self, exception: Exception) -> None:
"""
Expand All @@ -114,6 +176,61 @@ def fail(self, exception: Exception) -> None:
made to consume that value, ``exception`` is raised instead.
"""
self._q.put_nowait(_Guard(exception))
if self._callbacks:
self._notify_failure(exception)

def _notify_append(self, value: T) -> None:
for cb in self._callbacks:
cb.value_appended(self, value)

def _notify_closed(self, value: T) -> None:
for cb in self._callbacks:
cb.channel_closed(self)

def _notify_failure(self, e: Exception) -> None:
for cb in self._callbacks:
cb.channel_error(self, e)

def split(self) -> 'Channel[T]':
"""
Splits this channel.

This method returns a new channel that will receive all the values that
this channel receives, including errors. The returned channel will be
closed when this channel is closed. The returned channel can, in turn
be split further.

Iterating on both this channel and the channel returned by this method
will produce the values added on the channel in sequence for each
iteration provided that at most one iteration is active on each at a
given time.
:return:
"""
split = _SplitChannel()
self.add_callback(split)
return split

async def drain(self, dest: List[T]) -> None:
"""
Drains this channel into a list.

One by one, remove all items from this channel until the channel is
empty and add them to the given list. This method can be used to
periodically bulk-collect items that have accumulated in this channel.

:param dest: A list to drain this channel into.
"""


class _SplitChannel(Channel[T], ChannelCallback):
def value_appended(self, channel: 'Channel[T]', value: T):
self.append(value)

def channel_error(self, channel: 'Channel[T]', error: Exception):
self.fail(error)

def channel_closed(self, channel: 'Channel[T]'):
self.close()


async def select(*args: Channel[T]) -> Channel[Tuple[T, Channel[T]]]:
Expand Down
5 changes: 4 additions & 1 deletion src/stpr/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ def _start(fn) -> concurrent.futures.Future:
raise ValueError('Cannot run %s' % fn)

#print(f'running {coro} in {id(_LOOP)}')
return asyncio.run_coroutine_threadsafe(coro, _LOOP)
loop = asyncio.get_running_loop()
if loop is None:
loop = _LOOP
return asyncio.run_coroutine_threadsafe(coro, loop)


def run(fn) -> object | None:
Expand Down
2 changes: 1 addition & 1 deletion src/stpr/net.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from contextlib import AbstractAsyncContextManager
from socket import socket
import socket

from stpr.channels import Channel

Expand Down
2 changes: 1 addition & 1 deletion src/stpr/reactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async def update(self, value: T, force: bool = False) -> None:
async def _notify(self, value: T, old: T) -> None:
for cb in self._callbacks:
try:
await cb(self, value, old)
asyncio.create_task(cb(self, value, old))
except TypeError as e:
raise TypeError(f'Failed to invoke callback {cb}: {e}')

Expand Down
2 changes: 1 addition & 1 deletion src/stpr/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,5 @@ async def _call(fn, *args, **kwargs):
#print('Calling coro %s with %s, %s' % (fn, args, kwargs))
return await fn(*args, **kwargs)
else:
#print('Wrap-calling %s with %s' % (fn, args))
#print('Wrap-calling %s with %s, %s' % (fn, args, kwargs))
return await _run_sync(fn, *args, **kwargs)
Loading
Loading