-
Notifications
You must be signed in to change notification settings - Fork 745
Expand file tree
/
Copy pathfaster_hash_test.py
More file actions
1698 lines (1532 loc) · 61.6 KB
/
faster_hash_test.py
File metadata and controls
1698 lines (1532 loc) · 61.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import unittest
from enum import IntEnum
# pyre-ignore[21]
import fbgemm_gpu # noqa: F401
import torch
from hypothesis import given, settings, strategies as st
# check if we are in open source env to decide how to import necessary modules
try:
# pyre-ignore[21]
from fbgemm_gpu import open_source # noqa: F401
# pyre-ignore[21]
from test_utils import ( # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
gpu_unavailable,
)
except Exception:
from fbgemm_gpu.test.test_utils import gpu_unavailable
# Load the faster-hash custom operator library; presumably this is what makes
# the torch.ops.fbgemm.* ops used by the tests below (create_zch_buffer,
# zero_collision_hash) available.
# NOTE(review): the argument is a Buck target label rather than a filesystem
# path — confirm how it resolves in the open-source environment.
torch.ops.load_library(
    "//deeplearning/fbgemm/fbgemm_gpu:faster_hash_ops"
)
class HashZchKernelEvictionPolicy(IntEnum):
    """
    Integer codes selecting the eviction policy of the ZCH kernel.

    The tests pass ``<member>.value`` as the ``eviction_policy`` argument of
    ``torch.ops.fbgemm.zero_collision_hash``.
    """

    # Presumably evicts slots whose stored score is below ``eviction_threshold``.
    THRESHOLD_EVICTION = 0
    # Least-recently-used eviction.
    LRU_EVICTION = 1
class FasterHashTest(unittest.TestCase):
@unittest.skipIf(*gpu_unavailable)
def test_simple_zch_no_evict(self) -> None:
    """
    Exercise zero collision hash insertion when no eviction is needed.

    Part 1 builds a 200-slot identity table and inserts the ids 0-99 and
    100-199 as two batches, both with per-element ``local_sizes`` of 100 and
    with ``offsets`` of 0 and 100 respectively.
    Part 2 builds a fresh 100-slot table and inserts the ids 100-199 without
    ``local_sizes``/``offsets``.

    Assertions:
    1. Part 1: the first batch lands on slots 0-99, the second on 100-199.
       Part 2: the ids land on slots 0-99.
    2. Every slot of the identity table ends up occupied.
    3. Readonly lookups, on GPU and on CPU, return the same slots as the
       inserting call.

    Raises:
    AssertionError: If the test detects a mismatch from the expected behavior.
    """
    cuda = torch.device("cuda")
    zch = torch.ops.fbgemm.zero_collision_hash
    probe_args = {"max_probe": 100, "circular_probe": True}
    lookup_args = {"exp_hours": -1, "readonly": True}

    # Part 1: two batches into a 200-slot table, placed via local_sizes/offsets.
    table, _ = torch.ops.fbgemm.create_zch_buffer(200, device=cuda)
    numbers = torch.arange(0, 100, dtype=torch.int64, device="cuda")
    shifted = numbers + 100
    sizes = torch.ones_like(numbers) * 100
    shifted_offsets = torch.ones_like(numbers) * 100

    output1, _ = zch(
        input=numbers,
        identities=table,
        local_sizes=sizes,
        offsets=torch.zeros_like(numbers),
        **probe_args,
    )
    output2, _ = zch(
        input=shifted,
        identities=table,
        local_sizes=sizes,
        offsets=shifted_offsets,
        **probe_args,
    )
    self.assertEqual(
        torch.unique(output1).tolist(),
        numbers.tolist(),
        f"{torch.unique(output1).tolist()=} != {numbers.tolist()=}",
    )
    self.assertEqual(torch.unique(output2).tolist(), shifted.tolist())
    self.assertTrue(torch.all(table != -1))

    # A readonly lookup of the second batch must reproduce output2 on GPU ...
    gpu_lookup, _ = zch(
        input=shifted,
        identities=table,
        local_sizes=sizes,
        offsets=shifted_offsets,
        **probe_args,
        **lookup_args,
    )
    self.assertTrue(torch.equal(output2, gpu_lookup))
    # ... and on CPU.
    output_readonly_cpu, _ = zch(
        input=numbers.cpu() + 100,
        identities=table.cpu(),
        local_sizes=sizes.cpu(),
        offsets=shifted_offsets.cpu(),
        **probe_args,
        **lookup_args,
    )
    self.assertTrue(
        torch.equal(output2.cpu(), output_readonly_cpu),
        f"{output2.cpu()=} != {output_readonly_cpu=}",
    )

    # Part 2: ids 100-199 into a fresh 100-slot table, no local_sizes/offsets.
    table, _ = torch.ops.fbgemm.create_zch_buffer(100, device=cuda)
    upper_ids = torch.arange(100, 200, dtype=torch.int64, device="cuda")
    upper_slots, _ = zch(input=upper_ids, identities=table, **probe_args)
    self.assertEqual(torch.unique(upper_slots).tolist(), numbers.tolist())
    self.assertTrue(torch.all(table != -1))

    upper_lookup, _ = zch(
        input=upper_ids, identities=table, **probe_args, **lookup_args
    )
    self.assertTrue(torch.equal(upper_slots, upper_lookup))
    upper_lookup_cpu, _ = zch(
        input=upper_ids.cpu(),
        identities=table.cpu(),
        **probe_args,
        **lookup_args,
    )
    self.assertTrue(torch.equal(upper_lookup.cpu(), upper_lookup_cpu))
@unittest.skipIf(*gpu_unavailable)
def test_simple_zch_no_evict_rand(self) -> None:
    """
    Test the basic functionality of zero collision hash without evicting,
    using random ids that may repeat.
    It creates an identity table with 100 slots and inserts, in a single
    batch, 100 random numbers drawn from [0, 100) with zero collision hash.
    Because of repeats the table is generally not filled.
    Assertions:
    1. Every occurrence of the same input id is mapped to the same slot.
    2. Exactly the slots returned in the output are occupied; all other
       slots stay empty (-1).
    3. No id is stored in more than one slot.
    4. The readonly lookup (GPU and CPU) matches the original output.
    Raises:
    AssertionError: If the test detects a mismatch from the expected behavior.
    """
    # no evict - rand number.
    identities, _ = torch.ops.fbgemm.create_zch_buffer(
        100, device=torch.device("cuda")
    )
    random_numbers = torch.randint(0, 100, (100,), device="cuda")
    output, _ = torch.ops.fbgemm.zero_collision_hash(
        random_numbers,
        identities,
        100,
        circular_probe=True,
    )
    # 1. all occurrences of the same id share one slot.
    for i in range(100):
        to_test = output[random_numbers == i]
        if len(to_test) > 0:
            self.assertTrue(torch.all(to_test == to_test[0]))
    # 2. exactly the returned slots are occupied.
    unique_indices = torch.unique(output)
    all_indices = torch.arange(identities.size(0), device="cuda")
    not_select_indices = torch.isin(all_indices, unique_indices, invert=True)
    self.assertTrue(torch.all(identities[unique_indices] != -1))
    self.assertTrue(torch.all(identities[not_select_indices] == -1))
    # 3. every stored id occurs once.
    _, counts = torch.unique(
        identities[identities[:, 0] != -1][:, 0], return_counts=True
    )
    self.assertTrue(torch.all(counts == 1))
    # 4. readonly lookup.
    output_readonly, _ = torch.ops.fbgemm.zero_collision_hash(
        random_numbers,
        identities,
        100,
        circular_probe=True,
        exp_hours=-1,
        readonly=True,
    )
    self.assertTrue(torch.equal(output, output_readonly))
    # CPU
    output_readonly_cpu, _ = torch.ops.fbgemm.zero_collision_hash(
        random_numbers.cpu(),
        identities.cpu(),
        100,
        circular_probe=True,
        exp_hours=-1,
        readonly=True,
    )
    self.assertTrue(torch.equal(output.cpu(), output_readonly_cpu))
@unittest.skipIf(*gpu_unavailable)
def test_simple_zch_evict(self) -> None:
    """
    Exercise zero collision hash with expiry-based eviction.

    A 100-slot, eviction-enabled table is filled with the ids 0-99 using
    ``exp_hours`` of 7 * 24. Every slot's metadata is then reduced by
    ``exp_hours + 1`` so all slots count as expired, and the ids 100-199 are
    inserted.

    Assertions (on top of the no-evict checks):
    1. The first batch evicts nothing.
    2. The second batch evicts every slot 0-99 and reuses them.
    3. Readonly lookups return the same slots as the inserting calls.

    Raises:
    AssertionError: If the test detects a mismatch from the expected behavior.
    """
    ttl_hours = 7 * 24
    zch = torch.ops.fbgemm.zero_collision_hash
    identities, metadata = torch.ops.fbgemm.create_zch_buffer(
        100, support_evict=True, device=torch.device("cuda")
    )
    # `metadata` is mutated in place below, so the dict keeps seeing updates.
    write_args = {
        "max_probe": 100,
        "circular_probe": True,
        "exp_hours": ttl_hours,
        "metadata": metadata,
    }
    read_args = {
        "max_probe": 100,
        "circular_probe": True,
        "exp_hours": -1,
        "readonly": True,
    }

    first_ids = torch.arange(0, 100, dtype=torch.int64, device="cuda")
    first_slots, evicted = zch(input=first_ids, identities=identities, **write_args)
    self.assertEqual(torch.unique(first_slots).tolist(), first_ids.tolist())
    self.assertTrue(evicted.numel() == 0)
    first_lookup, _ = zch(input=first_ids, identities=identities, **read_args)
    self.assertTrue(torch.equal(first_slots, first_lookup))

    # Push every slot past its expiry so the next batch may evict all of them.
    metadata[:, 0] -= ttl_hours + 1
    second_ids = torch.arange(100, 200, dtype=torch.int64, device="cuda")
    second_slots, evicted = zch(
        input=second_ids, identities=identities, **write_args
    )
    self.assertEqual(torch.unique(second_slots).tolist(), first_ids.tolist())
    self.assertTrue(torch.all(evicted != -1))
    self.assertEqual(torch.unique(evicted).tolist(), first_ids.tolist())
    second_lookup, _ = zch(input=second_ids, identities=identities, **read_args)
    self.assertTrue(torch.equal(second_slots, second_lookup))
@unittest.skipIf(*gpu_unavailable)
def test_simple_zch_evict_with_rand_unique_numbers(self) -> None:
    """
    Test zero collision hash insertion into an eviction-enabled table,
    using unique random ids.
    It creates an eviction-enabled identity table with 100 slots and inserts,
    in a single batch, the unique values of 100 random draws from [0, 100)
    (with exp_hours set, but on an initially empty table).
    Assertions:
    1. Every occurrence of the same input id is mapped to the same slot.
    2. Exactly the slots returned in the output are occupied; all other
       slots stay empty (-1).
    3. No id is stored in more than one slot.
    4. No eviction happens, since the table starts empty and has room for
       every id.
    5. A readonly lookup against the identity column matches the output.
    Raises:
    AssertionError: If the test detects a mismatch from the expected behavior.
    """
    identities, metadata = torch.ops.fbgemm.create_zch_buffer(
        100, support_evict=True, device=torch.device("cuda")
    )
    # At most 100 unique ids from [0, 100), so they all fit in 100 slots.
    random_numbers = torch.unique(torch.randint(0, 100, (100,), device="cuda"))
    output, evict_slots = torch.ops.fbgemm.zero_collision_hash(
        random_numbers,
        identities,
        100,
        circular_probe=True,
        exp_hours=7 * 24,
        metadata=metadata,
    )
    # 4. the table started empty, so nothing may have been evicted.
    self.assertEqual(evict_slots.numel(), 0, f"{evict_slots=}")
    # 1. all occurrences of the same id share one slot.
    for i in range(100):
        to_test = output[random_numbers == i]
        if len(to_test) > 0:
            self.assertTrue(torch.all(to_test == to_test[0]))
    # 2. exactly the returned slots are occupied.
    unique_indices = torch.unique(output)
    all_indices = torch.arange(identities.size(0), device="cuda")
    not_select_indices = torch.isin(all_indices, unique_indices, invert=True)
    self.assertTrue(torch.all(identities[unique_indices] != -1))
    self.assertTrue(torch.all(identities[not_select_indices] == -1))
    # 3. every stored id occurs once.
    _, counts = torch.unique(
        identities[identities[:, 0] != -1][:, 0], return_counts=True
    )
    self.assertTrue(torch.all(counts == 1))
    # 5. readonly lookup against the identity column only.
    output_readonly, _ = torch.ops.fbgemm.zero_collision_hash(
        random_numbers,
        identities[:, 0].unsqueeze(1),
        100,
        circular_probe=True,
        exp_hours=-1,
        readonly=True,
    )
    self.assertTrue(torch.equal(output, output_readonly))
@unittest.skipIf(*gpu_unavailable)
def test_eviction_during_lookup(self) -> None:
    """
    Verify that inserting an id that is already stored never evicts, even
    when every other slot has expired, while a genuinely new id does evict.

    Steps on a 100-slot, eviction-enabled table:
    1. Insert ids 0-98, leaving exactly one slot empty.
    2. Insert id 101, which takes that last empty slot.
    3. Expire the 99 slots holding ids 0-98 and insert 101 again: nothing is
       evicted and no id is stored twice.
    4. A readonly lookup of ids 0-98 still returns their original slots.
    5. Insert the new id 102, which must evict exactly one slot.

    Raises:
    AssertionError: If the test detects a mismatch from the expected behavior.
    """
    ttl_hours = 7 * 24
    zch = torch.ops.fbgemm.zero_collision_hash
    identities, metadata = torch.ops.fbgemm.create_zch_buffer(
        100, support_evict=True, device=torch.device("cuda")
    )
    # `metadata` is mutated in place below, so the dict keeps seeing updates.
    upsert_args = {
        "max_probe": 100,
        "circular_probe": True,
        "exp_hours": ttl_hours,
        "metadata": metadata,
    }

    # Step 1: 99 ids into 100 slots leaves a single hole.
    ids_0_98 = torch.arange(0, 99, dtype=torch.int64, device="cuda")
    slots_0_98, _ = zch(input=ids_0_98, identities=identities, **upsert_args)
    hole_mask = identities[:, 0] == -1
    self.assertTrue(torch.sum(hole_mask) == 1, torch.sum(hole_mask))

    # Step 2: 101 fills the hole.
    id_101 = torch.tensor([101], dtype=torch.int64, device="cuda")
    zch(input=id_101, identities=identities, **upsert_args)
    self.assertTrue(torch.all(identities[:, 0] != -1))

    # Step 3: expire every slot except the former hole (now holding 101).
    metadata[~hole_mask, 0] -= ttl_hours + 1
    _, evicted = zch(input=id_101, identities=identities, **upsert_args)
    _, counts = torch.unique(identities[:, 0], return_counts=True)
    self.assertTrue(torch.all(counts == 1))
    self.assertTrue(evicted.numel() == 0)

    # Step 4: the expired-but-not-evicted ids keep their slots.
    slots_readonly, _ = zch(
        input=ids_0_98,
        identities=identities,
        max_probe=100,
        circular_probe=True,
        exp_hours=-1,
        readonly=True,
    )
    self.assertTrue(torch.equal(slots_0_98, slots_readonly))

    # Step 5: a brand-new id forces exactly one eviction.
    id_102 = torch.tensor([102], dtype=torch.int64, device="cuda")
    _, evicted = zch(input=id_102, identities=identities, **upsert_args)
    self.assertTrue(evicted.numel() == 1)
@unittest.skipIf(*gpu_unavailable)
def test_zch_int64_nohash_identity(self) -> None:
    """
    Check that zero collision hash handles int64 ids beyond the int32 range.

    A 100-slot, eviction-enabled table with ``long_type=True`` is filled with
    the ids 2**33 .. 2**33 + 99. Every slot's metadata is then reduced by
    ``exp_hours + 1`` and the ids 2**33 + 100 .. 2**33 + 199 are inserted.

    Assertions:
    1. After the first insertion the table holds exactly the first 100 ids.
    2. The second insertion evicts and reuses all 100 slots, and the table
       then holds exactly the second 100 ids.

    Raises:
    AssertionError: If the test detects a mismatch from the expected behavior.
    """
    base = 2**33  # larger than any int32 value
    ttl_hours = 7 * 24
    hash_args = {
        "max_probe": 100,
        "circular_probe": True,
        "readonly": False,
        "exp_hours": ttl_hours,
    }
    identities, metadata = torch.ops.fbgemm.create_zch_buffer(
        100, device=torch.device("cuda"), support_evict=True, long_type=True
    )

    def stored_sorted() -> torch.Tensor:
        # Occupied identity entries, flattened and sorted ascending.
        return torch.sort(identities[identities != -1].view(-1))[0]

    # no evict
    numbers = torch.arange(base, base + 100, dtype=torch.int64, device="cuda")
    torch.ops.fbgemm.zero_collision_hash(
        input=numbers, identities=identities, metadata=metadata, **hash_args
    )
    self.assertTrue(
        torch.equal(stored_sorted(), numbers),
        f"{identities=} vs {numbers=}",
    )

    # evict: every slot is expired before the second batch arrives.
    numbers_100_200 = torch.arange(
        base + 100, base + 200, dtype=torch.int64, device="cuda"
    )
    metadata[:, 0] -= ttl_hours + 1
    output, evict_slots = torch.ops.fbgemm.zero_collision_hash(
        input=numbers_100_200,
        identities=identities,
        metadata=metadata,
        **hash_args,
    )
    every_slot = list(range(100))
    self.assertEqual(torch.unique(output).tolist(), every_slot)
    self.assertTrue(torch.all(evict_slots != -1))
    self.assertEqual(torch.unique(evict_slots).tolist(), every_slot)
    self.assertTrue(
        torch.equal(stored_sorted(), numbers_100_200),
        f"{identities=} vs {numbers_100_200=}",
    )
@unittest.skipIf(*gpu_unavailable)
def test_zch_int32_nohash_identity(self) -> None:
    """
    Test with the capability of processing int32_t data for
    zero collision hash.
    The identity table is created with long_type=False, and the ids are
    built from a large base that still fits in int32. This makes the test
    exercise large int32 values rather than values that overflow the dtype.
    Assertions:
    1. The output is correct for zero collision hash.
    2. The eviction works correctly when processing int32_t data.
    Raises:
    AssertionError: If the test detects a mismatch from the expected behavior.
    """
    # Large base id that is still representable as int32 (max 2**31 - 1).
    # A base of 2**33 does not fit in int32, so
    # torch.arange(..., dtype=torch.int32) could not produce the intended ids.
    base = 2**30
    # no evict
    identities, metadata = torch.ops.fbgemm.create_zch_buffer(
        100, device=torch.device("cuda"), support_evict=True, long_type=False
    )
    numbers = torch.arange(base, base + 100, dtype=torch.int32, device="cuda")
    # Guard against silent dtype overflow of the generated ids.
    self.assertEqual(numbers[0].item(), base)
    torch.ops.fbgemm.zero_collision_hash(
        input=numbers,
        identities=identities,
        max_probe=100,
        circular_probe=True,
        readonly=False,
        exp_hours=7 * 24,
        metadata=metadata,
    )
    self.assertTrue(
        torch.equal(
            torch.sort(identities[identities != -1].view(-1))[0],
            numbers,
        ),
        f"{identities=} vs {numbers=}",
    )
    # evict: expire every slot, then insert the next 100 ids.
    numbers_100_200 = torch.arange(
        base + 100, base + 200, dtype=torch.int32, device="cuda"
    )
    metadata[:, 0] -= 7 * 24 + 1
    output, evict_slots = torch.ops.fbgemm.zero_collision_hash(
        input=numbers_100_200,
        identities=identities,
        max_probe=100,
        circular_probe=True,
        readonly=False,
        exp_hours=7 * 24,
        metadata=metadata,
    )
    expect_indices = list(range(100))
    self.assertEqual(torch.unique(output).tolist(), expect_indices)
    self.assertTrue(torch.all(evict_slots != -1))
    self.assertEqual(torch.unique(evict_slots).tolist(), expect_indices)
    self.assertTrue(
        torch.equal(
            torch.sort(identities[identities != -1].view(-1))[0],
            numbers_100_200,
        ),
        f"{identities=} vs {numbers_100_200=}",
    )
@unittest.skipIf(*gpu_unavailable)
def test_fallback(self) -> None:
    """
    Test the fallback functionality of zero collision hash.
    It creates an identity table with 100 slots and fills it with the ids
    0-99. It then queries the 30 ids 90-119 while the table is full. Of
    these, 10 ids (90-99) are already stored and 20 ids (100-119) are not.
    The queries run in both non-readonly and readonly mode, with fallback
    enabled and disabled.
    Assertions:
    1. When fallback is enabled, all the ids (including non-existing ones)
       are mapped to a position.
    2. When fallback is disabled, existing ids are mapped to the positions
       holding them and non-existing ones are mapped to -1.
    Raises:
    AssertionError: If the test detects a mismatch from the expected behavior.
    """
    # init and fill all 100 slots with the ids 0-99
    identities, _ = torch.ops.fbgemm.create_zch_buffer(
        100, device=torch.device("cuda"), long_type=True
    )
    torch.ops.fbgemm.zero_collision_hash(
        input=torch.arange(0, 100, device="cuda"),
        identities=identities,
        max_probe=100,
        circular_probe=True,
        readonly=False,
    )
    # 90-99 already exist; the trailing ids 100-119 do not.
    ids = torch.arange(90, 120, device="cuda")
    existing_ids = torch.arange(90, 100, device="cuda")
    num_new_ids = ids.numel() - existing_ids.numel()
    # Call order: fallback enabled (non-readonly, then readonly), followed by
    # fallback disabled (non-readonly, then readonly).
    for disable_fallback in (False, True):
        for readonly in (False, True):
            with self.subTest(readonly=readonly, disable_fallback=disable_fallback):
                remapped_ids, _ = torch.ops.fbgemm.zero_collision_hash(
                    input=ids,
                    identities=identities,
                    max_probe=100,
                    circular_probe=True,
                    readonly=readonly,
                    disable_fallback=disable_fallback,
                )
                if not disable_fallback:
                    # all ids (including non-existing ones) get a position
                    self.assertTrue(torch.all(remapped_ids != -1))
                else:
                    # existing ids map to the slots holding them ...
                    found_slots = remapped_ids[remapped_ids != -1]
                    self.assertTrue(
                        torch.equal(
                            torch.index_select(
                                identities, 0, found_slots
                            ).squeeze(),
                            existing_ids,
                        )
                    )
                    # ... and non-existing ones map to -1
                    self.assertTrue(
                        torch.all(remapped_ids[-num_new_ids:] == -1)
                    )
@unittest.skipIf(*gpu_unavailable)
def test_simple_zch_individual_score_evict(self) -> None:
    """
    Test the zero collision hash with individual score eviction.
    It creates an eviction-enabled identity table with 100 slots and inserts
    the ids 0-99 with per-id input metadata (eviction scores) 500-599.
    It then inserts the ids 100-199 with scores 600-699 and an eviction
    threshold of 550, so only the 50 slots scored below 550 can be evicted.
    Finally it re-inserts the ids 100-199 with lower scores 0-99 and checks
    that the stored metadata is not overwritten.
    Assertions:
    1. After inserting ids 0-99, the remapped ids cover exactly the slots 0-99.
    2. The stored metadata values are exactly the input metadata 500-599.
    3. No eviction happens during the first insertion.
    4. A readonly lookup of ids 0-99 matches the inserting call.
    5. Inserting ids 100-199 with threshold 550 evicts exactly 50 slots.
    6. Afterwards all stored metadata is >= 550.
    7. A readonly lookup of ids 100-199 matches the inserting call.
    8. Re-inserting ids 100-199 with lower scores returns the same slots and
       leaves the stored metadata unchanged.
    Raises:
    AssertionError: If the test detects a mismatch from the expected behavior.
    """
    # evict
    identities, metadata = torch.ops.fbgemm.create_zch_buffer(
        100, support_evict=True, long_type=True, device=torch.device("cuda")
    )
    numbers_0_100 = torch.arange(0, 100, dtype=torch.int64, device="cuda")
    input_metadata_500_600 = torch.arange(
        500, 600, dtype=torch.int32, device="cuda"
    )
    output, evict_slots = torch.ops.fbgemm.zero_collision_hash(
        numbers_0_100,
        identities,
        100,
        circular_probe=True,
        metadata=metadata,
        input_metadata=input_metadata_500_600,
        eviction_threshold=100,
    )
    self.assertEqual(torch.unique(output).tolist(), numbers_0_100.tolist())
    self.assertEqual(
        torch.unique(metadata).tolist(), input_metadata_500_600.tolist()
    )
    self.assertEqual(evict_slots.numel(), 0)
    # readonly lookup.
    output_readonly, _ = torch.ops.fbgemm.zero_collision_hash(
        numbers_0_100,
        identities,
        100,
        circular_probe=True,
        readonly=True,
    )
    self.assertTrue(torch.equal(output, output_readonly))
    numbers_100_200 = torch.arange(100, 200, dtype=torch.int64, device="cuda")
    input_metadata_600_700 = torch.arange(
        600, 700, dtype=torch.int32, device="cuda"
    )
    # evict with eviction_threshold 550: only the 50 slots whose eviction
    # scores are below 550 can be evicted.
    output, evict_slots = torch.ops.fbgemm.zero_collision_hash(
        numbers_100_200,
        identities,
        100,
        circular_probe=True,
        metadata=metadata,
        input_metadata=input_metadata_600_700,
        eviction_threshold=550,
    )
    self.assertEqual(evict_slots.numel(), 50)
    self.assertTrue(torch.all(metadata >= 550))
    # readonly lookup.
    output_readonly, _ = torch.ops.fbgemm.zero_collision_hash(
        numbers_100_200,
        identities,
        100,
        circular_probe=True,
        readonly=True,
    )
    self.assertTrue(torch.equal(output, output_readonly))
    # attempt to update with lower input_metadata values
    metadata0 = metadata.clone()
    input_metadata_0_100 = torch.arange(0, 100, dtype=torch.int32, device="cuda")
    output_lower_metadata, _ = torch.ops.fbgemm.zero_collision_hash(
        numbers_100_200,
        identities,
        100,
        circular_probe=True,
        metadata=metadata,
        input_metadata=input_metadata_0_100,
        eviction_threshold=550,
    )
    self.assertTrue(torch.equal(output_lower_metadata, output))
    # metadata should not be overwritten
    self.assertTrue(torch.equal(metadata, metadata0))
@unittest.skipIf(*gpu_unavailable)
def test_zch_lru_evict(self) -> None:
    """
    Test the zero collision hash with LRU eviction.
    Tested eviction policy: HashZchKernelEvictionPolicy.LRU_EVICTION.value
    It fills a 100-slot, eviction-enabled table with the ids 0-99 (input
    metadata ttl + cur_hour, threshold cur_hour). It then replaces the
    metadata with random values in [0, 500) and inserts the ids 100-119 with
    threshold 600, so every stored metadata value is below the new threshold.
    Assertions:
    1. The first insertion maps ids 0-99 onto slots 0-99 without eviction,
       and readonly lookups (GPU and CPU) return the same slots.
    2. The second insertion evicts exactly 20 slots, all among the 40 slots
       with the smallest metadata, and stores ids 100-119 there.
    3. The returned slots are exactly the evicted slots, the new metadata is
       written only to those slots, and readonly lookups (GPU and CPU)
       return the same slots.
    Raises:
    AssertionError: If the test detects a mismatch from the expected behavior.
    """
    # No evict
    identities, metadata = torch.ops.fbgemm.create_zch_buffer(
        100, support_evict=True, device=torch.device("cuda")
    )
    numbers_0_100 = torch.arange(0, 100, dtype=torch.int64, device="cuda")
    cur_hour = 500
    ttl = 72
    input_metadata = torch.full_like(
        numbers_0_100,
        ttl + cur_hour,
        dtype=torch.int32,
        device="cuda",
    )
    output, evict_slots = torch.ops.fbgemm.zero_collision_hash(
        numbers_0_100,
        identities,
        100,
        circular_probe=True,
        metadata=metadata,
        eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value,
        input_metadata=input_metadata,
        eviction_threshold=cur_hour,
    )
    self.assertEqual(
        torch.unique(output).tolist(), numbers_0_100.tolist(), f"{output=}"
    )
    self.assertTrue(torch.all(metadata != -1), metadata)
    self.assertEqual(evict_slots.numel(), 0)
    self.assertEqual(
        torch.unique(identities).tolist(), numbers_0_100.tolist(), f"{identities=}"
    )
    # readonly lookup.
    output_readonly, _ = torch.ops.fbgemm.zero_collision_hash(
        numbers_0_100,
        identities,
        100,
        circular_probe=True,
        readonly=True,
        eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value,
    )
    # assertTrue(a, b) would only check that `a` is non-empty; compare values.
    self.assertEqual(output.tolist(), output_readonly.tolist())
    output_readonly_cpu, _ = torch.ops.fbgemm.zero_collision_hash(
        numbers_0_100.cpu(),
        identities.cpu(),
        100,
        circular_probe=True,
        exp_hours=-1,
        readonly=True,
        eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value,
    )
    self.assertTrue(
        torch.equal(output_readonly_cpu, output_readonly.cpu()),
        f"{output_readonly_cpu=} v.s {output_readonly.cpu()=}",
    )
    numbers_100_120 = torch.arange(100, 120, dtype=torch.int64, device="cuda")
    new_cur_hour = 600
    new_input_metadata = torch.full_like(
        numbers_100_120,
        ttl + new_cur_hour,
        dtype=torch.int32,
        device="cuda",
    )
    # Replace the metadata with random update hours in [0, 500) to trigger
    # LRU eviction. This rebinds `metadata` to a new tensor, and the new
    # tensor is what the eviction call below receives.
    metadata = torch.randint(
        500, (100, 1), dtype=torch.int32, device=metadata.device
    )
    # slot indices ordered by ascending metadata (least recently updated first)
    eviction_order = (
        torch.sort(metadata, 0)
        .indices.index_select(1, torch.tensor([0], device=metadata.device))
        .squeeze()
    )
    # all rows were occupied, do evict for all input numbers
    # evict by LRU
    output, evict_slots = torch.ops.fbgemm.zero_collision_hash(
        input=numbers_100_120,
        identities=identities,
        max_probe=100,
        circular_probe=True,
        metadata=metadata,
        eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value,
        input_metadata=new_input_metadata,
        eviction_threshold=new_cur_hour,
    )
    self.assertEqual(evict_slots.numel(), 20)
    self.assertTrue(
        set(evict_slots.tolist()).issubset(set(eviction_order[:40].tolist())),
        f"{evict_slots=}, {eviction_order=}",
    )
    self.assertTrue(
        torch.equal(
            torch.sort(identities[identities >= 100])[0],
            torch.sort(numbers_100_120)[0],
        ),
        f"{identities=} vs {numbers_100_120=}",
    )
    self.assertTrue(
        torch.equal(evict_slots, torch.sort(output)[0]),
        f"{evict_slots=} vs {output=}",
    )
    self.assertTrue(
        torch.equal(
            torch.nonzero(metadata >= 500), torch.nonzero(identities >= 100)
        ),
        f"{torch.nonzero(metadata >= 500)=} vs {torch.nonzero(identities >= 100)=}",
    )
    # readonly lookup again
    output_readonly, _ = torch.ops.fbgemm.zero_collision_hash(
        numbers_100_120,
        identities,
        100,
        circular_probe=True,
        readonly=True,
    )
    # assertTrue(a, b) would only check that `a` is non-empty; compare values.
    self.assertEqual(output.tolist(), output_readonly.tolist())
    output_readonly_cpu, _ = torch.ops.fbgemm.zero_collision_hash(
        numbers_100_120.cpu(),
        identities.cpu(),
        100,
        circular_probe=True,
        exp_hours=-1,
        readonly=True,
    )
    self.assertTrue(
        torch.equal(output_readonly_cpu, output_readonly.cpu()),
        f"{output_readonly_cpu=} v.s {output_readonly.cpu()=}",
    )
@unittest.skipIf(*gpu_unavailable)
def test_zch_lru_evict_with_unexpired_slots(self) -> None:
"""
Test the zero collision hash with LRU eviction with unexpired slots.
Raises:
AssertionError: If the test detects a mismatch from the expected behavior,
an assertion error will be raised.
"""
# No evict
identities, metadata = torch.ops.fbgemm.create_zch_buffer(
100, support_evict=True, device=torch.device("cuda")
)
numbers_0_100 = torch.arange(0, 100, dtype=torch.int64, device="cuda")
cur_hour = 1000
ttl = 72
input_metadata = torch.full_like(
numbers_0_100,
ttl + cur_hour,
dtype=torch.int32,
device="cuda",
)
output, evict_slots = torch.ops.fbgemm.zero_collision_hash(
numbers_0_100,
identities,
100,
circular_probe=True,
metadata=metadata,
eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value,
eviction_threshold=cur_hour,
input_metadata=input_metadata,
)
self.assertEqual(
torch.unique(output).tolist(), numbers_0_100.tolist(), f"{output=}"
)
self.assertTrue(torch.all(metadata != -1), metadata)
self.assertTrue(evict_slots.numel() == 0)
self.assertEqual(
torch.unique(identities).tolist(), numbers_0_100.tolist(), f"{identities=}"
)
# readonly lookup.
output_readonly, _ = torch.ops.fbgemm.zero_collision_hash(
numbers_0_100,
identities,
100,
circular_probe=True,
readonly=True,
eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value,
)
self.assertTrue(output.tolist(), output_readonly.tolist())
output_readonly_cpu, _ = torch.ops.fbgemm.zero_collision_hash(
numbers_0_100.cpu(),
identities.cpu(),
100,
circular_probe=True,
exp_hours=-1,
readonly=True,
eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value,
)
self.assertTrue(
torch.equal(output_readonly_cpu, output_readonly.cpu()),
f"{output_readonly_cpu=} v.s {output_readonly.cpu()=}",
)
numbers_100_150 = torch.arange(100, 150, dtype=torch.int64, device="cuda")
# 20 slots expired, 80 unexpired
metadata_to_update = torch.randint(
500, 1050, (20, 1), dtype=torch.int32, device=metadata.device
)
metadata[0:20] = metadata_to_update
metadata_index_0_20 = torch.arange(
0, 20, dtype=torch.int64, device=metadata.device
)
new_cur_hour = 1050
new_input_metadata = torch.full_like(
numbers_100_150,
ttl + new_cur_hour,
dtype=torch.int32,
device="cuda",
)
# all rows were occupied, do evict by LRU + TTL rule
output, evict_slots = torch.ops.fbgemm.zero_collision_hash(
input=numbers_100_150,
identities=identities,
max_probe=100,
circular_probe=True,
metadata=metadata,
eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value,
eviction_threshold=new_cur_hour,
input_metadata=new_input_metadata,
)