Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 29 additions & 20 deletions buildz80com.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,24 @@ def pack_2bit_weights(weights: np.ndarray) -> bytes:
chunk = mapped[i:i+4]
if len(chunk) < 4:
chunk = np.pad(chunk, (0, 4 - len(chunk)), constant_values=2)
# 0,1,2,3 are as -2,-1,0,1 but in muladd for speed will be treated as -2,0,1,-1
for c in range(4):
match int(chunk[c]):
case 0:
chunk[c] = 0 # -2 in position 0
case 1:
chunk[c] = 3 # -1 in position 3
case 2:
chunk[c] = 1 # 0 in position 1
case 3:
chunk[c] = 2 # 1 in position 2
case _:
raise Exception("weight value not valid")
byte = \
(int(chunk[2]) << 6) | \
(int(chunk[1]) << 4) | \
(int(chunk[0]) << 2) | \
(int(chunk[3])) # chunk 3 as last since in the evaluation there will be first a rotation
(chunk[2] << 6) | \
(chunk[1] << 4) | \
(chunk[0] << 2) | \
(chunk[3]) # chunk 3 as last since in the evaluation there will be first a rotation
packed.append(byte)

return bytes(packed)
Expand Down Expand Up @@ -453,8 +466,9 @@ def build_autoreg(model_path: str = 'command_model_autoreg.pt'):
b.inc_hl() # inc hl
b.ld_d_hl() # ld d, (hl)
b.inc_hl() # inc hl
b.ld_mem_label_hl('CURIN') # ld (CURIN), hl
b.call('MULADD') # call MULADD
b.ld_mem_label_hl('CURIN') # ld (CURIN), hl
b.dec_a() # dec a ; to verify if is 1 (that means weitgh "0 equivalent")
b.call_nz('MULADD') # call nz, MULADD ; "0 equivalent" is the most common state, so call will be performed more rarely
b.inc_c() # inc c
b.djnz('LWT') # djnz LWT
#
Expand All @@ -467,11 +481,11 @@ def build_autoreg(model_path: str = 'command_model_autoreg.pt'):
b.ld_hl_mem_label('ACC') # ld hl, (ACC)
b.add_hl_de() # add hl, de
b.ld_mem_label_hl('ACC') # ld (ACC), hl
# ; After each layer, arithmetic right-shift by 2 to prevent overflow
b.sra_h() # sra h ; Shift right arithmetic (preserves sign)
# ; after each layer, arithmetic right-shift by 2 to prevent overflow
b.sra_h() # sra h ; shift right arithmetic (preserves sign)
b.rr_l() # rr l
b.sra_h() # sra h
b.rr_l() # rr l ; ACC = ACC / 4
b.rr_l() # rr l ; ACC = ACC / 4
b.ld_iyd_l(0) # ld (iy+0), l
b.ld_iyd_h(1) # ld (iy+1), h
b.inc_iy() # inc iy
Expand All @@ -480,21 +494,16 @@ def build_autoreg(model_path: str = 'command_model_autoreg.pt'):
b.djnz('LNEUR') # dec b : jp nz, LNEUR
b.ret() # ret
#
# === MULADD === #
# a is 0,1,2,3 that are as -2,-1,0,+1
# === MULADD === # ; since last dec here a is $FF,1,2 that are as -2,1,-1
b.label('MULADD') # MULADD:
# TODO: ACC should be put in a register
b.ld_hl_mem_label('ACC') # ld hl, (ACC)
b.dec_a() # dec a
b.jr_z('MA_M1') # jr z, MA_M1 ; jump if a is -1 equivalent
b.jr_z('MA_P1') # jr z, MA_P1 ; jump if a is +1 equivalent (33% jump probability)
b.sbc_hl_de() # sbc hl, de ; a is -1 or -2 equivalent
b.dec_a() # dec a
b.ret_z() # ret z ; a is zero equivalent
b.dec_a() # dec a
b.jr_z('MA_P1') # jr z, MA_P1 ; jump if a is +1 equivalent
b.label('MA_M2') # MA_M2: ; a is -2 equivalent
b.sbc_hl_de() # sbc hl, de
b.label('MA_M1') # MA_M1: ; -1
b.sbc_hl_de() # sbc hl, de
b.jr_z('MA_MRET') # jr z, MA_MRET ; skip next sbc if a is just -1 equivalent (50% jump probability)
b.sbc_hl_de() # sbc hl, de ; second time since a is -2 equivalent
b.label('MA_MRET') # MA_MRET:
b.ld_mem_label_hl('ACC') # ld (ACC), hl
b.ret() # ret
b.label('MA_P1') # MA_P1: ; a is +1
Expand Down
4 changes: 4 additions & 0 deletions libz80.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ def call(self, label: str):
self.emit(0xCD)
self.fixup_word(label)

def call_nz(self, label: str):
self.emit(0xC4)
self.fixup_word(label)

def call_addr(self, addr: int):
self.emit(0xCD, addr & 0xFF, (addr >> 8) & 0xFF)

Expand Down