Merge "[mdns] restart probing/announcing the services on host address removal" into main

This commit is contained in:
Handa Wang 2024-03-25 04:34:11 +00:00 committed by Gerrit Code Review
commit 1bd299524f
5 changed files with 344 additions and 9 deletions

View file

@ -22,12 +22,10 @@ import android.annotation.NonNull;
import android.annotation.Nullable;
import android.annotation.RequiresApi;
import android.net.LinkAddress;
import android.net.nsd.NsdManager;
import android.net.nsd.NsdServiceInfo;
import android.os.Build;
import android.os.Handler;
import android.os.Looper;
import android.util.ArraySet;
import com.android.internal.annotations.VisibleForTesting;
import com.android.net.module.util.HexDump;
@ -284,6 +282,7 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand
if (!mRecordRepository.hasActiveService(id)) return;
mProber.stop(id);
mAnnouncer.stop(id);
final String hostname = mRecordRepository.getHostnameForServiceId(id);
final MdnsAnnouncer.ExitAnnouncementInfo exitInfo = mRecordRepository.exitService(id);
if (exitInfo != null) {
// This effectively schedules onAllServicesRemoved(), as it is to be called when the
@ -303,6 +302,24 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand
}
});
}
// Re-probe/re-announce the services which have the same custom hostname. These services
// were probed/announced using host addresses which were just removed so they should be
// re-probed/re-announced without those addresses.
if (hostname != null) {
final List<MdnsProber.ProbingInfo> probingInfos =
mRecordRepository.restartProbingForHostname(hostname);
for (MdnsProber.ProbingInfo probingInfo : probingInfos) {
mProber.stop(probingInfo.getServiceId());
mProber.startProbing(probingInfo);
}
final List<MdnsAnnouncer.AnnouncementInfo> announcementInfos =
mRecordRepository.restartAnnouncingForHostname(hostname);
for (MdnsAnnouncer.AnnouncementInfo announcementInfo : announcementInfos) {
mAnnouncer.stop(announcementInfo.getServiceId());
mAnnouncer.startSending(
announcementInfo.getServiceId(), announcementInfo, 0 /* initialDelayMs */);
}
}
}
/**

View file

@ -925,22 +925,79 @@ public class MdnsRecordRepository {
}
}
@Nullable
public String getHostnameForServiceId(int id) {
ServiceRegistration registration = mServices.get(id);
if (registration == null) {
return null;
}
return registration.serviceInfo.getHostname();
}
/**
* Restart probing the services which are being probed and using the given custom hostname.
*
* @return The list of {@link MdnsProber.ProbingInfo} to be used by advertiser.
*/
public List<MdnsProber.ProbingInfo> restartProbingForHostname(@NonNull String hostname) {
final ArrayList<MdnsProber.ProbingInfo> probingInfos = new ArrayList<>();
forEachActiveServiceRegistrationWithHostname(
hostname,
(id, registration) -> {
if (!registration.isProbing) {
return;
}
probingInfos.add(makeProbingInfo(id, registration));
});
return probingInfos;
}
/**
* Restart announcing the services which are using the given custom hostname.
*
* @return The list of {@link MdnsAnnouncer.AnnouncementInfo} to be used by advertiser.
*/
public List<MdnsAnnouncer.AnnouncementInfo> restartAnnouncingForHostname(
@NonNull String hostname) {
final ArrayList<MdnsAnnouncer.AnnouncementInfo> announcementInfos = new ArrayList<>();
forEachActiveServiceRegistrationWithHostname(
hostname,
(id, registration) -> {
if (registration.isProbing) {
return;
}
announcementInfos.add(makeAnnouncementInfo(id, registration));
});
return announcementInfos;
}
/**
* Called to indicate that probing succeeded for a service.
*
* @param probeSuccessInfo The successful probing info.
* @return The {@link MdnsAnnouncer.AnnouncementInfo} to send, now that probing has succeeded.
*/
public MdnsAnnouncer.AnnouncementInfo onProbingSucceeded(
MdnsProber.ProbingInfo probeSuccessInfo)
throws IOException {
int serviceId = probeSuccessInfo.getServiceId();
MdnsProber.ProbingInfo probeSuccessInfo) throws IOException {
final int serviceId = probeSuccessInfo.getServiceId();
final ServiceRegistration registration = mServices.get(serviceId);
if (registration == null) {
throw new IOException("Service is not registered: " + serviceId);
}
registration.setProbing(false);
return makeAnnouncementInfo(serviceId, registration);
}
/**
* Make the announcement info of the given service ID.
*
* @param serviceId The service ID.
* @param registration The service registration.
* @return The {@link MdnsAnnouncer.AnnouncementInfo} of the given service ID.
*/
private MdnsAnnouncer.AnnouncementInfo makeAnnouncementInfo(
int serviceId, ServiceRegistration registration) {
final Set<MdnsRecord> answersSet = new LinkedHashSet<>();
final ArrayList<MdnsRecord> additionalAnswers = new ArrayList<>();
@ -972,8 +1029,8 @@ public class MdnsRecordRepository {
addNsecRecordsForUniqueNames(additionalAnswers,
mGeneralRecords.iterator(), registration.allRecords.iterator());
return new MdnsAnnouncer.AnnouncementInfo(
probeSuccessInfo.getServiceId(), new ArrayList<>(answersSet), additionalAnswers);
return new MdnsAnnouncer.AnnouncementInfo(serviceId,
new ArrayList<>(answersSet), additionalAnswers);
}
/**

View file

@ -2205,6 +2205,66 @@ class NsdManagerTest {
}
}
@Test
fun testAdvertisingAndDiscovery_reregisterCustomHostWithDifferentAddresses_newAddressesFound() {
val si1 = NsdServiceInfo().also {
it.network = testNetwork1.network
it.hostname = customHostname
it.hostAddresses = listOf(
parseNumericAddress("192.0.2.23"),
parseNumericAddress("2001:db8::1"))
}
val si2 = NsdServiceInfo().also {
it.network = testNetwork1.network
it.serviceName = serviceName
it.serviceType = serviceType
it.hostname = customHostname
it.port = TEST_PORT
}
val si3 = NsdServiceInfo().also {
it.network = testNetwork1.network
it.hostname = customHostname
it.hostAddresses = listOf(
parseNumericAddress("192.0.2.24"),
parseNumericAddress("2001:db8::2"))
}
val registrationRecord1 = NsdRegistrationRecord()
val registrationRecord2 = NsdRegistrationRecord()
val registrationRecord3 = NsdRegistrationRecord()
val discoveryRecord = NsdDiscoveryRecord()
tryTest {
registerService(registrationRecord1, si1)
registerService(registrationRecord2, si2)
nsdManager.unregisterService(registrationRecord1)
registrationRecord1.expectCallback<ServiceUnregistered>()
registerService(registrationRecord3, si3)
nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
testNetwork1.network, Executor { it.run() }, discoveryRecord)
val discoveredInfo = discoveryRecord.waitForServiceDiscovered(
serviceName, serviceType, testNetwork1.network)
val resolvedInfo = resolveService(discoveredInfo)
assertEquals(serviceName, discoveredInfo.serviceName)
assertEquals(TEST_PORT, resolvedInfo.port)
assertEquals(customHostname, resolvedInfo.hostname)
assertAddressEquals(
listOf(parseNumericAddress("192.0.2.24"), parseNumericAddress("2001:db8::2")),
resolvedInfo.hostAddresses)
} cleanupStep {
nsdManager.stopServiceDiscovery(discoveryRecord)
discoveryRecord.expectCallbackEventually<DiscoveryStopped>()
} cleanup {
nsdManager.unregisterService(registrationRecord2)
nsdManager.unregisterService(registrationRecord3)
}
}
@Test
fun testServiceTypeClientRemovedAfterSocketDestroyed() {
val si = makeTestServiceInfo(testNetwork1.network)

View file

@ -18,7 +18,6 @@ package com.android.server.connectivity.mdns
import android.net.InetAddresses.parseNumericAddress
import android.net.LinkAddress
import android.net.nsd.NsdManager
import android.net.nsd.NsdServiceInfo
import android.os.Build
import android.os.HandlerThread
@ -48,6 +47,7 @@ import org.mockito.Mockito.any
import org.mockito.Mockito.anyInt
import org.mockito.Mockito.anyString
import org.mockito.Mockito.argThat
import org.mockito.Mockito.atLeastOnce
import org.mockito.Mockito.doAnswer
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.eq
@ -55,6 +55,8 @@ import org.mockito.Mockito.mock
import org.mockito.Mockito.never
import org.mockito.Mockito.times
import org.mockito.Mockito.verify
import org.mockito.Mockito.clearInvocations
import org.mockito.Mockito.inOrder
private const val LOG_TAG = "testlogtag"
private const val TIMEOUT_MS = 10_000L
@ -65,6 +67,7 @@ private val TEST_HOSTNAME = arrayOf("Android_test", "local")
private const val TEST_SERVICE_ID_1 = 42
private const val TEST_SERVICE_ID_DUPLICATE = 43
private const val TEST_SERVICE_ID_2 = 44
private val TEST_SERVICE_1 = NsdServiceInfo().apply {
serviceType = "_testservice._tcp"
serviceName = "MyTestService"
@ -78,6 +81,13 @@ private val TEST_SERVICE_1_SUBTYPE = NsdServiceInfo().apply {
port = 12345
}
private val TEST_SERVICE_1_CUSTOM_HOST = NsdServiceInfo().apply {
serviceType = "_testservice._tcp"
serviceName = "MyTestService"
hostname = "MyTestHost"
port = 12345
}
@RunWith(DevSdkIgnoreRunner::class)
@IgnoreUpTo(Build.VERSION_CODES.S_V2)
class MdnsInterfaceAdvertiserTest {
@ -182,6 +192,63 @@ class MdnsInterfaceAdvertiserTest {
verify(cb).onAllServicesRemoved(socket)
}
@Test
fun testAddRemoveServiceWithCustomHost_restartProbingForProbingServices() {
val customHost1 = NsdServiceInfo().apply {
hostname = "MyTestHost"
hostAddresses = listOf(
parseNumericAddress("192.0.2.23"),
parseNumericAddress("2001:db8::1"))
}
addServiceAndFinishProbing(TEST_SERVICE_ID_1, customHost1)
addServiceAndFinishProbing(TEST_SERVICE_ID_2, TEST_SERVICE_1_CUSTOM_HOST)
repository.setServiceProbing(TEST_SERVICE_ID_2)
val probingInfo = mock(ProbingInfo::class.java)
doReturn("MyTestHost")
.`when`(repository).getHostnameForServiceId(TEST_SERVICE_ID_1)
doReturn(TEST_SERVICE_ID_2).`when`(probingInfo).serviceId
doReturn(listOf(probingInfo))
.`when`(repository).restartProbingForHostname("MyTestHost")
val inOrder = inOrder(prober, announcer)
// Remove the custom host: the custom host's announcement is stopped and the probing
// services which use that hostname are re-announced.
advertiser.removeService(TEST_SERVICE_ID_1)
inOrder.verify(prober).stop(TEST_SERVICE_ID_1)
inOrder.verify(announcer).stop(TEST_SERVICE_ID_1)
inOrder.verify(prober).stop(TEST_SERVICE_ID_2)
inOrder.verify(prober).startProbing(probingInfo)
}
@Test
fun testAddRemoveServiceWithCustomHost_restartAnnouncingForProbedServices() {
val customHost1 = NsdServiceInfo().apply {
hostname = "MyTestHost"
hostAddresses = listOf(
parseNumericAddress("192.0.2.23"),
parseNumericAddress("2001:db8::1"))
}
addServiceAndFinishProbing(TEST_SERVICE_ID_1, customHost1)
val announcementInfo =
addServiceAndFinishProbing(TEST_SERVICE_ID_2, TEST_SERVICE_1_CUSTOM_HOST)
doReturn("MyTestHost")
.`when`(repository).getHostnameForServiceId(TEST_SERVICE_ID_1)
doReturn(TEST_SERVICE_ID_2).`when`(announcementInfo).serviceId
doReturn(listOf(announcementInfo))
.`when`(repository).restartAnnouncingForHostname("MyTestHost")
clearInvocations(announcer)
// Remove the custom host: the custom host's announcement is stopped and the probed services
// which use that hostname are re-announced.
advertiser.removeService(TEST_SERVICE_ID_1)
verify(prober).stop(TEST_SERVICE_ID_1)
verify(announcer, atLeastOnce()).stop(TEST_SERVICE_ID_1)
verify(announcer).stop(TEST_SERVICE_ID_2)
verify(announcer).startSending(TEST_SERVICE_ID_2, announcementInfo, 0L /* initialDelayMs */)
}
@Test
fun testDoubleRemove() {
addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)

View file

@ -24,6 +24,7 @@ import android.os.HandlerThread
import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_HOST
import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_SERVICE
import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo
import com.android.server.connectivity.mdns.MdnsRecord.TYPE_A
import com.android.server.connectivity.mdns.MdnsRecord.TYPE_AAAA
import com.android.server.connectivity.mdns.MdnsRecord.TYPE_PTR
@ -51,6 +52,10 @@ import org.junit.After
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.ArgumentCaptor
import org.mockito.ArgumentMatchers.eq
import org.mockito.Mockito.mock
import org.mockito.Mockito.verify
private const val TEST_SERVICE_ID_1 = 42
private const val TEST_SERVICE_ID_2 = 43
@ -112,6 +117,14 @@ private val TEST_SERVICE_CUSTOM_HOST_1 = NsdServiceInfo().apply {
port = TEST_PORT
}
private val TEST_SERVICE_CUSTOM_HOST_NO_ADDRESSES = NsdServiceInfo().apply {
hostname = "TestHost"
hostAddresses = listOf()
serviceType = "_testservice._tcp"
serviceName = "TestService"
port = TEST_PORT
}
@RunWith(DevSdkIgnoreRunner::class)
@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
class MdnsRecordRepositoryTest {
@ -1676,6 +1689,127 @@ class MdnsRecordRepositoryTest {
assertEquals(0, reply.additionalAnswers.size)
assertEquals(knownAnswers, reply.knownAnswers)
}
@Test
fun testRestartProbingForHostname() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1,
setOf(TEST_SUBTYPE, TEST_SUBTYPE2))
repository.addService(TEST_SERVICE_CUSTOM_HOST_ID_1,
TEST_SERVICE_CUSTOM_HOST_NO_ADDRESSES, null)
repository.setServiceProbing(TEST_SERVICE_CUSTOM_HOST_ID_1)
repository.removeService(TEST_CUSTOM_HOST_ID_1)
val probingInfos = repository.restartProbingForHostname("TestHost")
assertEquals(1, probingInfos.size)
val probingInfo = probingInfos.get(0)
assertEquals(TEST_SERVICE_CUSTOM_HOST_ID_1, probingInfo.serviceId)
val packet = probingInfo.getPacket(0)
assertEquals(0, packet.transactionId)
assertEquals(MdnsConstants.FLAGS_QUERY, packet.flags)
assertEquals(0, packet.answers.size)
assertEquals(0, packet.additionalRecords.size)
assertEquals(1, packet.questions.size)
val serviceName = arrayOf("TestService", "_testservice", "_tcp", "local")
assertEquals(MdnsAnyRecord(serviceName, false /* unicast */), packet.questions[0])
assertThat(packet.authorityRecords).containsExactly(
MdnsServiceRecord(
serviceName,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
SHORT_TTL /* ttlMillis */,
0 /* servicePriority */,
0 /* serviceWeight */,
TEST_PORT,
TEST_CUSTOM_HOST_1_NAME))
}
@Test
fun testRestartAnnouncingForHostname() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1,
setOf(TEST_SUBTYPE, TEST_SUBTYPE2))
repository.addServiceAndFinishProbing(TEST_SERVICE_CUSTOM_HOST_ID_1,
TEST_SERVICE_CUSTOM_HOST_NO_ADDRESSES)
repository.removeService(TEST_CUSTOM_HOST_ID_1)
val announcementInfos = repository.restartAnnouncingForHostname("TestHost")
assertEquals(1, announcementInfos.size)
val announcementInfo = announcementInfos.get(0)
assertEquals(TEST_SERVICE_CUSTOM_HOST_ID_1, announcementInfo.serviceId)
val packet = announcementInfo.getPacket(0)
assertEquals(0, packet.transactionId)
assertEquals(0x8400 /* response, authoritative */, packet.flags)
assertEquals(0, packet.questions.size)
assertEquals(0, packet.authorityRecords.size)
val serviceName = arrayOf("TestService", "_testservice", "_tcp", "local")
val serviceType = arrayOf("_testservice", "_tcp", "local")
val v4AddrRev = getReverseDnsAddress(TEST_ADDRESSES[0].address)
val v6Addr1Rev = getReverseDnsAddress(TEST_ADDRESSES[1].address)
val v6Addr2Rev = getReverseDnsAddress(TEST_ADDRESSES[2].address)
assertThat(packet.answers).containsExactly(
MdnsPointerRecord(
serviceType,
0L /* receiptTimeMillis */,
// Not a unique name owned by the announcer, so cacheFlush=false
false /* cacheFlush */,
4500000L /* ttlMillis */,
serviceName),
MdnsServiceRecord(
serviceName,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
120000L /* ttlMillis */,
0 /* servicePriority */,
0 /* serviceWeight */,
TEST_PORT /* servicePort */,
TEST_CUSTOM_HOST_1_NAME),
MdnsTextRecord(
serviceName,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
4500000L /* ttlMillis */,
emptyList() /* entries */),
MdnsPointerRecord(
arrayOf("_services", "_dns-sd", "_udp", "local"),
0L /* receiptTimeMillis */,
false /* cacheFlush */,
4500000L /* ttlMillis */,
serviceType))
assertThat(packet.additionalRecords).containsExactly(
MdnsNsecRecord(v4AddrRev,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
120000L /* ttlMillis */,
v4AddrRev,
intArrayOf(TYPE_PTR)),
MdnsNsecRecord(TEST_HOSTNAME,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
120000L /* ttlMillis */,
TEST_HOSTNAME,
intArrayOf(TYPE_A, TYPE_AAAA)),
MdnsNsecRecord(v6Addr1Rev,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
120000L /* ttlMillis */,
v6Addr1Rev,
intArrayOf(TYPE_PTR)),
MdnsNsecRecord(v6Addr2Rev,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
120000L /* ttlMillis */,
v6Addr2Rev,
intArrayOf(TYPE_PTR)),
MdnsNsecRecord(serviceName,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
4500000L /* ttlMillis */,
serviceName,
intArrayOf(TYPE_TXT, TYPE_SRV)))
}
}
private fun MdnsRecordRepository.initWithService(