diff --git a/vecops_fp32/Makefile b/vecops_fp32/Makefile new file mode 100644 index 0000000..2fb1165 --- /dev/null +++ b/vecops_fp32/Makefile @@ -0,0 +1,18 @@ +TARGET := vecops_fp32 +SRC := main.c +HDR := ssve.h + +ifeq ($(origin CC),default) +$(error Please set CC explicitly, for example: make CC=clang) +endif + +CFLAGS := -O3 -Wall -march=armv9.2-a+sme2+sve+sve2 +.PHONY: all clean + +all: $(TARGET) + +$(TARGET): $(SRC) $(HDR) + $(CC) $(CFLAGS) $(SRC) -o $(TARGET) + +clean: + rm -f $(TARGET) diff --git a/vecops_fp32/README.md b/vecops_fp32/README.md new file mode 100644 index 0000000..baac6d9 --- /dev/null +++ b/vecops_fp32/README.md @@ -0,0 +1,99 @@ +# Arm Streaming SVE (SSVE) Vector Operations + +## Vector operations + +The `ssve.h` header contains a set of basic vector operations implemented with Arm SSVE intrinsics. The kernels support two data formats, `float32` and `complex float32`. In the function names, `f32` means `float32` while `cf32` means `complex float32`. The table below summarizes the available operations. + +| SSVE routine | Notes | +| --- | --- | +| `mul_cf32` | Complex element-wise multiply. | +| `power_cf32` | Outputs L2 norm for each complex element. | +| `conj_scale_cf32` | Applies complex conjugation and then scales by a real scalar. | +| `dot_cf32` | Complex inner product without conjugation. | +| `conj_mul_cf32` | Multiplies `conj(a)` by `b` element-wise. | +| `conj_dot_cf32` | Conjugate complex dot product. | +| `mul_f32` | Real element-wise multiply. | +| `scale_f32` | Scalar multiply operation. | +| `dot_f32` | Real dot product reduced into `c[0]`. | +| `add_f32` | Element-wise addition kernel. | + +Important: the code in this repository is written under a fixed **512-bit** SSVE assumption. The blocking factors, tuple sizes, and tail-handling logic in `ssve.h`, as well as the way `main.c` is used for measurement, are documented with that assumption in mind. This is not presented as a vector-length-agnostic implementation. + +The notes below describe the intended correspondence assuming a **512-bit** SSVE vector length, which means: + +- One SSVE `float32` vector contains 16 lanes. +- `svld1_f32_x4` / `svst1_f32_x4` cover 64 `float` values per tuple. + +The f32 kernels operate on `n` scalar `float32` elements, while the cf32 kernels operate on `n` scalar `complex float32` elements. The complex data is stored as interleaved `(real, imag)` pairs. + +```text +[re0, im0, re1, im1, re2, im2, ...] +``` + +## Benchmark + +`main.c` is a minimal benchmark entry point for the SSVE kernels. + +- It accepts three command-line arguments: `choice`, `n`, and `iter`. +- `choice` selects one of the ten routines implemented in `ssve.h`. +- `n` is passed directly to the selected routine, so its meaning depends on whether the routine is operating on real or complex data. +- `iter` controls how many times the selected routine is executed for timing. +- The program currently prints only the total elapsed time in nanoseconds. + +The benchmark maps `choice` to kernels as follows: + +| `choice` | Kernel | +| --- | --- | +| `1` | `mul_cf32` | +| `2` | `power_cf32` | +| `3` | `conj_scale_cf32` | +| `4` | `dot_cf32` | +| `5` | `conj_mul_cf32` | +| `6` | `conj_dot_cf32` | +| `7` | `mul_f32` | +| `8` | `scale_f32` | +| `9` | `dot_f32` | +| `10` | `add_f32` | + +The program runs the specified kernel `iter` times and then prints the total duration in `ns`. + +## Build instructions + +Build from the `vecops_fp32` directory and pass the compiler explicitly via `CC`. + +The generated binary is intended for AArch64 targets that support the SME2 feature enabled by the Makefile flags. + +```sh +make CC=clang +``` + +The Makefile uses the following flags: + +```text +-O3 -Wall -march=armv9.2-a+sme2+sve+sve2 +``` + +This produces the benchmark binary `vecops_fp32`. + +To run the benchmark: + +```sh +./vecops_fp32 +``` + +Example: + +```sh +./vecops_fp32 1 512 1000 +``` + +To remove the generated binary: + +```sh +make clean CC=clang +``` + +## Numerical Behavior + +- The complex SSVE routines use `FCMLA`-based sequences, so their last-bit results can differ slightly from implementations with non-SIMD instructions. +- Reduction routines can also differ slightly because vector code changes the accumulation order compared with implementations with non-SIMD instructions. diff --git a/vecops_fp32/main.c b/vecops_fp32/main.c new file mode 100644 index 0000000..a792774 --- /dev/null +++ b/vecops_fp32/main.c @@ -0,0 +1,145 @@ +#include "ssve.h" +#include +#include +#include +#include +#include + +static uint64_t get_time_unit() +{ + uint64_t freq; + __asm__ volatile("mrs %0, cntfrq_el0" : "=r"(freq)); + return freq; +} + +static uint64_t get_time_count() +{ + uint64_t stamp; + __asm__ volatile("mrs %0, cntvct_el0" : "=r"(stamp)); + return stamp; +} + +static uint64_t get_time_interval_ns(const uint64_t t1, const uint64_t t2, const uint64_t unit) +{ + assert(t2 >= t1); + return (t2 - t1) * 1000000000 / unit; +} + +static void fill_vec(float *vec, size_t count, float weight) +{ + for (size_t i = 0; i < count; i++) + vec[i] = weight * (i + 1); +} + +static void print_choice_usage(const char *prog) +{ + fprintf(stderr, "Usage: %s \n", prog); + fprintf(stderr, "choice=1 -> mul_cf32\n"); + fprintf(stderr, "choice=2 -> power_cf32\n"); + fprintf(stderr, "choice=3 -> conj_scale_cf32\n"); + fprintf(stderr, "choice=4 -> dot_cf32\n"); + fprintf(stderr, "choice=5 -> conj_mul_cf32\n"); + fprintf(stderr, "choice=6 -> conj_dot_cf32\n"); + fprintf(stderr, "choice=7 -> mul_f32\n"); + fprintf(stderr, "choice=8 -> scale_f32\n"); + fprintf(stderr, "choice=9 -> dot_f32\n"); + fprintf(stderr, "choice=10 -> add_f32\n"); + fprintf(stderr, "Note: n is passed directly to the selected ssve kernel.\n"); +} + +static const char *run_ssve_choice(size_t choice, const float *a, const float *b, float *c, + uint64_t n, float scale) +{ + switch (choice) { + case 1: + mul_cf32(a, b, c, n); + return "mul_cf32"; + case 2: + power_cf32(a, c, n); + return "power_cf32"; + case 3: + conj_scale_cf32(a, scale, c, n); + return "conj_scale_cf32"; + case 4: + dot_cf32(a, b, c, n); + return "dot_cf32"; + case 5: + conj_mul_cf32(a, b, c, n); + return "conj_mul_cf32"; + case 6: + conj_dot_cf32(a, b, c, n); + return "conj_dot_cf32"; + case 7: + mul_f32(a, b, c, n); + return "mul_f32"; + case 8: + scale_f32(a, scale, c, n); + return "scale_f32"; + case 9: + dot_f32(a, b, c, n); + return "dot_f32"; + case 10: + add_f32(a, b, c, n); + return "add_f32"; + default: + return NULL; + } +} + +int main(int argc, char *argv[]) +{ + if (argc != 4) { + fprintf(stderr, "Error argc!\n"); + print_choice_usage(argv[0]); + return 1; + } + size_t choice = atoi(argv[1]), n = atoi(argv[2]), iter = atoi(argv[3]); + if (n == 0 || iter == 0) { + fprintf(stderr, "Error n or iter!\n"); + return 1; + } + float *a = NULL, *b = NULL, *c = NULL; + a = (float *)malloc(sizeof(float) * n * 2); + b = (float *)malloc(sizeof(float) * n * 2); + c = (float *)malloc(sizeof(float) * n * 2); + if (a == NULL || b == NULL || c == NULL) { + fprintf(stderr, "malloc failed!\n"); + free(a); + free(b); + free(c); + return 1; + } + fill_vec(a, n * 2, 0.1f); + fill_vec(b, n * 2, 0.2f); + fill_vec(c, n * 2, 0.0f); + + const float scale = 0.75f; + uint64_t unit = get_time_unit(); + uint64_t t1 = 0; + uint64_t t2 = 0; + uint64_t elapsed_ns = 0; + + if (run_ssve_choice(choice, a, b, c, (uint64_t)n, scale) == NULL) { + fprintf(stderr, "Unsupported choice: %zu\n", choice); + print_choice_usage(argv[0]); + free(a); + free(b); + free(c); + return 1; + } + + fill_vec(c, n * 2, 0.0f); + t1 = get_time_count(); + for (size_t i = 0; i < iter; i++) { + (void)run_ssve_choice(choice, a, b, c, (uint64_t)n, scale); + } + t2 = get_time_count(); + elapsed_ns = get_time_interval_ns(t1, t2, unit); + + printf("%llu\n", (unsigned long long)elapsed_ns); + + free(a); + free(b); + free(c); + return 0; +} \ No newline at end of file diff --git a/vecops_fp32/ssve.h b/vecops_fp32/ssve.h new file mode 100644 index 0000000..6c4eff6 --- /dev/null +++ b/vecops_fp32/ssve.h @@ -0,0 +1,623 @@ +#ifndef __SSVE_H__ +#define __SSVE_H__ + +#include +#include + +void mul_cf32(const float *restrict a, const float *restrict b, float *restrict c, + const uint64_t n) __arm_streaming +{ + const uint64_t iter = n / 32; + svcount_t pg = svwhilelt_c32(0ULL, 2 * n, 4); + svbool_t pt = svptrue_b32(); + for (uint64_t i = 0; i < iter; i++) { + svfloat32x4_t va = svld1_f32_x4(pg, a); + svfloat32x4_t vb = svld1_f32_x4(pg, b); + svfloat32_t vc1, vc2, vc3, vc4; + vc1 = svdup_n_f32(0); + vc2 = svdup_n_f32(0); + vc3 = svdup_n_f32(0); + vc4 = svdup_n_f32(0); + svfloat32_t va1 = svget4_f32(va, 0), va2 = svget4_f32(va, 1), va3 = svget4_f32(va, 2), + va4 = svget4_f32(va, 3); + svfloat32_t vb1 = svget4_f32(vb, 0), vb2 = svget4_f32(vb, 1), vb3 = svget4_f32(vb, 2), + vb4 = svget4_f32(vb, 3); + vc1 = svcmla_f32_x(pt, vc1, va1, vb1, 0); + vc2 = svcmla_f32_x(pt, vc2, va2, vb2, 0); + vc3 = svcmla_f32_x(pt, vc3, va3, vb3, 0); + vc4 = svcmla_f32_x(pt, vc4, va4, vb4, 0); + vc1 = svcmla_f32_x(pt, vc1, va1, vb1, 90); + vc2 = svcmla_f32_x(pt, vc2, va2, vb2, 90); + vc3 = svcmla_f32_x(pt, vc3, va3, vb3, 90); + vc4 = svcmla_f32_x(pt, vc4, va4, vb4, 90); + svfloat32x4_t vc = svcreate4_f32(vc1, vc2, vc3, vc4); + svst1_f32_x4(pg, c, vc); + a += 64; + b += 64; + c += 64; + } + const uint64_t remain = (n & 31) / 8 + ((n & 7) ? 1 : 0); + for (uint64_t i = 0; i < remain; i++) { + svbool_t pb = svwhilelt_b32(16 * i, (n & 31) * 2); + svfloat32_t va = svld1_f32(pb, a); + svfloat32_t vb = svld1_f32(pb, b); + svfloat32_t vc = svdup_n_f32(0); + vc = svcmla_f32_x(pb, vc, va, vb, 0); + vc = svcmla_f32_x(pb, vc, va, vb, 90); + svst1_f32(pb, c, vc); + a += 16; + b += 16; + c += 16; + } +} + +void power_cf32(const float *restrict a, float *restrict c, const uint64_t n) __arm_streaming +{ + const uint64_t iter = n / 64; + svbool_t pt = svptrue_b32(); + svcount_t ptc = svptrue_c32(); + for (uint64_t i = 0; i < iter; i++) { + svfloat32x2_t v0 = svld2_f32(pt, a); + svfloat32x2_t v1 = svld2_f32(pt, a + 32); + svfloat32x2_t v2 = svld2_f32(pt, a + 64); + svfloat32x2_t v3 = svld2_f32(pt, a + 96); + svfloat32_t va0 = svget2_f32(v0, 0), va1 = svget2_f32(v0, 1); + svfloat32_t va2 = svget2_f32(v1, 0), va3 = svget2_f32(v1, 1); + svfloat32_t va4 = svget2_f32(v2, 0), va5 = svget2_f32(v2, 1); + svfloat32_t va6 = svget2_f32(v3, 0), va7 = svget2_f32(v3, 1); + svfloat32_t r0 = svmul_f32_x(pt, va0, va0); + svfloat32_t r1 = svmul_f32_x(pt, va2, va2); + svfloat32_t r2 = svmul_f32_x(pt, va4, va4); + svfloat32_t r3 = svmul_f32_x(pt, va6, va6); + r0 = svmla_f32_x(pt, r0, va1, va1); + r1 = svmla_f32_x(pt, r1, va3, va3); + r2 = svmla_f32_x(pt, r2, va5, va5); + r3 = svmla_f32_x(pt, r3, va7, va7); + a += 128; + svst1_f32_x4(ptc, c, svcreate4_f32(r0, r1, r2, r3)); + c += 64; + } + if (n & 63) { + svbool_t p0 = svwhilelt_b32(0ULL, (n & 63)); + svbool_t p1 = svwhilelt_b32(16ULL, (n & 63)); + svbool_t p2 = svwhilelt_b32(32ULL, (n & 63)); + svbool_t p3 = svwhilelt_b32(48ULL, (n & 63)); + svcount_t pg = svwhilelt_c32(0ULL, (n & 63), 4); + svfloat32x2_t v0 = svld2_f32(p0, a); + svfloat32x2_t v1 = svld2_f32(p1, a + 32); + svfloat32x2_t v2 = svld2_f32(p2, a + 64); + svfloat32x2_t v3 = svld2_f32(p3, a + 96); + svfloat32_t va0 = svget2_f32(v0, 0), va1 = svget2_f32(v0, 1); + svfloat32_t va2 = svget2_f32(v1, 0), va3 = svget2_f32(v1, 1); + svfloat32_t va4 = svget2_f32(v2, 0), va5 = svget2_f32(v2, 1); + svfloat32_t va6 = svget2_f32(v3, 0), va7 = svget2_f32(v3, 1); + svfloat32_t r0 = svmul_f32_x(p0, va0, va0); + svfloat32_t r1 = svmul_f32_x(p1, va2, va2); + svfloat32_t r2 = svmul_f32_x(p2, va4, va4); + svfloat32_t r3 = svmul_f32_x(p3, va6, va6); + r0 = svmla_f32_m(p0, r0, va1, va1); + r1 = svmla_f32_m(p1, r1, va3, va3); + r2 = svmla_f32_m(p2, r2, va5, va5); + r3 = svmla_f32_m(p3, r3, va7, va7); + svst1_f32_x4(pg, c, svcreate4_f32(r0, r1, r2, r3)); + } +} + +void scale_f32(const float *restrict a, const float scale, float *restrict c, + const uint64_t n) __arm_streaming +{ + const uint64_t iter = n / 128; + svcount_t ptc = svptrue_c32(); + svbool_t pt = svptrue_b32(); + svfloat32_t vscale = svdup_n_f32(scale); + for (uint64_t i = 0; i < iter; i++) { + svfloat32x4_t va0 = svld1_f32_x4(ptc, a); + svfloat32x4_t va1 = svld1_f32_x4(ptc, a + 64); + svfloat32_t va00 = svget4_f32(va0, 0), va01 = svget4_f32(va0, 1), va02 = svget4_f32(va0, 2), + va03 = svget4_f32(va0, 3); + svfloat32_t va10 = svget4_f32(va1, 0), va11 = svget4_f32(va1, 1), va12 = svget4_f32(va1, 2), + va13 = svget4_f32(va1, 3); + svfloat32_t vc0 = svmul_f32_x(pt, va00, vscale); + svfloat32_t vc1 = svmul_f32_x(pt, va01, vscale); + svfloat32_t vc2 = svmul_f32_x(pt, va02, vscale); + svfloat32_t vc3 = svmul_f32_x(pt, va03, vscale); + svfloat32_t vc4 = svmul_f32_x(pt, va10, vscale); + svfloat32_t vc5 = svmul_f32_x(pt, va11, vscale); + svfloat32_t vc6 = svmul_f32_x(pt, va12, vscale); + svfloat32_t vc7 = svmul_f32_x(pt, va13, vscale); + svfloat32x4_t vc_0 = svcreate4_f32(vc0, vc1, vc2, vc3); + svst1_f32_x4(ptc, c, vc_0); + a += 128; + svst1_f32(pt, c + 64, vc4); + svst1_f32(pt, c + 80, vc5); + svst1_f32(pt, c + 96, vc6); + svst1_f32(pt, c + 112, vc7); + c += 128; + } + for (uint64_t i = 0; i < (n & 127); i += 64) { + svcount_t pg = svwhilelt_c32(i, n & 127, 4); + svfloat32x4_t va = svld1_f32_x4(pg, a); + svfloat32_t va0 = svget4_f32(va, 0), va1 = svget4_f32(va, 1), va2 = svget4_f32(va, 2), + va3 = svget4_f32(va, 3); + svfloat32_t vc0 = svmul_f32_x(pt, va0, vscale); + svfloat32_t vc1 = svmul_f32_x(pt, va1, vscale); + svfloat32_t vc2 = svmul_f32_x(pt, va2, vscale); + svfloat32_t vc3 = svmul_f32_x(pt, va3, vscale); + svfloat32x4_t vc = svcreate4_f32(vc0, vc1, vc2, vc3); + a += 64; + svst1_f32_x4(pg, c, vc); + c += 64; + } +} + +void dot_f32(const float *restrict a, const float *restrict b, float *restrict c, + const uint64_t n) __arm_streaming +{ + const uint64_t iter = n / 128; + svfloat32_t vc0 = svdup_n_f32(0), vc1 = svdup_n_f32(0), vc2 = svdup_n_f32(0), + vc3 = svdup_n_f32(0); + + svcount_t ptc = svptrue_c32(); + svbool_t pt = svptrue_b32(); + for (uint64_t i = 0; i < iter; i++) { + svfloat32x4_t va0 = svld1_f32_x4(ptc, a); + svfloat32x4_t vb0 = svld1_f32_x4(ptc, b); + svfloat32x4_t va1 = svld1_f32_x4(ptc, a + 64); + svfloat32x4_t vb1 = svld1_f32_x4(ptc, b + 64); + svfloat32_t va00 = svget4_f32(va0, 0), va01 = svget4_f32(va0, 1), va02 = svget4_f32(va0, 2), + va03 = svget4_f32(va0, 3); + svfloat32_t vb00 = svget4_f32(vb0, 0), vb01 = svget4_f32(vb0, 1), vb02 = svget4_f32(vb0, 2), + vb03 = svget4_f32(vb0, 3); + svfloat32_t va10 = svget4_f32(va1, 0), va11 = svget4_f32(va1, 1), va12 = svget4_f32(va1, 2), + va13 = svget4_f32(va1, 3); + svfloat32_t vb10 = svget4_f32(vb1, 0), vb11 = svget4_f32(vb1, 1), vb12 = svget4_f32(vb1, 2), + vb13 = svget4_f32(vb1, 3); + vc0 = svmla_f32_x(pt, vc0, va00, vb00); + vc1 = svmla_f32_x(pt, vc1, va01, vb01); + vc2 = svmla_f32_x(pt, vc2, va02, vb02); + vc3 = svmla_f32_x(pt, vc3, va03, vb03); + vc0 = svmla_f32_x(pt, vc0, va10, vb10); + vc1 = svmla_f32_x(pt, vc1, va11, vb11); + vc2 = svmla_f32_x(pt, vc2, va12, vb12); + vc3 = svmla_f32_x(pt, vc3, va13, vb13); + a += 128; + b += 128; + } + if (n & 64) { + svfloat32x4_t va = svld1_f32_x4(ptc, a); + svfloat32x4_t vb = svld1_f32_x4(ptc, b); + svfloat32_t va0 = svget4_f32(va, 0), va1 = svget4_f32(va, 1), va2 = svget4_f32(va, 2), + va3 = svget4_f32(va, 3); + svfloat32_t vb0 = svget4_f32(vb, 0), vb1 = svget4_f32(vb, 1), vb2 = svget4_f32(vb, 2), + vb3 = svget4_f32(vb, 3); + vc0 = svmla_f32_x(pt, vc0, va0, vb0); + vc1 = svmla_f32_x(pt, vc1, va1, vb1); + vc2 = svmla_f32_x(pt, vc2, va2, vb2); + vc3 = svmla_f32_x(pt, vc3, va3, vb3); + a += 64; + b += 64; + } + if (n & 63) { + ptc = svwhilelt_c32(0ULL, (n & 63), 4); + svfloat32x4_t va = svld1_f32_x4(ptc, a); + svfloat32x4_t vb = svld1_f32_x4(ptc, b); + svfloat32_t va0 = svget4_f32(va, 0), va1 = svget4_f32(va, 1), va2 = svget4_f32(va, 2), + va3 = svget4_f32(va, 3); + svfloat32_t vb0 = svget4_f32(vb, 0), vb1 = svget4_f32(vb, 1), vb2 = svget4_f32(vb, 2), + vb3 = svget4_f32(vb, 3); + svboolx2_t p01 = svpext_lane_c32_x2(ptc, 0), p23 = svpext_lane_c32_x2(ptc, 1); + svbool_t p0 = svget2(p01, 0), p1 = svget2(p01, 1), p2 = svget2(p23, 0), p3 = svget2(p23, 1); + vc0 = svmla_f32_m(p0, vc0, va0, vb0); + vc1 = svmla_f32_m(p1, vc1, va1, vb1); + vc2 = svmla_f32_m(p2, vc2, va2, vb2); + vc3 = svmla_f32_m(p3, vc3, va3, vb3); + } + svfloat32_t res_p0 = svadd_f32_x(pt, vc0, vc1); + svfloat32_t res_p1 = svadd_f32_x(pt, vc2, vc3); + svfloat32_t res = svadd_f32_x(pt, res_p0, res_p1); + *c = svaddv(pt, res); +} + +void add_f32(const float *restrict a, const float *restrict b, float *restrict c, + const uint64_t n) __arm_streaming +{ + const uint64_t iter = n / 128; + svcount_t ptc = svptrue_c32(); + svbool_t pt = svptrue_b32(); + for (uint64_t i = 0; i < iter; i++) { + svfloat32x4_t va0 = svld1_f32_x4(ptc, a); + svfloat32x4_t va1 = svld1_f32_x4(ptc, a + 64); + svfloat32x4_t vb0 = svld1_f32_x4(ptc, b); + svfloat32x4_t vb1 = svld1_f32_x4(ptc, b + 64); + svfloat32_t va00 = svget4_f32(va0, 0), va01 = svget4_f32(va0, 1), va02 = svget4_f32(va0, 2), + va03 = svget4_f32(va0, 3); + svfloat32_t va10 = svget4_f32(va1, 0), va11 = svget4_f32(va1, 1), va12 = svget4_f32(va1, 2), + va13 = svget4_f32(va1, 3); + svfloat32_t vb00 = svget4_f32(vb0, 0), vb01 = svget4_f32(vb0, 1), vb02 = svget4_f32(vb0, 2), + vb03 = svget4_f32(vb0, 3); + svfloat32_t vb10 = svget4_f32(vb1, 0), vb11 = svget4_f32(vb1, 1), vb12 = svget4_f32(vb1, 2), + vb13 = svget4_f32(vb1, 3); + svfloat32_t vc00 = svadd_f32_x(pt, va00, vb00); + svfloat32_t vc01 = svadd_f32_x(pt, va01, vb01); + svfloat32_t vc02 = svadd_f32_x(pt, va02, vb02); + svfloat32_t vc03 = svadd_f32_x(pt, va03, vb03); + svfloat32x4_t vc0 = svcreate4_f32(vc00, vc01, vc02, vc03); + svfloat32_t vc10 = svadd_f32_x(pt, va10, vb10); + svfloat32_t vc11 = svadd_f32_x(pt, va11, vb11); + svfloat32_t vc12 = svadd_f32_x(pt, va12, vb12); + svfloat32_t vc13 = svadd_f32_x(pt, va13, vb13); + svfloat32x4_t vc1 = svcreate4_f32(vc10, vc11, vc12, vc13); + a += 128; + b += 128; + svst1_f32_x4(ptc, c, vc0); + svst1_f32_x4(ptc, c + 64, vc1); + c += 128; + } + for (uint64_t i = 0; i < (n & 127); i += 64) { + svcount_t pg = svwhilelt_c32(i, n & 127, 4); + svfloat32x4_t va = svld1_f32_x4(pg, a); + svfloat32x4_t vb = svld1_f32_x4(pg, b); + svfloat32_t va0 = svget4_f32(va, 0), va1 = svget4_f32(va, 1), va2 = svget4_f32(va, 2), + va3 = svget4_f32(va, 3); + svfloat32_t vb0 = svget4_f32(vb, 0), vb1 = svget4_f32(vb, 1), vb2 = svget4_f32(vb, 2), + vb3 = svget4_f32(vb, 3); + svfloat32_t vc0 = svadd_f32_x(pt, va0, vb0); + svfloat32_t vc1 = svadd_f32_x(pt, va1, vb1); + svfloat32_t vc2 = svadd_f32_x(pt, va2, vb2); + svfloat32_t vc3 = svadd_f32_x(pt, va3, vb3); + svfloat32x4_t vc = svcreate4_f32(vc0, vc1, vc2, vc3); + a += 64; + b += 64; + svst1_f32_x4(pg, c, vc); + c += 64; + } +} + +void conj_scale_cf32(const float *restrict a, const float scale, float *restrict c, + const uint64_t n) __arm_streaming +{ + const uint64_t iter = n / 32; + svcount_t ptc = svptrue_c32(); + svbool_t pt = svptrue_b32(); + svfloat32_t vs1 = svdup_n_f32(scale); + svfloat32_t vs2 = svneg_f32_x(pt, vs1); + svfloat32_t vscale = svzip1_f32(vs1, vs2); + for (uint64_t i = 0; i < iter; i++) { + svfloat32x4_t va = svld1_f32_x4(ptc, a); + svfloat32_t va0 = svget4_f32(va, 0), va1 = svget4_f32(va, 1), va2 = svget4_f32(va, 2), + va3 = svget4_f32(va, 3); + svfloat32_t vc0 = svmul_f32_x(pt, va0, vscale); + svfloat32_t vc1 = svmul_f32_x(pt, va1, vscale); + svfloat32_t vc2 = svmul_f32_x(pt, va2, vscale); + svfloat32_t vc3 = svmul_f32_x(pt, va3, vscale); + a += 64; + svfloat32x4_t vc = svcreate4_f32(vc0, vc1, vc2, vc3); + svst1_f32_x4(ptc, c, vc); + c += 64; + } + if (n & 31) { + svcount_t pg = svwhilelt_c32(0ULL, (n & 31) * 2, 4); + svfloat32x4_t va = svld1_f32_x4(pg, a); + svfloat32_t va0 = svget4_f32(va, 0), va1 = svget4_f32(va, 1), va2 = svget4_f32(va, 2), + va3 = svget4_f32(va, 3); + svfloat32_t vc0 = svmul_f32_x(pt, va0, vscale); + svfloat32_t vc1 = svmul_f32_x(pt, va1, vscale); + svfloat32_t vc2 = svmul_f32_x(pt, va2, vscale); + svfloat32_t vc3 = svmul_f32_x(pt, va3, vscale); + svfloat32x4_t vc = svcreate4_f32(vc0, vc1, vc2, vc3); + svst1_f32_x4(pg, c, vc); + } +} + +void dot_cf32(const float *restrict a, const float *restrict b, float *restrict c, + const uint64_t n) __arm_streaming +{ + const uint64_t iter = n / 64; + svcount_t ptc = svptrue_c32(); + svbool_t pt = svptrue_b32(); + svfloat32_t vc0 = svdup_n_f32(0), vc1 = svdup_n_f32(0), vc2 = svdup_n_f32(0), + vc3 = svdup_n_f32(0); + svfloat32_t vc4 = svdup_n_f32(0), vc5 = svdup_n_f32(0), vc6 = svdup_n_f32(0), + vc7 = svdup_n_f32(0); + for (uint64_t i = 0; i < iter; i++) { + svfloat32x4_t va0 = svld1_f32_x4(ptc, a); + svfloat32x4_t vb0 = svld1_f32_x4(ptc, b); + svfloat32x4_t va1 = svld1_f32_x4(ptc, a + 64); + svfloat32x4_t vb1 = svld1_f32_x4(ptc, b + 64); + svfloat32_t va00 = svget4_f32(va0, 0), va01 = svget4_f32(va0, 1), va02 = svget4_f32(va0, 2), + va03 = svget4_f32(va0, 3); + svfloat32_t vb00 = svget4_f32(vb0, 0), vb01 = svget4_f32(vb0, 1), vb02 = svget4_f32(vb0, 2), + vb03 = svget4_f32(vb0, 3); + svfloat32_t va10 = svget4_f32(va1, 0), va11 = svget4_f32(va1, 1), va12 = svget4_f32(va1, 2), + va13 = svget4_f32(va1, 3); + svfloat32_t vb10 = svget4_f32(vb1, 0), vb11 = svget4_f32(vb1, 1), vb12 = svget4_f32(vb1, 2), + vb13 = svget4_f32(vb1, 3); + vc0 = svcmla_f32_x(pt, vc0, va00, vb00, 0); + vc1 = svcmla_f32_x(pt, vc1, va01, vb01, 0); + vc2 = svcmla_f32_x(pt, vc2, va02, vb02, 0); + vc3 = svcmla_f32_x(pt, vc3, va03, vb03, 0); + vc4 = svcmla_f32_x(pt, vc4, va10, vb10, 0); + vc5 = svcmla_f32_x(pt, vc5, va11, vb11, 0); + vc6 = svcmla_f32_x(pt, vc6, va12, vb12, 0); + vc7 = svcmla_f32_x(pt, vc7, va13, vb13, 0); + vc0 = svcmla_f32_x(pt, vc0, va00, vb00, 90); + vc1 = svcmla_f32_x(pt, vc1, va01, vb01, 90); + vc2 = svcmla_f32_x(pt, vc2, va02, vb02, 90); + vc3 = svcmla_f32_x(pt, vc3, va03, vb03, 90); + vc4 = svcmla_f32_x(pt, vc4, va10, vb10, 90); + vc5 = svcmla_f32_x(pt, vc5, va11, vb11, 90); + vc6 = svcmla_f32_x(pt, vc6, va12, vb12, 90); + vc7 = svcmla_f32_x(pt, vc7, va13, vb13, 90); + a += 128; + b += 128; + } + if (n & 32) { + svfloat32x4_t va = svld1_f32_x4(ptc, a); + svfloat32x4_t vb = svld1_f32_x4(ptc, b); + svfloat32_t va0 = svget4_f32(va, 0), va1 = svget4_f32(va, 1), va2 = svget4_f32(va, 2), + va3 = svget4_f32(va, 3); + svfloat32_t vb0 = svget4_f32(vb, 0), vb1 = svget4_f32(vb, 1), vb2 = svget4_f32(vb, 2), + vb3 = svget4_f32(vb, 3); + vc0 = svcmla_f32_x(pt, vc0, va0, vb0, 0); + vc1 = svcmla_f32_x(pt, vc1, va1, vb1, 0); + vc2 = svcmla_f32_x(pt, vc2, va2, vb2, 0); + vc3 = svcmla_f32_x(pt, vc3, va3, vb3, 0); + vc0 = svcmla_f32_x(pt, vc0, va0, vb0, 90); + vc1 = svcmla_f32_x(pt, vc1, va1, vb1, 90); + vc2 = svcmla_f32_x(pt, vc2, va2, vb2, 90); + vc3 = svcmla_f32_x(pt, vc3, va3, vb3, 90); + a += 64; + b += 64; + } + if (n & 31) { + svcount_t pg = svwhilelt_c32(0ULL, (n & 31) * 2, 4); + svfloat32x4_t va = svld1_f32_x4(pg, a); + svfloat32x4_t vb = svld1_f32_x4(pg, b); + svfloat32_t va0 = svget4_f32(va, 0), va1 = svget4_f32(va, 1), va2 = svget4_f32(va, 2), + va3 = svget4_f32(va, 3); + svfloat32_t vb0 = svget4_f32(vb, 0), vb1 = svget4_f32(vb, 1), vb2 = svget4_f32(vb, 2), + vb3 = svget4_f32(vb, 3); + svboolx2_t p01 = svpext_lane_c32_x2(pg, 0); + svboolx2_t p23 = svpext_lane_c32_x2(pg, 1); + svbool_t p0 = svget2(p01, 0), p1 = svget2(p01, 1), p2 = svget2(p23, 0), p3 = svget2(p23, 1); + vc0 = svcmla_f32_m(p0, vc0, va0, vb0, 0); + vc1 = svcmla_f32_m(p1, vc1, va1, vb1, 0); + vc2 = svcmla_f32_m(p2, vc2, va2, vb2, 0); + vc3 = svcmla_f32_m(p3, vc3, va3, vb3, 0); + vc0 = svcmla_f32_m(p0, vc0, va0, vb0, 90); + vc1 = svcmla_f32_m(p1, vc1, va1, vb1, 90); + vc2 = svcmla_f32_m(p2, vc2, va2, vb2, 90); + vc3 = svcmla_f32_m(p3, vc3, va3, vb3, 90); + } + svfloat32_t res_p0 = svadd_f32_x(pt, vc0, vc1); + svfloat32_t res_p1 = svadd_f32_x(pt, vc2, vc3); + svfloat32_t res_p2 = svadd_f32_x(pt, vc4, vc5); + svfloat32_t res_p3 = svadd_f32_x(pt, vc6, vc7); + svfloat32_t res0 = svadd_f32_x(pt, res_p0, res_p1); + svfloat32_t res1 = svadd_f32_x(pt, res_p2, res_p3); + svfloat32_t r0 = svuzp1_f32(res0, res1); + svfloat32_t r1 = svuzp2_f32(res0, res1); + *c = svaddv_f32(pt, r0); + *(c + 1) = svaddv_f32(pt, r1); +} + +void mul_f32(const float *restrict a, const float *restrict b, float *restrict c, + const uint64_t n) __arm_streaming +{ + const uint64_t iter = n / 128; + svcount_t ptc = svptrue_c32(); + svbool_t pt = svptrue_b32(); + for (uint64_t i = 0; i < iter; i++) { + svfloat32x4_t va0 = svld1_f32_x4(ptc, a); + svfloat32x4_t vb0 = svld1_f32_x4(ptc, b); + svfloat32x4_t va1 = svld1_f32_x4(ptc, a + 64); + svfloat32x4_t vb1 = svld1_f32_x4(ptc, b + 64); + svfloat32_t va00 = svget4_f32(va0, 0), va01 = svget4_f32(va0, 1), va02 = svget4_f32(va0, 2), + va03 = svget4_f32(va0, 3); + svfloat32_t va10 = svget4_f32(va1, 0), va11 = svget4_f32(va1, 1), va12 = svget4_f32(va1, 2), + va13 = svget4_f32(va1, 3); + svfloat32_t vb00 = svget4_f32(vb0, 0), vb01 = svget4_f32(vb0, 1), vb02 = svget4_f32(vb0, 2), + vb03 = svget4_f32(vb0, 3); + svfloat32_t vb10 = svget4_f32(vb1, 0), vb11 = svget4_f32(vb1, 1), vb12 = svget4_f32(vb1, 2), + vb13 = svget4_f32(vb1, 3); + svfloat32_t vc00 = svmul_f32_x(pt, va00, vb00); + svfloat32_t vc01 = svmul_f32_x(pt, va01, vb01); + svfloat32_t vc02 = svmul_f32_x(pt, va02, vb02); + svfloat32_t vc03 = svmul_f32_x(pt, va03, vb03); + svfloat32_t vc10 = svmul_f32_x(pt, va10, vb10); + svfloat32_t vc11 = svmul_f32_x(pt, va11, vb11); + svfloat32_t vc12 = svmul_f32_x(pt, va12, vb12); + svfloat32_t vc13 = svmul_f32_x(pt, va13, vb13); + a += 128; + b += 128; + svfloat32x4_t vc0 = svcreate4_f32(vc00, vc01, vc02, vc03); + svst1_f32_x4(ptc, c, vc0); + svfloat32x4_t vc1 = svcreate4_f32(vc10, vc11, vc12, vc13); + svst1_f32_x4(ptc, c + 64, vc1); + c += 128; + } + for (uint64_t i = 0; i < (n & 127); i += 64) { + svcount_t pg = svwhilelt_c32(i, n & 127, 4); + svfloat32x4_t va = svld1_f32_x4(pg, a); + svfloat32x4_t vb = svld1_f32_x4(pg, b); + svfloat32_t va0 = svget4_f32(va, 0), va1 = svget4_f32(va, 1), va2 = svget4_f32(va, 2), + va3 = svget4_f32(va, 3); + svfloat32_t vb0 = svget4_f32(vb, 0), vb1 = svget4_f32(vb, 1), vb2 = svget4_f32(vb, 2), + vb3 = svget4_f32(vb, 3); + svfloat32_t vc0 = svmul_f32_x(pt, va0, vb0); + svfloat32_t vc1 = svmul_f32_x(pt, va1, vb1); + svfloat32_t vc2 = svmul_f32_x(pt, va2, vb2); + svfloat32_t vc3 = svmul_f32_x(pt, va3, vb3); + svfloat32x4_t vc = svcreate4_f32(vc0, vc1, vc2, vc3); + a += 64; + b += 64; + svst1_f32_x4(pg, c, vc); + c += 64; + } +} + +void conj_mul_cf32(const float *restrict a, const float *restrict b, float *restrict c, + const uint64_t n) __arm_streaming +{ + const uint64_t iter = n / 64; + svcount_t pg = svptrue_c32(); + svbool_t pt = svptrue_b32(); + for (uint64_t i = 0; i < iter; i++) { + svfloat32x4_t va0 = svld1_f32_x4(pg, a); + svfloat32x4_t vb0 = svld1_f32_x4(pg, b); + svfloat32_t vc0, vc1, vc2, vc3, vc4, vc5, vc6, vc7; + svfloat32x4_t va1 = svld1_f32_x4(pg, a + 64); + svfloat32x4_t vb1 = svld1_f32_x4(pg, b + 64); + vc0 = svdup_n_f32(0); + vc1 = svdup_n_f32(0); + vc2 = svdup_n_f32(0); + vc3 = svdup_n_f32(0); + vc4 = svdup_n_f32(0); + vc5 = svdup_n_f32(0); + vc6 = svdup_n_f32(0); + vc7 = svdup_n_f32(0); + svfloat32_t va00 = svget4_f32(va0, 0), va01 = svget4_f32(va0, 1), va02 = svget4_f32(va0, 2), + va03 = svget4_f32(va0, 3); + svfloat32_t va10 = svget4_f32(va1, 0), va11 = svget4_f32(va1, 1), va12 = svget4_f32(va1, 2), + va13 = svget4_f32(va1, 3); + svfloat32_t vb00 = svget4_f32(vb0, 0), vb01 = svget4_f32(vb0, 1), vb02 = svget4_f32(vb0, 2), + vb03 = svget4_f32(vb0, 3); + svfloat32_t vb10 = svget4_f32(vb1, 0), vb11 = svget4_f32(vb1, 1), vb12 = svget4_f32(vb1, 2), + vb13 = svget4_f32(vb1, 3); + vc0 = svcmla_f32_x(pt, vc0, vb00, va00, 0); + vc1 = svcmla_f32_x(pt, vc1, vb01, va01, 0); + vc2 = svcmla_f32_x(pt, vc2, vb02, va02, 0); + vc3 = svcmla_f32_x(pt, vc3, vb03, va03, 0); + vc4 = svcmla_f32_x(pt, vc4, vb10, va10, 0); + vc5 = svcmla_f32_x(pt, vc5, vb11, va11, 0); + vc6 = svcmla_f32_x(pt, vc6, vb12, va12, 0); + vc7 = svcmla_f32_x(pt, vc7, vb13, va13, 0); + vc0 = svcmla_f32_x(pt, vc0, vb00, va00, 270); + vc1 = svcmla_f32_x(pt, vc1, vb01, va01, 270); + vc2 = svcmla_f32_x(pt, vc2, vb02, va02, 270); + vc3 = svcmla_f32_x(pt, vc3, vb03, va03, 270); + svfloat32x4_t vc_p0 = svcreate4_f32(vc0, vc1, vc2, vc3); + svst1_f32_x4(pg, c, vc_p0); + vc4 = svcmla_f32_x(pt, vc4, vb10, va10, 270); + vc5 = svcmla_f32_x(pt, vc5, vb11, va11, 270); + vc6 = svcmla_f32_x(pt, vc6, vb12, va12, 270); + vc7 = svcmla_f32_x(pt, vc7, vb13, va13, 270); + svfloat32x4_t vc_p1 = svcreate4_f32(vc4, vc5, vc6, vc7); + a += 128; + b += 128; + svst1_f32_x4(pg, c + 64, vc_p1); + c += 128; + } + for (uint64_t i = 0; i < (n & 63); i += 32) { + svcount_t pg = svwhilelt_c32(i * 2, (n & 63) * 2, 4); + svfloat32x4_t va = svld1_f32_x4(pg, a); + svfloat32x4_t vb = svld1_f32_x4(pg, b); + svfloat32_t vc0, vc1, vc2, vc3; + vc0 = svdup_n_f32(0); + vc1 = svdup_n_f32(0); + vc2 = svdup_n_f32(0); + vc3 = svdup_n_f32(0); + svfloat32_t va0 = svget4_f32(va, 0), va1 = svget4_f32(va, 1), va2 = svget4_f32(va, 2), + va3 = svget4_f32(va, 3); + svfloat32_t vb0 = svget4_f32(vb, 0), vb1 = svget4_f32(vb, 1), vb2 = svget4_f32(vb, 2), + vb3 = svget4_f32(vb, 3); + vc0 = svcmla_f32_x(pt, vc0, vb0, va0, 0); + vc1 = svcmla_f32_x(pt, vc1, vb1, va1, 0); + vc2 = svcmla_f32_x(pt, vc2, vb2, va2, 0); + vc3 = svcmla_f32_x(pt, vc3, vb3, va3, 0); + vc0 = svcmla_f32_x(pt, vc0, vb0, va0, 270); + vc1 = svcmla_f32_x(pt, vc1, vb1, va1, 270); + vc2 = svcmla_f32_x(pt, vc2, vb2, va2, 270); + vc3 = svcmla_f32_x(pt, vc3, vb3, va3, 270); + a += 64; + b += 64; + svfloat32x4_t vc = svcreate4_f32(vc0, vc1, vc2, vc3); + svst1_f32_x4(pg, c, vc); + c += 64; + } +} + +void conj_dot_cf32(const float *restrict a, const float *restrict b, float *restrict c, + const uint64_t n) __arm_streaming +{ + const uint64_t iter = n / 64; + svcount_t ptc = svptrue_c32(); + svbool_t pt = svptrue_b32(); + svfloat32_t vc0, vc1, vc2, vc3, vc4, vc5, vc6, vc7; + vc0 = svdup_n_f32(0); + vc1 = svdup_n_f32(0); + vc2 = svdup_n_f32(0); + vc3 = svdup_n_f32(0); + vc4 = svdup_n_f32(0); + vc5 = svdup_n_f32(0); + vc6 = svdup_n_f32(0); + vc7 = svdup_n_f32(0); + for (uint64_t i = 0; i < iter; i++) { + svfloat32x4_t va0 = svld1_f32_x4(ptc, a); + svfloat32x4_t vb0 = svld1_f32_x4(ptc, b); + svfloat32x4_t va1 = svld1_f32_x4(ptc, a + 64); + svfloat32x4_t vb1 = svld1_f32_x4(ptc, b + 64); + svfloat32_t va00 = svget4_f32(va0, 0), va01 = svget4_f32(va0, 1), va02 = svget4_f32(va0, 2), + va03 = svget4_f32(va0, 3); + svfloat32_t va10 = svget4_f32(va1, 0), va11 = svget4_f32(va1, 1), va12 = svget4_f32(va1, 2), + va13 = svget4_f32(va1, 3); + svfloat32_t vb00 = svget4_f32(vb0, 0), vb01 = svget4_f32(vb0, 1), vb02 = svget4_f32(vb0, 2), + vb03 = svget4_f32(vb0, 3); + svfloat32_t vb10 = svget4_f32(vb1, 0), vb11 = svget4_f32(vb1, 1), vb12 = svget4_f32(vb1, 2), + vb13 = svget4_f32(vb1, 3); + vc0 = svcmla_f32_x(pt, vc0, vb00, va00, 0); + vc1 = svcmla_f32_x(pt, vc1, vb01, va01, 0); + vc2 = svcmla_f32_x(pt, vc2, vb02, va02, 0); + vc3 = svcmla_f32_x(pt, vc3, vb03, va03, 0); + vc4 = svcmla_f32_x(pt, vc4, vb10, va10, 0); + vc5 = svcmla_f32_x(pt, vc5, vb11, va11, 0); + vc6 = svcmla_f32_x(pt, vc6, vb12, va12, 0); + vc7 = svcmla_f32_x(pt, vc7, vb13, va13, 0); + vc0 = svcmla_f32_x(pt, vc0, vb00, va00, 270); + vc1 = svcmla_f32_x(pt, vc1, vb01, va01, 270); + vc2 = svcmla_f32_x(pt, vc2, vb02, va02, 270); + vc3 = svcmla_f32_x(pt, vc3, vb03, va03, 270); + vc4 = svcmla_f32_x(pt, vc4, vb10, va10, 270); + vc5 = svcmla_f32_x(pt, vc5, vb11, va11, 270); + vc6 = svcmla_f32_x(pt, vc6, vb12, va12, 270); + vc7 = svcmla_f32_x(pt, vc7, vb13, va13, 270); + a += 128; + b += 128; + } + for (uint64_t i = 0; i < (n & 63); i += 32) { + svcount_t pg = svwhilelt_c32(i * 2, (n & 63) * 2, 4); + svfloat32x4_t va = svld1_f32_x4(pg, a); + svfloat32x4_t vb = svld1_f32_x4(pg, b); + svboolx2_t p01 = svpext_lane_c32_x2(pg, 0); + svboolx2_t p23 = svpext_lane_c32_x2(pg, 1); + svbool_t p0 = svget2(p01, 0), p1 = svget2(p01, 1), p2 = svget2(p23, 0), p3 = svget2(p23, 1); + svfloat32_t va0 = svget4_f32(va, 0), va1 = svget4_f32(va, 1), va2 = svget4_f32(va, 2), + va3 = svget4_f32(va, 3); + svfloat32_t vb0 = svget4_f32(vb, 0), vb1 = svget4_f32(vb, 1), vb2 = svget4_f32(vb, 2), + vb3 = svget4_f32(vb, 3); + vc0 = svcmla_f32_m(p0, vc0, vb0, va0, 0); + vc1 = svcmla_f32_m(p1, vc1, vb1, va1, 0); + vc2 = svcmla_f32_m(p2, vc2, vb2, va2, 0); + vc3 = svcmla_f32_m(p3, vc3, vb3, va3, 0); + vc0 = svcmla_f32_m(p0, vc0, vb0, va0, 270); + vc1 = svcmla_f32_m(p1, vc1, vb1, va1, 270); + vc2 = svcmla_f32_m(p2, vc2, vb2, va2, 270); + vc3 = svcmla_f32_m(p3, vc3, vb3, va3, 270); + a += 64; + b += 64; + } + svfloat32_t res_p0 = svadd_f32_x(pt, vc0, vc1); + svfloat32_t res_p1 = svadd_f32_x(pt, vc2, vc3); + svfloat32_t res_p2 = svadd_f32_x(pt, vc4, vc5); + svfloat32_t res_p3 = svadd_f32_x(pt, vc6, vc7); + svfloat32_t res0 = svadd_f32_x(pt, res_p0, res_p1); + svfloat32_t res1 = svadd_f32_x(pt, res_p2, res_p3); + svfloat32_t res_real = svuzp1_f32(res0, res1); + svfloat32_t res_imag = svuzp2_f32(res0, res1); + float real = svaddv_f32(pt, res_real); + float imag = svaddv_f32(pt, res_imag); + *c = real; + *(c + 1) = imag; +} + +#endif \ No newline at end of file