To maximise performance in the CUDA V1 version, a number of tricks are used that limit the supported input shapes.

TLDR: it should work for inputs of shape `N x 2048 x D`, where `D` is a multiple of 32.

- [ ] A custom in-memory transpose is used to improve performance. It works on 32x32 tiles, which restricts T and D to multiples of 32.
- [ ] The scan runs in a single block of 1024 threads with 2 elements per thread, which limits processing to T <= 2048 (1024 x 2).
- [ ] For the backward pass, I have not implemented the full indexing approach for reversing the indices, which limits it to T = 2048.
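The three restrictions can be checked up front before launching the kernels. This is a minimal sketch that assumes the input is laid out as (N, T, D) with T = 2048 as in the TLDR. `shape_errors` and the constant names are hypothetical and are not part of the CUDA V1 code.

```python
# Hypothetical helper: the constraints come from the README, this code does not.
TILE = 32                           # edge of the square transpose tile
THREADS = 1024                      # threads in the single scan block
ITEMS_PER_THREAD = 2                # elements handled by each scan thread
MAX_T = THREADS * ITEMS_PER_THREAD  # 2048, the largest T one block can scan


def shape_errors(n: int, t: int, d: int, backward: bool = True) -> list[str]:
    """Return the CUDA V1 constraints violated by an (N, T, D) input."""
    errors = []
    if min(n, t, d) < 1:
        errors.append("N, T and D must all be positive")
    if t % TILE or d % TILE:
        errors.append(f"T and D must be multiples of {TILE} (tiled transpose)")
    if t > MAX_T:
        errors.append(f"T must be at most {MAX_T} (single-block scan)")
    if backward and t != MAX_T:
        errors.append(f"the backward pass requires T == {MAX_T}")
    return errors
```

For example, `shape_errors(8, 2048, 64)` returns an empty list. `shape_errors(8, 1024, 64)` returns only the backward-pass message, and it passes with `backward=False`.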
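The tiling restriction is easiest to see in a CPU reference of a tiled transpose. The sketch below is a hypothetical Python emulation of the common shared-memory pattern, not the V1 kernel itself. Each outer iteration plays the role of one 32x32 thread block: it copies a tile into a buffer and writes the tile back transposed. It has no partial-tile bounds checks, which is assumed here to match the V1 restriction, so both dimensions must be multiples of the tile size. Supporting other sizes would require guards such as `if row < rows` on the last tile along each axis.

```python
TILE = 32  # tile edge used by the V1 transpose


def tiled_transpose(a, rows, cols, tile=TILE):
    """Transpose a row-major flat list of shape (rows, cols) tile by tile.

    Returns a row-major flat list of shape (cols, rows).
    """
    if rows % tile or cols % tile:
        raise ValueError(f"rows and cols must be multiples of {tile}")
    out = [0] * (rows * cols)
    for tr in range(0, rows, tile):      # each (tr, tc) pair ~ one thread block
        for tc in range(0, cols, tile):
            # Load phase: thread (i, j) reads row tr+i, column tc+j into the
            # "shared" buffer. A CUDA kernel would typically pad this buffer
            # to tile x (tile + 1) to avoid shared-memory bank conflicts.
            buf = [[a[(tr + i) * cols + tc + j] for j in range(tile)]
                   for i in range(tile)]
            # Store phase: thread (i, j) writes buf[j][i], so consecutive j
            # hit consecutive output addresses (coalesced writes on a GPU).
            for i in range(tile):
                for j in range(tile):
                    out[(tc + i) * rows + tr + j] = buf[j][i]
    return out
```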
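The README does not name the scan algorithm or its operator. The following Python emulation assumes a Blelloch-style work-efficient exclusive scan, the classic pattern in which each thread owns two elements, and uses addition purely for illustration. It shows why one block of 1024 threads caps T at 2048: the shared buffer holds exactly `2 * 1024` elements, and shorter inputs are padded with the identity. `block_exclusive_scan` is a hypothetical name. Each inner `for tid` loop stands in for the active threads at one tree level; on the GPU, those levels would be separated by `__syncthreads()`.

```python
THREADS = 1024          # threads in the single scan block (assumed)
ITEMS_PER_THREAD = 2    # each thread owns two elements of the buffer


def block_exclusive_scan(xs, threads=THREADS):
    """Emulate a single-block Blelloch exclusive scan (sum) on the CPU."""
    n = ITEMS_PER_THREAD * threads  # size of the shared-memory buffer
    if threads < 1 or n & (n - 1):
        raise ValueError("buffer size must be a positive power of two")
    if len(xs) > n:
        raise ValueError(f"one block of {threads} threads scans at most {n} items")
    temp = list(xs) + [0] * (n - len(xs))  # pad with the identity of +

    # Up-sweep: build partial sums in place, halving active threads each level.
    offset, d = 1, n >> 1
    while d > 0:
        for tid in range(d):  # threads with tid < d are active
            ai = offset * (2 * tid + 1) - 1
            bi = offset * (2 * tid + 2) - 1
            temp[bi] += temp[ai]
        offset <<= 1
        d >>= 1

    # Down-sweep: clear the root, then push prefixes back down the tree.
    temp[n - 1] = 0
    d = 1
    while d < n:
        offset >>= 1
        for tid in range(d):
            ai = offset * (2 * tid + 1) - 1
            bi = offset * (2 * tid + 2) - 1
            temp[ai], temp[bi] = temp[bi], temp[bi] + temp[ai]
        d <<= 1

    return temp[: len(xs)]
```

For example, `block_exclusive_scan([1, 2, 3, 4])` returns `[0, 1, 3, 6]`. To get a suffix (reverse) scan, of the kind a backward pass over time usually needs, reverse the input, scan it, and reverse the result.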