diff --git a/pulsar-broker-auth-sasl/src/test/java/org/apache/pulsar/broker/authentication/SaslAuthenticateTest.java b/pulsar-broker-auth-sasl/src/test/java/org/apache/pulsar/broker/authentication/SaslAuthenticateTest.java index bb8595dacedd8..4e111eb3aeee4 100644 --- a/pulsar-broker-auth-sasl/src/test/java/org/apache/pulsar/broker/authentication/SaslAuthenticateTest.java +++ b/pulsar-broker-auth-sasl/src/test/java/org/apache/pulsar/broker/authentication/SaslAuthenticateTest.java @@ -61,13 +61,12 @@ import org.apache.pulsar.common.api.AuthData; import org.apache.pulsar.common.sasl.SaslConstants; import org.apache.pulsar.common.util.ObjectMapperFactory; -import org.testng.Assert; +import org.awaitility.Awaitility; import org.testng.annotations.AfterClass; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeClass; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; -import org.testng.collections.CollectionUtils; @CustomLog public class SaslAuthenticateTest extends ProducerConsumerBase { @@ -220,7 +219,7 @@ protected void setup() throws Exception { @Override protected void cleanup() throws Exception { FileUtils.deleteQuietly(secretKeyFile); - Assert.assertFalse(secretKeyFile.exists()); + assertFalse(secretKeyFile.exists()); super.internalCleanup(); } @@ -310,49 +309,56 @@ public void testSaslServerAndClientAuth() throws Exception { @Test @SuppressWarnings("unchecked") - public void testSaslOnlyAuthFirstStage() throws Exception { + public void testSaslOnlyAuthFirstStageKeepsInflightContextsBeforeExpiry() throws Exception { @Cleanup AuthenticationProviderSasl saslServer = new AuthenticationProviderSasl(); - // The cache expiration time is set to 500ms. Residual auth info should be cleaned up - conf.setInflightSaslContextExpiryMs(500); + conf.setInflightSaslContextExpiryMs(Integer.MAX_VALUE); saslServer.initialize(AuthenticationProvider.Context.builder().config(conf).build()); HttpServletRequest servletRequest = mock(HttpServletRequest.class); - doReturn("Init").when(servletRequest).getHeader("State"); - // 10 clients only do one-stage verification, resulting in 10 auth info remaining in memory + doReturn(SaslConstants.SASL_STATE_CLIENT_INIT).when(servletRequest).getHeader(SaslConstants.SASL_HEADER_STATE); for (int i = 0; i < 10; i++) { - AuthenticationDataProvider dataProvider = authSasl.getAuthData("localhost"); + AuthenticationDataProvider dataProvider = authSasl.getAuthData(localHostname); AuthData initData1 = dataProvider.authenticate(AuthData.INIT_AUTH_DATA); doReturn(Base64.getEncoder().encodeToString(initData1.getBytes())).when( - servletRequest).getHeader("SASL-Token"); - doReturn(String.valueOf(i)).when(servletRequest).getHeader("SASL-Server-ID"); + servletRequest).getHeader(SaslConstants.SASL_AUTH_TOKEN); + doReturn(String.valueOf(i)).when(servletRequest).getHeader(SaslConstants.SASL_STATE_SERVER); saslServer.authenticateHttpRequest(servletRequest, mock(HttpServletResponse.class)); } + Field field = AuthenticationProviderSasl.class.getDeclaredField("authStates"); field.setAccessible(true); Cache cache = (Cache) field.get(saslServer); assertEquals(cache.asMap().size(), 10); - // Add more auth info into memory + } + + @Test + @SuppressWarnings("unchecked") + public void testSaslOnlyAuthFirstStageExpiresResidualContexts() throws Exception { + @Cleanup + AuthenticationProviderSasl saslServer = new AuthenticationProviderSasl(); + conf.setInflightSaslContextExpiryMs(50); + saslServer.initialize(AuthenticationProvider.Context.builder().config(conf).build()); + + HttpServletRequest servletRequest = mock(HttpServletRequest.class); + doReturn(SaslConstants.SASL_STATE_CLIENT_INIT).when(servletRequest).getHeader(SaslConstants.SASL_HEADER_STATE); for (int i = 0; i < 10; i++) { - AuthenticationDataProvider dataProvider = authSasl.getAuthData("localhost"); + AuthenticationDataProvider dataProvider = authSasl.getAuthData(localHostname); AuthData initData1 = dataProvider.authenticate(AuthData.INIT_AUTH_DATA); doReturn(Base64.getEncoder().encodeToString(initData1.getBytes())).when( - servletRequest).getHeader("SASL-Token"); - doReturn(String.valueOf(10 + i)).when(servletRequest).getHeader("SASL-Server-ID"); + servletRequest).getHeader(SaslConstants.SASL_AUTH_TOKEN); + doReturn(String.valueOf(i)).when(servletRequest).getHeader(SaslConstants.SASL_STATE_SERVER); saslServer.authenticateHttpRequest(servletRequest, mock(HttpServletResponse.class)); } - long start = System.currentTimeMillis(); - while (true) { - if (System.currentTimeMillis() - start > 5000) { - fail(); - } - cache = (Cache) field.get(saslServer); - // Residual auth info should be cleaned up - if (CollectionUtils.hasElements(cache.asMap())) { - break; + + Field field = AuthenticationProviderSasl.class.getDeclaredField("authStates"); + field.setAccessible(true); + Cache cache = (Cache) field.get(saslServer); + Awaitility.await().atMost(5, TimeUnit.SECONDS).pollDelay(100, TimeUnit.MILLISECONDS).untilAsserted(() -> { + for (int i = 0; i < 10; i++) { + assertNull(cache.getIfPresent(((long) i))); } - Thread.sleep(5); - } + }); } @Test