perf: C++ backend performance improvements (Phase 2) #13
Merged
Three targeted changes to the single-group fast EM path:

1. Save a 1-based copy of `mloc` before the decrement so `uP2()` is called with the pre-saved matrix instead of creating a J×L temporary via `mloc+1` on every EM iteration.
2. Hoist the missing-data mask (`mX0` / `mX1` / `mXMissing`) and the `has_nan()` check out of the while-loop. The data matrix `mX` never changes during estimation, so the mask only needs to be built once.
3. Precompute per-item, per-category column-index vectors from `mloc` before the loop. Previously `arma::find(mloc.row(j)==k)` was called twice per (j,k) pair per EM iteration; now each lookup is O(1).

Also replace the J×N ones-matrix multiply used to compute `expN` in the no-missing case with `arma::repmat(arma::sum(msdPost,0),J,1)`, reducing cost from O(J·N·L) to O(N·L + J·L).

Numerical results are identical; only wall-clock time changes.
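The index precomputation in change 3 can be sketched in plain C++. This is an illustration, not the code from the PR: `std::vector` stands in for the Armadillo types, the names `IndexMatrix`, `IndexCache`, `build_index_cache` and `sum_category` are hypothetical, and the sketch assumes `mloc` holds 0-based category labels after the decrement.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for mloc: J rows (items) by L columns (latent classes).
// Assumption: entry (j, l) is a 0-based category label.
using IndexMatrix = std::vector<std::vector<int>>;

// cache[j][k] lists every column l with mloc[j][l] == k, in ascending order.
using IndexCache = std::vector<std::vector<std::vector<std::size_t>>>;

// One pass over mloc, O(J*L) in total, done once before the EM loop.
IndexCache build_index_cache(const IndexMatrix& mloc) {
    IndexCache cache(mloc.size());
    for (std::size_t j = 0; j < mloc.size(); ++j) {
        for (std::size_t l = 0; l < mloc[j].size(); ++l) {
            assert(mloc[j][l] >= 0);
            const std::size_t k = static_cast<std::size_t>(mloc[j][l]);
            if (k >= cache[j].size()) cache[j].resize(k + 1);
            cache[j][k].push_back(l);
        }
    }
    return cache;
}

// Example consumer: sum row j of `values` over the columns in category k.
// Fetching cache[j][k] is a constant-time lookup; no row scan is repeated.
double sum_category(const std::vector<std::vector<double>>& values,
                    const IndexCache& cache, std::size_t j, std::size_t k) {
    double total = 0.0;
    for (std::size_t l : cache[j][k]) total += values[j][l];
    return total;
}
```

Inside the loop, each `(j, k)` pair then reads `cache[j][k]` directly instead of scanning row `j` again, which is the effect the commit describes for the repeated `arma::find` calls.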
In the no-missing-data branch of both LikNR and LikNR_LC, expN was computed as ones(J,N) * msdPost — an O(J·N·L) matrix multiplication whose result has identical rows equal to sum(msdPost, 0). Replace with arma::repmat(arma::sum(msdPost,0), J, 1) which costs O(N·L + J·L) and produces the same result exactly.
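A minimal plain-C++ sketch of why the two forms agree, assuming `msdPost` is an N×L matrix. The `Matrix` alias and both function names are hypothetical stand-ins, not Armadillo calls: the first function mimics `ones(J,N) * msdPost`, the second mimics `repmat(sum(msdPost,0), J, 1)`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;  // row-major stand-in

// Baseline: multiply a J x N matrix of ones by the N x L matrix `post`.
// Every output row repeats the same N*L additions, so the cost is O(J*N*L).
Matrix ones_times(std::size_t J, const Matrix& post) {
    const std::size_t N = post.size();
    const std::size_t L = (N == 0) ? 0 : post[0].size();
    Matrix out(J, std::vector<double>(L, 0.0));
    for (std::size_t j = 0; j < J; ++j)
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t l = 0; l < L; ++l)
                out[j][l] += 1.0 * post[n][l];
    return out;
}

// Replacement: compute the column sums once, O(N*L),
// then copy that single row into J rows, O(J*L).
Matrix replicate_column_sums(std::size_t J, const Matrix& post) {
    const std::size_t N = post.size();
    const std::size_t L = (N == 0) ? 0 : post[0].size();
    std::vector<double> colsum(L, 0.0);
    for (std::size_t n = 0; n < N; ++n) {
        assert(post[n].size() == L);  // rows must have equal length
        for (std::size_t l = 0; l < L; ++l) colsum[l] += post[n][l];
    }
    return Matrix(J, colsum);
}
```

As a rough operation count, taking the benchmark dimensions quoted in the PR summary (N=1000, J=30, L=32): J·N·L = 960,000 multiply-adds for the baseline, against N·L + J·L = 32,000 + 960 = 32,960 additions and copies for the replacement.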
Mord() computes ordinal moment matrices (Xi11, Xi21, Xi22) used for fit statistics. The Xi21 loop is O(nitem^3) and Xi22 is O(nitem^4); both previously called LCprob.rows(arma::find(item_no==(x+1))) inside their innermost iterations, re-scanning the full item_no vector on every visit. Extract a one-time precomputation step that builds lc[i] for each item before any loop runs. All inner-loop variables (lci, lcj, lck, lcl) become const references into this cache, so no matrix data is copied and no find() call is repeated. Numerical output is identical; the saving grows as O(nitem^4) for large tests with many items.
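The caching pattern can be illustrated in plain C++. This is a sketch under stated assumptions, not `Mord()` itself: `Matrix`, `build_item_cache` and `toy_pair_sum` are hypothetical names, `item_no` is assumed to hold 1-based item numbers (one per row of `LCprob`), and the toy loop computes an arbitrary pairwise sum rather than the real Xi21/Xi22 moments.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;  // row-major stand-in

// Group the rows of `lcprob` by item in a single pass.
// lc[i] holds the rows belonging to item i + 1 (item_no is 1-based).
std::vector<Matrix> build_item_cache(const Matrix& lcprob,
                                     const std::vector<int>& item_no,
                                     std::size_t nitem) {
    assert(lcprob.size() == item_no.size());
    std::vector<Matrix> lc(nitem);
    for (std::size_t r = 0; r < lcprob.size(); ++r) {
        assert(item_no[r] >= 1 && static_cast<std::size_t>(item_no[r]) <= nitem);
        lc[static_cast<std::size_t>(item_no[r] - 1)].push_back(lcprob[r]);
    }
    return lc;
}

// Toy nested loop (not the real moment formulas): for each item pair i < j,
// add the dot product of the first cached row of each item.
// lci and lcj are const references into the cache, so nothing is copied
// and no search over item_no happens inside the loop.
double toy_pair_sum(const std::vector<Matrix>& lc) {
    double total = 0.0;
    for (std::size_t i = 0; i < lc.size(); ++i) {
        const Matrix& lci = lc[i];
        for (std::size_t j = i + 1; j < lc.size(); ++j) {
            const Matrix& lcj = lc[j];
            if (lci.empty() || lcj.empty()) continue;
            for (std::size_t c = 0; c < lci[0].size(); ++c)
                total += lci[0][c] * lcj[0][c];
        }
    }
    return total;
}
```

The cache is built once in time proportional to the number of rows, so the cost of locating an item's rows no longer multiplies with the depth of the loop nest.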
6e76a30 to 9baf9e6
Owner: Thanks!
Summary
- Hoist the missing-data mask (`mX0`/`mX1`/`mXMissing`) out of the `fast_GDINA_EM` while-loop in `Lik2.cpp`. The data matrix never changes during estimation, so the mask and `has_nan()` check were being redundantly rebuilt every EM iteration.
- Save a 1-based copy of `mloc` before the decrement in `fast_GDINA_EM`, eliminating a `J x L` temporary matrix created via `mloc+1` on every iteration.
- Precompute per-item, per-category column-index vectors from `mloc` before the EM loop in `fast_GDINA_EM`. Previously `arma::find(mloc.row(j)==k)` was called twice per `(j,k)` pair per iteration; now each lookup is O(1).
- Replace `ones(J,N) * msdPost` with `repmat(sum(msdPost,0), J, 1)` in `LikNR`, `LikNR_LC`, and `fast_GDINA_EM`. In the no-missing-data case every row of `expN` is identical, so the J x N ones-matrix multiply (O(J x N x L)) is wasteful; the replacement costs O(N x L + J x L).
- Cache `LCprob` row slices in a `std::vector<arma::mat>` in `Mord()`. The Xi21 loop is O(nitem^3) and Xi22 is O(nitem^4); without caching, each iteration re-ran `rows(find(...))` over the full category index vector.

Benchmark
Measured on the built-in `sim30GDINA` dataset (N=1000, J=30, K=5, L=32) for `LikNR`, `fast_GDINA_EM` (30 iterations), and `Mord`.

Gains scale with larger J, N, or K.
`Mord` benefits most because the savings compound across its O(nitem^4) Xi22 loop.

Verification
- … `:::` in test files, unrelated to this PR).
- … `R CMD INSTALL`.