diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java index c1c7d5f7df..61eb76699b 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java @@ -22,12 +22,10 @@ import android.annotation.NonNull; import android.annotation.Nullable; import android.annotation.RequiresApi; import android.net.LinkAddress; -import android.net.nsd.NsdManager; import android.net.nsd.NsdServiceInfo; import android.os.Build; import android.os.Handler; import android.os.Looper; -import android.util.ArraySet; import com.android.internal.annotations.VisibleForTesting; import com.android.net.module.util.HexDump; @@ -284,6 +282,7 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand if (!mRecordRepository.hasActiveService(id)) return; mProber.stop(id); mAnnouncer.stop(id); + final String hostname = mRecordRepository.getHostnameForServiceId(id); final MdnsAnnouncer.ExitAnnouncementInfo exitInfo = mRecordRepository.exitService(id); if (exitInfo != null) { // This effectively schedules onAllServicesRemoved(), as it is to be called when the @@ -303,6 +302,24 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand } }); } + // Re-probe/re-announce the services which have the same custom hostname. These services + // were probed/announced using host addresses which were just removed so they should be + // re-probed/re-announced without those addresses. + if (hostname != null) { + final List probingInfos = + mRecordRepository.restartProbingForHostname(hostname); + for (MdnsProber.ProbingInfo probingInfo : probingInfos) { + mProber.stop(probingInfo.getServiceId()); + mProber.startProbing(probingInfo); + } + final List announcementInfos = + mRecordRepository.restartAnnouncingForHostname(hostname); + for (MdnsAnnouncer.AnnouncementInfo announcementInfo : announcementInfos) { + mAnnouncer.stop(announcementInfo.getServiceId()); + mAnnouncer.startSending( + announcementInfo.getServiceId(), announcementInfo, 0 /* initialDelayMs */); + } + } } /** diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java index ac64c3a9e8..073e465d82 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java @@ -925,22 +925,79 @@ public class MdnsRecordRepository { } } + @Nullable + public String getHostnameForServiceId(int id) { + ServiceRegistration registration = mServices.get(id); + if (registration == null) { + return null; + } + return registration.serviceInfo.getHostname(); + } + + /** + * Restart probing the services which are being probed and using the given custom hostname. + * + * @return The list of {@link MdnsProber.ProbingInfo} to be used by advertiser. + */ + public List restartProbingForHostname(@NonNull String hostname) { + final ArrayList probingInfos = new ArrayList<>(); + forEachActiveServiceRegistrationWithHostname( + hostname, + (id, registration) -> { + if (!registration.isProbing) { + return; + } + probingInfos.add(makeProbingInfo(id, registration)); + }); + return probingInfos; + } + + /** + * Restart announcing the services which are using the given custom hostname. + * + * @return The list of {@link MdnsAnnouncer.AnnouncementInfo} to be used by advertiser. + */ + public List restartAnnouncingForHostname( + @NonNull String hostname) { + final ArrayList announcementInfos = new ArrayList<>(); + forEachActiveServiceRegistrationWithHostname( + hostname, + (id, registration) -> { + if (registration.isProbing) { + return; + } + announcementInfos.add(makeAnnouncementInfo(id, registration)); + }); + return announcementInfos; + } + /** * Called to indicate that probing succeeded for a service. + * * @param probeSuccessInfo The successful probing info. * @return The {@link MdnsAnnouncer.AnnouncementInfo} to send, now that probing has succeeded. */ public MdnsAnnouncer.AnnouncementInfo onProbingSucceeded( - MdnsProber.ProbingInfo probeSuccessInfo) - throws IOException { - - int serviceId = probeSuccessInfo.getServiceId(); + MdnsProber.ProbingInfo probeSuccessInfo) throws IOException { + final int serviceId = probeSuccessInfo.getServiceId(); final ServiceRegistration registration = mServices.get(serviceId); if (registration == null) { throw new IOException("Service is not registered: " + serviceId); } registration.setProbing(false); + return makeAnnouncementInfo(serviceId, registration); + } + + /** + * Make the announcement info of the given service ID. + * + * @param serviceId The service ID. + * @param registration The service registration. + * @return The {@link MdnsAnnouncer.AnnouncementInfo} of the given service ID. + */ + private MdnsAnnouncer.AnnouncementInfo makeAnnouncementInfo( + int serviceId, ServiceRegistration registration) { final Set answersSet = new LinkedHashSet<>(); final ArrayList additionalAnswers = new ArrayList<>(); @@ -972,8 +1029,8 @@ public class MdnsRecordRepository { addNsecRecordsForUniqueNames(additionalAnswers, mGeneralRecords.iterator(), registration.allRecords.iterator()); - return new MdnsAnnouncer.AnnouncementInfo( - probeSuccessInfo.getServiceId(), new ArrayList<>(answersSet), additionalAnswers); + return new MdnsAnnouncer.AnnouncementInfo(serviceId, + new ArrayList<>(answersSet), additionalAnswers); } /** diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt index 61117dff71..6dd4857496 100644 --- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt +++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt @@ -2205,6 +2205,66 @@ class NsdManagerTest { } } + @Test + fun testAdvertisingAndDiscovery_reregisterCustomHostWithDifferentAddresses_newAddressesFound() { + val si1 = NsdServiceInfo().also { + it.network = testNetwork1.network + it.hostname = customHostname + it.hostAddresses = listOf( + parseNumericAddress("192.0.2.23"), + parseNumericAddress("2001:db8::1")) + } + val si2 = NsdServiceInfo().also { + it.network = testNetwork1.network + it.serviceName = serviceName + it.serviceType = serviceType + it.hostname = customHostname + it.port = TEST_PORT + } + val si3 = NsdServiceInfo().also { + it.network = testNetwork1.network + it.hostname = customHostname + it.hostAddresses = listOf( + parseNumericAddress("192.0.2.24"), + parseNumericAddress("2001:db8::2")) + } + + val registrationRecord1 = NsdRegistrationRecord() + val registrationRecord2 = NsdRegistrationRecord() + val registrationRecord3 = NsdRegistrationRecord() + + val discoveryRecord = NsdDiscoveryRecord() + + tryTest { + registerService(registrationRecord1, si1) + registerService(registrationRecord2, si2) + + nsdManager.unregisterService(registrationRecord1) + registrationRecord1.expectCallback() + + registerService(registrationRecord3, si3) + + nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, + testNetwork1.network, Executor { it.run() }, discoveryRecord) + val discoveredInfo = discoveryRecord.waitForServiceDiscovered( + serviceName, serviceType, testNetwork1.network) + val resolvedInfo = resolveService(discoveredInfo) + + assertEquals(serviceName, discoveredInfo.serviceName) + assertEquals(TEST_PORT, resolvedInfo.port) + assertEquals(customHostname, resolvedInfo.hostname) + assertAddressEquals( + listOf(parseNumericAddress("192.0.2.24"), parseNumericAddress("2001:db8::2")), + resolvedInfo.hostAddresses) + } cleanupStep { + nsdManager.stopServiceDiscovery(discoveryRecord) + discoveryRecord.expectCallbackEventually() + } cleanup { + nsdManager.unregisterService(registrationRecord2) + nsdManager.unregisterService(registrationRecord3) + } + } + @Test fun testServiceTypeClientRemovedAfterSocketDestroyed() { val si = makeTestServiceInfo(testNetwork1.network) diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt index 69fec859d6..7ac7beed38 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt @@ -18,7 +18,6 @@ package com.android.server.connectivity.mdns import android.net.InetAddresses.parseNumericAddress import android.net.LinkAddress -import android.net.nsd.NsdManager import android.net.nsd.NsdServiceInfo import android.os.Build import android.os.HandlerThread @@ -48,6 +47,7 @@ import org.mockito.Mockito.any import org.mockito.Mockito.anyInt import org.mockito.Mockito.anyString import org.mockito.Mockito.argThat +import org.mockito.Mockito.atLeastOnce import org.mockito.Mockito.doAnswer import org.mockito.Mockito.doReturn import org.mockito.Mockito.eq @@ -55,6 +55,8 @@ import org.mockito.Mockito.mock import org.mockito.Mockito.never import org.mockito.Mockito.times import org.mockito.Mockito.verify +import org.mockito.Mockito.clearInvocations +import org.mockito.Mockito.inOrder private const val LOG_TAG = "testlogtag" private const val TIMEOUT_MS = 10_000L @@ -65,6 +67,7 @@ private val TEST_HOSTNAME = arrayOf("Android_test", "local") private const val TEST_SERVICE_ID_1 = 42 private const val TEST_SERVICE_ID_DUPLICATE = 43 +private const val TEST_SERVICE_ID_2 = 44 private val TEST_SERVICE_1 = NsdServiceInfo().apply { serviceType = "_testservice._tcp" serviceName = "MyTestService" @@ -78,6 +81,13 @@ private val TEST_SERVICE_1_SUBTYPE = NsdServiceInfo().apply { port = 12345 } +private val TEST_SERVICE_1_CUSTOM_HOST = NsdServiceInfo().apply { + serviceType = "_testservice._tcp" + serviceName = "MyTestService" + hostname = "MyTestHost" + port = 12345 +} + @RunWith(DevSdkIgnoreRunner::class) @IgnoreUpTo(Build.VERSION_CODES.S_V2) class MdnsInterfaceAdvertiserTest { @@ -182,6 +192,63 @@ class MdnsInterfaceAdvertiserTest { verify(cb).onAllServicesRemoved(socket) } + @Test + fun testAddRemoveServiceWithCustomHost_restartProbingForProbingServices() { + val customHost1 = NsdServiceInfo().apply { + hostname = "MyTestHost" + hostAddresses = listOf( + parseNumericAddress("192.0.2.23"), + parseNumericAddress("2001:db8::1")) + } + addServiceAndFinishProbing(TEST_SERVICE_ID_1, customHost1) + addServiceAndFinishProbing(TEST_SERVICE_ID_2, TEST_SERVICE_1_CUSTOM_HOST) + repository.setServiceProbing(TEST_SERVICE_ID_2) + val probingInfo = mock(ProbingInfo::class.java) + doReturn("MyTestHost") + .`when`(repository).getHostnameForServiceId(TEST_SERVICE_ID_1) + doReturn(TEST_SERVICE_ID_2).`when`(probingInfo).serviceId + doReturn(listOf(probingInfo)) + .`when`(repository).restartProbingForHostname("MyTestHost") + val inOrder = inOrder(prober, announcer) + + // Remove the custom host: the custom host's announcement is stopped and the probing + // services which use that hostname are re-announced. + advertiser.removeService(TEST_SERVICE_ID_1) + + inOrder.verify(prober).stop(TEST_SERVICE_ID_1) + inOrder.verify(announcer).stop(TEST_SERVICE_ID_1) + inOrder.verify(prober).stop(TEST_SERVICE_ID_2) + inOrder.verify(prober).startProbing(probingInfo) + } + + @Test + fun testAddRemoveServiceWithCustomHost_restartAnnouncingForProbedServices() { + val customHost1 = NsdServiceInfo().apply { + hostname = "MyTestHost" + hostAddresses = listOf( + parseNumericAddress("192.0.2.23"), + parseNumericAddress("2001:db8::1")) + } + addServiceAndFinishProbing(TEST_SERVICE_ID_1, customHost1) + val announcementInfo = + addServiceAndFinishProbing(TEST_SERVICE_ID_2, TEST_SERVICE_1_CUSTOM_HOST) + doReturn("MyTestHost") + .`when`(repository).getHostnameForServiceId(TEST_SERVICE_ID_1) + doReturn(TEST_SERVICE_ID_2).`when`(announcementInfo).serviceId + doReturn(listOf(announcementInfo)) + .`when`(repository).restartAnnouncingForHostname("MyTestHost") + clearInvocations(announcer) + + // Remove the custom host: the custom host's announcement is stopped and the probed services + // which use that hostname are re-announced. + advertiser.removeService(TEST_SERVICE_ID_1) + + verify(prober).stop(TEST_SERVICE_ID_1) + verify(announcer, atLeastOnce()).stop(TEST_SERVICE_ID_1) + verify(announcer).stop(TEST_SERVICE_ID_2) + verify(announcer).startSending(TEST_SERVICE_ID_2, announcementInfo, 0L /* initialDelayMs */) + } + @Test fun testDoubleRemove() { addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1) diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt index c69b1e1050..271cc65243 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt @@ -24,6 +24,7 @@ import android.os.HandlerThread import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_HOST import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_SERVICE +import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo import com.android.server.connectivity.mdns.MdnsRecord.TYPE_A import com.android.server.connectivity.mdns.MdnsRecord.TYPE_AAAA import com.android.server.connectivity.mdns.MdnsRecord.TYPE_PTR @@ -51,6 +52,10 @@ import org.junit.After import org.junit.Before import org.junit.Test import org.junit.runner.RunWith +import org.mockito.ArgumentCaptor +import org.mockito.ArgumentMatchers.eq +import org.mockito.Mockito.mock +import org.mockito.Mockito.verify private const val TEST_SERVICE_ID_1 = 42 private const val TEST_SERVICE_ID_2 = 43 @@ -112,6 +117,14 @@ private val TEST_SERVICE_CUSTOM_HOST_1 = NsdServiceInfo().apply { port = TEST_PORT } +private val TEST_SERVICE_CUSTOM_HOST_NO_ADDRESSES = NsdServiceInfo().apply { + hostname = "TestHost" + hostAddresses = listOf() + serviceType = "_testservice._tcp" + serviceName = "TestService" + port = TEST_PORT +} + @RunWith(DevSdkIgnoreRunner::class) @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2) class MdnsRecordRepositoryTest { @@ -1676,6 +1689,127 @@ class MdnsRecordRepositoryTest { assertEquals(0, reply.additionalAnswers.size) assertEquals(knownAnswers, reply.knownAnswers) } + + @Test + fun testRestartProbingForHostname() { + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags()) + repository.initWithService(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1, + setOf(TEST_SUBTYPE, TEST_SUBTYPE2)) + repository.addService(TEST_SERVICE_CUSTOM_HOST_ID_1, + TEST_SERVICE_CUSTOM_HOST_NO_ADDRESSES, null) + repository.setServiceProbing(TEST_SERVICE_CUSTOM_HOST_ID_1) + repository.removeService(TEST_CUSTOM_HOST_ID_1) + + val probingInfos = repository.restartProbingForHostname("TestHost") + + assertEquals(1, probingInfos.size) + val probingInfo = probingInfos.get(0) + assertEquals(TEST_SERVICE_CUSTOM_HOST_ID_1, probingInfo.serviceId) + val packet = probingInfo.getPacket(0) + assertEquals(0, packet.transactionId) + assertEquals(MdnsConstants.FLAGS_QUERY, packet.flags) + assertEquals(0, packet.answers.size) + assertEquals(0, packet.additionalRecords.size) + assertEquals(1, packet.questions.size) + val serviceName = arrayOf("TestService", "_testservice", "_tcp", "local") + assertEquals(MdnsAnyRecord(serviceName, false /* unicast */), packet.questions[0]) + assertThat(packet.authorityRecords).containsExactly( + MdnsServiceRecord( + serviceName, + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + SHORT_TTL /* ttlMillis */, + 0 /* servicePriority */, + 0 /* serviceWeight */, + TEST_PORT, + TEST_CUSTOM_HOST_1_NAME)) + } + + @Test + fun testRestartAnnouncingForHostname() { + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags()) + repository.initWithService(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1, + setOf(TEST_SUBTYPE, TEST_SUBTYPE2)) + repository.addServiceAndFinishProbing(TEST_SERVICE_CUSTOM_HOST_ID_1, + TEST_SERVICE_CUSTOM_HOST_NO_ADDRESSES) + repository.removeService(TEST_CUSTOM_HOST_ID_1) + + val announcementInfos = repository.restartAnnouncingForHostname("TestHost") + + assertEquals(1, announcementInfos.size) + val announcementInfo = announcementInfos.get(0) + assertEquals(TEST_SERVICE_CUSTOM_HOST_ID_1, announcementInfo.serviceId) + val packet = announcementInfo.getPacket(0) + assertEquals(0, packet.transactionId) + assertEquals(0x8400 /* response, authoritative */, packet.flags) + assertEquals(0, packet.questions.size) + assertEquals(0, packet.authorityRecords.size) + val serviceName = arrayOf("TestService", "_testservice", "_tcp", "local") + val serviceType = arrayOf("_testservice", "_tcp", "local") + val v4AddrRev = getReverseDnsAddress(TEST_ADDRESSES[0].address) + val v6Addr1Rev = getReverseDnsAddress(TEST_ADDRESSES[1].address) + val v6Addr2Rev = getReverseDnsAddress(TEST_ADDRESSES[2].address) + assertThat(packet.answers).containsExactly( + MdnsPointerRecord( + serviceType, + 0L /* receiptTimeMillis */, + // Not a unique name owned by the announcer, so cacheFlush=false + false /* cacheFlush */, + 4500000L /* ttlMillis */, + serviceName), + MdnsServiceRecord( + serviceName, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + 120000L /* ttlMillis */, + 0 /* servicePriority */, + 0 /* serviceWeight */, + TEST_PORT /* servicePort */, + TEST_CUSTOM_HOST_1_NAME), + MdnsTextRecord( + serviceName, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + 4500000L /* ttlMillis */, + emptyList() /* entries */), + MdnsPointerRecord( + arrayOf("_services", "_dns-sd", "_udp", "local"), + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + 4500000L /* ttlMillis */, + serviceType)) + assertThat(packet.additionalRecords).containsExactly( + MdnsNsecRecord(v4AddrRev, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + 120000L /* ttlMillis */, + v4AddrRev, + intArrayOf(TYPE_PTR)), + MdnsNsecRecord(TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + 120000L /* ttlMillis */, + TEST_HOSTNAME, + intArrayOf(TYPE_A, TYPE_AAAA)), + MdnsNsecRecord(v6Addr1Rev, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + 120000L /* ttlMillis */, + v6Addr1Rev, + intArrayOf(TYPE_PTR)), + MdnsNsecRecord(v6Addr2Rev, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + 120000L /* ttlMillis */, + v6Addr2Rev, + intArrayOf(TYPE_PTR)), + MdnsNsecRecord(serviceName, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + 4500000L /* ttlMillis */, + serviceName, + intArrayOf(TYPE_TXT, TYPE_SRV))) + } } private fun MdnsRecordRepository.initWithService(