Merge "blockimgdiff: Factor out the diff_worker"

am: 8ab919dcce

Change-Id: I590dac85679de6209a3ba00052a309f41567cbae
Authored by Tianjie Xu on 2018-12-13 12:16:18 -08:00; committed by android-build-merger
commit 5437eeeb1c

blockimgdiff.py

@@ -26,7 +26,8 @@ import os.path
 import re
 import sys
 import threading
-from collections import deque, OrderedDict
+import zlib
+from collections import deque, namedtuple, OrderedDict
 from hashlib import sha1
 
 import common
@@ -36,8 +37,12 @@ __all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]
 
 logger = logging.getLogger(__name__)
 
+# The tuple contains the style and bytes of a bsdiff|imgdiff patch.
+PatchInfo = namedtuple("PatchInfo", ["imgdiff", "content"])
+
 
 def compute_patch(srcfile, tgtfile, imgdiff=False):
+  """Calls bsdiff|imgdiff to compute the patch data, returns a PatchInfo."""
   patchfile = common.MakeTempFile(prefix='patch-')
   cmd = ['imgdiff', '-z'] if imgdiff else ['bsdiff']
@@ -52,7 +57,7 @@ def compute_patch(srcfile, tgtfile, imgdiff=False):
     raise ValueError(output)
 
   with open(patchfile, 'rb') as f:
-    return f.read()
+    return PatchInfo(imgdiff, f.read())
 
 
 class Image(object):
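
As context for the new return type, a minimal standalone sketch of how a PatchInfo value is built and read; the payload bytes below are made up purely for illustration:

from collections import namedtuple

# Same shape as the type introduced above: the patch style plus the raw bytes.
PatchInfo = namedtuple("PatchInfo", ["imgdiff", "content"])

# Hypothetical bsdiff payload; the real content comes from the bsdiff/imgdiff
# output file that compute_patch() reads back.
info = PatchInfo(imgdiff=False, content=b"BSDIFF40\x00...")
assert not info.imgdiff
print(len(info.content))  # the patch size, later recorded as patch_len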
@@ -203,17 +208,17 @@ class Transfer(object):
     self.id = len(by_id)
     by_id.append(self)
 
-    self._patch = None
+    self._patch_info = None
 
   @property
-  def patch(self):
-    return self._patch
+  def patch_info(self):
+    return self._patch_info
 
-  @patch.setter
-  def patch(self, patch):
-    if patch:
+  @patch_info.setter
+  def patch_info(self, info):
+    if info:
       assert self.style == "diff"
-    self._patch = patch
+    self._patch_info = info
 
   def NetStashChange(self):
     return (sum(sr.size() for (_, sr) in self.stash_before) -
@@ -224,7 +229,7 @@ class Transfer(object):
     self.use_stash = []
     self.style = "new"
     self.src_ranges = RangeSet()
-    self.patch = None
+    self.patch_info = None
 
   def __str__(self):
     return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
@@ -462,16 +467,7 @@ class BlockImageDiff(object):
     self.AbbreviateSourceNames()
     self.FindTransfers()
 
-    # Find the ordering dependencies among transfers (this is O(n^2)
-    # in the number of transfers).
-    self.GenerateDigraph()
-    # Find a sequence of transfers that satisfies as many ordering
-    # dependencies as possible (heuristically).
-    self.FindVertexSequence()
-    # Fix up the ordering dependencies that the sequence didn't
-    # satisfy.
-    self.ReverseBackwardEdges()
-    self.ImproveVertexSequence()
+    self.FindSequenceForTransfers()
 
     # Ensure the runtime stash size is under the limit.
     if common.OPTIONS.cache_size is not None:
@@ -829,7 +825,7 @@ class BlockImageDiff(object):
           # These are identical; we don't need to generate a patch,
           # just issue copy commands on the device.
           xf.style = "move"
-          xf.patch = None
+          xf.patch_info = None
           tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
           if xf.src_ranges != xf.tgt_ranges:
             logger.info(
@@ -839,11 +835,10 @@ class BlockImageDiff(object):
                       xf.tgt_name + " (from " + xf.src_name + ")"),
                   str(xf.tgt_ranges), str(xf.src_ranges))
           else:
-            if xf.patch:
-              # We have already generated the patch with imgdiff, while
-              # splitting large APKs (i.e. in FindTransfers()).
-              assert not self.disable_imgdiff
-              imgdiff = True
+            if xf.patch_info:
+              # We have already generated the patch (e.g. during split of large
+              # APKs or reduction of stash size)
+              imgdiff = xf.patch_info.imgdiff
             else:
               imgdiff = self.CanUseImgdiff(
                   xf.tgt_name, xf.tgt_ranges, xf.src_ranges)
@@ -854,85 +849,16 @@ class BlockImageDiff(object):
         else:
           assert False, "unknown style " + xf.style
 
-    if diff_queue:
-      if self.threads > 1:
-        logger.info("Computing patches (using %d threads)...", self.threads)
-      else:
-        logger.info("Computing patches...")
-
-      diff_total = len(diff_queue)
-      patches = [None] * diff_total
-      error_messages = []
-
-      # Using multiprocessing doesn't give additional benefits, due to the
-      # pattern of the code. The diffing work is done by subprocess.call, which
-      # already runs in a separate process (not affected much by the GIL -
-      # Global Interpreter Lock). Using multiprocess also requires either a)
-      # writing the diff input files in the main process before forking, or b)
-      # reopening the image file (SparseImage) in the worker processes. Doing
-      # neither of them further improves the performance.
-      lock = threading.Lock()
-
-      def diff_worker():
-        while True:
-          with lock:
-            if not diff_queue:
-              return
-            xf_index, imgdiff, patch_index = diff_queue.pop()
-            xf = self.transfers[xf_index]
-
-          patch = xf.patch
-          if not patch:
-            src_ranges = xf.src_ranges
-            tgt_ranges = xf.tgt_ranges
-
-            src_file = common.MakeTempFile(prefix="src-")
-            with open(src_file, "wb") as fd:
-              self.src.WriteRangeDataToFd(src_ranges, fd)
-
-            tgt_file = common.MakeTempFile(prefix="tgt-")
-            with open(tgt_file, "wb") as fd:
-              self.tgt.WriteRangeDataToFd(tgt_ranges, fd)
-
-            message = []
-            try:
-              patch = compute_patch(src_file, tgt_file, imgdiff)
-            except ValueError as e:
-              message.append(
-                  "Failed to generate %s for %s: tgt=%s, src=%s:\n%s" % (
-                      "imgdiff" if imgdiff else "bsdiff",
-                      xf.tgt_name if xf.tgt_name == xf.src_name else
-                      xf.tgt_name + " (from " + xf.src_name + ")",
-                      xf.tgt_ranges, xf.src_ranges, e.message))
-            if message:
-              with lock:
-                error_messages.extend(message)
-
-          with lock:
-            patches[patch_index] = (xf_index, patch)
-
-      threads = [threading.Thread(target=diff_worker)
-                 for _ in range(self.threads)]
-      for th in threads:
-        th.start()
-      while threads:
-        threads.pop().join()
-
-      if error_messages:
-        logger.error('ERROR:')
-        logger.error('\n'.join(error_messages))
-        logger.error('\n\n\n')
-        sys.exit(1)
-    else:
-      patches = []
+    patches = self.ComputePatchesForInputList(diff_queue, False)
 
     offset = 0
     with open(prefix + ".patch.dat", "wb") as patch_fd:
-      for index, patch in patches:
+      for index, patch_info, _ in patches:
         xf = self.transfers[index]
-        xf.patch_len = len(patch)
+        xf.patch_len = len(patch_info.content)
        xf.patch_start = offset
        offset += xf.patch_len
-        patch_fd.write(patch)
+        patch_fd.write(patch_info.content)
 
         tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
         logger.info(
@@ -999,6 +925,32 @@ class BlockImageDiff(object):
       for i in range(s, e):
         assert touched[i] == 1
 
+  def FindSequenceForTransfers(self):
+    """Finds a sequence for the given transfers.
+
+    The goal is to minimize the violation of order dependencies between these
+    transfers, so that fewer blocks are stashed when applying the update.
+    """
+
+    # Clear the existing dependency between transfers
+    for xf in self.transfers:
+      xf.goes_before = OrderedDict()
+      xf.goes_after = OrderedDict()
+
+      xf.stash_before = []
+      xf.use_stash = []
+
+    # Find the ordering dependencies among transfers (this is O(n^2)
+    # in the number of transfers).
+    self.GenerateDigraph()
+    # Find a sequence of transfers that satisfies as many ordering
+    # dependencies as possible (heuristically).
+    self.FindVertexSequence()
+    # Fix up the ordering dependencies that the sequence didn't
+    # satisfy.
+    self.ReverseBackwardEdges()
+    self.ImproveVertexSequence()
+
   def ImproveVertexSequence(self):
     logger.info("Improving vertex order...")
@@ -1248,6 +1200,105 @@ class BlockImageDiff(object):
         b.goes_before[a] = size
         a.goes_after[b] = size
 
+  def ComputePatchesForInputList(self, diff_queue, compress_target):
+    """Returns a list of patch information for the input list of transfers.
+
+    Args:
+      diff_queue: a list of transfers with style 'diff'
+      compress_target: If True, compresses the target ranges of each transfer
+          and saves the size.
+
+    Returns:
+      A list of (transfer order, patch_info, compressed_size) tuples.
+    """
+    if not diff_queue:
+      return []
+
+    if self.threads > 1:
+      logger.info("Computing patches (using %d threads)...", self.threads)
+    else:
+      logger.info("Computing patches...")
+
+    diff_total = len(diff_queue)
+    patches = [None] * diff_total
+    error_messages = []
+
+    # Using multiprocessing doesn't give additional benefits, due to the
+    # pattern of the code. The diffing work is done by subprocess.call, which
+    # already runs in a separate process (not affected much by the GIL -
+    # Global Interpreter Lock). Using multiprocess also requires either a)
+    # writing the diff input files in the main process before forking, or b)
+    # reopening the image file (SparseImage) in the worker processes. Doing
+    # neither of them further improves the performance.
+    lock = threading.Lock()
+
+    def diff_worker():
+      while True:
+        with lock:
+          if not diff_queue:
+            return
+          xf_index, imgdiff, patch_index = diff_queue.pop()
+          xf = self.transfers[xf_index]
+
+        message = []
+        compressed_size = None
+
+        patch_info = xf.patch_info
+        if not patch_info:
+          src_file = common.MakeTempFile(prefix="src-")
+          with open(src_file, "wb") as fd:
+            self.src.WriteRangeDataToFd(xf.src_ranges, fd)
+
+          tgt_file = common.MakeTempFile(prefix="tgt-")
+          with open(tgt_file, "wb") as fd:
+            self.tgt.WriteRangeDataToFd(xf.tgt_ranges, fd)
+
+          try:
+            patch_info = compute_patch(src_file, tgt_file, imgdiff)
+          except ValueError as e:
+            message.append(
+                "Failed to generate %s for %s: tgt=%s, src=%s:\n%s" % (
+                    "imgdiff" if imgdiff else "bsdiff",
+                    xf.tgt_name if xf.tgt_name == xf.src_name else
+                    xf.tgt_name + " (from " + xf.src_name + ")",
+                    xf.tgt_ranges, xf.src_ranges, e.message))
+
+        if compress_target:
+          tgt_data = self.tgt.ReadRangeSet(xf.tgt_ranges)
+          try:
+            # Compresses with the default level
+            compress_obj = zlib.compressobj(6, zlib.DEFLATED, -zlib.MAX_WBITS)
+            compressed_data = (compress_obj.compress("".join(tgt_data))
+                               + compress_obj.flush())
+            compressed_size = len(compressed_data)
+          except zlib.error as e:
+            message.append(
+                "Failed to compress the data in target range {} for {}:\n"
+                "{}".format(xf.tgt_ranges, xf.tgt_name, e.message))
+
+        if message:
+          with lock:
+            error_messages.extend(message)
+
+        with lock:
+          patches[patch_index] = (xf_index, patch_info, compressed_size)
+
+    threads = [threading.Thread(target=diff_worker)
+               for _ in range(self.threads)]
+    for th in threads:
+      th.start()
+    while threads:
+      threads.pop().join()
+
+    if error_messages:
+      logger.error('ERROR:')
+      logger.error('\n'.join(error_messages))
+      logger.error('\n\n\n')
+      sys.exit(1)
+
+    return patches
+
   def FindTransfers(self):
     """Parse the file_map to generate all the transfers."""
@@ -1585,7 +1636,7 @@ class BlockImageDiff(object):
                                     self.tgt.RangeSha1(tgt_ranges),
                                     self.src.RangeSha1(src_ranges),
                                     "diff", self.transfers)
-          transfer_split.patch = patch
+          transfer_split.patch_info = PatchInfo(True, patch)
 
   def AbbreviateSourceNames(self):
     for k in self.src.file_map.keys():
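
The factored-out ComputePatchesForInputList() keeps the original worker pattern: a plain list used as a work queue, a single lock guarding both the queue and the result slots, and one Python thread per configured worker, while the heavy diffing itself runs in bsdiff/imgdiff child processes so the GIL is not the bottleneck. Below is a minimal, generic sketch of that pattern; the names run_workers and handle are hypothetical and not part of this change:

import threading

def run_workers(work_items, num_threads, handle):
  """Runs handle() over work_items on num_threads threads; keeps input order."""
  results = [None] * len(work_items)
  queue = list(enumerate(work_items))
  lock = threading.Lock()

  def worker():
    while True:
      with lock:
        if not queue:
          return
        index, item = queue.pop()
      value = handle(item)  # e.g. spawn bsdiff/imgdiff via subprocess here
      with lock:
        results[index] = value

  threads = [threading.Thread(target=worker) for _ in range(num_threads)]
  for th in threads:
    th.start()
  while threads:
    threads.pop().join()
  return results

print(run_workers([1, 2, 3], 2, lambda x: x * x))  # [1, 4, 9]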