Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion gemma/activations.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ struct AttentionActivations {
size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config,
size_t max_workers, const Allocator& allocator,
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
: rep_factor(max_workers *
: heads(layer_config.heads),
qkv_dim(layer_config.qkv_dim),
rep_factor(max_workers *
AttentionActivations::kThreadReplicationFactor /
layer_config.heads),
// `vocab_size == 0` means it is for Vit part, VitAttention
Expand Down Expand Up @@ -144,6 +146,13 @@ struct AttentionActivations {
// `inv_timescale*` are not batched.
}

size_t heads;
size_t qkv_dim;
AlignedBF16Vector bf16_queries;
std::vector<int16_t, hwy::AlignedAllocator<int16_t>> int16_queries;
AlignedFloatVector float_queries;
AlignedFloatVector q_scales;

// Maximum factor by which we might scale-up work to maximize parallelism.
size_t rep_factor = 1;
// Parameters for flash attention. The size of the vector is somewhere between
Expand Down Expand Up @@ -191,6 +200,10 @@ struct AttentionActivationsPtrs {
: config(config),
flash_params(flash_params),
split_flash_params(split_flash_params),
bf16_queries(nullptr),
int16_queries(nullptr),
float_queries(nullptr),
q_scales(nullptr),
div_seq_len(static_cast<uint32_t>(seq_len)),
div_heads(static_cast<uint32_t>(config.layer_configs[0].heads)),
query_scale(ChooseQueryScale(config)) {}
Expand All @@ -212,6 +225,10 @@ struct AttentionActivationsPtrs {
att_sums = activations.att_sums;
inv_timescale = activations.inv_timescale;
inv_timescale_global = activations.inv_timescale_global;
bf16_queries = &activations.bf16_queries;
int16_queries = &activations.int16_queries;
float_queries = &activations.float_queries;
q_scales = &activations.q_scales;
}

void SetBatchSize(size_t batch_size) {
Expand Down Expand Up @@ -277,6 +294,10 @@ struct AttentionActivationsPtrs {
sub_task_exp_denominator_sums;
std::vector<AlignedFloatVector>*
sub_task_max_logits;
AlignedBF16Vector* bf16_queries;
std::vector<int16_t, hwy::AlignedAllocator<int16_t>>* int16_queries;
AlignedFloatVector* float_queries;
AlignedFloatVector* q_scales;
// Inverse timescales for RoPE computation.
MatPtrT<float> inv_timescale;
// Inverse timescales for global RoPE computation.
Expand Down
373 changes: 149 additions & 224 deletions gemma/flash_attention.cc

Large diffs are not rendered by default.

13 changes: 6 additions & 7 deletions gemma/flash_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,24 @@ namespace gcpp {
ThreadingContext& ctx, AttentionImpl attention_impl); \
\
void DispatchTileFlashAttentionReturnExpSumsAndMaxLogits( \
hwy::Span<const MatPtr> kvs, int q_count, \
const hwy::Span<const float* HWY_RESTRICT> q_T_in_groups_up_to_4, \
hwy::Span<const MatPtr> kvs, size_t q_count, \
const float* HWY_RESTRICT q_base, \
hwy::Span<const size_t> start_pos_per_query, \
hwy::Span<const size_t> last_pos_per_query, const float att_cap, \
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums, \
float* HWY_RESTRICT max_logits); \
\
void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16( \
hwy::Span<const MatPtr> kvs, int q_count, \
const hwy::Span<const BF16 * HWY_RESTRICT> q_T_in_groups_up_to_4, \
hwy::Span<const MatPtr> kvs, size_t q_count, \
const BF16* HWY_RESTRICT q_base, \
hwy::Span<const size_t> start_pos_per_query, \
hwy::Span<const size_t> last_pos_per_query, const float att_cap, \
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums, \
float* HWY_RESTRICT max_logits); \
\
void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16( \
hwy::Span<const MatPtr> kvs, int q_count, \
const hwy::Span<const int16_t* HWY_RESTRICT> q_T_in_groups_up_to_4, \
hwy::Span<const float> q_scales, \
hwy::Span<const MatPtr> kvs, size_t q_count, \
const int16_t* HWY_RESTRICT q_base, hwy::Span<const float> q_scales, \
hwy::Span<const size_t> start_pos_per_query, \
hwy::Span<const size_t> last_pos_per_query, const float att_cap, \
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums, \
Expand Down
137 changes: 48 additions & 89 deletions gemma/flash_attention_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,17 @@ void PopulateTestKVCache(MatStorageT<T>& kv, gcpp::KVEncoding encoding,
}
}

// Builds synthetic query data for the attention tests.
//
// Returns a buffer of `num_queries * qkv_dim` floats stored query after query
// (the value for query `q`, dimension `d` lives at index `q * qkv_dim + d`).
// Each element equals `base * (q + 1) / (d + 1)`, where `base` is
// `0.01f * hwy::Unpredictable1()`.
AlignedFloatVector PopulateTestQueries(size_t num_queries, size_t qkv_dim) {
  // NOTE(review): hwy::Unpredictable1() presumably evaluates to 1 at runtime
  // while being opaque to the optimizer, so the values are not
  // constant-folded - confirm against the Highway documentation.
  const float base = 0.01f * hwy::Unpredictable1();
  AlignedFloatVector queries(num_queries * qkv_dim);
  size_t flat = 0;  // Running index; always equals q * qkv_dim + d.
  for (size_t q = 0; q < num_queries; ++q) {
    for (size_t d = 0; d < qkv_dim; ++d) {
      // Keep the multiply-then-divide order so results match bit-for-bit.
      queries[flat++] = base * (q + 1) / (d + 1);
    }
  }
  return queries;
}

struct AttentionTestEnv {
AttentionTestEnv(size_t num_queries, size_t kv_seq_len, size_t qkv_dim,
AttentionImpl attention_impl);
Expand Down Expand Up @@ -492,18 +503,7 @@ void TestTiledFlashAttention() {
2 * qkv_dim * gcpp::KVCache::kTileSize),
ctx.allocator, MatPadding::kPacked);
PopulateTestKVCache(kv, gcpp::KVEncoding::kF32, qkv_dim);
std::vector<float> q_float(4 * qkv_dim);
std::vector<float> q_float2(4 * qkv_dim);
// fill in qs with predictable, synthetic data
for (size_t i = 0; i < 4; ++i) {
for (size_t j = 0; j < qkv_dim; j++) {
float val_1 = 0.01f * (i + 1) / (j + 1);
float val_2 = 0.01f * (i + 4 + 1) / (j + 1);
q_float[j * 4 + i] = val_1;
q_float2[j * 4 + i] = val_2;
}
}
const float* q_T[2] = {q_float.data(), q_float2.data()};
AlignedFloatVector q_all = PopulateTestQueries(num_queries, qkv_dim);

MatStorageT<float> att_out("att_out", Extents2D(num_queries, qkv_dim),
ctx.allocator, MatPadding::kPacked);
Expand Down Expand Up @@ -536,7 +536,7 @@ void TestTiledFlashAttention() {

hwy::Span<const MatPtr> kvs(&kv, 1);
DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(
kvs, num_queries, hwy::Span<const float*>(q_T, 2),
kvs, num_queries, q_all.data(),
hwy::Span<const size_t>(start_pos_per_query),
hwy::Span<const size_t>(last_pos_per_query), att_cap, att_out,
exp_denominator_sums.data(), max_logits.data());
Expand Down Expand Up @@ -578,22 +578,11 @@ void TestTiledFlashAttentionBF16() {
ctx.allocator, MatPadding::kPacked);
PopulateTestKVCache(kv, gcpp::KVEncoding::kBF16TwoTranspositions, qkv_dim);

std::vector<float> q_all(num_queries * qkv_dim);
for (size_t i = 0; i < num_queries; ++i) {
for (size_t j = 0; j < qkv_dim; ++j) {
q_all[i * qkv_dim + j] = 0.01f * (i + 1) / (j + 1);
}
}
std::vector<float*> q_ptrs(num_queries);
for (int i = 0; i < num_queries; ++i) {
q_ptrs[i] = q_all.data() + i * qkv_dim;
}
auto [transposed_queries, transposed_queries_ptrs, _] =
TransposeQueriesToGroupsOfNBF16orInt16<BF16>(hwy::Span<float*>(q_ptrs),
qkv_dim, /*group_size=*/4);
hwy::Span<const BF16*> q_T(
const_cast<const BF16**>(transposed_queries_ptrs.data()),
transposed_queries_ptrs.size());
AlignedFloatVector q_all = PopulateTestQueries(num_queries, qkv_dim);
std::vector<BF16, hwy::AlignedAllocator<BF16>> bf16_queries(num_queries *
qkv_dim);
CompressQueriesBF16Contiguous(q_all.data(), qkv_dim, num_queries,
bf16_queries.data());

MatStorageT<float> att_out("att_out", Extents2D(num_queries, qkv_dim),
ctx.allocator, MatPadding::kPacked);
Expand Down Expand Up @@ -624,7 +613,8 @@ void TestTiledFlashAttentionBF16() {
}
hwy::Span<const MatPtr> kvs(&kv, 1);
DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
kvs, num_queries, q_T, hwy::Span<const size_t>(start_pos_per_query),
kvs, num_queries, bf16_queries.data(),
hwy::Span<const size_t>(start_pos_per_query),
hwy::Span<const size_t>(last_pos_per_query), att_cap, att_out,
exp_denominator_sums.data(), max_logits.data());

Expand Down Expand Up @@ -673,18 +663,7 @@ void TestTiledFlashAttentionInt8() {
ctx.allocator, MatPadding::kPacked);
PopulateTestKVCache(kv, gcpp::KVEncoding::kInt8, qkv_dim);

std::vector<float> q_float(4 * qkv_dim);
std::vector<float> q_float2(4 * qkv_dim);
// fill in qs with predictable, synthetic data
for (size_t i = 0; i < 4; ++i) {
for (size_t j = 0; j < qkv_dim; j++) {
float val_1 = 0.01f * (i + 1) / (j + 1);
float val_2 = 0.01f * (i + 4 + 1) / (j + 1);
q_float[j * 4 + i] = val_1;
q_float2[j * 4 + i] = val_2;
}
}
const float* q_T[2] = {q_float.data(), q_float2.data()};
AlignedFloatVector q_all = PopulateTestQueries(num_queries, qkv_dim);

MatStorageT<float> att_out("att_out", Extents2D(num_queries, qkv_dim),
ctx.allocator, MatPadding::kPacked);
Expand Down Expand Up @@ -717,7 +696,7 @@ void TestTiledFlashAttentionInt8() {

hwy::Span<const MatPtr> kvs(&kv, 1);
DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(
kvs, num_queries, hwy::Span<const float*>(q_T, 2),
kvs, num_queries, q_all.data(),
hwy::Span<const size_t>(start_pos_per_query),
hwy::Span<const size_t>(last_pos_per_query), att_cap, att_out,
exp_denominator_sums.data(), max_logits.data());
Expand All @@ -741,45 +720,35 @@ void TestTiledFlashAttentionInt8() {


void TestTiledFlashAttentionInt8BF16() {
int qkv_dim = 64;
int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by
// tiles size to test the padding logic.
int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize);
size_t qkv_dim = 64;
size_t kv_seq_len = 60; // number of tokens we will attend to. Not divisible
// by tiles size to test the padding logic.
size_t padded_kv_seq_len =
hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize);
float att_cap = 10.0f;
int num_queries = 8;
int num_queries_per_timestep = 4;
int num_tokens = num_queries / num_queries_per_timestep;
int kv_seq_end =
size_t num_queries = 8;
size_t num_queries_per_timestep = 4;
size_t num_tokens = num_queries / num_queries_per_timestep;
size_t kv_seq_end =
kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep);
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);

int num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize;
int tile_size_bytes = 2 * qkv_dim * gcpp::KVCache::kTileSize +
2 * sizeof(BF16) * gcpp::KVCache::kTileSize;
size_t num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize;
size_t tile_size_bytes = 2 * qkv_dim * gcpp::KVCache::kTileSize +
2 * sizeof(BF16) * gcpp::KVCache::kTileSize;

MatStorageT<int8_t> kv("kv", Extents2D(num_tiles, tile_size_bytes),
ctx.allocator, MatPadding::kPacked);

// fill in kvs with predictable, synthetic data matching BF16 paired layout
PopulateTestKVCache(kv, gcpp::KVEncoding::kInt8TwoTranspositions, qkv_dim);

std::vector<float> q_all(num_queries * qkv_dim);
for (int i = 0; i < num_queries; ++i) {
for (int j = 0; j < qkv_dim; ++j) {
q_all[i * qkv_dim + j] = 0.01f * (i + 1) / (j + 1);
}
}
std::vector<float*> q_ptrs(num_queries);
for (int i = 0; i < num_queries; ++i) {
q_ptrs[i] = q_all.data() + i * qkv_dim;
}
auto [transposed_queries, transposed_queries_ptrs, _] =
TransposeQueriesToGroupsOfNBF16orInt16<BF16>(hwy::Span<float*>(q_ptrs),
qkv_dim, /*group_size=*/4);
hwy::Span<const BF16*> q_T(
const_cast<const BF16**>(transposed_queries_ptrs.data()),
transposed_queries_ptrs.size());
AlignedFloatVector q_all = PopulateTestQueries(num_queries, qkv_dim);
std::vector<BF16, hwy::AlignedAllocator<BF16>> bf16_queries(num_queries *
qkv_dim);
CompressQueriesBF16Contiguous(q_all.data(), qkv_dim, num_queries,
bf16_queries.data());

MatStorageT<float> att_out("att_out", Extents2D(num_queries, qkv_dim),
ctx.allocator, MatPadding::kPacked);
Expand Down Expand Up @@ -812,7 +781,8 @@ void TestTiledFlashAttentionInt8BF16() {

hwy::Span<const MatPtr> kvs(&kv, 1);
DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
kvs, num_queries, q_T, hwy::Span<const size_t>(start_pos_per_query),
kvs, num_queries, bf16_queries.data(),
hwy::Span<const size_t>(start_pos_per_query),
hwy::Span<const size_t>(last_pos_per_query), att_cap, att_out,
exp_denominator_sums.data(), max_logits.data());

Expand Down Expand Up @@ -853,23 +823,12 @@ void TestTiledFlashAttentionInt8Int16() {
// fill in kvs with predictable, synthetic data matching BF16 paired layout
PopulateTestKVCache(kv, gcpp::KVEncoding::kInt8TwoTranspositions, qkv_dim);

std::vector<float> q_all(num_queries * qkv_dim);
for (int i = 0; i < num_queries; ++i) {
for (int j = 0; j < qkv_dim; ++j) {
q_all[i * qkv_dim + j] = 0.01f * (i + 1) / (j + 1);
}
}
std::vector<float*> q_ptrs(num_queries);
for (int i = 0; i < num_queries; ++i) {
q_ptrs[i] = q_all.data() + i * qkv_dim;
}
auto [transposed_queries, transposed_queries_ptrs, q_scales] =
TransposeQueriesToGroupsOfNBF16orInt16<int16_t>(
hwy::Span<float*>(q_ptrs), qkv_dim, /*group_size=*/4);
hwy::Span<const int16_t*> q_T(
const_cast<const int16_t**>(transposed_queries_ptrs.data()),
transposed_queries_ptrs.size());

AlignedFloatVector q_all = PopulateTestQueries(num_queries, qkv_dim);
std::vector<int16_t, hwy::AlignedAllocator<int16_t>> int16_queries(
num_queries * qkv_dim);
AlignedFloatVector q_scales(num_queries);
CompressQueriesInt16Contiguous(q_all.data(), qkv_dim, num_queries,
int16_queries.data(), q_scales.data());
MatStorageT<float> att_out("att_out", Extents2D(num_queries, qkv_dim),
ctx.allocator, MatPadding::kPacked);
using DF = hn::ScalableTag<float>;
Expand Down Expand Up @@ -901,7 +860,7 @@ void TestTiledFlashAttentionInt8Int16() {

hwy::Span<const MatPtr> kvs(&kv, 1);
DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16(
kvs, num_queries, q_T, q_scales,
kvs, num_queries, int16_queries.data(), q_scales,
hwy::Span<const size_t>(start_pos_per_query),
hwy::Span<const size_t>(last_pos_per_query), att_cap, att_out,
exp_denominator_sums.data(), max_logits.data());
Expand Down
Loading
Loading