Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ License: MIT + file LICENSE
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
Imports: attention (>= 0.4.0)
LinkingTo: Rcpp, RcppArmadillo, RcppParallel, RcppEigen
Imports: Rcpp,
RcppArmadillo,
RcppParallel,
RcppEigen
Suggests:
covr,
testthat (>= 3.0.0)
Expand Down
14 changes: 4 additions & 10 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
# Generated by roxygen2: do not edit by hand

export(SoftMax)
export(attention)
export(row_means)
export(row_vars)
export(transformer)
importFrom(attention,SoftMax)
importFrom(attention,attention)
importFrom(stats,rnorm)
importFrom(Rcpp, evalCpp)
importFrom(RcppParallel, RcppParallelLibs)
exportPattern("^[[:alpha:]]+")
useDynLib(transformer)
63 changes: 63 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

SoftMax <- function(x) {
.Call('_transformer_SoftMax', PACKAGE = 'transformer', x)
}

attention <- function(Q, K, V) {
.Call('_transformer_attention', PACKAGE = 'transformer', Q, K, V)
}

feed_forward <- function(x, dff, d_model) {
.Call('_transformer_feed_forward', PACKAGE = 'transformer', x, dff, d_model)
}

fill_mat_rnorm <- function(mat) {
invisible(.Call('_transformer_fill_mat_rnorm', PACKAGE = 'transformer', mat))
}

fill_mat_row <- function(nb_Row, row_Vec) {
.Call('_transformer_fill_mat_row', PACKAGE = 'transformer', nb_Row, row_Vec)
}

insert_sub_mat <- function(mat, sub_Mat, col_Index) {
invisible(.Call('_transformer_insert_sub_mat', PACKAGE = 'transformer', mat, sub_Mat, col_Index))
}

layer_norm <- function(x, epsilon) {
.Call('_transformer_layer_norm', PACKAGE = 'transformer', x, epsilon)
}

mat_mult <- function(mat1, mat2) {
.Call('_transformer_mat_mult', PACKAGE = 'transformer', mat1, mat2)
}

mat_sum <- function(mat1, mat2) {
.Call('_transformer_mat_sum', PACKAGE = 'transformer', mat1, mat2)
}

multi_head <- function(Q, K, V, d_model, num_heads) {
.Call('_transformer_multi_head', PACKAGE = 'transformer', Q, K, V, d_model, num_heads)
}

pmax_mat <- function(mat) {
.Call('_transformer_pmax_mat', PACKAGE = 'transformer', mat)
}

row_max <- function(mat) {
.Call('_transformer_row_max', PACKAGE = 'transformer', mat)
}

row_means <- function(mat) {
.Call('_transformer_row_means', PACKAGE = 'transformer', mat)
}

row_vars <- function(mat) {
.Call('_transformer_row_vars', PACKAGE = 'transformer', mat)
}

sub_mat <- function(mat, col_Index) {
.Call('_transformer_sub_mat', PACKAGE = 'transformer', mat, col_Index)
}

3 changes: 0 additions & 3 deletions R/attention.R

This file was deleted.

19 changes: 0 additions & 19 deletions R/feed_forward.R

This file was deleted.

12 changes: 0 additions & 12 deletions R/layer_norm.R

This file was deleted.

36 changes: 0 additions & 36 deletions R/multi_head.R

This file was deleted.

10 changes: 0 additions & 10 deletions R/row_means.R

This file was deleted.

9 changes: 0 additions & 9 deletions R/row_vars.R

This file was deleted.

3 changes: 0 additions & 3 deletions R/softmax.R

This file was deleted.

25 changes: 0 additions & 25 deletions R/transformer.R

This file was deleted.

12 changes: 12 additions & 0 deletions src/Makevars
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
PKG_LIBS = $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS) $(SHLIB_OPENMP_CFLAGS)
PKG_LIBS += $(shell "${R_HOME}/bin${R_ARCH_BIN}/Rscript.exe" \
-e "RcppParallel::RcppParallelLibs()")

PKG_CFLAGS = $(SHLIB_OPENMP_CFLAGS)

PKG_CXXFLAGS = $(SHLIB_OPENMP_CXXFLAGS)
PKG_CXXFLAGS += -std=c++11
PKG_CXXFLAGS += -DRCPP_PARALLEL_USE_TBB=1

CXX=g++
CXX_STD = CXX11
Loading