Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions examples/Raven-RWKV/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Image for the Raven-RWKV gRPC service: CUDA base image, a conda-managed
# Python 3.9 environment, the model weights fetched at build time, and the
# server started as a non-root user.
FROM nvidia/cuda:11.7.1-devel-ubuntu20.04

# Update, install
# DEBIAN_FRONTEND keeps apt from prompting during the build; the package lists
# are removed in the same layer so they do not end up in the image.
RUN apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential ninja-build git wget && \
    rm -rf /var/lib/apt/lists/*

# Install Miniconda and create the Python 3.9 environment used by the service
RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
    bash Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda && \
    rm Miniconda3-latest-Linux-x86_64.sh && \
    /opt/conda/bin/conda create -y --name py39 python=3.9 && \
    /opt/conda/bin/conda clean -ya

# Put the py39 environment first on PATH so `python3` and `pip` resolve to it
ENV PATH=/opt/conda/envs/py39/bin:$PATH

RUN pip install --upgrade pip setuptools wheel

# Create user instead of using root
ENV USER='user'
RUN groupadd -r $USER && useradd -r -g $USER $USER
RUN mkdir -p /home/$USER/app
RUN chown -R $USER:$USER /home/$USER
USER $USER

# Define workdir
WORKDIR /home/$USER/app

# Install project
COPY requirements.txt requirements.txt
RUN pip install -r requirements.txt

COPY get_models.py .

# Get model weights and tokenizer
RUN python3 get_models.py

# Copy rest
COPY . .

# Publish port
# EXPOSE only documents the container port; the host mapping is supplied at
# run time with `docker run -p 50051:50051`.
EXPOSE 50051

# Enjoy
ENTRYPOINT ["python3", "server.py"]
CMD ["--address", "[::]:50051"]
48 changes: 48 additions & 0 deletions examples/Raven-RWKV/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# RavenRWKV service

## Description

This project uses the [RWKV-LM](https://github.com/BlinkDL/RWKV-LM) model and turns it into a gRPC service that can be used through [SimpleAI](https://github.com/lhenault/simpleAI).

RWKV is an RNN with Transformer-level language model performance that can be trained like a GPT transformer and is 100% attention-free. It combines the best of RNNs and transformers: great performance, fast inference, low VRAM usage, fast training, "infinite" ctx_len, and free sentence embedding.

## Usage

Edit the `MODEL` variable in `get_models.py` to choose the model size and context.

Edit the `STRATEGY` variable in `lib_raven.py` to decide how the weights will be loaded; experiment with it to optimise the throughput for your system. See below for a graphical explanation, or check out [ChatRWKV](https://github.com/BlinkDL/ChatRWKV) for more information.

![Strategies as of 20 Apr 2023](https://raw.githubusercontent.com/BlinkDL/ChatRWKV/536b4b3bf87fbd999798141f409b151ca91a76c7/ChatRWKV-strategy.png)

## Build

```bash
docker build . -t raven-rwkv-service:latest
```

## Start service

```bash
docker run -it --rm -p 50051:50051 --gpus all raven-rwkv-service:latest
```

## Add to model.toml

```toml
[raven]
[raven.metadata]
owned_by = 'BlinkDL'
permission = []
description = 'RWKV fine tuned for instruction answering'
[raven.network]
url = 'localhost:50051'
```

## Credits

Heavily borrowed from lhenault & BlinkDL

https://huggingface.co/spaces/BlinkDL/Raven-RWKV-7B
60 changes: 60 additions & 0 deletions examples/Raven-RWKV/get_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from pathlib import Path

import requests
from huggingface_hub import hf_hub_download

# Alias of the checkpoint to download; must be one of the keys of `models`.
MODEL = "raven-1b-ctx4096"

# The tokenizer file lives next to this module.
TOKENIZER_PATH = Path(__file__).with_name("20B_tokenizer.json")

# (alias, Hugging Face repo id, checkpoint file name without the ".pth" suffix)
_CHECKPOINTS = (
    (
        "raven-14b-ctx4096",
        "BlinkDL/rwkv-4-raven",
        "RWKV-4-Raven-14B-v8-Eng-20230408-ctx4096",
    ),
    (
        "raven-7b-ctx4096",
        "BlinkDL/rwkv-4-raven",
        "RWKV-4-Raven-7B-v7-Eng-20230404-ctx4096",
    ),
    (
        "raven-7b-ctx1024",
        "BlinkDL/rwkv-4-pile-7b",
        "RWKV-4-Pile-7B-Instruct-test4-20230326",
    ),
    (
        "rwkv-4-pile-169m",
        "BlinkDL/rwkv-4-pile-169m",
        "RWKV-4-Pile-169M-20220807-8023",
    ),
    (
        "raven-1b-ctx4096",
        "BlinkDL/rwkv-4-raven",
        "RWKV-4-Raven-1B5-v11-Eng99%-Other1%-20230425-ctx4096",
    ),
    (
        "raven-3b-ctx4096",
        "BlinkDL/rwkv-4-raven",
        "RWKV-4-Raven-3B-v11-Eng99%-Other1%-20230425-ctx4096",
    ),
)

# Lookup table from model alias to the repo and checkpoint title it maps to.
models = {
    alias: {"repo_id": repo_id, "title": title}
    for alias, repo_id, title in _CHECKPOINTS
}


def fetch_tokenizer(tokenizer_path: Path):
    """Download the 20B tokenizer JSON and write it to ``tokenizer_path``.

    Args:
        tokenizer_path: Destination file. Missing parent directories are
            created.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeout.
    """
    url = "https://huggingface.co/spaces/BlinkDL/Raven-RWKV-7B/raw/main/20B_tokenizer.json"
    tokenizer_path.parent.mkdir(parents=True, exist_ok=True)

    # A timeout keeps a stalled connection from hanging the caller forever.
    response = requests.get(url, timeout=60)
    # Fail loudly instead of saving an HTML error page as the tokenizer.
    response.raise_for_status()
    tokenizer_path.write_bytes(response.content)


def get_model_path(model="rwkv-4-pile-169m"):
    """Ensure the tokenizer is present and return the local path of a checkpoint.

    The tokenizer is downloaded to ``TOKENIZER_PATH`` if it is not there yet,
    then the ``.pth`` file for ``model`` is resolved through
    ``hf_hub_download``.

    Args:
        model: Key of the ``models`` table selecting the checkpoint.

    Returns:
        Whatever ``hf_hub_download`` returns for the checkpoint file.

    Raises:
        KeyError: If ``model`` is not a key of ``models``.
    """
    # Use the module-level constant so this function and the modules that
    # import TOKENIZER_PATH always agree on where the tokenizer lives.
    if not TOKENIZER_PATH.exists():
        fetch_tokenizer(TOKENIZER_PATH)

    try:
        model_params = models[model]
    except KeyError:
        known = ", ".join(sorted(models))
        raise KeyError(f"Unknown model {model!r}; expected one of: {known}") from None

    return hf_hub_download(
        repo_id=model_params["repo_id"],
        filename=f"{model_params['title']}.pth",
    )


# Script entry point: fetch the tokenizer (if missing) and the checkpoint
# selected by MODEL. The Dockerfile runs this file at build time.
if __name__ == "__main__":
    get_model_path(MODEL)
108 changes: 108 additions & 0 deletions examples/Raven-RWKV/lib_raven.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import logging
from typing import List

from get_models import MODEL, TOKENIZER_PATH, get_model_path

# if RWKV_CUDA_ON='1' then use CUDA kernel for seq mode (much faster)
# these settings must be configured before attempting to import rwkv
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

# Named weight-loading strategies; the selected string is passed to
# RWKV(strategy=...) in get_model().
STRATEGIES = {
    "streaming": "cuda fp16i8 *40+ -> cpu fp32 *1",  # Quite slow, take ~3gb VRAM
    "fp16i8": "cuda fp16i8 *40 -> cpu fp32 *1",  # fits the 14b on a T4, quite fast
    "cpu": "cpu fp32 *1",  # requires a lot of RAM
}

# Strategy actually used when the model is loaded.
STRATEGY = STRATEGIES["cpu"]

# Name the logger after the module, not the file path, so it fits the dotted
# logger hierarchy and can be addressed from a logging configuration.
logger = logging.getLogger(__name__)

# Maximum number of trailing tokens kept per input by embedding().
ctx_limit = 4096


def get_model():
    """Load the checkpoint selected by ``MODEL`` and build its pipeline.

    Returns:
        A ``(model, pipeline)`` tuple: the ``RWKV`` instance created with
        ``STRATEGY`` and a ``PIPELINE`` wrapping it with the tokenizer at
        ``TOKENIZER_PATH``.
    """
    weights_path = get_model_path(MODEL)
    rwkv_model = RWKV(model=weights_path, strategy=STRATEGY)
    rwkv_pipeline = PIPELINE(rwkv_model, str(TOKENIZER_PATH))
    return rwkv_model, rwkv_pipeline


def generate_prompt(instruction, prompt=None):
    """Build the instruction-following prompt text.

    Args:
        instruction: The task description, inserted under ``# Instruction:``.
        prompt: Optional extra context. When truthy it is inserted under an
            additional ``# Input:`` section and the preamble mentions it.

    Returns:
        The prompt string, ending with ``# Response:`` and a newline.
    """
    # The text is assembled from ordinary string literals: the previous
    # triple-quoted version carried leftover quote characters from an earlier
    # implicit concatenation, which leaked literal '""' into the prompt.
    if prompt:
        return (
            "Below is an instruction that describes a task, paired with an input"
            " that provides further context. Write a response that appropriately"
            " completes the request.\n"
            "\n"
            "# Instruction:\n"
            f"{instruction}\n"
            "\n"
            "# Input:\n"
            f"{prompt}\n"
            "\n"
            "# Response:\n"
        )
    return (
        "Below is an instruction that describes a task. Write a response that"
        " appropriately completes the request.\n"
        "\n"
        "# Instruction:\n"
        f"{instruction}\n"
        "\n"
        "# Response:\n"
    )


def complete(
    instruction,
    model,
    pipeline: PIPELINE,
    prompt="",
    token_count=200,
    temperature=1.0,
    top_p=0.7,
    presencePenalty=0.1,
    countPenalty=0.1,
    stop_words=None,
):
    """Stream a completion for ``instruction`` from ``pipeline``.

    Args:
        instruction: Text passed to the pipeline as the generation context.
        model: Accepted but not used by this function.
        pipeline: Object whose ``igenerate`` method produces the output.
        prompt: Accepted but not used by this function.
        token_count: Forwarded to ``igenerate`` as ``token_count``.
        temperature: Sampling temperature; values below 0.2 are raised to 0.2.
        top_p: Converted to float and forwarded as ``top_p``.
        presencePenalty: Forwarded as ``alpha_presence``.
        countPenalty: Forwarded as ``alpha_frequency``.
        stop_words: Forwarded as ``stop_words``.

    Yields:
        Each item produced by ``pipeline.igenerate``, unchanged.
    """
    clamped_temperature = max(0.2, float(temperature))
    nucleus = float(top_p)

    sampling_args = PIPELINE_ARGS(
        temperature=clamped_temperature,
        top_p=nucleus,
        alpha_frequency=countPenalty,
        alpha_presence=presencePenalty,
        token_ban=[],  # ban the generation of some tokens
        token_stop=[0],  # stop generation whenever you see any token here
        stop_words=stop_words,
    )

    stream = pipeline.igenerate(
        ctx=instruction, token_count=token_count, args=sampling_args
    )
    for piece in stream:
        yield piece


def embedding(
    inputs: List[str],
    model,
    pipeline,
    temperature=1.0,  # TODO remove
    top_p=0.7,
    presencePenalty=0.1,
    countPenalty=0.1,
):
    """Return an embedding for the first string in ``inputs``.

    The text is tokenised with ``pipeline.encode``, truncated to its last
    ``ctx_limit`` tokens and run through ``model.forward``; the last element
    of the returned state is used as the embedding.

    NOTE(review): only ``inputs[0]`` is embedded, any further inputs are
    ignored. This matches the previous behaviour; confirm whether callers
    expect one vector per input.

    Args:
        inputs: Texts to embed; must contain at least one string.
        model: Object exposing ``forward(tokens, state)``.
        pipeline: Object exposing ``encode(text)``.
        temperature, top_p, presencePenalty, countPenalty: Kept for signature
            compatibility only; they do not affect the result.

    Returns:
        The last state element, with a leading dimension added via
        ``unsqueeze(0)`` when it is one-dimensional.

    Raises:
        IndexError: If ``inputs`` is empty.
    """
    if not inputs:
        raise IndexError("embedding() requires at least one input string")

    # Tokenise only the input that is actually embedded.
    tokens = pipeline.encode(inputs[0])[-ctx_limit:]
    _, state = model.forward(tokens, None)
    *_, vector = state

    if len(vector.shape) == 1:
        vector = vector.unsqueeze(0)
    return vector
36 changes: 36 additions & 0 deletions examples/Raven-RWKV/logging.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Logging configuration in the logging.config.fileConfig INI format.
# Declares two loggers, two stdout handlers and two message formats.

[loggers]
keys=root,uicheckapp

[handlers]
keys=consoleHandler,detailedConsoleHandler

[formatters]
keys=normalFormatter,detailedFormatter

# Root logger: INFO and above, written through the plain console handler.
[logger_root]
level=INFO
handlers=consoleHandler

# Dedicated DEBUG logger with the detailed format; propagate=0 keeps its
# records from also being handled by the root logger.
# NOTE(review): none of the Python modules reviewed alongside this config
# request a logger named "uicheckapp" - confirm it is still needed.
[logger_uicheckapp]
level=DEBUG
handlers=detailedConsoleHandler
qualname=uicheckapp
propagate=0

# Both handlers write to stdout and accept DEBUG and above; they differ only
# in the formatter they use.
[handler_consoleHandler]
class=StreamHandler
level=DEBUG
formatter=normalFormatter
args=(sys.stdout,)

[handler_detailedConsoleHandler]
class=StreamHandler
level=DEBUG
formatter=detailedFormatter
args=(sys.stdout,)

# key=value style line: timestamp, level, logger name, function, line number, message.
[formatter_normalFormatter]
format=%(asctime)s loglevel=%(levelname)-6s logger=%(name)s %(funcName)s() L%(lineno)-4d %(message)s

# Same as normalFormatter, plus the source file path and line number as call_trace.
[formatter_detailedFormatter]
format=%(asctime)s loglevel=%(levelname)-6s logger=%(name)s %(funcName)s() L%(lineno)-4d %(message)s call_trace=%(pathname)s L%(lineno)-4d
Loading