diff --git a/buildz80com.py b/buildz80com.py index 1437076..0c74d56 100755 --- a/buildz80com.py +++ b/buildz80com.py @@ -34,11 +34,24 @@ def pack_2bit_weights(weights: np.ndarray) -> bytes: chunk = mapped[i:i+4] if len(chunk) < 4: chunk = np.pad(chunk, (0, 4 - len(chunk)), constant_values=2) + # codes 0,1,2,3 stand for weights -2,-1,0,+1 here; re-encode them so that MULADD can read 0,1,2,3 as -2,0,+1,-1 (faster test order) + for c in range(4): + match int(chunk[c]): + case 0: + chunk[c] = 0 # -2 in position 0 + case 1: + chunk[c] = 3 # -1 in position 3 + case 2: + chunk[c] = 1 # 0 in position 1 + case 3: + chunk[c] = 2 # 1 in position 2 + case _: + raise Exception("weight value not valid") byte = \ - (int(chunk[2]) << 6) | \ - (int(chunk[1]) << 4) | \ - (int(chunk[0]) << 2) | \ - (int(chunk[3])) # chunk 3 as last since in the evaluation there will be first a rotation + (chunk[2] << 6) | \ + (chunk[1] << 4) | \ + (chunk[0] << 2) | \ + (chunk[3]) # chunk 3 as last since in the evaluation there will be first a rotation ; NOTE(review): no int() cast here, so if chunk has a narrow signed dtype (e.g. int8) the << 6 can overflow; confirm the dtype of mapped packed.append(byte) return bytes(packed) @@ -453,8 +466,9 @@ def build_autoreg(model_path: str = 'command_model_autoreg.pt'): b.inc_hl() # inc hl b.ld_d_hl() # ld d, (hl) b.inc_hl() # inc hl - b.ld_mem_label_hl('CURIN') # ld (CURIN), hl - b.call('MULADD') # call MULADD + b.ld_mem_label_hl('CURIN') # ld (CURIN), hl + b.dec_a() # dec a ; sets Z when a was 1 (the code for weight "0 equivalent") + b.call_nz('MULADD') # call nz, MULADD ; "0 equivalent" is the most common state, so call will be performed more rarely b.inc_c() # inc c b.djnz('LWT') # djnz LWT # @@ -467,11 +481,11 @@ def build_autoreg(model_path: str = 'command_model_autoreg.pt'): b.ld_hl_mem_label('ACC') # ld hl, (ACC) b.add_hl_de() # add hl, de b.ld_mem_label_hl('ACC') # ld (ACC), hl - # ; After each layer, arithmetic right-shift by 2 to prevent overflow - b.sra_h() # sra h ; Shift right arithmetic (preserves sign) + # ; after each layer, arithmetic right-shift by 2 to prevent overflow + b.sra_h() # sra h ; shift right arithmetic (preserves sign) 
b.rr_l() # rr l b.sra_h() # sra h - b.rr_l() # rr l ; ACC = ACC / 4 + b.rr_l() # rr l ; ACC = ACC / 4 b.ld_iyd_l(0) # ld (iy+0), l b.ld_iyd_h(1) # ld (iy+1), h b.inc_iy() # inc iy @@ -480,21 +494,16 @@ def build_autoreg(model_path: str = 'command_model_autoreg.pt'): b.djnz('LNEUR') # dec b : jp nz, LNEUR b.ret() # ret # - # === MULADD === # - # a is 0,1,2,3 that are as -2,-1,0,+1 + # === MULADD === # ; the caller already did dec a, so on entry a is $FF, 1 or 2, meaning weight -2, +1, -1 b.label('MULADD') # MULADD: - # TODO: ACC should be put in a register b.ld_hl_mem_label('ACC') # ld hl, (ACC) b.dec_a() # dec a - b.jr_z('MA_M1') # jr z, MA_M1 ; jump if a is -1 equivalent + b.jr_z('MA_P1') # jr z, MA_P1 ; jump if a is +1 equivalent (33% jump probability) + b.sbc_hl_de() # sbc hl, de ; a is -1 or -2 equivalent b.dec_a() # dec a - b.ret_z() # ret z ; a is zero equivalent - b.dec_a() # dec a - b.jr_z('MA_P1') # jr z, MA_P1 ; jump if a is +1 equivalent - b.label('MA_M2') # MA_M2: ; a is -2 equivalent - b.sbc_hl_de() # sbc hl, de - b.label('MA_M1') # MA_M1: ; -1 - b.sbc_hl_de() # sbc hl, de + b.jr_z('MA_MRET') # jr z, MA_MRET ; skip next sbc if a is just -1 equivalent (50% jump probability) + b.sbc_hl_de() # sbc hl, de ; second time since a is -2 equivalent ; NOTE(review): dec a leaves carry untouched, so a borrow from the first sbc is also subtracted here; confirm that cannot skew the -2 case + b.label('MA_MRET') # MA_MRET: b.ld_mem_label_hl('ACC') # ld (ACC), hl b.ret() # ret b.label('MA_P1') # MA_P1: ; a is +1 diff --git a/libz80.py b/libz80.py index ace7b23..773ad2a 100644 --- a/libz80.py +++ b/libz80.py @@ -83,6 +83,10 @@ def call(self, label: str): self.emit(0xCD) self.fixup_word(label) + def call_nz(self, label: str): + self.emit(0xC4) # opcode of CALL NZ,nn + self.fixup_word(label) + def call_addr(self, addr: int): self.emit(0xCD, addr & 0xFF, (addr >> 8) & 0xFF)