From 869aaa7e99134ac1db07b60a35a3a3d4bf024567 Mon Sep 17 00:00:00 2001 From: truthixify Date: Thu, 14 Aug 2025 03:50:48 +0100 Subject: [PATCH] integrated a multitoken vault system with comprehensive tests --- .DS_Store | Bin 8196 -> 8196 bytes src/contracts/core.cairo | 8 +- src/contracts/factory.cairo | 3 +- src/contracts/vault.cairo | 404 +++++++++++++++++++++------------- src/interfaces/icore.cairo | 4 +- src/interfaces/ifactory.cairo | 10 +- src/interfaces/ivault.cairo | 135 +++++++----- tests/test_core.cairo | 20 +- tests/test_vault.cairo | 316 +++++++++++++++++--------- 9 files changed, 575 insertions(+), 325 deletions(-) diff --git a/.DS_Store b/.DS_Store index aeca7ee3bca58441486d18e39b36fef87b5eeb04..1fea7c3199bc37c0b481cf9441f52e293d40ba2c 100644 GIT binary patch delta 37 tcmZp1XmOa}&nUk!U^hRb{AL~jf9B1Tg%cPzACYljo>;)MnO)*9I{@Y83Tbk|6stdIb1k}abubY^JaF5zw7{)WEW}x diff --git a/src/contracts/core.cairo b/src/contracts/core.cairo index 602abc9..936cd44 100644 --- a/src/contracts/core.cairo +++ b/src/contracts/core.cairo @@ -234,7 +234,7 @@ mod Core { /// - If the payout is attempted before the scheduled start time or after the end time. /// - If the payout is attempted before the required interval has passed since the last /// execution. - fn schedule_payout(ref self: ContractState) { + fn schedule_payout(ref self: ContractState, token: ContractAddress) { let caller = get_caller_address(); let members = self.member.get_members(); let no_of_members = members.len(); @@ -243,8 +243,8 @@ mod Core { let vault_address = org_info.vault_address; let vault_dispatcher = IVaultDispatcher { contract_address: vault_address }; - let total_bonus = vault_dispatcher.get_bonus_allocation(); - let total_funds = vault_dispatcher.get_balance(); + let total_bonus = vault_dispatcher.get_bonus_allocation(token); + let total_funds = vault_dispatcher.get_token_balance(token); let current_schedule = self.disbursement.get_current_schedule(); assert(current_schedule.status == ScheduleStatus::ACTIVE, 'Schedule not active'); @@ -282,7 +282,7 @@ mod Core { .disbursement .compute_renumeration(current_member_response, total_bonus, total_weight); let timestamp = get_block_timestamp(); - vault_dispatcher.pay_member(current_member_response.address, amount); + vault_dispatcher.pay_member(token, current_member_response.address, amount); // self.member.record_member_payment(current_member_response.id, amount, timestamp) } diff --git a/src/contracts/factory.cairo b/src/contracts/factory.cairo index c5bb322..4d3b340 100644 --- a/src/contracts/factory.cairo +++ b/src/contracts/factory.cairo @@ -376,10 +376,11 @@ pub mod Factory { let vault_count = self.vaults_count.read(); let vault_id: u256 = vault_count.into(); let mut constructor_calldata = array![]; - token.serialize(ref constructor_calldata); + let tokens = array![token]; // available_funds.serialize(ref constructor_calldata); // starting_bonus_allocation.serialize(ref constructor_calldata); owner.serialize(ref constructor_calldata); + tokens.serialize(ref constructor_calldata); // Deploy the Vault let processed_class_hash: ClassHash = self.vault_class_hash.read(); diff --git a/src/contracts/vault.cairo b/src/contracts/vault.cairo index 3eb03f4..cb036d3 100644 --- a/src/contracts/vault.cairo +++ b/src/contracts/vault.cairo @@ -1,10 +1,10 @@ -/// ## A Starknet contract for managing an organization's financial vault. +/// ## A Starknet contract for managing an organization's financial vault for multiple tokens. /// /// This contract is responsible for: -/// - Securely holding a single type of ERC20 token. +/// - Securely holding multiple types of ERC20 tokens. /// - Processing deposits and withdrawals from authorized addresses. /// - Executing payments to organization members. -/// - Allocating funds for bonuses. +/// - Allocating funds for bonuses on a per-token basis. /// - Recording all transactions for auditing purposes. /// - Providing security features like an emergency freeze. /// @@ -19,7 +19,8 @@ pub mod Vault { use openzeppelin::token::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait}; use openzeppelin::upgrades::UpgradeableComponent; use starknet::storage::{ - Map, StoragePathEntry, StoragePointerReadAccess, StoragePointerWriteAccess, + Map, MutableVecTrait, StoragePathEntry, StoragePointerReadAccess, StoragePointerWriteAccess, + Vec, VecTrait, }; use starknet::{ ContractAddress, get_block_timestamp, get_caller_address, get_contract_address, get_tx_info, @@ -35,21 +36,19 @@ pub mod Vault { /// Maps a contract address to a boolean indicating if it's permitted to interact with the /// vault. permitted_addresses: Map, - /// The total balance of the managed token held by the vault. - available_funds: u256, - /// The portion of the total balance allocated for bonus payments. - total_bonus: u256, - /// Maps a transaction ID (`u64`) to a `Transaction` struct, storing a history of all vault + /// Maps a token address to its portion of the total balance allocated for bonus payments. + bonus_allocations: Map, + /// Maps a transaction ID to a `Transaction` struct, storing a history of all vault /// operations. - transaction_history: Map< - u64, Transaction, - >, // No 1. Transaction x, no 2, transaction y etc for history, and it begins with 1 + transaction_history: Map, /// A counter for the total number of transactions processed. transactions_count: u64, /// The current operational status of the vault (e.g., VAULTACTIVE, VAULTFROZEN). vault_status: VaultStatus, - /// The contract address of the single ERC20 token this vault manages. - token: ContractAddress, + /// Maps a token address to a boolean, indicating if it's an accepted asset for the vault. + accepted_tokens: Map, + /// A list of all accepted token addresses for easy retrieval. + accepted_tokens_list: Vec, /// Substorage for the Ownable component. #[substorage(v0)] ownable: OwnableComponent::Storage, @@ -74,15 +73,16 @@ pub mod Vault { TransactionRecorded: TransactionRecorded, /// Emitted when funds are allocated to the bonus pool. BonusAllocation: BonusAllocation, + /// Emitted when a new token is accepted by the vault. + TokenAccepted: TokenAccepted, + /// Emitted when a token is removed from the list of accepted tokens. + TokenRemoved: TokenRemoved, /// Flat event for Ownable component events. #[flat] OwnableEvent: OwnableComponent::Event, /// Flat event for Upgradeable component events. #[flat] UpgradeableEvent: UpgradeableComponent::Event, - // TODO: - // Add an event here that gets emitted if the money goes below a certain threshold - // Threshold Will be decided. } /// Event data for a successful deposit. @@ -133,151 +133,162 @@ pub mod Vault { pub timestamp: u64, } + /// Event data for when a new token is accepted. + #[derive(Drop, starknet::Event)] + pub struct TokenAccepted { + pub added_by: ContractAddress, + pub token: ContractAddress, + } + + /// Event data for when a token is removed. + #[derive(Drop, starknet::Event)] + pub struct TokenRemoved { + pub removed_by: ContractAddress, + pub token: ContractAddress, + } + #[abi(embed_v0)] impl OwnableMixinImpl = OwnableComponent::OwnableMixinImpl; impl OwnableInternalImpl = OwnableComponent::InternalImpl; - impl UpgradeableInternalImpl = UpgradeableComponent::InternalImpl; - // TODO: - // Add to this constructor, a way to add addresses and store them as permitted addresses here /// Initializes the Vault contract. /// /// ### Parameters - /// - `token`: The contract address of the ERC20 token to be managed. /// - `owner`: The address that will have initial ownership and permissions. #[constructor] fn constructor( - ref self: ContractState, - token: ContractAddress, // available_funds: u256, - // bonus_allocation: u256, - owner: ContractAddress, + ref self: ContractState, owner: ContractAddress, tokens: Array, ) { - self.token.write(token); - self.total_bonus.write(0); self.permitted_addresses.entry(owner).write(true); - self._sync_available_funds(); - } - - // TODO: - // From the ivault, add functions in the interfaces for subtracting from and adding to bonus - // IMPLEMENT HERE + let mut i = 0; + while i != tokens.len() { + self.accepted_tokens.entry(*tokens.at(i)).write(true); + self.accepted_tokens_list.push(*tokens.at(i)); - // TODO: - // Implement a storage variable, that will be in the constructor, for the token address to be - // supplied at deployment For now, we want a single-token implementation + i += 1; + } + } /// # VaultImpl /// /// Public-facing implementation of the `IVault` interface. #[abi(embed_v0)] pub impl VaultImpl of IVault { - /// Accepts a deposit of the managed token. + /// Deposits a specified amount of an accepted token into the vault. /// /// ### Parameters - /// - `amount`: The amount to deposit. - /// - `address`: The address from which the funds are being sent. + /// - `token`: The `ContractAddress` of the token being deposited. + /// - `amount`: The amount of funds to deposit as a `u256`. + /// - `from_address`: The `ContractAddress` from which the funds are being sent. /// /// ### Panics - /// - If `amount` or `address` is zero. - /// - If the direct caller or the source `address` is not permitted. + /// - If `amount` or `from_address` is zero. + /// - If the `token` is not on the accepted list. + /// - If the direct caller or the source `from_address` is not permitted. /// - If the vault is frozen. - fn deposit_funds(ref self: ContractState, amount: u256, address: ContractAddress) { + fn deposit_funds( + ref self: ContractState, + token: ContractAddress, + amount: u256, + from_address: ContractAddress, + ) { assert(amount.is_non_zero(), 'Invalid Amount'); - assert(address.is_non_zero(), 'Invalid Address'); + assert(from_address.is_non_zero(), 'Invalid Address'); + assert(self.is_token_acceptable(token), 'Token not accepted'); + let caller = get_caller_address(); - let permitted = self.permitted_addresses.entry(caller).read(); - assert(permitted, 'Direct Caller not permitted'); - assert(self.permitted_addresses.entry(address).read(), 'Deep Caller Not Permitted'); - let current_vault_status = self.vault_status.read(); + assert(self.permitted_addresses.entry(caller).read(), 'Direct Caller not permitted'); assert( - current_vault_status != VaultStatus::VAULTFROZEN, 'Vault Frozen for Transactions', + self.permitted_addresses.entry(from_address).read(), 'Deep Caller Not Permitted', + ); + assert( + self.vault_status.read() != VaultStatus::VAULTFROZEN, + 'Vault Frozen for Transactions', ); - - self._sync_available_funds(); let timestamp = get_block_timestamp(); let this_contract = get_contract_address(); - let token = self.token.read(); let token_dispatcher = IERC20Dispatcher { contract_address: token }; - token_dispatcher.transfer_from(address, this_contract, amount); + token_dispatcher.transfer_from(from_address, this_contract, amount); - self._record_transaction(token, amount, TransactionType::DEPOSIT, address); + self._record_transaction(token, amount, TransactionType::DEPOSIT, from_address); - self._sync_available_funds(); - - self.emit(DepositSuccessful { caller: address, token, timestamp, amount }) + self.emit(DepositSuccessful { caller: from_address, token, timestamp, amount }); } - /// Withdraws the managed token to a specified address. + /// Withdraws a specified amount of an accepted token from the vault. /// /// ### Parameters - /// - `amount`: The amount to withdraw. - /// - `address`: The address to receive the funds. + /// - `token`: The `ContractAddress` of the token being withdrawn. + /// - `amount`: The amount of funds to withdraw as a `u256`. + /// - `to_address`: The `ContractAddress` to receive the funds. /// /// ### Panics - /// - If `amount` or `address` is zero. - /// - If the direct caller or the destination `address` is not permitted. + /// - If `amount` or `to_address` is zero. + /// - If the `token` is not on the accepted list. + /// - If the caller is not a permitted address. /// - If the vault is frozen. - /// - If the requested amount exceeds the vault's balance. - fn withdraw_funds(ref self: ContractState, amount: u256, address: ContractAddress) { + /// - If the requested `amount` exceeds the vault's balance for that token. + fn withdraw_funds( + ref self: ContractState, + token: ContractAddress, + amount: u256, + to_address: ContractAddress, + ) { assert(amount.is_non_zero(), 'Invalid Amount'); - assert(address.is_non_zero(), 'Invalid Address'); - let caller = get_caller_address(); - let permitted = self.permitted_addresses.entry(caller).read(); - assert(permitted, 'Direct Caller not permitted'); - assert(self.permitted_addresses.entry(address).read(), 'Deep Caller Not Permitted'); + assert(to_address.is_non_zero(), 'Invalid Address'); + assert(self.is_token_acceptable(token), 'Token not accepted'); - let current_vault_status = self.vault_status.read(); + let caller = get_caller_address(); + assert(self.permitted_addresses.entry(caller).read(), 'Direct Caller not permitted'); + assert(self.permitted_addresses.entry(to_address).read(), 'Deep Caller Not Permitted'); assert( - current_vault_status != VaultStatus::VAULTFROZEN, 'Vault Frozen for Transactions', + self.vault_status.read() != VaultStatus::VAULTFROZEN, + 'Vault Frozen for Transactions', ); - self._sync_available_funds(); - - let timestamp = get_block_timestamp(); + let token_balance = self.get_token_balance(token); + assert(amount <= token_balance, 'Insufficient Balance'); - let token = self.token.read(); let token_dispatcher = IERC20Dispatcher { contract_address: token }; - let vault_balance = token_dispatcher.balance_of(get_contract_address()); - assert(amount <= vault_balance, 'Insufficient Balance'); - - token_dispatcher.transfer(address, amount); - self._record_transaction(token, amount, TransactionType::WITHDRAWAL, address); + token_dispatcher.transfer(to_address, amount); - self._sync_available_funds(); + self._record_transaction(token, amount, TransactionType::WITHDRAWAL, to_address); - self.emit(WithdrawalSuccessful { caller: address, token, amount, timestamp }) + let timestamp = get_block_timestamp(); + self.emit(WithdrawalSuccessful { caller: to_address, token, amount, timestamp }); } - /// Allocates a portion of the vault's funds to the bonus pool. + /// Allocates a portion of a token's funds to the bonus pool. /// /// ### Parameters + /// - `token`: The `ContractAddress` of the token for the bonus allocation. /// - `amount`: The amount to allocate for bonuses. /// - `address`: The address initiating the allocation. /// /// ### Panics - /// - If `amount` or `address` is zero. - /// - If the direct caller or the source `address` is not permitted. + /// - If the caller is not a permitted address. + /// - If the `token` is not on the accepted list. + /// - If the `amount` exceeds the available, non-bonus portion of the token's balance. fn add_to_bonus_allocation( - ref self: ContractState, amount: u256, address: ContractAddress, + ref self: ContractState, token: ContractAddress, amount: u256, address: ContractAddress, ) { - assert(amount.is_non_zero(), 'Invalid Amount'); - assert(address.is_non_zero(), 'Invalid Address'); let caller = get_caller_address(); - let permitted = self.permitted_addresses.entry(caller).read(); - assert(permitted, 'Direct Caller not permitted'); + assert(self.permitted_addresses.entry(caller).read(), 'Direct Caller not permitted'); assert(self.permitted_addresses.entry(address).read(), 'Deep Caller Not Permitted'); + assert(self.is_token_acceptable(token), 'Token not accepted'); - self._sync_available_funds(); + let current_token_balance = self.get_token_balance(token); + let current_token_bonus = self.bonus_allocations.entry(token).read(); + assert( + amount <= current_token_balance - current_token_bonus, 'Bonus exceeds available', + ); - self.total_bonus.write(self.total_bonus.read() + amount); - self - ._record_transaction( - self.token.read(), amount, TransactionType::BONUS_ALLOCATION, address, - ); + self.bonus_allocations.entry(token).write(current_token_bonus + amount); + self._record_transaction(token, amount, TransactionType::BONUS_ALLOCATION, address); } /// Freezes all vault operations as a security measure. @@ -288,7 +299,7 @@ pub mod Vault { fn emergency_freeze(ref self: ContractState) { let caller = get_caller_address(); let permitted = self.permitted_addresses.entry(caller).read(); - assert(permitted, 'Caller not permitted'); + assert(permitted, 'Direct Caller not permitted'); assert(self.vault_status.read() != VaultStatus::VAULTFROZEN, 'Vault Already Frozen'); self.vault_status.write(VaultStatus::VAULTFROZEN); @@ -302,80 +313,156 @@ pub mod Vault { fn unfreeze_vault(ref self: ContractState) { let caller = get_caller_address(); let permitted = self.permitted_addresses.entry(caller).read(); - assert(permitted, 'Caller not permitted'); + assert(permitted, 'Direct Caller not permitted'); assert(self.vault_status.read() != VaultStatus::VAULTRESUMED, 'Vault Not Frozen'); self.vault_status.write(VaultStatus::VAULTRESUMED); } - // fn bulk_transfer(ref self: ContractState, recipients: Span) {} + /// Executes a payment from the vault to a specific address using a specified token. + /// + /// ### Parameters + /// - `token`: The `ContractAddress` of the token for the payment. + /// - `recipient`: The `ContractAddress` of the member to receive the payment. + /// - `amount`: The payment amount as a `u256`. + /// + /// ### Panics + /// - If `recipient` or `amount` is zero. + /// - If the `token` is not on the accepted list. + /// - If the caller is not a permitted address. + /// - If the payment `amount` exceeds the vault's balance for that token. + /// - If the token transfer fails. + fn pay_member( + ref self: ContractState, + token: ContractAddress, + recipient: ContractAddress, + amount: u256, + ) { + assert(recipient.is_non_zero(), 'Invalid Address'); + assert(amount.is_non_zero(), 'Invalid Amount'); + assert(self.is_token_acceptable(token), 'Token not accepted'); - /// Returns the vault's total balance of the managed token. + let caller = get_caller_address(); + assert(self.permitted_addresses.entry(caller).read(), 'Direct Caller not permitted'); + + let token_balance = self.get_token_balance(token); + assert(amount <= token_balance, 'Amount exceeds balance'); + + let token_dispatcher = IERC20Dispatcher { contract_address: token }; + let transfered = token_dispatcher.transfer(recipient, amount); + assert(transfered, 'Transfer failed'); + + self._record_transaction(token, amount, TransactionType::PAYMENT, caller); + } + + /// Adds a new ERC20 token to the list of accepted tokens for the vault. (Owner only) /// - /// ### Returns - /// - `u256`: The total balance. - fn get_balance(self: @ContractState) -> u256 { - // let caller = get_caller_address(); - // assert(self.permitted_addresses.entry(caller).read(), 'Caller Not Permitted'); - let token_address = self.token.read(); - let token_dispatcher = IERC20Dispatcher { contract_address: token_address }; - let vault_address = get_contract_address(); - let balance = token_dispatcher.balance_of(vault_address); - balance + /// ### Parameters + /// - `token`: The `ContractAddress` of the token to be accepted. + /// + /// ### Panics + /// - If the caller is not the contract owner. + /// - If the `token` address is zero. + /// - If the `token` has already been accepted. + fn add_accepted_token(ref self: ContractState, token: ContractAddress) { + let caller = get_caller_address(); + + assert(token.is_non_zero(), 'Invalid token address'); + assert(!self.is_token_acceptable(token), 'Token already accepted'); + assert(self.permitted_addresses.entry(caller).read(), 'Direct Caller not permitted'); + + self.accepted_tokens.entry(token).write(true); + self.accepted_tokens_list.push(token); + self.emit(TokenAccepted { added_by: get_caller_address(), token }); } - /// Returns the funds currently available for use. + /// Removes an ERC20 token from the list of accepted tokens for the vault. (Owner only) /// - /// ### Returns - /// - `u256`: The available fund balance. - fn get_available_funds(self: @ContractState) -> u256 { - self.available_funds.read() + /// ### Parameters + /// - `token`: The `ContractAddress` of the token to be removed. + fn remove_accepted_token(ref self: ContractState, token: ContractAddress) { + let caller = get_caller_address(); + assert(self.is_token_acceptable(token), 'Token not accepted'); + assert(self.permitted_addresses.entry(caller).read(), 'Direct Caller not permitted'); + + self.accepted_tokens.entry(token).write(false); + self.emit(TokenRemoved { removed_by: get_caller_address(), token }); } - /// Returns the total amount allocated for bonuses. + /// Retrieves the total balance of a specific token held by the vault. + /// + /// ### Parameters + /// - `token`: The `ContractAddress` of the token to query. /// /// ### Returns - /// - `u256`: The bonus allocation amount. - fn get_bonus_allocation(self: @ContractState) -> u256 { - // let caller = get_caller_address(); - // assert(self.permitted_addresses.entry(caller).read(), 'Caller Not Permitted'); - self.total_bonus.read() + /// - `u256`: The total balance of the specified token. + fn get_token_balance(self: @ContractState, token: ContractAddress) -> u256 { + let token_dispatcher = IERC20Dispatcher { contract_address: token }; + token_dispatcher.balance_of(get_contract_address()) } - /// Pays a member from the vault's funds. + /// Retrieves a list of all tokens the vault is authorized to manage. /// - /// ### Parameters - /// - `recipient`: The address of the member to pay. - /// - `amount`: The amount of the payment. + /// ### Returns + /// - `Array`: A list of accepted token contract addresses. + fn get_accepted_tokens(self: @ContractState) -> Array { + let mut accepted_tokens: Array = array![]; + let mut i = 0; + + while i != self.accepted_tokens_list.len() { + let token_addr = self.accepted_tokens_list.at(i).read(); + + if self.is_token_acceptable(token_addr) { + accepted_tokens.append(token_addr); + } + + i += 1; + } + + accepted_tokens + } + + /// Retrieves the vault's balance for every accepted token. /// - /// ### Panics - /// - If `recipient` or `amount` is zero. - /// - If the caller is not a permitted address. - /// - If the payment amount exceeds the vault's balance. - /// - If the token transfer fails. - fn pay_member(ref self: ContractState, recipient: ContractAddress, amount: u256) { - assert(recipient.is_non_zero(), 'Invalid Address'); - assert(amount.is_non_zero(), 'Invalid Amount'); - let caller = get_caller_address(); - assert(self.permitted_addresses.entry(caller).read(), 'Caller Not Permitted'); + /// ### Returns + /// - `Array<(ContractAddress, u256)>`: A list of tuples, each containing a token address + /// and its corresponding balance. + fn get_all_token_balances(self: @ContractState) -> Array<(ContractAddress, u256)> { + let mut all_balances = array![]; + let mut i = 0; + let this_contract = get_contract_address(); - self._sync_available_funds(); + while i != self.accepted_tokens_list.len() { + let token_addr = self.accepted_tokens_list.at(i).read(); - let token_address = self.token.read(); - let token = IERC20Dispatcher { contract_address: token_address }; - let token_balance = token.balance_of(get_contract_address()); - assert(amount <= token_balance, 'Amount Overflow'); - let transfer = token.transfer(recipient, amount); - assert(transfer, 'Transfer failed'); - self._record_transaction(token_address, amount, TransactionType::PAYMENT, caller); + if self.is_token_acceptable(token_addr) { + let token_dispatcher = IERC20Dispatcher { contract_address: token_addr }; + let balance = token_dispatcher.balance_of(this_contract); - self._sync_available_funds(); + all_balances.append((token_addr, balance)); + } + + i += 1; + } + + all_balances } - /// Returns the current status of the vault. + /// Retrieves the current total amount allocated for bonuses for a specific token. + /// + /// ### Parameters + /// - `token`: The `ContractAddress` of the token to query. /// /// ### Returns - /// - `VaultStatus`: The vault's current status enum. + /// - `u256`: The bonus allocation for the specified token. + fn get_bonus_allocation(self: @ContractState, token: ContractAddress) -> u256 { + self.bonus_allocations.entry(token).read() + } + + /// Returns the current operational status of the vault. + /// + /// ### Returns + /// - `VaultStatus`: The vault's current status enum (e.g., `VAULTACTIVE`, `VAULTFROZEN`). fn get_vault_status(self: @ContractState) -> VaultStatus { self.vault_status.read() } @@ -406,6 +493,17 @@ pub mod Vault { assert(org_address.is_non_zero(), 'Invalid Address'); self.permitted_addresses.entry(org_address).write(true); } + + /// Checks whether a specific token is accepted by the vault for transactions. + /// + /// ### Parameters + /// - `token`: The `ContractAddress` of the token to check. + /// + /// ### Returns + /// - `bool`: `true` if the token is accepted, `false` otherwise. + fn is_token_acceptable(self: @ContractState, token: ContractAddress) -> bool { + self.accepted_tokens.entry(token).read() + } } /// # InternalFunctions @@ -422,7 +520,7 @@ pub mod Vault { /// - If the caller is not a permitted address. fn _add_transaction(ref self: ContractState, transaction: Transaction) { let caller = get_caller_address(); - assert(self.permitted_addresses.entry(caller).read(), 'Caller not permitted'); + assert(self.permitted_addresses.entry(caller).read(), 'Direct Caller not permitted'); let current_transaction_count = self.transactions_count.read(); self.transaction_history.entry(current_transaction_count + 1).write(transaction); self.transactions_count.write(current_transaction_count + 1); @@ -446,7 +544,9 @@ pub mod Vault { caller: ContractAddress, ) { let actual_caller = get_caller_address(); - assert(self.permitted_addresses.entry(actual_caller).read(), 'Caller Not Permitted'); + assert( + self.permitted_addresses.entry(actual_caller).read(), 'Direct Caller not permitted', + ); let timestamp = get_block_timestamp(); let tx_info = get_tx_info(); let transaction = Transaction { @@ -468,15 +568,5 @@ pub mod Vault { }, ); } - - /// Updates the `available_funds` storage variable to match the contract's actual token - /// balance. - fn _sync_available_funds(ref self: ContractState) { - let token_address = self.token.read(); - let token_dispatcher = IERC20Dispatcher { contract_address: token_address }; - let vault_address = get_contract_address(); - let actual_balance = token_dispatcher.balance_of(vault_address); - self.available_funds.write(actual_balance); - } } } diff --git a/src/interfaces/icore.cairo b/src/interfaces/icore.cairo index a3ab5f5..ad870cf 100644 --- a/src/interfaces/icore.cairo +++ b/src/interfaces/icore.cairo @@ -1,3 +1,5 @@ +use starknet::ContractAddress; + /// # ICore /// /// This trait defines the public interface for the central core contract of an organization. @@ -18,7 +20,7 @@ pub trait ICore { /// ## Parameters /// /// - `ref self: T`: The current state of the contract. - fn schedule_payout(ref self: T); + fn schedule_payout(ref self: T, token: ContractAddress); /// # initialize_disbursement_schedule /// diff --git a/src/interfaces/ifactory.cairo b/src/interfaces/ifactory.cairo index 688a0ce..b9c5ad6 100644 --- a/src/interfaces/ifactory.cairo +++ b/src/interfaces/ifactory.cairo @@ -208,9 +208,9 @@ pub trait IFactory { // fn get_vault_org_pairs(self: @T) -> Array<(ContractAddress, ContractAddress)>; // in the future, you can upgrade a deployed org core from here - // fn initialize_upgrade(ref self: T, vaults: Array, cores: - // Array); - // this function would pick the updated class hash from the storage, if the class hash has been - // updated at present, it can only pick the latest... - // in the future, it can pick a specific class hash version +// fn initialize_upgrade(ref self: T, vaults: Array, cores: +// Array); +// this function would pick the updated class hash from the storage, if the class hash has been +// updated at present, it can only pick the latest... +// in the future, it can pick a specific class hash version } diff --git a/src/interfaces/ivault.cairo b/src/interfaces/ivault.cairo index 1300d4c..b454490 100644 --- a/src/interfaces/ivault.cairo +++ b/src/interfaces/ivault.cairo @@ -3,153 +3,188 @@ use starknet::ContractAddress; /// # IVault /// -/// This trait defines the public interface for a vault component. It outlines the core -/// functionalities for managing an organization's funds, including deposits, withdrawals, -/// member payments, and security measures like emergency freezes. The interface is designed -/// to be implemented by a Starknet component responsible for the secure handling and -/// tracking of financial assets. +/// This trait defines the public interface for a multi-token vault component. It outlines the core +/// functionalities for managing an organization's funds across various ERC20 tokens, including +/// deposits, withdrawals, member payments, and security measures. #[starknet::interface] pub trait IVault { /// # deposit_funds /// - /// Deposits a specified amount of a given token into the vault. + /// Deposits a specified amount of an accepted token into the vault. /// /// ## Parameters - /// /// - `ref self: TContractState`: The current state of the contract. + /// - `token`: The `ContractAddress` of the token being deposited. /// - `amount`: The amount of funds to deposit as a `u256`. - /// - `address`: The `ContractAddress` of the token being deposited. - fn deposit_funds(ref self: TContractState, amount: u256, address: ContractAddress); + /// - `from_address`: The `ContractAddress` from which the funds are being sent. + fn deposit_funds( + ref self: TContractState, + token: ContractAddress, + amount: u256, + from_address: ContractAddress, + ); /// # withdraw_funds /// - /// Withdraws a specified amount of a given token from the vault. This is a privileged action. + /// Withdraws a specified amount of an accepted token from the vault. /// /// ## Parameters /// /// - `ref self: TContractState`: The current state of the contract. + /// - `token`: The `ContractAddress` of the token being withdrawn. /// - `amount`: The amount of funds to withdraw as a `u256`. - /// - `address`: The `ContractAddress` of the token being withdrawn. - fn withdraw_funds(ref self: TContractState, amount: u256, address: ContractAddress); + /// - `to_address`: The `ContractAddress` to receive the funds. + fn withdraw_funds( + ref self: TContractState, token: ContractAddress, amount: u256, to_address: ContractAddress, + ); /// # emergency_freeze /// - /// Halts all outbound transactions from the vault. This function serves as a security - /// measure to prevent unauthorized fund movements in case of a compromise. + /// Halts all outbound transactions from the vault as a global security measure. /// /// ## Parameters - /// /// - `ref self: TContractState`: The current state of the contract. fn emergency_freeze(ref self: TContractState); /// # unfreeze_vault /// - /// Lifts the emergency freeze, restoring normal vault operations. This is a privileged action. + /// Lifts the emergency freeze, restoring normal vault operations. /// /// ## Parameters - /// /// - `ref self: TContractState`: The current state of the contract. fn unfreeze_vault(ref self: TContractState); - // fn bulk_transfer(ref self: TContractState, recipients: Span); - /// # pay_member /// - /// Executes a payment from the vault to a specific member's address. + /// Executes a payment from the vault to a specific member's address using a specified token. /// /// ## Parameters - /// /// - `ref self: TContractState`: The current state of the contract. + /// - `token`: The `ContractAddress` of the token for the payment. /// - `recipient`: The `ContractAddress` of the member to receive the payment. /// - `amount`: The payment amount as a `u256`. - fn pay_member(ref self: TContractState, recipient: ContractAddress, amount: u256); + fn pay_member( + ref self: TContractState, token: ContractAddress, recipient: ContractAddress, amount: u256, + ); /// # add_to_bonus_allocation /// - /// Allocates a certain amount of funds for bonus payments. These funds are tracked - /// separately from the main available balance. + /// Allocates funds of a specific token for bonus payments. /// /// ## Parameters - /// /// - `ref self: TContractState`: The current state of the contract. + /// - `token`: The `ContractAddress` of the token for the bonus allocation. /// - `amount`: The amount to allocate for bonuses. - /// - `address`: The `ContractAddress` of the token for the bonus allocation. - fn add_to_bonus_allocation(ref self: TContractState, amount: u256, address: ContractAddress); + /// - `address`: The address initiating the allocation. + fn add_to_bonus_allocation( + ref self: TContractState, token: ContractAddress, amount: u256, address: ContractAddress, + ); + + /// # add_accepted_token + /// + /// Adds a new ERC20 token to the list of accepted tokens for the vault. (Owner only) + /// + /// ## Parameters + /// - `ref self: TContractState`: The current state of the contract. + /// - `token`: The `ContractAddress` of the token to be accepted. + fn add_accepted_token(ref self: TContractState, token: ContractAddress); - /// # get_balance + /// # remove_accepted_token /// - /// Retrieves the total balance of the vault for all managed assets. + /// Removes an ERC20 token from the list of accepted tokens for the vault. (Owner only) /// /// ## Parameters + /// - `ref self: TContractState`: The current state of the contract. + /// - `token`: The `ContractAddress` of the token to be removed. + fn remove_accepted_token(ref self: TContractState, token: ContractAddress); + + /// # get_token_balance + /// + /// Retrieves the total balance of a specific token held by the vault. /// + /// ## Parameters /// - `self: @TContractState`: A snapshot of the contract's state. + /// - `token`: The `ContractAddress` of the token to query. /// /// ## Returns - /// - /// The total balance as a `u256`. - fn get_balance(self: @TContractState) -> u256; + /// The total balance of the specified token as a `u256`. + fn get_token_balance(self: @TContractState, token: ContractAddress) -> u256; - /// # get_available_funds + /// # get_all_token_balances /// - /// Retrieves the amount of funds available for general use, excluding any earmarked - /// allocations like bonuses. + /// Retrieves the vault's balance for every accepted token. /// /// ## Parameters - /// /// - `self: @TContractState`: A snapshot of the contract's state. /// /// ## Returns + /// An `Array` of (`ContractAddress`, `u256`) tuples representing each token and its balance. + fn get_all_token_balances(self: @TContractState) -> Array<(ContractAddress, u256)>; + + + /// # get_accepted_tokens + /// + /// Retrieves a list of all tokens the vault is authorized to manage. + /// + /// ## Parameters + /// - `self: @TContractState`: A snapshot of the contract's state. /// - /// The available funds as a `u256`. - fn get_available_funds(self: @TContractState) -> u256; + /// ## Returns + /// An `Array` of accepted tokens. + fn get_accepted_tokens(self: @TContractState) -> Array; /// # get_vault_status /// /// Returns the current operational status of the vault (e.g., Active, Frozen). /// /// ## Parameters - /// /// - `self: @TContractState`: A snapshot of the contract's state. /// /// ## Returns - /// /// The current `VaultStatus` enum. fn get_vault_status(self: @TContractState) -> VaultStatus; /// # get_bonus_allocation /// - /// Retrieves the current total amount allocated for bonuses. + /// Retrieves the current total amount allocated for bonuses for a specific token. /// /// ## Parameters - /// /// - `self: @TContractState`: A snapshot of the contract's state. + /// - `token`: The `ContractAddress` of the token to query. /// /// ## Returns + /// The bonus allocation for the specified token as a `u256`. + fn get_bonus_allocation(self: @TContractState, token: ContractAddress) -> u256; + + /// # is_token_acceptable + /// + /// Checks whether a specific token is accepted by the vault for transactions. + /// + /// ## Parameters + /// - `self: @TContractState`: A snapshot of the contract's state. + /// - `token`: The `ContractAddress` of the token to check. /// - /// The total bonus allocation as a `u256`. - fn get_bonus_allocation(self: @TContractState) -> u256; + /// ## Returns + /// A boolean value: `true` if the token is accepted, `false` otherwise. + fn is_token_acceptable(self: @TContractState, token: ContractAddress) -> bool; /// # get_transaction_history /// /// Retrieves a log of all transactions processed by the vault. /// /// ## Parameters - /// /// - `self: @TContractState`: A snapshot of the contract's state. /// /// ## Returns - /// /// An `Array` containing the vault's transaction history. fn get_transaction_history(self: @TContractState) -> Array; /// # allow_org_core_address /// - /// Grants permission to a core organization contract to interact with the vault. - /// This is necessary for enabling automated payments and other coordinated actions. + /// Grants permission to a contract to interact with the vault. /// /// ## Parameters - /// /// - `ref self: TContractState`: The current state of the contract. - /// - `org_address`: The `ContractAddress` of the core organization contract to authorize. + /// - `org_address`: The `ContractAddress` of the contract to authorize. fn allow_org_core_address(ref self: TContractState, org_address: ContractAddress); } diff --git a/tests/test_core.cairo b/tests/test_core.cairo index cbb57c2..61e27ba 100644 --- a/tests/test_core.cairo +++ b/tests/test_core.cairo @@ -205,8 +205,8 @@ fn setup_full_organization() -> ( // Fund the vault start_cheat_caller_address(vault_address, owner); - vault_dispatcher.deposit_funds(5000000000000000000000, owner); - vault_dispatcher.add_to_bonus_allocation(1000000000000000000000, owner); + vault_dispatcher.deposit_funds(token_address, 5000000000000000000000, owner); + vault_dispatcher.add_to_bonus_allocation(token_address, 1000000000000000000000, owner); stop_cheat_caller_address(vault_address); ( @@ -242,6 +242,7 @@ fn setup_organization_no_bonus() -> ( let (factory_address, _) = factory_contract.deploy(@factory_calldata).unwrap(); let factory_dispatcher = IFactoryDispatcher { contract_address: factory_address }; + start_cheat_caller_address(factory_address, owner()); let (core_address, vault_address) = factory_dispatcher .setup_org( token: token_address, @@ -254,6 +255,7 @@ fn setup_organization_no_bonus() -> ( first_admin_alias: 'admin', organization_type: 0, ); + stop_cheat_caller_address(factory_address); let core_dispatcher = ICoreDispatcher { contract_address: core_address }; let vault_dispatcher = IVaultDispatcher { contract_address: vault_address }; @@ -269,7 +271,7 @@ fn setup_organization_no_bonus() -> ( // Fund the vault without bonus allocation start_cheat_caller_address(vault_address, owner()); - vault_dispatcher.deposit_funds(5000000000000000000000, owner()); // 5,000 tokens + vault_dispatcher.deposit_funds(token_address, 5000000000000000000000, owner()); // 5,000 tokens // Skip add_to_bonus_allocation to keep bonus at 0 stop_cheat_caller_address(vault_address); @@ -397,7 +399,7 @@ fn test_schedule_payout_before_start_time() { start_cheat_block_timestamp(core_address, start_time - 100); start_cheat_caller_address(core_address, owner); - core_dispatcher.schedule_payout(); + core_dispatcher.schedule_payout(_token_address); stop_cheat_caller_address(core_address); stop_cheat_block_timestamp(core_address); @@ -430,7 +432,7 @@ fn test_schedule_payout_after_end_time() { start_cheat_block_timestamp(core_address, end_time + 100); start_cheat_caller_address(core_address, owner); - core_dispatcher.schedule_payout(); + core_dispatcher.schedule_payout(_token_address); stop_cheat_caller_address(core_address); stop_cheat_block_timestamp(core_address); @@ -466,7 +468,7 @@ fn test_schedule_payout_with_paused_schedule() { start_cheat_block_timestamp(core_address, start_time + 100); start_cheat_caller_address(core_address, owner); - core_dispatcher.schedule_payout(); + core_dispatcher.schedule_payout(_token_address); stop_cheat_caller_address(core_address); stop_cheat_block_timestamp(core_address); @@ -495,7 +497,7 @@ fn test_schedule_payout_at_end_timestamp() { stop_cheat_caller_address(core_address); start_cheat_block_timestamp(core_address, end_time); start_cheat_caller_address(core_address, owner); - core_dispatcher.schedule_payout(); + core_dispatcher.schedule_payout(_token_address); stop_cheat_caller_address(core_address); stop_cheat_block_timestamp(core_address); } @@ -516,7 +518,7 @@ fn test_schedule_payout_inactive_schedule() { add_test_members(core_dispatcher, core_address); start_cheat_caller_address(core_address, owner); - core_dispatcher.schedule_payout(); + core_dispatcher.schedule_payout(_token_address); stop_cheat_caller_address(core_address); } @@ -556,7 +558,7 @@ fn test_schedule_payout_successful() { let employee1_balance = _token_dispatcher.balance_of(employee1()); start_cheat_caller_address(core_address, owner); - core_dispatcher.schedule_payout(); + core_dispatcher.schedule_payout(_token_address); stop_cheat_caller_address(core_address); let new_employee1_balance = _token_dispatcher.balance_of(employee1()); diff --git a/tests/test_vault.cairo b/tests/test_vault.cairo index b306286..4098599 100644 --- a/tests/test_vault.cairo +++ b/tests/test_vault.cairo @@ -111,14 +111,25 @@ fn deploy_mock_erc20() -> (IMockERC20Dispatcher, ContractAddress) { (dispatcher, contract_address) } -fn deploy_vault() -> (IVaultDispatcher, ContractAddress, ContractAddress) { - let (token_dispatcher, token_address) = deploy_mock_erc20(); +fn deploy_another_mock_erc20() -> (IMockERC20Dispatcher, ContractAddress) { + let contract = declare("AnotherMockERC20").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@array![]).unwrap(); + let dispatcher = IMockERC20Dispatcher { contract_address }; + + (dispatcher, contract_address) +} + +fn deploy_vault() -> (IVaultDispatcher, ContractAddress, ContractAddress, ContractAddress) { + let (token1_dispatcher, token1_address) = deploy_mock_erc20(); + let (token2_dispatcher, token2_address) = deploy_mock_erc20(); let vault_contract = declare("Vault").unwrap().contract_class(); let owner_address = owner(); + let token_addresses = array![token1_address, token2_address]; // Use array! macro with explicit typing for deployment - let constructor_calldata = array![token_address.into(), owner_address.into()]; + let mut constructor_calldata = array![owner_address.into()]; + token_addresses.serialize(ref constructor_calldata); let (vault_address, _) = vault_contract.deploy(@constructor_calldata).unwrap(); let vault_dispatcher = IVaultDispatcher { contract_address: vault_address }; @@ -131,159 +142,191 @@ fn deploy_vault() -> (IVaultDispatcher, ContractAddress, ContractAddress) { vault_dispatcher.allow_org_core_address(recipient()); stop_cheat_caller_address(vault_address); - // Setup token approvals for all addresses - start_cheat_caller_address(token_address, owner_address); - token_dispatcher.approve(vault_address, 1000000000000000000000); - stop_cheat_caller_address(token_address); + // Setup token1 approvals for all addresses + start_cheat_caller_address(token1_address, owner_address); + token1_dispatcher.approve(vault_address, 1000000000000000000000); + stop_cheat_caller_address(token1_address); + + start_cheat_caller_address(token1_address, permitted_caller()); + token1_dispatcher.approve(vault_address, 1000000000000000000000); + stop_cheat_caller_address(token1_address); + + start_cheat_caller_address(token1_address, recipient()); + token1_dispatcher.approve(vault_address, 1000000000000000000000); + stop_cheat_caller_address(token1_address); - start_cheat_caller_address(token_address, permitted_caller()); - token_dispatcher.approve(vault_address, 1000000000000000000000); - stop_cheat_caller_address(token_address); + // Transfer some tokens to the vault for payment operations + start_cheat_caller_address(token1_address, owner_address); + token1_dispatcher.transfer(vault_address, 100000000000000000000); // 100 tokens + stop_cheat_caller_address(token1_address); + + // Setup token2 approvals for all addresses + start_cheat_caller_address(token2_address, owner_address); + token2_dispatcher.approve(vault_address, 2000000000000000000000); + stop_cheat_caller_address(token2_address); - start_cheat_caller_address(token_address, recipient()); - token_dispatcher.approve(vault_address, 1000000000000000000000); - stop_cheat_caller_address(token_address); + start_cheat_caller_address(token2_address, permitted_caller()); + token2_dispatcher.approve(vault_address, 2000000000000000000000); + stop_cheat_caller_address(token2_address); + + start_cheat_caller_address(token2_address, recipient()); + token2_dispatcher.approve(vault_address, 2000000000000000000000); + stop_cheat_caller_address(token2_address); // Transfer some tokens to the vault for payment operations - start_cheat_caller_address(token_address, owner_address); - token_dispatcher.transfer(vault_address, 100000000000000000000); // 100 tokens - stop_cheat_caller_address(token_address); + start_cheat_caller_address(token2_address, owner_address); + token2_dispatcher.transfer(vault_address, 200000000000000000000); // 200 tokens + stop_cheat_caller_address(token2_address); - (vault_dispatcher, vault_address, token_address) + (vault_dispatcher, vault_address, token1_address, token2_address) } // Constructor Tests #[test] fn test_constructor_initializes_correctly() { - let (vault, _vault_address, _token_address) = deploy_vault(); - assert(vault.get_balance() == 100000000000000000000, 'Incorrect initial balance'); - assert(vault.get_bonus_allocation() == 0, 'Incorrect bonus allocation'); + let (vault, _vault_address, _token1_address, _token2_address) = deploy_vault(); + assert( + vault.get_token_balance(_token1_address) == 100000000000000000000, + 'Incorrect initial balance', + ); + assert(vault.get_bonus_allocation(_token1_address) == 0, 'Incorrect bonus allocation'); + assert( + vault.get_token_balance(_token2_address) == 200000000000000000000, + 'Incorrect initial balance', + ); + assert(vault.get_bonus_allocation(_token2_address) == 0, 'Incorrect bonus allocation'); assert(vault.get_vault_status() == VaultStatus::VAULTRESUMED, 'Vault should be resumed'); } // Deposit Tests #[test] fn test_deposit_funds_success() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); let deposit_amount = 100000000000000000; // 0.1 ETH start_cheat_caller_address(vault_address, permitted_caller()); - let first_vault_balance = vault.get_balance(); - vault.deposit_funds(deposit_amount, owner()); + let first_vault_balance = vault.get_token_balance(_token1_address); + vault.deposit_funds(_token1_address, deposit_amount, owner()); stop_cheat_caller_address(vault_address); // Check balance updated let expected_balance = first_vault_balance + deposit_amount; - assert(vault.get_balance() == expected_balance, 'Balance not updated correctly'); + assert( + vault.get_token_balance(_token1_address) == expected_balance, + 'Balance not updated correctly', + ); } #[test] #[should_panic(expected: 'Direct Caller not permitted')] fn test_deposit_funds_unauthorized_caller() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _) = deploy_vault(); start_cheat_caller_address(vault_address, non_permitted_caller()); - vault.deposit_funds(100000000000000000, owner()); + vault.deposit_funds(_token1_address, 100000000000000000, owner()); stop_cheat_caller_address(vault_address); } #[test] #[should_panic(expected: 'Deep Caller Not Permitted')] fn test_deposit_funds_unauthorized_deep_caller() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _) = deploy_vault(); start_cheat_caller_address(vault_address, permitted_caller()); - vault.deposit_funds(100000000000000000, non_permitted_caller()); + vault.deposit_funds(_token1_address, 100000000000000000, non_permitted_caller()); stop_cheat_caller_address(vault_address); } #[test] #[should_panic(expected: 'Vault Frozen for Transactions')] fn test_deposit_funds_when_vault_frozen() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _) = deploy_vault(); // Freeze vault first start_cheat_caller_address(vault_address, permitted_caller()); vault.emergency_freeze(); // Try to deposit - should fail - vault.deposit_funds(100000000000000000, owner()); + vault.deposit_funds(_token1_address, 100000000000000000, owner()); stop_cheat_caller_address(vault_address); } // Withdrawal Tests #[test] fn test_withdraw_funds_success() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); let withdraw_amount = 1000000000000000; // 0.1 ETH start_cheat_caller_address(vault_address, owner()); - vault.deposit_funds(withdraw_amount + 1000, owner()); + vault.deposit_funds(_token1_address, withdraw_amount + 1000, owner()); stop_cheat_caller_address(vault_address); start_cheat_caller_address(vault_address, permitted_caller()); - let first_vault_balance = vault.get_balance(); - vault.withdraw_funds(withdraw_amount, recipient()); + let first_vault_balance = vault.get_token_balance(_token1_address); + vault.withdraw_funds(_token1_address, withdraw_amount, recipient()); stop_cheat_caller_address(vault_address); // Check balance updated let expected_balance = first_vault_balance - withdraw_amount; - assert(vault.get_balance() == expected_balance, 'Balance not updated correctly'); + assert( + vault.get_token_balance(_token1_address) == expected_balance, + 'Balance not updated correctly', + ); } #[test] #[should_panic(expected: 'Direct Caller not permitted')] fn test_withdraw_funds_unauthorized_caller() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _) = deploy_vault(); start_cheat_caller_address(vault_address, non_permitted_caller()); - vault.withdraw_funds(100000000000000000, recipient()); + vault.withdraw_funds(_token1_address, 100000000000000000, recipient()); stop_cheat_caller_address(vault_address); } #[test] #[should_panic(expected: 'Deep Caller Not Permitted')] fn test_withdraw_funds_unauthorized_deep_caller() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _) = deploy_vault(); start_cheat_caller_address(vault_address, permitted_caller()); - vault.withdraw_funds(100000000000000000, non_permitted_caller()); + vault.withdraw_funds(_token1_address, 100000000000000000, non_permitted_caller()); stop_cheat_caller_address(vault_address); } #[test] #[should_panic(expected: 'Insufficient Balance')] fn test_withdraw_funds_insufficient_balance() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _) = deploy_vault(); let excessive_amount = 300000000000000000000; // 2 ETH (more than available) start_cheat_caller_address(vault_address, owner()); - vault.deposit_funds(excessive_amount / 1000, owner()); + vault.deposit_funds(_token1_address, excessive_amount / 1000, owner()); stop_cheat_caller_address(vault_address); start_cheat_caller_address(vault_address, permitted_caller()); - vault.withdraw_funds(excessive_amount, recipient()); + vault.withdraw_funds(_token1_address, excessive_amount, recipient()); stop_cheat_caller_address(vault_address); } #[test] #[should_panic(expected: 'Vault Frozen for Transactions')] fn test_withdraw_funds_when_vault_frozen() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _) = deploy_vault(); // Freeze vault first start_cheat_caller_address(vault_address, permitted_caller()); vault.emergency_freeze(); // Try to withdraw - should fail - vault.withdraw_funds(100000000000000000, recipient()); + vault.withdraw_funds(_token1_address, 100000000000000000, recipient()); stop_cheat_caller_address(vault_address); } // Freeze/Unfreeze Tests #[test] fn test_emergency_freeze_success() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); start_cheat_caller_address(vault_address, permitted_caller()); vault.emergency_freeze(); @@ -294,9 +337,9 @@ fn test_emergency_freeze_success() { } #[test] -#[should_panic(expected: 'Caller not permitted')] +#[should_panic(expected: 'Direct Caller not permitted')] fn test_emergency_freeze_unauthorized() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); start_cheat_caller_address(vault_address, non_permitted_caller()); vault.emergency_freeze(); @@ -306,7 +349,7 @@ fn test_emergency_freeze_unauthorized() { #[test] #[should_panic(expected: 'Vault Already Frozen')] fn test_emergency_freeze_already_frozen() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); start_cheat_caller_address(vault_address, permitted_caller()); vault.emergency_freeze(); @@ -316,7 +359,7 @@ fn test_emergency_freeze_already_frozen() { #[test] fn test_unfreeze_vault_success() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); start_cheat_caller_address(vault_address, permitted_caller()); @@ -330,9 +373,9 @@ fn test_unfreeze_vault_success() { } #[test] -#[should_panic(expected: 'Caller not permitted')] +#[should_panic(expected: 'Direct Caller not permitted')] fn test_unfreeze_vault_unauthorized() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); // Freeze first start_cheat_caller_address(vault_address, permitted_caller()); @@ -348,7 +391,7 @@ fn test_unfreeze_vault_unauthorized() { #[test] #[should_panic(expected: 'Vault Not Frozen')] fn test_unfreeze_vault_not_frozen() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); start_cheat_caller_address(vault_address, permitted_caller()); vault.unfreeze_vault(); // Should fail - vault is not frozen @@ -358,87 +401,94 @@ fn test_unfreeze_vault_not_frozen() { // Pay Member Tests #[test] fn test_pay_member_success() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); let payment_amount = 100000000000000; // 0.1 ETH start_cheat_caller_address(vault_address, owner()); - vault.deposit_funds(payment_amount + 1000, owner()); + vault.deposit_funds(_token1_address, payment_amount + 1000, owner()); stop_cheat_caller_address(vault_address); start_cheat_caller_address(vault_address, permitted_caller()); - let first_vault_balance = vault.get_balance(); - vault.pay_member(recipient(), payment_amount); + let first_vault_balance = vault.get_token_balance(_token1_address); + vault.pay_member(_token1_address, recipient(), payment_amount); stop_cheat_caller_address(vault_address); // Check balance updated let expected_balance = first_vault_balance - payment_amount; - assert(vault.get_balance() == expected_balance, 'Balance not updated correctly'); + assert( + vault.get_token_balance(_token1_address) == expected_balance, + 'Balance not updated correctly', + ); } #[test] -#[should_panic(expected: 'Caller Not Permitted')] +#[should_panic(expected: 'Direct Caller not permitted')] fn test_pay_member_unauthorized() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); start_cheat_caller_address(vault_address, non_permitted_caller()); - vault.pay_member(recipient(), 100000000000000000); + vault.pay_member(_token1_address, recipient(), 100000000000000000); stop_cheat_caller_address(vault_address); } // Bonus Allocation Tests #[test] fn test_add_to_bonus_allocation_success() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); let bonus_amount = 100000000000000000; // 0.1 ETH start_cheat_caller_address(vault_address, permitted_caller()); - let starting_bonus_allocation = vault.get_bonus_allocation(); - vault.add_to_bonus_allocation(bonus_amount, owner()); + let starting_bonus_allocation = vault.get_bonus_allocation(_token1_address); + vault.add_to_bonus_allocation(_token1_address, bonus_amount, owner()); stop_cheat_caller_address(vault_address); // Check bonus updated let expected_bonus = starting_bonus_allocation + bonus_amount; - assert(vault.get_bonus_allocation() == expected_bonus, 'Bonus not updated correctly'); + assert( + vault.get_bonus_allocation(_token1_address) == expected_bonus, + 'Bonus not updated correctly', + ); } #[test] #[should_panic(expected: 'Direct Caller not permitted')] fn test_add_to_bonus_allocation_unauthorized() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); start_cheat_caller_address(vault_address, non_permitted_caller()); - vault.add_to_bonus_allocation(100000000000000000, owner()); + vault.add_to_bonus_allocation(_token1_address, 100000000000000000, owner()); stop_cheat_caller_address(vault_address); } #[test] #[should_panic(expected: 'Deep Caller Not Permitted')] fn test_add_to_bonus_allocation_unauthorized_deep_caller() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); start_cheat_caller_address(vault_address, permitted_caller()); - vault.add_to_bonus_allocation(100000000000000000, non_permitted_caller()); + vault.add_to_bonus_allocation(_token1_address, 100000000000000000, non_permitted_caller()); stop_cheat_caller_address(vault_address); } // Transaction History Tests #[test] fn test_transaction_history_records_correctly() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); let deposit_amount = 100000000000000000; let withdraw_amount = 50000000000000000; start_cheat_caller_address(vault_address, permitted_caller()); // Perform some transactions - vault.deposit_funds(deposit_amount, owner()); - vault.withdraw_funds(withdraw_amount, recipient()); - vault.add_to_bonus_allocation(25000000000000000, owner()); + vault.deposit_funds(_token1_address, deposit_amount, owner()); + vault.withdraw_funds(_token1_address, withdraw_amount, recipient()); + vault.add_to_bonus_allocation(_token1_address, 25000000000000000, owner()); stop_cheat_caller_address(vault_address); // Check transaction history let history = vault.get_transaction_history(); + println!("history: {:?}", history); assert(history.len() == 3, 'Should have 3 transactions'); // Check first transaction (deposit) @@ -466,14 +516,14 @@ fn test_transaction_history_records_correctly() { // Access Control Tests #[test] fn test_allow_org_core_address() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); let new_org_address = contract_address_const::<'new_org'>(); vault.allow_org_core_address(new_org_address); // Test that the new address can now call functions start_cheat_caller_address(vault_address, new_org_address); - vault.deposit_funds(100000000000000000, owner()); + vault.deposit_funds(_token1_address, 100000000000000000, owner()); stop_cheat_caller_address(vault_address); } @@ -481,22 +531,22 @@ fn test_allow_org_core_address() { #[test] #[ignore] fn test_get_balance() { - let (vault, _, _) = deploy_vault(); - let balance = vault.get_balance(); + let (vault, _, _token1_address, _token2_address) = deploy_vault(); + let balance = vault.get_token_balance(_token1_address); assert(balance == 1000000000000000000, 'Incorrect balance'); } #[test] #[ignore] fn test_get_bonus_allocation() { - let (vault, _, _) = deploy_vault(); - let bonus = vault.get_bonus_allocation(); + let (vault, _, _token1_address, _token2_address) = deploy_vault(); + let bonus = vault.get_bonus_allocation(_token1_address); assert(bonus == 0, 'Incorrect bonus allocation'); } #[test] fn test_get_vault_status() { - let (vault, _, _) = deploy_vault(); + let (vault, _, _token1_address, _token2_address) = deploy_vault(); let status = vault.get_vault_status(); assert(status == VaultStatus::VAULTRESUMED, 'Incorrect vault status'); } @@ -505,67 +555,72 @@ fn test_get_vault_status() { #[test] #[should_panic(expected: 'Invalid Amount')] fn test_zero_amount_deposit() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); start_cheat_caller_address(vault_address, permitted_caller()); - let first_vault_balance = vault.get_balance(); - let first_bonus_allocation = vault.get_bonus_allocation(); + let first_vault_balance = vault.get_token_balance(_token1_address); + let first_bonus_allocation = vault.get_bonus_allocation(_token1_address); // Test zero deposit - vault.deposit_funds(0, owner()); + vault.deposit_funds(_token1_address, 0, owner()); } #[test] #[should_panic(expected: 'Invalid Address')] fn test_zero_address_deposit() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); start_cheat_caller_address(vault_address, permitted_caller()); - let first_vault_balance = vault.get_balance(); - let first_bonus_allocation = vault.get_bonus_allocation(); + let first_vault_balance = vault.get_token_balance(_token1_address); + let first_bonus_allocation = vault.get_bonus_allocation(_token1_address); // Test zero deposit - vault.deposit_funds(250, zero_addr()); + vault.deposit_funds(_token1_address, 250, zero_addr()); } // Integration Tests #[test] fn test_complete_vault_workflow() { - let (vault, vault_address, _) = deploy_vault(); + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); start_cheat_caller_address(vault_address, permitted_caller()); - let mut vault_balance = vault.get_balance(); - let mut bonus_allocation = vault.get_bonus_allocation(); + let mut vault_balance = vault.get_token_balance(_token1_address); + let mut bonus_allocation = vault.get_bonus_allocation(_token1_address); // 1. Deposit funds let deposit_amount = 200000000000000000; // 0.2 ETH - vault.deposit_funds(deposit_amount, owner()); + vault.deposit_funds(_token1_address, deposit_amount, owner()); println!("Did not fail before deposit"); let expected_balance = vault_balance + deposit_amount; vault_balance += deposit_amount; - assert(vault.get_balance() == expected_balance, 'Deposit failed'); + assert(vault.get_token_balance(_token1_address) == expected_balance, 'Deposit failed'); // 2. Add bonus allocation let bonus_amount = 100000000000000000; // 0.1 ETH - vault.add_to_bonus_allocation(bonus_amount, owner()); + vault.add_to_bonus_allocation(_token1_address, bonus_amount, owner()); let expected_bonus = bonus_allocation + bonus_amount; bonus_allocation += bonus_amount; - assert(vault.get_bonus_allocation() == expected_bonus, 'Bonus allocation failed'); + assert( + vault.get_bonus_allocation(_token1_address) == expected_bonus, 'Bonus allocation failed', + ); // 3. Pay member let payment_amount = 150000000000000; // 0.15 ETH - vault.pay_member(recipient(), payment_amount); + vault.pay_member(_token1_address, recipient(), payment_amount); println!("I didn't fail after pay_member"); let expected_balance_after_payment = expected_balance - payment_amount; - assert(vault.get_balance() == expected_balance_after_payment, 'Payment failed'); + assert( + vault.get_token_balance(_token1_address) == expected_balance_after_payment, + 'Payment failed', + ); // 4. Withdraw funds let withdraw_amount = 100000000000000000; // 0.1 ETH - vault.withdraw_funds(withdraw_amount, recipient()); + vault.withdraw_funds(_token1_address, withdraw_amount, recipient()); println!("I didn't fail after withdrawal"); let final_expected_balance = expected_balance_after_payment - withdraw_amount; - assert(vault.get_balance() == final_expected_balance, 'Withdrawal failed'); + assert(vault.get_token_balance(_token1_address) == final_expected_balance, 'Withdrawal failed'); // 5. Check transaction history let history = vault.get_transaction_history(); @@ -580,3 +635,68 @@ fn test_complete_vault_workflow() { stop_cheat_caller_address(vault_address); } + +#[test] +fn test_add_accepted_token() { + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); + + let accepted_tokens_before = vault.get_accepted_tokens(); + + assert!(accepted_tokens_before.len() == 2, "Incorrect number of accepted tokens"); + assert(accepted_tokens_before.at(0) == @_token1_address, 'Incorrect accepted token'); + assert(accepted_tokens_before.at(1) == @_token2_address, 'Incorrect accepted token'); + + let new_token_address = contract_address_const::<'new_token'>(); + + start_cheat_caller_address(vault_address, permitted_caller()); + vault.add_accepted_token(new_token_address); + stop_cheat_caller_address(vault_address); + + let accepted_tokens_after = vault.get_accepted_tokens(); + + assert!(accepted_tokens_after.len() == 3, "Incorrect number of accepted tokens"); + assert(accepted_tokens_after.at(0) == @_token1_address, 'Incorrect accepted token'); + assert(accepted_tokens_after.at(1) == @_token2_address, 'Incorrect accepted token'); + assert(accepted_tokens_after.at(2) == @new_token_address, 'Incorrect accepted token'); +} + +#[test] +#[should_panic(expected: 'Token already accepted')] +fn test_add_accepted_token_already_accepted() { + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); + + start_cheat_caller_address(vault_address, permitted_caller()); + vault.add_accepted_token(_token1_address); + stop_cheat_caller_address(vault_address); +} + +#[test] +fn test_remove_accepted_token() { + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); + + let accepted_tokens_before = vault.get_accepted_tokens(); + + assert!(accepted_tokens_before.len() == 2, "Incorrect number of accepted tokens"); + assert(accepted_tokens_before.at(0) == @_token1_address, 'Incorrect accepted token'); + assert(accepted_tokens_before.at(1) == @_token2_address, 'Incorrect accepted token'); + + start_cheat_caller_address(vault_address, permitted_caller()); + vault.remove_accepted_token(_token1_address); + stop_cheat_caller_address(vault_address); + + let accepted_tokens_after = vault.get_accepted_tokens(); + println!("accepted_tokens_after: {:?}", accepted_tokens_after); + + assert!(accepted_tokens_after.len() == 1, "Incorrect number of accepted tokens"); + assert(accepted_tokens_after.at(0) == @_token2_address, 'Incorrect accepted token'); +} + +#[test] +#[should_panic(expected: 'Direct Caller not permitted')] +fn test_remove_accepted_token_unauthorized() { + let (vault, vault_address, _token1_address, _token2_address) = deploy_vault(); + + start_cheat_caller_address(vault_address, non_permitted_caller()); + vault.remove_accepted_token(_token1_address); + stop_cheat_caller_address(vault_address); +}