Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions block_stack_allocator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#pragma once
#include <assert.h>


extern __shared__ char arena[];

// ---- Alignment helpers -----
__device__ __forceinline__
constexpr size_t align_up(size_t offset, size_t align) {
return (offset + align - 1u) & ~(align - 1u);
}

class BlockStackAllocator {
public:
static constexpr size_t DefaultAlign = 16;

template<typename T>
class BlockScopedAlloc {
public:
__device__ BlockScopedAlloc(BlockStackAllocator& bsa, T* ptr, size_t cp)
: bsa_(bsa), ptr_(ptr), cp_(cp) {}

BlockScopedAlloc(const BlockScopedAlloc&) = delete;
BlockScopedAlloc& operator=(const BlockScopedAlloc&) = delete;

__device__ BlockScopedAlloc(BlockScopedAlloc&& o)
: bsa_(o.bsa_), ptr_(o.ptr_), cp_(o.cp_) {
o.ptr_ = nullptr;
}

__device__ ~BlockScopedAlloc() {
if (ptr_) bsa_.restore(cp_);
}

__device__ operator T*() const { return ptr_; }
__device__ T* get() const { return ptr_; }

private:
BlockStackAllocator& bsa_;
T* ptr_;
size_t cp_;
};

__device__ void init(size_t capacity) {
if (threadIdx.x == 0) {
capacity_ = capacity;
offset_ = 0;
}
__syncthreads();
}

__device__ void* alloc_raw(size_t bytes, size_t align = DefaultAlign) {
__shared__ size_t alloc_start;
if (threadIdx.x == 0) {
size_t aligned = align_up(offset_, align);
if (aligned + bytes <= capacity_) {
alloc_start = aligned;
offset_ = aligned + bytes;
} else {
alloc_start = capacity_ + 1; // sentinel: out of memory
}
}
__syncthreads();
if (alloc_start > capacity_) return nullptr;
return static_cast<void*>(arena + alloc_start);
}

template<typename T>
__device__ BlockScopedAlloc<T> alloc(size_t count = 1, bool zero_init = false) {
size_t cp = offset_;
__syncthreads();

T* ptr = static_cast<T*>(alloc_raw(sizeof(T) * count, alignof(T)));

return BlockScopedAlloc<T>(*this, ptr, cp);
}

// ---- Collective checkpoint / restore / reset -----------------------
__device__ size_t checkpoint() {
__syncthreads();
return offset_;
}

__device__ void restore(size_t cp) {
if (threadIdx.x == 0) offset_ = cp;
__syncthreads();
}

__device__ void reset() {
if (threadIdx.x == 0) offset_ = 0;
__syncthreads();
}

// ---- Queries -------------------------------------------------------
__device__ size_t used() const { return offset_; }
__device__ size_t remaining() const { return capacity_ - offset_; }
__device__ size_t capacity() const { return capacity_; }

private:
size_t capacity_;
size_t offset_;
};


__device__ __forceinline__
BlockStackAllocator& block_allocator() {
static __shared__ BlockStackAllocator bsa;
return bsa;
}
13 changes: 12 additions & 1 deletion transform_level3.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "util.h"
#include "mxm_level3.h"
#include "block_stack_allocator.h"

/**
* Transform wrapper for Level-3 (B in LDS + register accumulation).
Expand Down Expand Up @@ -36,13 +37,23 @@ LAUNCH_BOUNDS(MAX_THREADS_PER_BLOCK, 1)
__global__ void transform_kernel_level3_k(int nfuncs,
const T* A, const T* B, T* C, T* workspace) {
constexpr int K2NDIM = K * K * K;
__shared__ BlockStackAllocator bs;
size_type sh_mem = mra::mTxmq_level3_shmem_size<T>(K);
bs.init(sh_mem);
auto bsa = bs.alloc<T>(K*K);
T* B_shmem = bsa.get();

for(size_t i = thread_id();i<(K*K);i+=block_size()){
B_shmem[i] = B[i];
}

T* w = workspace + blockIdx.x * K2NDIM;
for (int i = blockIdx.x; i < nfuncs; i += gridDim.x) {
const T* a = A + i * K2NDIM;
T* c = C + i * K2NDIM;
/* result pointer starts at c; workspace is w */
T* result = c;
transform_level3_k<T, K>(a, B, result, w);
transform_level3_k<T, K>(a, B_shmem, result, w);
}
}

Expand Down