diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
index 73a17572ef..c0ea49aa80 100644
--- a/fbgemm_gpu/FbgemmGpu.cmake
+++ b/fbgemm_gpu/FbgemmGpu.cmake
@@ -24,6 +24,7 @@ set(tbe_eeg_cpu_sources
     src/tbe/eeg/indices_generator.cpp)
 
 set(fbgemm_gpu_sources_cpu_static
+    src/faster_hash_ops/faster_hash.cpp
     src/intraining_embedding_pruning_ops/intraining_embedding_pruning_cpu.cpp
     src/memory_utils/memory_utils.cpp
     src/memory_utils/memory_utils_ops.cpp
@@ -45,11 +46,6 @@ set(fbgemm_gpu_sources_cpu_static
     src/sparse_ops/sparse_ops_meta.cpp
     ${tbe_eeg_cpu_sources})
 
-if(NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_ROCM)
-  list(APPEND fbgemm_gpu_sources_cpu_static
-    src/faster_hash_ops/faster_hash.cpp)
-endif()
-
 if(NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_CPU)
   list(APPEND fbgemm_gpu_sources_cpu_static
     src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp
@@ -73,6 +69,7 @@ endif()
 
 if(NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_CPU)
   set(fbgemm_gpu_sources_gpu_static
+      src/faster_hash_ops/faster_hash.cu
       src/histogram_binning_calibration_ops.cu
      src/input_combine_ops/input_combine.cu
       src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu
@@ -131,11 +128,6 @@ if(NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_CPU)
       src/sparse_ops/sparse_segment_sum_csr.cu
       src/sparse_ops/sparse_zipf.cu
       src/sparse_ops/sparse_block_bucketize_features_2d_weights.cu)
-
-  if(NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_ROCM)
-    list(APPEND fbgemm_gpu_sources_gpu_static
-      src/faster_hash_ops/faster_hash.cu)
-  endif()
 endif()
diff --git a/fbgemm_gpu/src/faster_hash_ops/faster_hash.cpp b/fbgemm_gpu/src/faster_hash_ops/faster_hash.cpp
index 2a2238766e..7917828900 100644
--- a/fbgemm_gpu/src/faster_hash_ops/faster_hash.cpp
+++ b/fbgemm_gpu/src/faster_hash_ops/faster_hash.cpp
@@ -34,9 +34,9 @@ template <
     typename TInput,
     typename TIdentity>
 void process_item_zch(
-    const at::PackedTensorAccessor64<TInput, 1>& input,
-    at::PackedTensorAccessor64<int64_t, 1> output,
-    const at::PackedTensorAccessor64<TIdentity, 2>& identities,
+    const at::TensorAccessor<TInput, 1>& input,
+    at::TensorAccessor<int64_t, 1> output,
+    const at::TensorAccessor<TIdentity, 2>& identities,
     int64_t modulo,
     int64_t max_probe,
     const int64_t* const local_sizes,
@@ -140,9 +140,9 @@ void _zero_collision_hash_cpu_out(
           HAS_OFFSET,                                   \
           TInput,                                       \
           TIdentity>(                                   \
-          input.packed_accessor64<TInput, 1>(),         \
-          output.packed_accessor64<int64_t, 1>(),       \
-          identities.packed_accessor64<TIdentity, 2>(), \
+          input.accessor<TInput, 1>(),                  \
+          output.accessor<int64_t, 1>(),                \
+          identities.accessor<TIdentity, 2>(),          \
           modulo,                                       \
           max_probe,                                    \
           local_sizes_ptr,                              \
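Note on the .cpp change above: the CPU path swaps device-style packed accessors for host-side `at::TensorAccessor`. `Tensor::accessor<T, N>()` returns an accessor that borrows the tensor's sizes and strides for fast host loops, whereas `packed_accessor64<T, N>()` copies that metadata by value so the accessor can be passed into a `__global__` kernel; on a CPU-only code path the plain accessor is the idiomatic choice and keeps the CUDA-oriented type out of the host build. A minimal sketch of the host-side pattern (hypothetical helper, not from this patch):

```cpp
// Illustrative only: host code reading a CPU tensor through at::TensorAccessor,
// obtained via Tensor::accessor<T, N>(). Valid for CPU tensors only;
// packed_accessor64<T, N>() is the device-kernel counterpart.
#include <ATen/ATen.h>

int64_t sum_1d(const at::Tensor& t) {
  // Borrows the tensor's sizes/strides; no metadata copy is made.
  const auto acc = t.accessor<int64_t, 1>();
  int64_t total = 0;
  for (int64_t i = 0; i < acc.size(0); ++i) {
    total += acc[i];
  }
  return total;
}
```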
diff --git a/fbgemm_gpu/src/faster_hash_ops/faster_hash.cu b/fbgemm_gpu/src/faster_hash_ops/faster_hash.cu
index d1ac231a36..fef6584024 100644
--- a/fbgemm_gpu/src/faster_hash_ops/faster_hash.cu
+++ b/fbgemm_gpu/src/faster_hash_ops/faster_hash.cu
@@ -32,6 +32,7 @@ static constexpr int32_t kDefaultTensor = -1;
 static constexpr int64_t kMaxIdentityNum = INT32_MAX;
 static constexpr int64_t kMaxHours = INT32_MAX;
 static constexpr int64_t kSecondsInHour = 60 * 60;
+static constexpr int32_t kMaxSpinCount = 10 * 1000;
 
 template <typename T>
 __device__ __inline__ T CAS(T* data, T cmp, T val) {
@@ -167,12 +168,20 @@ __device__ __inline__ int64_t check_min(
   // wait.
   auto insert_idx = output_index + offset;
   int32_t last_seen = kDefaultTensor;
+  int32_t spin_count = 0;
   while (true) {
     last_seen =
         atomicCAS(metadata + insert_idx, kDefaultTensor, kDefaultTensor);
     if (last_seen != kDefaultTensor) {
       break;
     }
+#ifdef USE_ROCM
+    if (++spin_count > kMaxSpinCount) {
+      // Metadata write may not be visible yet due to atomic contention.
+      // Slot not considered; keep the original min_index and move to the next slot.
+      return min_index;
+    }
+#endif
   }
 
   // only check those expired slots
@@ -225,12 +234,20 @@ __device__ __inline__ bool check_evict<1>(
   // has not been written yet, while the other id checking the slot's eviction
   // status. Therefore, wait until the metadata is not -1.
   int32_t identity_metadata = kDefaultTensor;
+  int32_t spin_counter = 0;
   while (true) {
     identity_metadata =
         atomicCAS(metadata + output_index, kDefaultTensor, kDefaultTensor);
     if (identity_metadata != kDefaultTensor) {
       break;
     }
+#ifdef USE_ROCM
+    if (++spin_counter > kMaxSpinCount) {
+      // Metadata write may not be visible yet due to atomic contention.
+      // Slot not considered; return false and move to the next slot.
+      return false;
+    }
+#endif
   }
   bool is_more_recent = (identity_metadata < metadata_val);
   bool threshold_met = (eviction_threshold > identity_metadata);
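The two `USE_ROCM` blocks above turn a previously unbounded `atomicCAS` polling loop into a bounded wait: `atomicCAS(p, v, v)` reads `*p` atomically without modifying it, and the spin cap keeps a reader from hanging forever if the producer's metadata store is not yet visible. A minimal standalone sketch of the same pattern (hypothetical helper, not from this patch; `kMaxSpinCount` mirrors the constant added above):

```cpp
// Bounded spin-wait on a metadata slot. Returns true once the slot is
// observed to hold a real value, false if the wait is abandoned.
__device__ bool wait_for_metadata(int32_t* slot) {
  constexpr int32_t kDefaultTensor = -1;     // sentinel: slot not written yet
  constexpr int32_t kMaxSpinCount = 10 * 1000;
  int32_t spin_count = 0;
  // atomicCAS with cmp == val performs an atomic read without changing *slot.
  while (atomicCAS(slot, kDefaultTensor, kDefaultTensor) == kDefaultTensor) {
    if (++spin_count > kMaxSpinCount) {
      return false;  // give up; caller skips this slot instead of deadlocking
    }
  }
  return true;  // metadata became visible
}
```

Because the bailout is compiled only under `USE_ROCM`, the CUDA build keeps the original unbounded wait; a caller that times out treats the slot conservatively (`check_min` keeps its current `min_index`, `check_evict` reports the slot as non-evictable).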
diff --git a/fbgemm_gpu/test/faster_hash_test.py b/fbgemm_gpu/test/faster_hash_test.py
index e185d0da98..6e8b795906 100644
--- a/fbgemm_gpu/test/faster_hash_test.py
+++ b/fbgemm_gpu/test/faster_hash_test.py
@@ -22,10 +22,9 @@
     # pyre-ignore[21]
     from test_utils import (  # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
         gpu_unavailable,
-        skipIfRocm,
     )
 except Exception:
-    from fbgemm_gpu.test.test_utils import gpu_unavailable, skipIfRocm
+    from fbgemm_gpu.test.test_utils import gpu_unavailable
 
 torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:faster_hash_ops")
 
@@ -36,7 +35,7 @@ class HashZchKernelEvictionPolicy(IntEnum):
 
 
 class FasterHashTest(unittest.TestCase):
-    @skipIfRocm("The CUDA kernel is not supported on ROCm")
+    @unittest.skipIf(*gpu_unavailable)
     def test_simple_zch_no_evict(self) -> None:
         """
@@ -149,7 +148,6 @@ def test_simple_zch_no_evict(self) -> None:
         )
         self.assertTrue(torch.equal(output_readonly.cpu(), output_readonly_cpu))
 
-    @skipIfRocm("The CUDA kernel is not supported on ROCm")
     @unittest.skipIf(*gpu_unavailable)
     def test_simple_zch_no_evict_rand(self) -> None:
         """
@@ -214,7 +212,6 @@ def test_simple_zch_no_evict_rand(self) -> None:
         )
         self.assertTrue(torch.equal(output.cpu(), output_readonly_cpu))
 
-    @skipIfRocm("The CUDA kernel is not supported on ROCm")
     @unittest.skipIf(*gpu_unavailable)
     def test_simple_zch_evict(self) -> None:
         """
@@ -284,7 +281,6 @@ def test_simple_zch_evict(self) -> None:
         )
         self.assertTrue(torch.equal(output, output_readonly))
 
-    @skipIfRocm("The CUDA kernel is not supported on ROCm")
     @unittest.skipIf(*gpu_unavailable)
     def test_simple_zch_evict_with_rand_unique_numbers(self) -> None:
         """
@@ -340,7 +336,6 @@ def test_simple_zch_evict_with_rand_unique_numbers(self) -> None:
         )
         self.assertTrue(torch.equal(output, output_readonly))
 
-    @skipIfRocm("The CUDA kernel is not supported on ROCm")
     @unittest.skipIf(*gpu_unavailable)
     def test_eviction_during_lookup(self) -> None:
         """
@@ -427,7 +422,6 @@ def test_eviction_during_lookup(self) -> None:
         )
         self.assertTrue(evict_slots.numel() == 1)
 
-    @skipIfRocm("The CUDA kernel is not supported on ROCm")
     @unittest.skipIf(*gpu_unavailable)
     def test_zch_int64_nohash_identity(self) -> None:
         """
@@ -491,7 +485,6 @@ def test_zch_int64_nohash_identity(self) -> None:
             f"{identities=} vs {numbers_100_200=}",
         )
 
-    @skipIfRocm("The CUDA kernel is not supported on ROCm")
     @unittest.skipIf(*gpu_unavailable)
     def test_zch_int32_nohash_identity(self) -> None:
         """
@@ -555,7 +548,6 @@ def test_zch_int32_nohash_identity(self) -> None:
             f"{identities=} vs {numbers_100_200=}",
         )
 
-    @skipIfRocm("The CUDA kernel is not supported on ROCm")
     @unittest.skipIf(*gpu_unavailable)
     def test_fallback(self) -> None:
         """
@@ -654,7 +646,6 @@ def test_fallback(self) -> None:
         )
         self.assertTrue(torch.all(remapped_ids[-20:] == -1))
 
-    @skipIfRocm("The CUDA kernel is not supported on ROCm")
     @unittest.skipIf(*gpu_unavailable)
     def test_simple_zch_individual_score_evict(self) -> None:
         """
@@ -759,7 +750,6 @@ def test_simple_zch_individual_score_evict(self) -> None:
         # metadata should not be overwritten
         self.assertTrue(torch.equal(metadata, metadata0))
 
-    @skipIfRocm("The CUDA kernel is not supported on ROCm")
     @unittest.skipIf(*gpu_unavailable)
     def test_zch_lru_evict(self) -> None:
         """
@@ -908,7 +898,6 @@ def test_zch_lru_evict(self) -> None:
             f"{output_readonly_cpu=} v.s {output_readonly.cpu()=}",
         )
 
-    @skipIfRocm("The CUDA kernel is not supported on ROCm")
     @unittest.skipIf(*gpu_unavailable)
     def test_zch_lru_evict_with_unexpired_slots(self) -> None:
         """
@@ -1048,7 +1037,6 @@ def test_zch_lru_evict_with_unexpired_slots(self) -> None:
             f"{output_readonly_cpu=} v.s {output_readonly.cpu()=}",
         )
 
-    @skipIfRocm("The CUDA kernel is not supported on ROCm")
     @unittest.skipIf(*gpu_unavailable)
     def test_rand_numbers_zch_lru_evict(self) -> None:
         """
@@ -1154,7 +1142,6 @@ def test_rand_numbers_zch_lru_evict(self) -> None:
             f"{output_readonly_cpu=} v.s {output_readonly.cpu()=}",
         )
 
-    @skipIfRocm("The CUDA kernel is not supported on ROCm")
     @unittest.skipIf(*gpu_unavailable)
     def test_zch_lru_evict_with_offsets(self) -> None:
         """
@@ -1325,7 +1312,6 @@ def test_zch_lru_evict_with_offsets(self) -> None:
             f"{set(second_half[second_half >= 300].tolist())=}, {set(random_numbers_300_350.tolist())=}",
         )
 
-    @skipIfRocm("The CUDA kernel is not supported on ROCm")
     @unittest.skipIf(*gpu_unavailable)
     def test_opt_in_with_prob(self) -> None:
         """
@@ -1525,7 +1511,6 @@ def test_opt_in_with_prob(self) -> None:
         )
         self.assertTrue(torch.equal(output_readonly_cpu, output_readonly.cpu()))
 
-    @skipIfRocm("The CUDA kernel is not supported on ROCm")
     @unittest.skipIf(*gpu_unavailable)
     def test_zch_lru_evict_train_eval(self) -> None:
         """
@@ -1614,7 +1599,6 @@ def test_zch_lru_evict_train_eval(self) -> None:
             f"{output_readonly_cpu=} v.s {output_readonly=}",
         )
 
-    @skipIfRocm("The CUDA kernel is not supported on ROCm")
     @unittest.skipIf(*gpu_unavailable)
     def test_murmur_hash(self) -> None:
         """
@@ -1638,7 +1622,6 @@ def test_murmur_hash(self) -> None:
         output_item_second_round = torch.ops.fbgemm.murmur_hash3(input_item, 0, 0)
         self.assertTrue(torch.equal(output_item_first_round, output_item_second_round))
 
-    @skipIfRocm("The CUDA kernel is not supported on ROCm")
     @unittest.skipIf(*gpu_unavailable)
     @settings(deadline=None)
     # pyre-ignore [56]