From 3d123d04a26de94d9ff322024ec33f43f28473fe Mon Sep 17 00:00:00 2001 From: Joseph Antony Date: Thu, 23 Apr 2026 14:46:35 -0400 Subject: [PATCH] Memory Allocator --- block_stack_allocator.h | 109 ++++++++++++++++++++++++++++++++++++++++ transform_level3.h | 13 ++++- 2 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 block_stack_allocator.h diff --git a/block_stack_allocator.h b/block_stack_allocator.h new file mode 100644 index 0000000..b4ed8bd --- /dev/null +++ b/block_stack_allocator.h @@ -0,0 +1,109 @@ +#pragma once +#include + + +extern __shared__ char arena[]; + +// ---- Alignment helpers ----- +__device__ __forceinline__ +constexpr size_t align_up(size_t offset, size_t align) { + return (offset + align - 1u) & ~(align - 1u); +} + +class BlockStackAllocator { +public: + static constexpr size_t DefaultAlign = 16; + + template + class BlockScopedAlloc { + public: + __device__ BlockScopedAlloc(BlockStackAllocator& bsa, T* ptr, size_t cp) + : bsa_(bsa), ptr_(ptr), cp_(cp) {} + + BlockScopedAlloc(const BlockScopedAlloc&) = delete; + BlockScopedAlloc& operator=(const BlockScopedAlloc&) = delete; + + __device__ BlockScopedAlloc(BlockScopedAlloc&& o) + : bsa_(o.bsa_), ptr_(o.ptr_), cp_(o.cp_) { + o.ptr_ = nullptr; + } + + __device__ ~BlockScopedAlloc() { + if (ptr_) bsa_.restore(cp_); + } + + __device__ operator T*() const { return ptr_; } + __device__ T* get() const { return ptr_; } + + private: + BlockStackAllocator& bsa_; + T* ptr_; + size_t cp_; + }; + + __device__ void init(size_t capacity) { + if (threadIdx.x == 0) { + capacity_ = capacity; + offset_ = 0; + } + __syncthreads(); + } + + __device__ void* alloc_raw(size_t bytes, size_t align = DefaultAlign) { + __shared__ size_t alloc_start; + if (threadIdx.x == 0) { + size_t aligned = align_up(offset_, align); + if (aligned + bytes <= capacity_) { + alloc_start = aligned; + offset_ = aligned + bytes; + } else { + alloc_start = capacity_ + 1; // sentinel: out of memory + } + } + __syncthreads(); + if (alloc_start > capacity_) return nullptr; + return static_cast(arena + alloc_start); + } + + template + __device__ BlockScopedAlloc alloc(size_t count = 1, bool zero_init = false) { + size_t cp = offset_; + __syncthreads(); + + T* ptr = static_cast(alloc_raw(sizeof(T) * count, alignof(T))); + + return BlockScopedAlloc(*this, ptr, cp); + } + + // ---- Collective checkpoint / restore / reset ----------------------- + __device__ size_t checkpoint() { + __syncthreads(); + return offset_; + } + + __device__ void restore(size_t cp) { + if (threadIdx.x == 0) offset_ = cp; + __syncthreads(); + } + + __device__ void reset() { + if (threadIdx.x == 0) offset_ = 0; + __syncthreads(); + } + + // ---- Queries ------------------------------------------------------- + __device__ size_t used() const { return offset_; } + __device__ size_t remaining() const { return capacity_ - offset_; } + __device__ size_t capacity() const { return capacity_; } + +private: + size_t capacity_; + size_t offset_; +}; + + +__device__ __forceinline__ +BlockStackAllocator& block_allocator() { + static __shared__ BlockStackAllocator bsa; + return bsa; +} diff --git a/transform_level3.h b/transform_level3.h index c377277..e318694 100644 --- a/transform_level3.h +++ b/transform_level3.h @@ -2,6 +2,7 @@ #include "util.h" #include "mxm_level3.h" +#include "block_stack_allocator.h" /** * Transform wrapper for Level-3 (B in LDS + register accumulation). @@ -36,13 +37,23 @@ LAUNCH_BOUNDS(MAX_THREADS_PER_BLOCK, 1) __global__ void transform_kernel_level3_k(int nfuncs, const T* A, const T* B, T* C, T* workspace) { constexpr int K2NDIM = K * K * K; + __shared__ BlockStackAllocator bs; + size_type sh_mem = mra::mTxmq_level3_shmem_size(K); + bs.init(sh_mem); + auto bsa = bs.alloc(K*K); + T* B_shmem = bsa.get(); + + for(size_t i = thread_id();i<(K*K);i+=block_size()){ + B_shmem[i] = B[i]; + } + T* w = workspace + blockIdx.x * K2NDIM; for (int i = blockIdx.x; i < nfuncs; i += gridDim.x) { const T* a = A + i * K2NDIM; T* c = C + i * K2NDIM; /* result pointer starts at c; workspace is w */ T* result = c; - transform_level3_k(a, B, result, w); + transform_level3_k(a, B_shmem, result, w); } }