diff --git a/client/src/test/java/org/apache/rocketmq/client/trace/DefaultMQLitePullConsumerWithTraceTest.java b/client/src/test/java/org/apache/rocketmq/client/trace/DefaultMQLitePullConsumerWithTraceTest.java index e0573bdfb0b..c4065cf8527 100644 --- a/client/src/test/java/org/apache/rocketmq/client/trace/DefaultMQLitePullConsumerWithTraceTest.java +++ b/client/src/test/java/org/apache/rocketmq/client/trace/DefaultMQLitePullConsumerWithTraceTest.java @@ -27,7 +27,6 @@ import java.util.Set; import java.util.concurrent.ConcurrentMap; import org.apache.commons.lang3.reflect.FieldUtils; -import org.apache.rocketmq.client.ClientConfig; import org.apache.rocketmq.client.consumer.DefaultLitePullConsumer; import org.apache.rocketmq.client.consumer.PullCallback; import org.apache.rocketmq.client.consumer.PullResult; @@ -60,13 +59,13 @@ import org.apache.rocketmq.remoting.protocol.route.BrokerData; import org.apache.rocketmq.remoting.protocol.route.QueueData; import org.apache.rocketmq.remoting.protocol.route.TopicRouteData; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; -import org.mockito.Spy; import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnitRunner; import org.mockito.stubbing.Answer; @@ -78,14 +77,16 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) public class DefaultMQLitePullConsumerWithTraceTest { - @Spy - private MQClientInstance mQClientFactory = MQClientManager.getInstance().getOrCreateMQClientInstance(new ClientConfig()); + private MQClientInstance mQClientFactory; + private MQClientInstance mqClientInstance; + private MQClientInstance traceMqClientInstance; @Mock private MQClientAPIImpl mQClientAPIImpl; @@ -111,12 +112,45 @@ public static void setUpEnv() { @Before public void init() throws Exception { - Field field = MQClientInstance.class.getDeclaredField("rebalanceService"); - field.setAccessible(true); - RebalanceService rebalanceService = (RebalanceService) field.get(mQClientFactory); - field = RebalanceService.class.getDeclaredField("waitInterval"); - field.setAccessible(true); - field.set(rebalanceService, 100); + ConcurrentMap factoryTable = + (ConcurrentMap) FieldUtils.readDeclaredField( + MQClientManager.getInstance(), "factoryTable", true); + factoryTable.forEach((clientId, instance) -> instance.shutdown()); + factoryTable.clear(); + mQClientFactory = null; + mqClientInstance = null; + traceMqClientInstance = null; + asyncTraceDispatcher = null; + traceProducer = null; + rebalanceImpl = null; + offsetStore = null; + litePullConsumerImpl = null; + } + + @After + public void destroy() { + if (traceProducer != null) { + MQClientInstance traceClientFactory = traceProducer.getDefaultMQProducerImpl().getMqClientFactory(); + traceClientFactory.unregisterProducer(producerGroupTraceTemp); + traceClientFactory.unregisterProducer(traceProducer.getProducerGroup()); + } + + if (traceMqClientInstance != null && traceProducer != null) { + traceMqClientInstance.unregisterProducer(traceProducer.getProducerGroup()); + traceMqClientInstance.shutdown(); + } + + if (litePullConsumerImpl != null) { + if (mQClientFactory != null) { + mQClientFactory.unregisterConsumer(litePullConsumerImpl.groupName()); + mQClientFactory.shutdown(); + } + + if (mqClientInstance != null && mqClientInstance != mQClientFactory) { + mqClientInstance.unregisterConsumer(litePullConsumerImpl.groupName()); + mqClientInstance.shutdown(); + } + } } @Test @@ -126,8 +160,8 @@ public void testSubscribe_PollMessageSuccess_WithDefaultTraceTopic() throws Exce Set messageQueueSet = new HashSet<>(); messageQueueSet.add(createMessageQueue()); litePullConsumerImpl.updateTopicSubscribeInfo(topic, messageQueueSet); - litePullConsumer.setPollTimeoutMillis(20 * 1000); - List result = litePullConsumer.poll(); + List result = pollUntilFound(litePullConsumer); + assertThat(result).isNotEmpty(); assertThat(result.get(0).getTopic()).isEqualTo(topic); assertThat(result.get(0).getBody()).isEqualTo(new byte[] {'a'}); } finally { @@ -142,8 +176,8 @@ public void testSubscribe_PollMessageSuccess_WithCustomizedTraceTopic() throws E Set messageQueueSet = new HashSet<>(); messageQueueSet.add(createMessageQueue()); litePullConsumerImpl.updateTopicSubscribeInfo(topic, messageQueueSet); - litePullConsumer.setPollTimeoutMillis(20 * 1000); - List result = litePullConsumer.poll(); + List result = pollUntilFound(litePullConsumer); + assertThat(result).isNotEmpty(); assertThat(result.get(0).getTopic()).isEqualTo(topic); assertThat(result.get(0).getBody()).isEqualTo(new byte[] {'a'}); } finally { @@ -154,11 +188,15 @@ public void testSubscribe_PollMessageSuccess_WithCustomizedTraceTopic() throws E @Test public void testLitePullConsumerWithTraceTLS() throws Exception { DefaultLitePullConsumer consumer = new DefaultLitePullConsumer("consumerGroup"); - consumer.setUseTLS(true); - consumer.setEnableMsgTrace(true); - consumer.start(); - AsyncTraceDispatcher asyncTraceDispatcher = (AsyncTraceDispatcher) consumer.getTraceDispatcher(); - Assert.assertTrue(asyncTraceDispatcher.getTraceProducer().isUseTLS()); + try { + consumer.setUseTLS(true); + consumer.setEnableMsgTrace(true); + consumer.start(); + AsyncTraceDispatcher asyncTraceDispatcher = (AsyncTraceDispatcher) consumer.getTraceDispatcher(); + Assert.assertTrue(asyncTraceDispatcher.getTraceProducer().isUseTLS()); + } finally { + consumer.shutdown(); + } } private DefaultLitePullConsumer createLitePullConsumerWithDefaultTraceTopic() throws Exception { @@ -192,8 +230,18 @@ private void initDefaultLitePullConsumer(DefaultLitePullConsumer litePullConsume litePullConsumerImpl = (DefaultLitePullConsumerImpl) field.get(litePullConsumer); field = DefaultLitePullConsumerImpl.class.getDeclaredField("mQClientFactory"); field.setAccessible(true); + mqClientInstance = (MQClientInstance) field.get(litePullConsumerImpl); + mQClientFactory = spy(mqClientInstance); + mQClientFactory.getClientConfig().setDecodeReadBody(true); field.set(litePullConsumerImpl, mQClientFactory); + field = MQClientInstance.class.getDeclaredField("rebalanceService"); + field.setAccessible(true); + RebalanceService rebalanceService = (RebalanceService) field.get(mQClientFactory); + field = RebalanceService.class.getDeclaredField("waitInterval"); + field.setAccessible(true); + field.set(rebalanceService, 100); + PullAPIWrapper pullAPIWrapper = litePullConsumerImpl.getPullAPIWrapper(); field = PullAPIWrapper.class.getDeclaredField("mQClientFactory"); field.setAccessible(true); @@ -201,6 +249,7 @@ private void initDefaultLitePullConsumer(DefaultLitePullConsumer litePullConsume Field fieldTrace = DefaultMQProducerImpl.class.getDeclaredField("mQClientFactory"); fieldTrace.setAccessible(true); + traceMqClientInstance = traceProducer.getDefaultMQProducerImpl().getMqClientFactory(); fieldTrace.set(traceProducer.getDefaultMQProducerImpl(), mQClientFactory); field = MQClientInstance.class.getDeclaredField("mQClientAPIImpl"); @@ -225,6 +274,8 @@ private void initDefaultLitePullConsumer(DefaultLitePullConsumer litePullConsume traceProducer.getDefaultMQProducerImpl().getMqClientFactory().registerProducer(producerGroupTraceTemp, traceProducer.getDefaultMQProducerImpl()); + lenient().when(mQClientAPIImpl.getTopicRouteInfoFromNameServer(anyString(), anyLong())).thenReturn(createTopicRoute()); + when(mQClientFactory.getMQClientAPIImpl().pullMessage(anyString(), any(PullMessageRequestHeader.class), anyLong(), any(CommunicationMode.class), nullable(PullCallback.class))) .thenAnswer(new Answer() { @@ -252,6 +303,19 @@ public Object answer(InvocationOnMock mock) throws Throwable { } + private List pollUntilFound(DefaultLitePullConsumer litePullConsumer) { + litePullConsumer.setPollTimeoutMillis(1000); + long deadline = System.currentTimeMillis() + 20 * 1000; + List result = Collections.emptyList(); + while (System.currentTimeMillis() < deadline) { + result = litePullConsumer.poll(); + if (!result.isEmpty()) { + return result; + } + } + return result; + } + private PullResultExt createPullResult(PullMessageRequestHeader requestHeader, PullStatus pullStatus, List messageExtList) throws Exception { ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); diff --git a/test/src/test/java/org/apache/rocketmq/test/grpc/v2/GrpcBaseIT.java b/test/src/test/java/org/apache/rocketmq/test/grpc/v2/GrpcBaseIT.java index 1b1abd0a101..90b02f6a56e 100644 --- a/test/src/test/java/org/apache/rocketmq/test/grpc/v2/GrpcBaseIT.java +++ b/test/src/test/java/org/apache/rocketmq/test/grpc/v2/GrpcBaseIT.java @@ -636,6 +636,7 @@ public void testConsumeOrderly() throws Exception { } public void testSimpleConsumerSendAndRecvPriorityMessage() throws Exception { + brokerController1.getBrokerConfig().setPriorityOrderAsc(true); String topic = initTopicOnSampleTopicBroker(BROKER1_NAME, TopicMessageType.PRIORITY); String group = MQRandomUtils.getRandomConsumerGroup(); initConsumerGroup(group);