From 49ab7069baed67d1217207495c30ca5b5e997b94 Mon Sep 17 00:00:00 2001 From: Joe Richey Date: Fri, 29 May 2026 00:52:47 +0000 Subject: [PATCH 1/4] tpm2/transport/test: Support spawning Reference Simulator in Open Adds support for spawning a local reference simulator binary via the new '--tpm-sim-path' flag. The simulator is spawned inside a temporary directory with '--pick_ports' enabled. Port files are polled on startup and parsed to open the TCP transport. All unit tests under `tpm2/test` are refactored to use `testhelper.Open` to utilize this support. To show that this works: 1. Build the Reference Simulator: cmake --build /usr/local/google/home/joerichey/dev/tcg/TPM/build --target Simulator 2. Run tests using the built binary: go test ./tpm2/test -tpm-sim-path /usr/local/google/home/joerichey/dev/tcg/TPM/build/Simulator/Simulator Note: The RSA 3072 and 4096 parameter tests ('TestTestParms') are expected to fail when executed against the Reference Simulator because the Reference Simulator compiles with native support for 3072-bit and 4096-bit RSA keys, whereas the test suite expects them to be unsupported and return TPM_RC_VALUE. Signed-off-by: Joe Richey --- tpm2/test/activate_credential_test.go | 12 +- tpm2/test/audit_test.go | 12 +- tpm2/test/certify_test.go | 17 +-- tpm2/test/clear_test.go | 7 +- tpm2/test/combined_context_test.go | 7 +- tpm2/test/commit_test.go | 7 +- tpm2/test/create_loaded_test.go | 7 +- tpm2/test/duplicate_test.go | 7 +- tpm2/test/ecdh_test.go | 7 +- tpm2/test/ek_test.go | 7 +- tpm2/test/evict_control_test.go | 7 +- tpm2/test/get_random_test.go | 7 +- tpm2/test/get_time_test.go | 7 +- tpm2/test/hash_sequence_hash_test.go | 22 +--- tpm2/test/hierarchy_change_auth_test.go | 7 +- tpm2/test/hmac_start_test.go | 12 +- tpm2/test/hmac_test.go | 12 +- tpm2/test/import_test.go | 12 +- tpm2/test/load_external_test.go | 7 +- tpm2/test/names_test.go | 12 +- tpm2/test/nv_test.go | 22 +--- tpm2/test/object_change_auth_test.go | 7 +- tpm2/test/pcr_test.go | 18 +-- tpm2/test/policy_test.go | 80 +++--------- tpm2/test/read_public_test.go | 12 +- tpm2/test/rsa_encryption_test.go | 7 +- tpm2/test/sealing_test.go | 7 +- tpm2/test/sign_test.go | 7 +- tpm2/test/symmetric_encryption_test.go | 12 +- tpm2/test/test_parms_test.go | 7 +- tpm2/transport/tcp/tcp.go | 13 ++ tpm2/transport/test/open.go | 156 ++++++++++++++++++++++++ 32 files changed, 270 insertions(+), 273 deletions(-) create mode 100644 tpm2/transport/test/open.go diff --git a/tpm2/test/activate_credential_test.go b/tpm2/test/activate_credential_test.go index caed4d5f..e67ab9c9 100644 --- a/tpm2/test/activate_credential_test.go +++ b/tpm2/test/activate_credential_test.go @@ -6,7 +6,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) // p384Template is an SRK-like ECDH-P384 key based on the P384 EK template. @@ -59,10 +59,7 @@ var p384Template = TPMTPublic{ // This test checks that ActivateCredential can decrypt a credential created by the TPM in MakeCredential. func TestActivateTPMCredential(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() ekCreate := CreatePrimary{ @@ -140,10 +137,7 @@ func TestActivateTPMCredential(t *testing.T) { // This test checks that ActivateCredential can decrypt a credential created by a remote server using CreateCredential. func TestActivateSWCredential(t *testing.T) { - tpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("OpenSimulator() = %v", err) - } + tpm := testhelper.Open(t) defer tpm.Close() for _, tc := range []struct { diff --git a/tpm2/test/audit_test.go b/tpm2/test/audit_test.go index 894a6f86..b6660b0c 100644 --- a/tpm2/test/audit_test.go +++ b/tpm2/test/audit_test.go @@ -6,14 +6,11 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func TestAuditSession(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() // Create the audit session @@ -193,10 +190,7 @@ func TestAuditSession(t *testing.T) { // TestAuditSessionWithCertify tests audit session with a more complex command (Certify) // which has two AuthHandles func TestAuditSessionWithCertify(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() // Create the audit session diff --git a/tpm2/test/certify_test.go b/tpm2/test/certify_test.go index 8cba5aa2..aa05f698 100644 --- a/tpm2/test/certify_test.go +++ b/tpm2/test/certify_test.go @@ -9,14 +9,11 @@ import ( "github.com/google/go-cmp/cmp" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func TestCertify(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() Auth := []byte("password") @@ -172,10 +169,7 @@ func TestCertify(t *testing.T) { } func TestCreateAndCertifyCreation(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() public := New2B(TPMTPublic{ @@ -303,10 +297,7 @@ func TestCreateAndCertifyCreation(t *testing.T) { } func TestNVCertify(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() Auth := []byte("password") diff --git a/tpm2/test/clear_test.go b/tpm2/test/clear_test.go index bdac11f5..1536103c 100644 --- a/tpm2/test/clear_test.go +++ b/tpm2/test/clear_test.go @@ -5,14 +5,11 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func TestClear(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() srkCreate := CreatePrimary{ diff --git a/tpm2/test/combined_context_test.go b/tpm2/test/combined_context_test.go index d7cd40c1..a2a3de31 100644 --- a/tpm2/test/combined_context_test.go +++ b/tpm2/test/combined_context_test.go @@ -7,7 +7,7 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" . "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func ReadPublicName(t *testing.T, handle TPMHandle, thetpm transport.TPM) TPM2BName { @@ -24,10 +24,7 @@ func ReadPublicName(t *testing.T, handle TPMHandle, thetpm transport.TPM) TPM2BN } func TestCombinedContext(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() createPrimary := CreatePrimary{ diff --git a/tpm2/test/commit_test.go b/tpm2/test/commit_test.go index 226df310..e107af32 100644 --- a/tpm2/test/commit_test.go +++ b/tpm2/test/commit_test.go @@ -4,14 +4,11 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func TestCommit(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() diff --git a/tpm2/test/create_loaded_test.go b/tpm2/test/create_loaded_test.go index fc7114f0..63bceb6d 100644 --- a/tpm2/test/create_loaded_test.go +++ b/tpm2/test/create_loaded_test.go @@ -5,7 +5,7 @@ import ( . "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func getDeriver(t *testing.T, thetpm transport.TPM) NamedHandle { @@ -51,10 +51,7 @@ func getDeriver(t *testing.T, thetpm transport.TPM) NamedHandle { } func TestCreateLoaded(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() deriver := getDeriver(t, thetpm) diff --git a/tpm2/test/duplicate_test.go b/tpm2/test/duplicate_test.go index 7d537056..f0167722 100644 --- a/tpm2/test/duplicate_test.go +++ b/tpm2/test/duplicate_test.go @@ -5,16 +5,13 @@ import ( . "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) // TestDuplicate creates an object under Owner->SRK and duplicates it to // Endorsement->SRK. func TestDuplicate(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() t.Log("### Create Owner SRK") diff --git a/tpm2/test/ecdh_test.go b/tpm2/test/ecdh_test.go index d2bbce8f..d4cd8421 100644 --- a/tpm2/test/ecdh_test.go +++ b/tpm2/test/ecdh_test.go @@ -8,14 +8,11 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func TestECDH(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() // Create a TPM ECDH key diff --git a/tpm2/test/ek_test.go b/tpm2/test/ek_test.go index ea901bfa..d71a46d2 100644 --- a/tpm2/test/ek_test.go +++ b/tpm2/test/ek_test.go @@ -8,7 +8,7 @@ import ( . "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) // Decodes the provided hex strings into a byte array. Panics on non-hex chars. @@ -317,10 +317,7 @@ func ekTest(t *testing.T, ekTemplate TPMTPublic) { } } - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() for _, c := range cases { diff --git a/tpm2/test/evict_control_test.go b/tpm2/test/evict_control_test.go index fc84bb9f..a1c5f7ec 100644 --- a/tpm2/test/evict_control_test.go +++ b/tpm2/test/evict_control_test.go @@ -4,14 +4,11 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func TestEvictControl(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() srkCreate := CreatePrimary{ diff --git a/tpm2/test/get_random_test.go b/tpm2/test/get_random_test.go index 0a524cb6..e26623f2 100644 --- a/tpm2/test/get_random_test.go +++ b/tpm2/test/get_random_test.go @@ -4,14 +4,11 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func TestGetRandom(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() grc := GetRandom{ diff --git a/tpm2/test/get_time_test.go b/tpm2/test/get_time_test.go index f4ae6d04..8eefedca 100644 --- a/tpm2/test/get_time_test.go +++ b/tpm2/test/get_time_test.go @@ -5,14 +5,11 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func TestGetTime(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() createPrimary := CreatePrimary{ diff --git a/tpm2/test/hash_sequence_hash_test.go b/tpm2/test/hash_sequence_hash_test.go index 7378cefb..095489ed 100644 --- a/tpm2/test/hash_sequence_hash_test.go +++ b/tpm2/test/hash_sequence_hash_test.go @@ -9,15 +9,12 @@ import ( . "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func TestHash(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() run := func(t *testing.T, data []byte, hierarchy TPMHandle, thetpm transport.TPM) { @@ -48,10 +45,7 @@ func TestHash(t *testing.T) { } func TestHashNullHierarchy(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() data := []byte("carolyn") @@ -72,10 +66,7 @@ func TestHashNullHierarchy(t *testing.T) { } func TestHashSequence(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() run := func(t *testing.T, bufferSize int, password string, hierarchy TPMHandle, thetpm transport.TPM) { @@ -153,10 +144,7 @@ func TestHashSequence(t *testing.T) { } func TestHashSequenceNullHierarchy(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() run := func(t *testing.T, bufferSize int, password string, thetpm transport.TPM) { diff --git a/tpm2/test/hierarchy_change_auth_test.go b/tpm2/test/hierarchy_change_auth_test.go index c7ae382e..fefe39c3 100644 --- a/tpm2/test/hierarchy_change_auth_test.go +++ b/tpm2/test/hierarchy_change_auth_test.go @@ -5,14 +5,11 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func TestHierarchyChangeAuth(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() authKey := []byte("authkey") diff --git a/tpm2/test/hmac_start_test.go b/tpm2/test/hmac_start_test.go index cfb61112..bdc5054b 100644 --- a/tpm2/test/hmac_start_test.go +++ b/tpm2/test/hmac_start_test.go @@ -8,14 +8,11 @@ import ( . "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func TestHmacStart(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() run := func(t *testing.T, data []byte, password []byte, hierarchy TPMHandle, thetpm transport.TPM) []byte { @@ -251,10 +248,7 @@ func TestHmacStart(t *testing.T) { } func TestHmacStartNoKeyAuth(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() run := func(t *testing.T, data []byte, password []byte, hierarchy TPMHandle, thetpm transport.TPM) []byte { diff --git a/tpm2/test/hmac_test.go b/tpm2/test/hmac_test.go index 41feb84b..d8648e74 100644 --- a/tpm2/test/hmac_test.go +++ b/tpm2/test/hmac_test.go @@ -10,15 +10,12 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func TestHMAC(t *testing.T) { // connect to TPM simulator - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() // create HMAC key @@ -91,10 +88,7 @@ func TestImportedHMACKey(t *testing.T) { persistentHandle := TPMHandle(0x81000000) // connect to TPM simulator - theTPM, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + theTPM := testhelper.Open(t) defer theTPM.Close() // create primary key diff --git a/tpm2/test/import_test.go b/tpm2/test/import_test.go index 5745f4fc..12790429 100644 --- a/tpm2/test/import_test.go +++ b/tpm2/test/import_test.go @@ -8,15 +8,12 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) // This test checks that Import can import an object in the clear. func TestCleartextImport(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() srkCreate := CreatePrimary{ @@ -136,10 +133,7 @@ func makeSealedBlob(t *testing.T, nameAlg TPMIAlgHash, obfuscation []byte, conte // This test checks that Import can import an object created by a remote server using CreateDuplicate. func TestSWDuplicateImport(t *testing.T) { - tpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("OpenSimulator() = %v", err) - } + tpm := testhelper.Open(t) defer tpm.Close() for _, tc := range []struct { diff --git a/tpm2/test/load_external_test.go b/tpm2/test/load_external_test.go index 238d4e51..92f54d3f 100644 --- a/tpm2/test/load_external_test.go +++ b/tpm2/test/load_external_test.go @@ -5,7 +5,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func decodeHex(t *testing.T, h string) []byte { @@ -71,10 +71,7 @@ func TestLoadExternal(t *testing.T) { }, } - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() for name, load := range loads { diff --git a/tpm2/test/names_test.go b/tpm2/test/names_test.go index 270396a6..5c44c1d3 100644 --- a/tpm2/test/names_test.go +++ b/tpm2/test/names_test.go @@ -5,7 +5,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func TestHandleName(t *testing.T) { @@ -17,10 +17,7 @@ func TestHandleName(t *testing.T) { } func TestObjectName(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() createPrimary := CreatePrimary{ @@ -50,10 +47,7 @@ func TestObjectName(t *testing.T) { } func TestNVName(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() public := New2B( diff --git a/tpm2/test/nv_test.go b/tpm2/test/nv_test.go index 4dedabdc..b9033839 100644 --- a/tpm2/test/nv_test.go +++ b/tpm2/test/nv_test.go @@ -7,14 +7,11 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func TestNVAuthWrite(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() def := NVDefineSpace{ @@ -98,10 +95,7 @@ func TestNVAuthWrite(t *testing.T) { } func TestNVAuthIncrement(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() // Define the counter space @@ -210,10 +204,7 @@ func TestNVAuthIncrement(t *testing.T) { } func TestNVWriteLock(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() // Define the NV space with attributes that allow it to be locked @@ -324,10 +315,7 @@ func TestNVWriteLock(t *testing.T) { } func TestNVReadLock(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() // Define the NV space with attributes that allow it to be locked for reading diff --git a/tpm2/test/object_change_auth_test.go b/tpm2/test/object_change_auth_test.go index bd46dd95..7513369d 100644 --- a/tpm2/test/object_change_auth_test.go +++ b/tpm2/test/object_change_auth_test.go @@ -6,14 +6,11 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func TestObjectChangeAuth(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() // Create the SRK diff --git a/tpm2/test/pcr_test.go b/tpm2/test/pcr_test.go index 09a2bc96..edbfe5fe 100644 --- a/tpm2/test/pcr_test.go +++ b/tpm2/test/pcr_test.go @@ -9,7 +9,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func TestPCRs(t *testing.T) { @@ -87,10 +87,7 @@ func allZero(s []byte) bool { } func TestPCRReset(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() DebugPCR := uint(16) @@ -144,6 +141,9 @@ func TestPCRReset(t *testing.T) { if err != nil { t.Fatalf("failed to read PCRs") } + if len(pcrReadRsp.PCRValues.Digests) == 0 { + t.Skipf("PCR bank %v not allocated/supported, skipping", c.hashalg) + } postExtendPCR16 := pcrReadRsp.PCRValues.Digests[0].Buffer if allZero(postExtendPCR16) { t.Errorf("postExtendPCR16 not expected to be all Zero: %v", postExtendPCR16) @@ -169,10 +169,7 @@ func TestPCRReset(t *testing.T) { } func TestPCREvent(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() cases := []struct { @@ -212,6 +209,9 @@ func TestPCREvent(t *testing.T) { if err != nil { t.Fatalf("failed to read PCRs") } + if len(pcrReadRsp.PCRValues.Digests) == 0 { + t.Skipf("PCR bank %v not allocated/supported, skipping", c.hashalg) + } postExtendPCR16 := pcrReadRsp.PCRValues.Digests[0].Buffer if allZero(postExtendPCR16) { t.Errorf("postExtendPCR16 not expected to be all Zero: %v", postExtendPCR16) diff --git a/tpm2/test/policy_test.go b/tpm2/test/policy_test.go index 81a90890..56aaf04b 100644 --- a/tpm2/test/policy_test.go +++ b/tpm2/test/policy_test.go @@ -2,13 +2,12 @@ package tpm2test import ( "bytes" - "crypto/sha1" "crypto/sha256" "testing" . "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) // This test isn't interesting, but it checks that you can omit the handles on `StartAuthSession`. @@ -27,10 +26,7 @@ func TestCreatePolicySession(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() sas, err := StartAuthSession{ @@ -213,10 +209,7 @@ func primaryRSAEK(t *testing.T, thetpm transport.TPM) (NamedHandle, func()) { } func TestPolicySignedUpdate(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() sk, cleanup := signingKey(t, thetpm) @@ -275,10 +268,7 @@ func TestPolicySignedUpdate(t *testing.T) { } func TestPolicySecretUpdate(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() sk, cleanup := signingKey(t, thetpm) @@ -331,10 +321,7 @@ func TestPolicySecretUpdate(t *testing.T) { } func TestPolicyOrUpdate(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() // Use a trial session to calculate this policy @@ -411,24 +398,21 @@ func getExpectedPCRDigest(t *testing.T, thetpm transport.TPM, selection TPMLPCRS } func TestPolicyPCR(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() selection := TPMLPCRSelection{ PCRSelections: []TPMSPCRSelection{ { - Hash: TPMAlgSHA1, + Hash: TPMAlgSHA256, PCRSelect: PCClientCompatible.PCRs(0, 1, 2, 3, 7), }, }, } - expectedDigest := getExpectedPCRDigest(t, thetpm, selection, TPMAlgSHA1) + expectedDigest := getExpectedPCRDigest(t, thetpm, selection, TPMAlgSHA256) - wrongDigest := sha1.Sum(expectedDigest) + wrongDigest := sha256.Sum256(expectedDigest) tests := []struct { name string @@ -446,7 +430,7 @@ func TestPolicyPCR(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sess, cleanup2, err := PolicySession(thetpm, TPMAlgSHA1, 16, tt.authOption...) + sess, cleanup2, err := PolicySession(thetpm, TPMAlgSHA256, 16, tt.authOption...) if err != nil { t.Fatalf("setting up policy session: %v", err) @@ -482,7 +466,7 @@ func TestPolicyPCR(t *testing.T) { // If the pcrDigest is empty: see TPM 2.0 Part 3, 23.7. if tt.pcrDigest == nil { - expectedDigest := getExpectedPCRDigest(t, thetpm, selection, TPMAlgSHA1) + expectedDigest := getExpectedPCRDigest(t, thetpm, selection, TPMAlgSHA256) t.Logf("expectedDigest=%x", expectedDigest) // Create a populated policyPCR for the PolicyCalculator @@ -490,7 +474,7 @@ func TestPolicyPCR(t *testing.T) { } // Use the policy helper to calculate the same policy - pol, err := NewPolicyCalculator(TPMAlgSHA1) + pol, err := NewPolicyCalculator(TPMAlgSHA256) if err != nil { t.Fatalf("creating policy calculator: %v", err) } @@ -510,10 +494,7 @@ func TestPolicyPCR(t *testing.T) { } func TestPolicyCpHashUpdate(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() // Use a trial session to calculate this policy @@ -561,10 +542,7 @@ func TestPolicyCpHashUpdate(t *testing.T) { } func TestPolicyAuthorizeUpdate(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() // Use a trial session to calculate this policy @@ -618,10 +596,7 @@ func TestPolicyAuthorizeUpdate(t *testing.T) { } func TestPolicyNVWrittenUpdate(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() // Use a trial session to calculate this policy @@ -667,10 +642,7 @@ func TestPolicyNVWrittenUpdate(t *testing.T) { } func TestPolicyNVUpdate(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() nv, cleanup := nvIndex(t, thetpm) @@ -723,10 +695,7 @@ func TestPolicyNVUpdate(t *testing.T) { } func TestPolicyAuthorizeNVUpdate(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() nv, cleanup := nvIndex(t, thetpm) @@ -776,10 +745,7 @@ func TestPolicyAuthorizeNVUpdate(t *testing.T) { } func TestPolicyCommandCodeUpdate(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() // Use a trial session to calculate this policy @@ -824,10 +790,7 @@ func TestPolicyCommandCodeUpdate(t *testing.T) { } func TestPolicyAuthValue(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() password := []byte("foo") @@ -1008,10 +971,7 @@ func TestPolicyAuthValue(t *testing.T) { } func TestPolicyDuplicationSelectUpdate(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() ek, ekcleanup := primaryRSAEK(t, thetpm) diff --git a/tpm2/test/read_public_test.go b/tpm2/test/read_public_test.go index 871b8fc5..c6586adc 100644 --- a/tpm2/test/read_public_test.go +++ b/tpm2/test/read_public_test.go @@ -6,16 +6,13 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) // TestReadPublicKey compares the CreatePrimary response parameter outPublic with the output of ReadPublic outPublic. func TestReadPublicKey(t *testing.T) { // Open simulated TPM for testing. - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) // Defer the close of the simulated TPM to after use. // Without this, other programs/tests may not be able to get a handle to the TPM. @@ -113,10 +110,7 @@ func TestReadPublicKey(t *testing.T) { // TestReadPublicWithHMACSession tests that ReadPublic works when called with an HMAC session. func TestReadPublicWithHMACSession(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() createPrimaryCmd := CreatePrimary{ diff --git a/tpm2/test/rsa_encryption_test.go b/tpm2/test/rsa_encryption_test.go index 17011852..34ec42e9 100644 --- a/tpm2/test/rsa_encryption_test.go +++ b/tpm2/test/rsa_encryption_test.go @@ -5,14 +5,11 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func TestRSAEncryption(t *testing.T) { - theTpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + theTpm := testhelper.Open(t) t.Cleanup(func() { if err := theTpm.Close(); err != nil { t.Errorf("%v", err) diff --git a/tpm2/test/sealing_test.go b/tpm2/test/sealing_test.go index 8f498f61..0f8a781a 100644 --- a/tpm2/test/sealing_test.go +++ b/tpm2/test/sealing_test.go @@ -6,7 +6,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) // Test creating and unsealing a sealed data blob with a password and HMAC. @@ -25,10 +25,7 @@ func TestUnseal(t *testing.T) { } func unsealingTest(t *testing.T, srkTemplate TPMTPublic) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() // Create the SRK diff --git a/tpm2/test/sign_test.go b/tpm2/test/sign_test.go index c12caa29..b1179e0b 100644 --- a/tpm2/test/sign_test.go +++ b/tpm2/test/sign_test.go @@ -7,15 +7,12 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func TestSign(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() createPrimary := CreatePrimary{ diff --git a/tpm2/test/symmetric_encryption_test.go b/tpm2/test/symmetric_encryption_test.go index c76f1584..9c037e9e 100644 --- a/tpm2/test/symmetric_encryption_test.go +++ b/tpm2/test/symmetric_encryption_test.go @@ -9,16 +9,13 @@ import ( . "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) const maxDigestBuffer = 1024 func TestAESEncryption(t *testing.T) { - theTpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + theTpm := testhelper.Open(t) t.Cleanup(func() { if err := theTpm.Close(); err != nil { t.Errorf("%v", err) @@ -115,10 +112,7 @@ func TestAESEncryption(t *testing.T) { } func TestAESEncryptionBlock(t *testing.T) { - theTpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + theTpm := testhelper.Open(t) t.Cleanup(func() { if err := theTpm.Close(); err != nil { t.Errorf("%v", err) diff --git a/tpm2/test/test_parms_test.go b/tpm2/test/test_parms_test.go index d64e9505..4aef8d42 100644 --- a/tpm2/test/test_parms_test.go +++ b/tpm2/test/test_parms_test.go @@ -5,14 +5,11 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpm2/transport/simulator" + testhelper "github.com/google/go-tpm/tpm2/transport/test" ) func TestTestParms(t *testing.T) { - thetpm, err := simulator.OpenSimulator() - if err != nil { - t.Fatalf("could not connect to TPM simulator: %v", err) - } + thetpm := testhelper.Open(t) defer thetpm.Close() for _, tt := range []struct { diff --git a/tpm2/transport/tcp/tcp.go b/tpm2/transport/tcp/tcp.go index 1b6827fe..bea7e762 100644 --- a/tpm2/transport/tcp/tcp.go +++ b/tpm2/transport/tcp/tcp.go @@ -205,6 +205,19 @@ func (t *TPM) Reset() error { return t.sendBasicPlatformCommand(platformReset) } +// Stop tells the simulator process to exit. +func (t *TPM) Stop() error { + var errs []error + if err := binary.Write(t.cmd, binary.BigEndian, tpmStop); err != nil { + errs = append(errs, fmt.Errorf("could not write STOP to command service: %w", err)) + } + if err := binary.Write(t.plat, binary.BigEndian, platformStop); err != nil { + errs = append(errs, fmt.Errorf("could not write STOP to platform service: %w", err)) + } + return errors.Join(errs...) +} + + // Config provides the connection information for a running TCP TPM. type Config struct { // CommandAddress is the full host:port address of the Command server, e.g., diff --git a/tpm2/transport/test/open.go b/tpm2/transport/test/open.go new file mode 100644 index 00000000..c3c89bbe --- /dev/null +++ b/tpm2/transport/test/open.go @@ -0,0 +1,156 @@ +package testhelper + +import ( + "flag" + "fmt" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "github.com/google/go-tpm/tpm2" + "github.com/google/go-tpm/tpm2/transport" + "github.com/google/go-tpm/tpm2/transport/simulator" + "github.com/google/go-tpm/tpm2/transport/tcp" +) + +var ( + tpmSimPath = flag.String("tpm-sim-path", "", "Path to a TPM simulator binary") +) + +type process struct { + tb testing.TB + cmd *exec.Cmd + dir string + conn *tcp.TPM +} + +func startProcess(tb testing.TB, path string) *process { + dir, err := os.MkdirTemp("", "tpm-sim-*") + if err != nil { + tb.Fatalf("failed to create temp dir: %v", err) + } + + keep := false + defer func() { + if !keep { + os.RemoveAll(dir) + } + }() + + cmd := exec.Command(path, "--pick_ports") + cmd.Dir = dir + if err := cmd.Start(); err != nil { + tb.Fatalf("failed to start simulator process: %v", err) + } + defer func() { + if !keep { + cmd.Process.Kill() + cmd.Wait() + } + }() + + cmdPort, platPort, err := readPorts(dir) + if err != nil { + tb.Fatalf("failed to read ports: %v", err) + } + conn, err := tcp.Open(tcp.Config{ + CommandAddress: fmt.Sprintf("127.0.0.1:%d", cmdPort), + PlatformAddress: fmt.Sprintf("127.0.0.1:%d", platPort), + }) + if err != nil { + tb.Fatalf("failed to open TCP connection to simulator: %v", err) + } + defer func() { + if !keep { + conn.Close() + } + }() + + if err := conn.PowerOn(); err != nil { + tb.Fatalf("failed to power on simulator: %v", err) + } + + _, err = tpm2.Startup{ + StartupType: tpm2.TPMSUClear, + }.Execute(conn) + if err != nil { + tb.Fatalf("failed to startup simulator: %v", err) + } + + keep = true + return &process{ + tb: tb, + cmd: cmd, + dir: dir, + conn: conn, + } +} + +func (p *process) Send(cmd []byte) ([]byte, error) { + rsp, err := p.conn.Send(cmd) + if err == nil { + if hdr, err := tpm2.Unmarshal[tpm2.TPMRspHeader](rsp); err == nil { + if hdr.ResponseCode == tpm2.TPMRCRetry { + return p.conn.Send(cmd) + } + } + } + return rsp, err +} + +// Close implements the TPMCloser interface. +func (p *process) Close() error { + var err error + if err = p.conn.Stop(); err != nil { + p.tb.Errorf("failed to stop simulator: %v", err) + } + if err = p.conn.Close(); err != nil { + p.tb.Errorf("failed to close simulator connection: %v", err) + } + if err = p.cmd.Wait(); err != nil { + p.tb.Errorf("failed to wait for simulator process: %v", err) + } + if err = os.RemoveAll(p.dir); err != nil { + p.tb.Errorf("failed to remove temp dir %q: %v", p.dir, err) + } + return err // Report all errors but only return the last one +} + +func Open(tb testing.TB) transport.TPMCloser { + if *tpmSimPath != "" { + return startProcess(tb, *tpmSimPath) + } + tpm, err := simulator.OpenSimulator() + if err != nil { + tb.Fatalf("Unable to OpenSimulator: %v", err) + } + return tpm +} + +func readPorts(dir string) (cmdPort, platPort int, err error) { + deadline := time.Now().Add(5 * time.Second) + for { + if time.Now().After(deadline) { + return 0, 0, fmt.Errorf("timed out waiting for simulator port files") + } + + cmdPortBytes, err1 := os.ReadFile(filepath.Join(dir, "command.port")) + platPortBytes, err2 := os.ReadFile(filepath.Join(dir, "platform.port")) + if err1 == nil && err2 == nil { + cmdPortStr := strings.TrimSpace(string(cmdPortBytes)) + platPortStr := strings.TrimSpace(string(platPortBytes)) + if cmdPortStr != "" && platPortStr != "" { + cmdPort, err1 := strconv.Atoi(cmdPortStr) + platPort, err2 := strconv.Atoi(platPortStr) + if err1 == nil && err2 == nil { + return cmdPort, platPort, nil + } + } + } + time.Sleep(50 * time.Millisecond) + } +} From 4f9e41f0a4681bb8979e3da4980cbd0482ac20bb Mon Sep 17 00:00:00 2001 From: Megan Lu Date: Fri, 29 May 2026 23:56:38 +0000 Subject: [PATCH 2/4] fix: support larger contexts, refactor simulator transport helpers, simplify TCP Stop - Raise maxListLength to 32767 to support larger contexts from modern simulator configurations (Part 2, Section 10.3.1). - Rename transport/test directory to transport/testhelper and clean up test helper imports. - Refactor startProcess to initialize the process struct first and defer Close. - Simplify TCP Stop logic by only sending tpmStop and omitting platformStop. - Clean up port file polling in readPorts to use io/fs and time.Ticker. --- tpm2/reflect.go | 10 +- tpm2/test/activate_credential_test.go | 2 +- tpm2/test/audit_test.go | 2 +- tpm2/test/certify_test.go | 2 +- tpm2/test/clear_test.go | 2 +- tpm2/test/combined_context_test.go | 2 +- tpm2/test/commit_test.go | 2 +- tpm2/test/create_loaded_test.go | 2 +- tpm2/test/duplicate_test.go | 2 +- tpm2/test/ecdh_test.go | 2 +- tpm2/test/ek_test.go | 2 +- tpm2/test/evict_control_test.go | 2 +- tpm2/test/get_random_test.go | 2 +- tpm2/test/get_time_test.go | 2 +- tpm2/test/hash_sequence_hash_test.go | 2 +- tpm2/test/hierarchy_change_auth_test.go | 2 +- tpm2/test/hmac_start_test.go | 2 +- tpm2/test/hmac_test.go | 2 +- tpm2/test/import_test.go | 2 +- tpm2/test/load_external_test.go | 2 +- tpm2/test/names_test.go | 2 +- tpm2/test/nv_test.go | 2 +- tpm2/test/object_change_auth_test.go | 2 +- tpm2/test/pcr_test.go | 2 +- tpm2/test/policy_test.go | 2 +- tpm2/test/read_public_test.go | 2 +- tpm2/test/rsa_encryption_test.go | 2 +- tpm2/test/sealing_test.go | 2 +- tpm2/test/sign_test.go | 2 +- tpm2/test/symmetric_encryption_test.go | 2 +- tpm2/test/test_parms_test.go | 2 +- tpm2/transport/linuxtpm/linuxtpm_test.go | 2 +- .../transport/linuxudstpm/linuxudstpm_test.go | 2 +- tpm2/transport/tcp/tcp.go | 10 +- tpm2/transport/{test => testhelper}/helper.go | 0 tpm2/transport/{test => testhelper}/open.go | 104 ++++++++++-------- tpm2/transport/windowstpm/windowstpm_test.go | 2 +- 37 files changed, 103 insertions(+), 87 deletions(-) rename tpm2/transport/{test => testhelper}/helper.go (100%) rename tpm2/transport/{test => testhelper}/open.go (66%) diff --git a/tpm2/reflect.go b/tpm2/reflect.go index 714f1f57..81acc893 100644 --- a/tpm2/reflect.go +++ b/tpm2/reflect.go @@ -14,11 +14,11 @@ import ( ) const ( - // Chosen based on MAX_DIGEST_BUFFER, the length of the longest - // reasonable list returned by the reference implementation. - // The maxListLength must be greater than MAX_CONTEXT_SIZE = 1344, - // in order to allow for the unmarshalling of Context. - maxListLength uint32 = 4096 + // The maxListLength must be greater than MAX_CONTEXT_SIZE (which can be up to + // 4344 in modern reference implementations), in order to allow for the + // unmarshalling of Context. Under Part 2, Section 10.3.1 of the TPM + // specification, the maximum value of the size field in any TPM2B is 32767. + maxListLength uint32 = 0x7FFF // 32767 ) // execute sends the provided command and returns the TPM's response. diff --git a/tpm2/test/activate_credential_test.go b/tpm2/test/activate_credential_test.go index e67ab9c9..3c611f6e 100644 --- a/tpm2/test/activate_credential_test.go +++ b/tpm2/test/activate_credential_test.go @@ -6,7 +6,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) // p384Template is an SRK-like ECDH-P384 key based on the P384 EK template. diff --git a/tpm2/test/audit_test.go b/tpm2/test/audit_test.go index b6660b0c..74fb81cd 100644 --- a/tpm2/test/audit_test.go +++ b/tpm2/test/audit_test.go @@ -6,7 +6,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func TestAuditSession(t *testing.T) { diff --git a/tpm2/test/certify_test.go b/tpm2/test/certify_test.go index aa05f698..7fd313cf 100644 --- a/tpm2/test/certify_test.go +++ b/tpm2/test/certify_test.go @@ -9,7 +9,7 @@ import ( "github.com/google/go-cmp/cmp" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func TestCertify(t *testing.T) { diff --git a/tpm2/test/clear_test.go b/tpm2/test/clear_test.go index 1536103c..d3880749 100644 --- a/tpm2/test/clear_test.go +++ b/tpm2/test/clear_test.go @@ -5,7 +5,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func TestClear(t *testing.T) { diff --git a/tpm2/test/combined_context_test.go b/tpm2/test/combined_context_test.go index a2a3de31..02b8072b 100644 --- a/tpm2/test/combined_context_test.go +++ b/tpm2/test/combined_context_test.go @@ -7,7 +7,7 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" . "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func ReadPublicName(t *testing.T, handle TPMHandle, thetpm transport.TPM) TPM2BName { diff --git a/tpm2/test/commit_test.go b/tpm2/test/commit_test.go index e107af32..0046905a 100644 --- a/tpm2/test/commit_test.go +++ b/tpm2/test/commit_test.go @@ -4,7 +4,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func TestCommit(t *testing.T) { diff --git a/tpm2/test/create_loaded_test.go b/tpm2/test/create_loaded_test.go index 63bceb6d..0c79bb5d 100644 --- a/tpm2/test/create_loaded_test.go +++ b/tpm2/test/create_loaded_test.go @@ -5,7 +5,7 @@ import ( . "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func getDeriver(t *testing.T, thetpm transport.TPM) NamedHandle { diff --git a/tpm2/test/duplicate_test.go b/tpm2/test/duplicate_test.go index f0167722..edcef0cd 100644 --- a/tpm2/test/duplicate_test.go +++ b/tpm2/test/duplicate_test.go @@ -5,7 +5,7 @@ import ( . "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) // TestDuplicate creates an object under Owner->SRK and duplicates it to diff --git a/tpm2/test/ecdh_test.go b/tpm2/test/ecdh_test.go index d4cd8421..ae01429e 100644 --- a/tpm2/test/ecdh_test.go +++ b/tpm2/test/ecdh_test.go @@ -8,7 +8,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func TestECDH(t *testing.T) { diff --git a/tpm2/test/ek_test.go b/tpm2/test/ek_test.go index d71a46d2..1b00fe70 100644 --- a/tpm2/test/ek_test.go +++ b/tpm2/test/ek_test.go @@ -8,7 +8,7 @@ import ( . "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) // Decodes the provided hex strings into a byte array. Panics on non-hex chars. diff --git a/tpm2/test/evict_control_test.go b/tpm2/test/evict_control_test.go index a1c5f7ec..3ecfa45c 100644 --- a/tpm2/test/evict_control_test.go +++ b/tpm2/test/evict_control_test.go @@ -4,7 +4,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func TestEvictControl(t *testing.T) { diff --git a/tpm2/test/get_random_test.go b/tpm2/test/get_random_test.go index e26623f2..b6b24323 100644 --- a/tpm2/test/get_random_test.go +++ b/tpm2/test/get_random_test.go @@ -4,7 +4,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func TestGetRandom(t *testing.T) { diff --git a/tpm2/test/get_time_test.go b/tpm2/test/get_time_test.go index 8eefedca..380c8934 100644 --- a/tpm2/test/get_time_test.go +++ b/tpm2/test/get_time_test.go @@ -5,7 +5,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func TestGetTime(t *testing.T) { diff --git a/tpm2/test/hash_sequence_hash_test.go b/tpm2/test/hash_sequence_hash_test.go index 095489ed..e69705f3 100644 --- a/tpm2/test/hash_sequence_hash_test.go +++ b/tpm2/test/hash_sequence_hash_test.go @@ -9,7 +9,7 @@ import ( . "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func TestHash(t *testing.T) { diff --git a/tpm2/test/hierarchy_change_auth_test.go b/tpm2/test/hierarchy_change_auth_test.go index fefe39c3..27c34498 100644 --- a/tpm2/test/hierarchy_change_auth_test.go +++ b/tpm2/test/hierarchy_change_auth_test.go @@ -5,7 +5,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func TestHierarchyChangeAuth(t *testing.T) { diff --git a/tpm2/test/hmac_start_test.go b/tpm2/test/hmac_start_test.go index bdc5054b..8242c169 100644 --- a/tpm2/test/hmac_start_test.go +++ b/tpm2/test/hmac_start_test.go @@ -8,7 +8,7 @@ import ( . "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func TestHmacStart(t *testing.T) { diff --git a/tpm2/test/hmac_test.go b/tpm2/test/hmac_test.go index d8648e74..40aaa7aa 100644 --- a/tpm2/test/hmac_test.go +++ b/tpm2/test/hmac_test.go @@ -10,7 +10,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func TestHMAC(t *testing.T) { diff --git a/tpm2/test/import_test.go b/tpm2/test/import_test.go index 12790429..dbb3b9d9 100644 --- a/tpm2/test/import_test.go +++ b/tpm2/test/import_test.go @@ -8,7 +8,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) // This test checks that Import can import an object in the clear. diff --git a/tpm2/test/load_external_test.go b/tpm2/test/load_external_test.go index 92f54d3f..bf230f15 100644 --- a/tpm2/test/load_external_test.go +++ b/tpm2/test/load_external_test.go @@ -5,7 +5,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func decodeHex(t *testing.T, h string) []byte { diff --git a/tpm2/test/names_test.go b/tpm2/test/names_test.go index 5c44c1d3..ac6f0b17 100644 --- a/tpm2/test/names_test.go +++ b/tpm2/test/names_test.go @@ -5,7 +5,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func TestHandleName(t *testing.T) { diff --git a/tpm2/test/nv_test.go b/tpm2/test/nv_test.go index b9033839..7aa8d550 100644 --- a/tpm2/test/nv_test.go +++ b/tpm2/test/nv_test.go @@ -7,7 +7,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func TestNVAuthWrite(t *testing.T) { diff --git a/tpm2/test/object_change_auth_test.go b/tpm2/test/object_change_auth_test.go index 7513369d..f37c097d 100644 --- a/tpm2/test/object_change_auth_test.go +++ b/tpm2/test/object_change_auth_test.go @@ -6,7 +6,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func TestObjectChangeAuth(t *testing.T) { diff --git a/tpm2/test/pcr_test.go b/tpm2/test/pcr_test.go index edbfe5fe..26a2e27c 100644 --- a/tpm2/test/pcr_test.go +++ b/tpm2/test/pcr_test.go @@ -9,7 +9,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func TestPCRs(t *testing.T) { diff --git a/tpm2/test/policy_test.go b/tpm2/test/policy_test.go index 56aaf04b..95a1ad4b 100644 --- a/tpm2/test/policy_test.go +++ b/tpm2/test/policy_test.go @@ -7,7 +7,7 @@ import ( . "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) // This test isn't interesting, but it checks that you can omit the handles on `StartAuthSession`. diff --git a/tpm2/test/read_public_test.go b/tpm2/test/read_public_test.go index c6586adc..c60496d9 100644 --- a/tpm2/test/read_public_test.go +++ b/tpm2/test/read_public_test.go @@ -6,7 +6,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) // TestReadPublicKey compares the CreatePrimary response parameter outPublic with the output of ReadPublic outPublic. diff --git a/tpm2/test/rsa_encryption_test.go b/tpm2/test/rsa_encryption_test.go index 34ec42e9..5f7df43b 100644 --- a/tpm2/test/rsa_encryption_test.go +++ b/tpm2/test/rsa_encryption_test.go @@ -5,7 +5,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func TestRSAEncryption(t *testing.T) { diff --git a/tpm2/test/sealing_test.go b/tpm2/test/sealing_test.go index 0f8a781a..2e754673 100644 --- a/tpm2/test/sealing_test.go +++ b/tpm2/test/sealing_test.go @@ -6,7 +6,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) // Test creating and unsealing a sealed data blob with a password and HMAC. diff --git a/tpm2/test/sign_test.go b/tpm2/test/sign_test.go index b1179e0b..127b57c7 100644 --- a/tpm2/test/sign_test.go +++ b/tpm2/test/sign_test.go @@ -7,7 +7,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func TestSign(t *testing.T) { diff --git a/tpm2/test/symmetric_encryption_test.go b/tpm2/test/symmetric_encryption_test.go index 9c037e9e..3db9313b 100644 --- a/tpm2/test/symmetric_encryption_test.go +++ b/tpm2/test/symmetric_encryption_test.go @@ -9,7 +9,7 @@ import ( . "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) const maxDigestBuffer = 1024 diff --git a/tpm2/test/test_parms_test.go b/tpm2/test/test_parms_test.go index 4aef8d42..2624832d 100644 --- a/tpm2/test/test_parms_test.go +++ b/tpm2/test/test_parms_test.go @@ -5,7 +5,7 @@ import ( "testing" . "github.com/google/go-tpm/tpm2" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func TestTestParms(t *testing.T) { diff --git a/tpm2/transport/linuxtpm/linuxtpm_test.go b/tpm2/transport/linuxtpm/linuxtpm_test.go index 5a4696b4..7ac8c592 100644 --- a/tpm2/transport/linuxtpm/linuxtpm_test.go +++ b/tpm2/transport/linuxtpm/linuxtpm_test.go @@ -7,7 +7,7 @@ import ( "testing" "github.com/google/go-tpm/tpm2/transport" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) func open(path string) func() (transport.TPMCloser, error) { diff --git a/tpm2/transport/linuxudstpm/linuxudstpm_test.go b/tpm2/transport/linuxudstpm/linuxudstpm_test.go index 04969a52..e99559f2 100644 --- a/tpm2/transport/linuxudstpm/linuxudstpm_test.go +++ b/tpm2/transport/linuxudstpm/linuxudstpm_test.go @@ -9,7 +9,7 @@ import ( "testing" "github.com/google/go-tpm/tpm2/transport" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" ) var tpmSocket = flag.String("tpm_socket", "/dev/tpm0", "path to the TPM simulator UDS") diff --git a/tpm2/transport/tcp/tcp.go b/tpm2/transport/tcp/tcp.go index bea7e762..a012da62 100644 --- a/tpm2/transport/tcp/tcp.go +++ b/tpm2/transport/tcp/tcp.go @@ -207,14 +207,12 @@ func (t *TPM) Reset() error { // Stop tells the simulator process to exit. func (t *TPM) Stop() error { - var errs []error + // We only write tpmStop to the command socket because receiving it causes + // the simulator process to exit, which also kills the platform socket. if err := binary.Write(t.cmd, binary.BigEndian, tpmStop); err != nil { - errs = append(errs, fmt.Errorf("could not write STOP to command service: %w", err)) + return fmt.Errorf("could not write STOP to command service: %w", err) } - if err := binary.Write(t.plat, binary.BigEndian, platformStop); err != nil { - errs = append(errs, fmt.Errorf("could not write STOP to platform service: %w", err)) - } - return errors.Join(errs...) + return nil } diff --git a/tpm2/transport/test/helper.go b/tpm2/transport/testhelper/helper.go similarity index 100% rename from tpm2/transport/test/helper.go rename to tpm2/transport/testhelper/helper.go diff --git a/tpm2/transport/test/open.go b/tpm2/transport/testhelper/open.go similarity index 66% rename from tpm2/transport/test/open.go rename to tpm2/transport/testhelper/open.go index c3c89bbe..b3a4d126 100644 --- a/tpm2/transport/test/open.go +++ b/tpm2/transport/testhelper/open.go @@ -3,9 +3,9 @@ package testhelper import ( "flag" "fmt" + "io/fs" "os" "os/exec" - "path/filepath" "strconv" "strings" "testing" @@ -29,29 +29,26 @@ type process struct { } func startProcess(tb testing.TB, path string) *process { - dir, err := os.MkdirTemp("", "tpm-sim-*") - if err != nil { - tb.Fatalf("failed to create temp dir: %v", err) - } - + p := &process{tb: tb} keep := false defer func() { if !keep { - os.RemoveAll(dir) + p.Close() } }() + dir, err := os.MkdirTemp("", "tpm-sim-*") + if err != nil { + tb.Fatalf("failed to create temp dir: %v", err) + } + p.dir = dir + cmd := exec.Command(path, "--pick_ports") cmd.Dir = dir if err := cmd.Start(); err != nil { tb.Fatalf("failed to start simulator process: %v", err) } - defer func() { - if !keep { - cmd.Process.Kill() - cmd.Wait() - } - }() + p.cmd = cmd cmdPort, platPort, err := readPorts(dir) if err != nil { @@ -64,11 +61,7 @@ func startProcess(tb testing.TB, path string) *process { if err != nil { tb.Fatalf("failed to open TCP connection to simulator: %v", err) } - defer func() { - if !keep { - conn.Close() - } - }() + p.conn = conn if err := conn.PowerOn(); err != nil { tb.Fatalf("failed to power on simulator: %v", err) @@ -82,12 +75,7 @@ func startProcess(tb testing.TB, path string) *process { } keep = true - return &process{ - tb: tb, - cmd: cmd, - dir: dir, - conn: conn, - } + return p } func (p *process) Send(cmd []byte) ([]byte, error) { @@ -105,19 +93,37 @@ func (p *process) Send(cmd []byte) ([]byte, error) { // Close implements the TPMCloser interface. func (p *process) Close() error { var err error - if err = p.conn.Stop(); err != nil { - p.tb.Errorf("failed to stop simulator: %v", err) - } - if err = p.conn.Close(); err != nil { - p.tb.Errorf("failed to close simulator connection: %v", err) + var killed bool + if p.conn != nil { + if err = p.conn.Stop(); err != nil { + p.tb.Errorf("failed to stop simulator: %v", err) + if p.cmd != nil && p.cmd.Process != nil { + p.cmd.Process.Kill() + killed = true + } + } + if err = p.conn.Close(); err != nil { + p.tb.Errorf("failed to close simulator connection: %v", err) + } + } else { + if p.cmd != nil && p.cmd.Process != nil { + p.cmd.Process.Kill() + killed = true + } } - if err = p.cmd.Wait(); err != nil { - p.tb.Errorf("failed to wait for simulator process: %v", err) + if p.cmd != nil { + if werr := p.cmd.Wait(); werr != nil && !killed { + p.tb.Errorf("failed to wait for simulator process: %v", werr) + err = werr + } } - if err = os.RemoveAll(p.dir); err != nil { - p.tb.Errorf("failed to remove temp dir %q: %v", p.dir, err) + if p.dir != "" { + if derr := os.RemoveAll(p.dir); derr != nil { + p.tb.Errorf("failed to remove temp dir %q: %v", p.dir, derr) + err = derr + } } - return err // Report all errors but only return the last one + return err } func Open(tb testing.TB) transport.TPMCloser { @@ -132,14 +138,11 @@ func Open(tb testing.TB) transport.TPMCloser { } func readPorts(dir string) (cmdPort, platPort int, err error) { - deadline := time.Now().Add(5 * time.Second) - for { - if time.Now().After(deadline) { - return 0, 0, fmt.Errorf("timed out waiting for simulator port files") - } + fsys := os.DirFS(dir) - cmdPortBytes, err1 := os.ReadFile(filepath.Join(dir, "command.port")) - platPortBytes, err2 := os.ReadFile(filepath.Join(dir, "platform.port")) + tryRead := func() (int, int, bool) { + cmdPortBytes, err1 := fs.ReadFile(fsys, "command.port") + platPortBytes, err2 := fs.ReadFile(fsys, "platform.port") if err1 == nil && err2 == nil { cmdPortStr := strings.TrimSpace(string(cmdPortBytes)) platPortStr := strings.TrimSpace(string(platPortBytes)) @@ -147,10 +150,25 @@ func readPorts(dir string) (cmdPort, platPort int, err error) { cmdPort, err1 := strconv.Atoi(cmdPortStr) platPort, err2 := strconv.Atoi(platPortStr) if err1 == nil && err2 == nil { - return cmdPort, platPort, nil + return cmdPort, platPort, true } } } - time.Sleep(50 * time.Millisecond) + return 0, 0, false + } + + timeout := time.After(5 * time.Second) + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-timeout: + return 0, 0, fmt.Errorf("timed out waiting for simulator port files") + case <-ticker.C: + if cmdPort, platPort, ok := tryRead(); ok { + return cmdPort, platPort, nil + } + } } } diff --git a/tpm2/transport/windowstpm/windowstpm_test.go b/tpm2/transport/windowstpm/windowstpm_test.go index 278e71c7..38c61629 100644 --- a/tpm2/transport/windowstpm/windowstpm_test.go +++ b/tpm2/transport/windowstpm/windowstpm_test.go @@ -6,7 +6,7 @@ import ( "os" "testing" - testhelper "github.com/google/go-tpm/tpm2/transport/test" + "github.com/google/go-tpm/tpm2/transport/testhelper" "github.com/google/go-tpm/tpmutil/tbs" ) From 0d6d58ca80189364c8218e4031233f16387b9818 Mon Sep 17 00:00:00 2001 From: Joe Richey Date: Sat, 30 May 2026 00:29:01 +0000 Subject: [PATCH 3/4] Cleanup startProcess Signed-off-by: Joe Richey --- tpm2/transport/testhelper/open.go | 102 +++++++++++++----------------- 1 file changed, 45 insertions(+), 57 deletions(-) diff --git a/tpm2/transport/testhelper/open.go b/tpm2/transport/testhelper/open.go index b3a4d126..0c609252 100644 --- a/tpm2/transport/testhelper/open.go +++ b/tpm2/transport/testhelper/open.go @@ -24,12 +24,21 @@ var ( type process struct { tb testing.TB cmd *exec.Cmd - dir string conn *tcp.TPM } func startProcess(tb testing.TB, path string) *process { - p := &process{tb: tb} + dir, err := os.MkdirTemp("", "tpm-sim-*") + if err != nil { + tb.Fatalf("failed to create temp dir: %v", err) + } + + p := &process{ + tb: tb, + cmd: exec.Command(path, "--pick_ports"), + } + p.cmd.Dir = dir + keep := false defer func() { if !keep { @@ -37,40 +46,28 @@ func startProcess(tb testing.TB, path string) *process { } }() - dir, err := os.MkdirTemp("", "tpm-sim-*") - if err != nil { - tb.Fatalf("failed to create temp dir: %v", err) - } - p.dir = dir - - cmd := exec.Command(path, "--pick_ports") - cmd.Dir = dir - if err := cmd.Start(); err != nil { + if err := p.cmd.Start(); err != nil { tb.Fatalf("failed to start simulator process: %v", err) } - p.cmd = cmd - cmdPort, platPort, err := readPorts(dir) + cPort, pPort, err := readPorts(p.cmd.Dir) if err != nil { tb.Fatalf("failed to read ports: %v", err) } - conn, err := tcp.Open(tcp.Config{ - CommandAddress: fmt.Sprintf("127.0.0.1:%d", cmdPort), - PlatformAddress: fmt.Sprintf("127.0.0.1:%d", platPort), + p.conn, err = tcp.Open(tcp.Config{ + CommandAddress: fmt.Sprintf("127.0.0.1:%d", cPort), + PlatformAddress: fmt.Sprintf("127.0.0.1:%d", pPort), }) if err != nil { tb.Fatalf("failed to open TCP connection to simulator: %v", err) } - p.conn = conn - if err := conn.PowerOn(); err != nil { + if err := p.conn.PowerOn(); err != nil { tb.Fatalf("failed to power on simulator: %v", err) } - _, err = tpm2.Startup{ - StartupType: tpm2.TPMSUClear, - }.Execute(conn) - if err != nil { + startupCmd := tpm2.Startup{StartupType: tpm2.TPMSUClear} + if _, err = startupCmd.Execute(p.conn); err != nil { tb.Fatalf("failed to startup simulator: %v", err) } @@ -93,36 +90,30 @@ func (p *process) Send(cmd []byte) ([]byte, error) { // Close implements the TPMCloser interface. func (p *process) Close() error { var err error - var killed bool + var stopped bool if p.conn != nil { - if err = p.conn.Stop(); err != nil { + if err = p.conn.Stop(); err == nil { + stopped = true + } else { p.tb.Errorf("failed to stop simulator: %v", err) - if p.cmd != nil && p.cmd.Process != nil { - p.cmd.Process.Kill() - killed = true - } } if err = p.conn.Close(); err != nil { p.tb.Errorf("failed to close simulator connection: %v", err) } - } else { - if p.cmd != nil && p.cmd.Process != nil { - p.cmd.Process.Kill() - killed = true - } } - if p.cmd != nil { - if werr := p.cmd.Wait(); werr != nil && !killed { - p.tb.Errorf("failed to wait for simulator process: %v", werr) - err = werr + + if stopped { + if err = p.cmd.Wait(); err != nil { + p.tb.Errorf("failed to wait for simulator process: %v", err) } - } - if p.dir != "" { - if derr := os.RemoveAll(p.dir); derr != nil { - p.tb.Errorf("failed to remove temp dir %q: %v", p.dir, derr) - err = derr + } else if p.cmd.Process != nil && err = p.cmd.Process.Kill(); err != nil { + p.tb.Errorf("failed to kill simulator process: %v", err) } } + + if err = os.RemoveAll(p.cmd.Dir); err != nil { + p.tb.Errorf("failed to remove temp dir %q: %v", p.cmd.Dir, err) + } return err } @@ -141,20 +132,17 @@ func readPorts(dir string) (cmdPort, platPort int, err error) { fsys := os.DirFS(dir) tryRead := func() (int, int, bool) { - cmdPortBytes, err1 := fs.ReadFile(fsys, "command.port") - platPortBytes, err2 := fs.ReadFile(fsys, "platform.port") - if err1 == nil && err2 == nil { - cmdPortStr := strings.TrimSpace(string(cmdPortBytes)) - platPortStr := strings.TrimSpace(string(platPortBytes)) - if cmdPortStr != "" && platPortStr != "" { - cmdPort, err1 := strconv.Atoi(cmdPortStr) - platPort, err2 := strconv.Atoi(platPortStr) - if err1 == nil && err2 == nil { - return cmdPort, platPort, true - } - } + cBytes, err1 := fs.ReadFile(fsys, "command.port") + pBytes, err2 := fs.ReadFile(fsys, "platform.port") + if err1 != nil || err2 != nil { + return 0, 0, false + } + cPort, err1 := strconv.Atoi(strings.TrimSpace(string(cBytes))) + pPort, err2 := strconv.Atoi(strings.TrimSpace(string(pBytes))) + if err1 != nil || err2 != nil { + return 0, 0, false } - return 0, 0, false + return cPort, pPort, true } timeout := time.After(5 * time.Second) @@ -166,8 +154,8 @@ func readPorts(dir string) (cmdPort, platPort int, err error) { case <-timeout: return 0, 0, fmt.Errorf("timed out waiting for simulator port files") case <-ticker.C: - if cmdPort, platPort, ok := tryRead(); ok { - return cmdPort, platPort, nil + if cPort, pPort, ok := tryRead(); ok { + return cPort, pPort, nil } } } From 588fa0dc0ecbdf95e412698f1688f296ad872787 Mon Sep 17 00:00:00 2001 From: Megan Lu Date: Tue, 2 Jun 2026 20:56:25 +0000 Subject: [PATCH 4/4] Fix CI checks --- tpm2/transport/tcp/tcp.go | 1 - tpm2/transport/testhelper/open.go | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tpm2/transport/tcp/tcp.go b/tpm2/transport/tcp/tcp.go index a012da62..8204e572 100644 --- a/tpm2/transport/tcp/tcp.go +++ b/tpm2/transport/tcp/tcp.go @@ -215,7 +215,6 @@ func (t *TPM) Stop() error { return nil } - // Config provides the connection information for a running TCP TPM. type Config struct { // CommandAddress is the full host:port address of the Command server, e.g., diff --git a/tpm2/transport/testhelper/open.go b/tpm2/transport/testhelper/open.go index 0c609252..d5740747 100644 --- a/tpm2/transport/testhelper/open.go +++ b/tpm2/transport/testhelper/open.go @@ -34,7 +34,7 @@ func startProcess(tb testing.TB, path string) *process { } p := &process{ - tb: tb, + tb: tb, cmd: exec.Command(path, "--pick_ports"), } p.cmd.Dir = dir @@ -106,7 +106,8 @@ func (p *process) Close() error { if err = p.cmd.Wait(); err != nil { p.tb.Errorf("failed to wait for simulator process: %v", err) } - } else if p.cmd.Process != nil && err = p.cmd.Process.Kill(); err != nil { + } else if p.cmd.Process != nil { + if err = p.cmd.Process.Kill(); err != nil { p.tb.Errorf("failed to kill simulator process: %v", err) } }