Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions fbgemm_gpu/FbgemmGpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ set(tbe_eeg_cpu_sources
src/tbe/eeg/indices_generator.cpp)

set(fbgemm_gpu_sources_cpu_static
src/faster_hash_ops/faster_hash.cpp
src/intraining_embedding_pruning_ops/intraining_embedding_pruning_cpu.cpp
src/memory_utils/memory_utils.cpp
src/memory_utils/memory_utils_ops.cpp
Expand All @@ -45,11 +46,6 @@ set(fbgemm_gpu_sources_cpu_static
src/sparse_ops/sparse_ops_meta.cpp
${tbe_eeg_cpu_sources})

# Append the faster_hash CPU source to the static CPU source list for every
# build variant except ROCm.
# NOTE(review): src/faster_hash_ops/faster_hash.cpp also appears unconditionally
# in the set(fbgemm_gpu_sources_cpu_static ...) list above in this view --
# confirm whether this conditional append is still needed.
if(NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_ROCM)
list(APPEND fbgemm_gpu_sources_cpu_static
src/faster_hash_ops/faster_hash.cpp)
endif()

if(NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_CPU)
list(APPEND fbgemm_gpu_sources_cpu_static
src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp
Expand All @@ -73,6 +69,7 @@ endif()

if(NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_CPU)
set(fbgemm_gpu_sources_gpu_static
src/faster_hash_ops/faster_hash.cu
src/histogram_binning_calibration_ops.cu
src/input_combine_ops/input_combine.cu
src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu
Expand Down Expand Up @@ -131,11 +128,6 @@ if(NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_CPU)
src/sparse_ops/sparse_segment_sum_csr.cu
src/sparse_ops/sparse_zipf.cu
src/sparse_ops/sparse_block_bucketize_features_2d_weights.cu)

# Append the faster_hash GPU kernel source to the static GPU source list for
# every build variant except ROCm.
# NOTE(review): src/faster_hash_ops/faster_hash.cu also appears unconditionally
# in the set(fbgemm_gpu_sources_gpu_static ...) list above in this view --
# confirm whether this conditional append is still needed.
if(NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_ROCM)
list(APPEND fbgemm_gpu_sources_gpu_static
src/faster_hash_ops/faster_hash.cu)
endif()
endif()


Expand Down
12 changes: 6 additions & 6 deletions fbgemm_gpu/src/faster_hash_ops/faster_hash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ template <
typename TInput,
typename TIdentity>
void process_item_zch(
const at::PackedTensorAccessor64<TInput, 1>& input,
at::PackedTensorAccessor64<int64_t, 1> output,
const at::PackedTensorAccessor64<TIdentity, 2>& identities,
const at::TensorAccessor<TInput, 1>& input,
at::TensorAccessor<int64_t, 1> output,
const at::TensorAccessor<TIdentity, 2>& identities,
int64_t modulo,
int64_t max_probe,
const int64_t* const local_sizes,
Expand Down Expand Up @@ -140,9 +140,9 @@ void _zero_collision_hash_cpu_out(
HAS_OFFSET, \
TInput, \
TIdentity>( \
input.packed_accessor64<TInput, 1>(), \
output.packed_accessor64<int64_t, 1>(), \
identities.packed_accessor64<TIdentity, 2>(), \
input.accessor<TInput, 1>(), \
output.accessor<int64_t, 1>(), \
identities.accessor<TIdentity, 2>(), \
modulo, \
max_probe, \
local_sizes_ptr, \
Expand Down
17 changes: 17 additions & 0 deletions fbgemm_gpu/src/faster_hash_ops/faster_hash.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ static constexpr int32_t kDefaultTensor = -1;
static constexpr int64_t kMaxIdentityNum = INT32_MAX;
static constexpr int64_t kMaxHours = INT32_MAX;
// Number of seconds in one hour.
static constexpr int64_t kSecondsInHour = 60 * 60;
// Upper bound on the number of polling iterations in the metadata spin-wait
// loops. The bound is only enforced in code compiled under USE_ROCM; other
// builds keep spinning until the metadata value changes.
static constexpr int32_t kMaxSpinCount = 10 * 1000;

template <typename T>
__device__ __inline__ T CAS(T* data, T cmp, T val) {
Expand Down Expand Up @@ -167,12 +168,20 @@ __device__ __inline__ int64_t check_min(
// wait.
auto insert_idx = output_index + offset;
int32_t last_seen = kDefaultTensor;
int32_t spin_count = 0;
while (true) {
last_seen =
atomicCAS(metadata + insert_idx, kDefaultTensor, kDefaultTensor);
if (last_seen != kDefaultTensor) {
break;
}
#ifdef USE_ROCM
if (++spin_count > kMaxSpinCount) {
// The metadata write may not be visible yet due to atomic contention.
// After more than kMaxSpinCount polls, stop waiting: do not consider this
// slot, keep the original min_index, and move on to the next slot.
return min_index;
}
#endif
}

// only check those expired slots
Expand Down Expand Up @@ -225,12 +234,20 @@ __device__ __inline__ bool check_evict<1>(
// has not been written yet, while another id is checking the slot's eviction
// status. Therefore, wait until the metadata is no longer -1.
int32_t identity_metadata = kDefaultTensor;
int32_t spin_counter = 0;
while (true) {
identity_metadata =
atomicCAS(metadata + output_index, kDefaultTensor, kDefaultTensor);
if (identity_metadata != kDefaultTensor) {
break;
}
#ifdef USE_ROCM
if (++spin_counter > kMaxSpinCount) {
// The metadata write may not be visible yet due to atomic contention.
// After more than kMaxSpinCount polls, stop waiting: do not consider this
// slot, return false, and move on to the next slot.
return false;
}
#endif
}
bool is_more_recent = (identity_metadata < metadata_val);
bool threshold_met = (eviction_threshold > identity_metadata);
Expand Down
21 changes: 2 additions & 19 deletions fbgemm_gpu/test/faster_hash_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@
# pyre-ignore[21]
from test_utils import ( # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
gpu_unavailable,
skipIfRocm,
)
except Exception:
from fbgemm_gpu.test.test_utils import gpu_unavailable, skipIfRocm
from fbgemm_gpu.test.test_utils import gpu_unavailable

torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:faster_hash_ops")

Expand All @@ -36,7 +35,7 @@ class HashZchKernelEvictionPolicy(IntEnum):


class FasterHashTest(unittest.TestCase):
@skipIfRocm("The CUDA kernel is not supported on ROCm")

@unittest.skipIf(*gpu_unavailable)
def test_simple_zch_no_evict(self) -> None:
"""
Expand Down Expand Up @@ -149,7 +148,6 @@ def test_simple_zch_no_evict(self) -> None:
)
self.assertTrue(torch.equal(output_readonly.cpu(), output_readonly_cpu))

@skipIfRocm("The CUDA kernel is not supported on ROCm")
@unittest.skipIf(*gpu_unavailable)
def test_simple_zch_no_evict_rand(self) -> None:
"""
Expand Down Expand Up @@ -214,7 +212,6 @@ def test_simple_zch_no_evict_rand(self) -> None:
)
self.assertTrue(torch.equal(output.cpu(), output_readonly_cpu))

@skipIfRocm("The CUDA kernel is not supported on ROCm")
@unittest.skipIf(*gpu_unavailable)
def test_simple_zch_evict(self) -> None:
"""
Expand Down Expand Up @@ -284,7 +281,6 @@ def test_simple_zch_evict(self) -> None:
)
self.assertTrue(torch.equal(output, output_readonly))

@skipIfRocm("The CUDA kernel is not supported on ROCm")
@unittest.skipIf(*gpu_unavailable)
def test_simple_zch_evict_with_rand_unique_numbers(self) -> None:
"""
Expand Down Expand Up @@ -340,7 +336,6 @@ def test_simple_zch_evict_with_rand_unique_numbers(self) -> None:
)
self.assertTrue(torch.equal(output, output_readonly))

@skipIfRocm("The CUDA kernel is not supported on ROCm")
@unittest.skipIf(*gpu_unavailable)
def test_eviction_during_lookup(self) -> None:
"""
Expand Down Expand Up @@ -427,7 +422,6 @@ def test_eviction_during_lookup(self) -> None:
)
self.assertTrue(evict_slots.numel() == 1)

@skipIfRocm("The CUDA kernel is not supported on ROCm")
@unittest.skipIf(*gpu_unavailable)
def test_zch_int64_nohash_identity(self) -> None:
"""
Expand Down Expand Up @@ -491,7 +485,6 @@ def test_zch_int64_nohash_identity(self) -> None:
f"{identities=} vs {numbers_100_200=}",
)

@skipIfRocm("The CUDA kernel is not supported on ROCm")
@unittest.skipIf(*gpu_unavailable)
def test_zch_int32_nohash_identity(self) -> None:
"""
Expand Down Expand Up @@ -555,7 +548,6 @@ def test_zch_int32_nohash_identity(self) -> None:
f"{identities=} vs {numbers_100_200=}",
)

@skipIfRocm("The CUDA kernel is not supported on ROCm")
@unittest.skipIf(*gpu_unavailable)
def test_fallback(self) -> None:
"""
Expand Down Expand Up @@ -654,7 +646,6 @@ def test_fallback(self) -> None:
)
self.assertTrue(torch.all(remapped_ids[-20:] == -1))

@skipIfRocm("The CUDA kernel is not supported on ROCm")
@unittest.skipIf(*gpu_unavailable)
def test_simple_zch_individual_score_evict(self) -> None:
"""
Expand Down Expand Up @@ -759,7 +750,6 @@ def test_simple_zch_individual_score_evict(self) -> None:
# metadata should not be overwritten
self.assertTrue(torch.equal(metadata, metadata0))

@skipIfRocm("The CUDA kernel is not supported on ROCm")
@unittest.skipIf(*gpu_unavailable)
def test_zch_lru_evict(self) -> None:
"""
Expand Down Expand Up @@ -908,7 +898,6 @@ def test_zch_lru_evict(self) -> None:
f"{output_readonly_cpu=} v.s {output_readonly.cpu()=}",
)

@skipIfRocm("The CUDA kernel is not supported on ROCm")
@unittest.skipIf(*gpu_unavailable)
def test_zch_lru_evict_with_unexpired_slots(self) -> None:
"""
Expand Down Expand Up @@ -1048,7 +1037,6 @@ def test_zch_lru_evict_with_unexpired_slots(self) -> None:
f"{output_readonly_cpu=} v.s {output_readonly.cpu()=}",
)

@skipIfRocm("The CUDA kernel is not supported on ROCm")
@unittest.skipIf(*gpu_unavailable)
def test_rand_numbers_zch_lru_evict(self) -> None:
"""
Expand Down Expand Up @@ -1154,7 +1142,6 @@ def test_rand_numbers_zch_lru_evict(self) -> None:
f"{output_readonly_cpu=} v.s {output_readonly.cpu()=}",
)

@skipIfRocm("The CUDA kernel is not supported on ROCm")
@unittest.skipIf(*gpu_unavailable)
def test_zch_lru_evict_with_offsets(self) -> None:
"""
Expand Down Expand Up @@ -1325,7 +1312,6 @@ def test_zch_lru_evict_with_offsets(self) -> None:
f"{set(second_half[second_half >= 300].tolist())=}, {set(random_numbers_300_350.tolist())=}",
)

@skipIfRocm("The CUDA kernel is not supported on ROCm")
@unittest.skipIf(*gpu_unavailable)
def test_opt_in_with_prob(self) -> None:
"""
Expand Down Expand Up @@ -1525,7 +1511,6 @@ def test_opt_in_with_prob(self) -> None:
)
self.assertTrue(torch.equal(output_readonly_cpu, output_readonly.cpu()))

@skipIfRocm("The CUDA kernel is not supported on ROCm")
@unittest.skipIf(*gpu_unavailable)
def test_zch_lru_evict_train_eval(self) -> None:
"""
Expand Down Expand Up @@ -1614,7 +1599,6 @@ def test_zch_lru_evict_train_eval(self) -> None:
f"{output_readonly_cpu=} v.s {output_readonly=}",
)

@skipIfRocm("The CUDA kernel is not supported on ROCm")
@unittest.skipIf(*gpu_unavailable)
def test_murmur_hash(self) -> None:
"""
Expand All @@ -1638,7 +1622,6 @@ def test_murmur_hash(self) -> None:
output_item_second_round = torch.ops.fbgemm.murmur_hash3(input_item, 0, 0)
self.assertTrue(torch.equal(output_item_first_round, output_item_second_round))

@skipIfRocm("The CUDA kernel is not supported on ROCm")
@unittest.skipIf(*gpu_unavailable)
@settings(deadline=None)
# pyre-ignore [56]
Expand Down
Loading