diff --git a/Checkpoint.cpp b/Checkpoint.cpp index 7ff7766..eca49ef 100644 --- a/Checkpoint.cpp +++ b/Checkpoint.cpp @@ -597,21 +597,31 @@ void restoreSector(int device_fd, Used_Sectors& used_sectors, std::vector& // Read from the device // If we are validating, the read occurs as though the relocations had happened +// returns the amount asked for or an empty buffer on error. Partial reads are considered a failure std::vector relocatedRead(int device_fd, Relocations const& relocations, bool validating, sector_t sector, uint32_t size, uint32_t block_size) { if (!validating) { std::vector buffer(size); - lseek64(device_fd, sector * kSectorSize, SEEK_SET); - read(device_fd, &buffer[0], size); + off64_t offset = sector * kSectorSize; + if (lseek64(device_fd, offset, SEEK_SET) != offset) { + return std::vector(); + } + if (read(device_fd, &buffer[0], size) != static_cast(size)) { + return std::vector(); + } return buffer; } std::vector buffer(size); for (uint32_t i = 0; i < size; i += block_size, sector += block_size / kSectorSize) { auto relocation = --relocations.upper_bound(sector); - lseek64(device_fd, (sector + relocation->second - relocation->first) * kSectorSize, - SEEK_SET); - read(device_fd, &buffer[i], block_size); + off64_t offset = (sector + relocation->second - relocation->first) * kSectorSize; + if (lseek64(device_fd, offset, SEEK_SET) != offset) { + return std::vector(); + } + if (read(device_fd, &buffer[i], block_size) != static_cast(block_size)) { + return std::vector(); + } } return buffer; @@ -634,7 +644,10 @@ Status cp_restoreCheckpoint(const std::string& blockDevice, int restore_limit) { if (device_fd < 0) return error("Cannot open " + blockDevice); log_sector_v1_0 original_ls; - read(device_fd, reinterpret_cast(&original_ls), sizeof(original_ls)); + if (read(device_fd, reinterpret_cast(&original_ls), sizeof(original_ls)) != + sizeof(original_ls)) { + return error(EINVAL, "Cannot read sector"); + } if (original_ls.magic == kPartialRestoreMagic) { validating = false; action = "Restoring"; @@ -642,11 +655,19 @@ Status cp_restoreCheckpoint(const std::string& blockDevice, int restore_limit) { return error(EINVAL, "No magic"); } + if (original_ls.block_size < sizeof(log_sector_v1_0)) { + return error(EINVAL, "Block size is invalid"); + } + LOG(INFO) << action << " " << original_ls.sequence << " log sectors"; for (int sequence = original_ls.sequence; sequence >= 0 && status.isOk(); sequence--) { auto ls_buffer = relocatedRead(device_fd, relocations, validating, 0, original_ls.block_size, original_ls.block_size); + if (ls_buffer.size() != original_ls.block_size) { + status = error(EINVAL, "Failed to read log sector"); + break; + } log_sector_v1_0& ls = *reinterpret_cast(&ls_buffer[0]); Used_Sectors used_sectors; @@ -668,6 +689,14 @@ Status cp_restoreCheckpoint(const std::string& blockDevice, int restore_limit) { break; } + if (ls.header_size < sizeof(log_sector_v1_0) || ls.header_size > ls.block_size) { + status = error(EINVAL, "Log sector header size is invalid"); + break; + } + if (ls.count < 1 || ls.count > (ls.block_size - ls.header_size) / sizeof(log_entry)) { + status = error(EINVAL, "Log sector count is invalid"); + break; + } LOG(INFO) << action << " from log sector " << ls.sequence; for (log_entry* le = reinterpret_cast(&ls_buffer[ls.header_size]) + ls.count - 1; @@ -677,8 +706,16 @@ Status cp_restoreCheckpoint(const std::string& blockDevice, int restore_limit) { << " to " << le->source << " with checksum " << std::hex << le->checksum; + if (ls.block_size > UINT_MAX - le->size || le->size < ls.block_size) { + status = error(EINVAL, "log entry is invalid"); + break; + } auto buffer = relocatedRead(device_fd, relocations, validating, le->dest, le->size, ls.block_size); + if (buffer.size() != le->size) { + status = error(EINVAL, "Failed to read sector"); + break; + } uint32_t checksum = le->source / (ls.block_size / kSectorSize); for (size_t i = 0; i < le->size; i += ls.block_size) { crc32(&buffer[i], ls.block_size, &checksum); @@ -711,8 +748,17 @@ Status cp_restoreCheckpoint(const std::string& blockDevice, int restore_limit) { LOG(WARNING) << "Checkpoint validation failed - attempting to roll forward"; auto buffer = relocatedRead(device_fd, relocations, false, original_ls.sector0, original_ls.block_size, original_ls.block_size); - lseek64(device_fd, 0, SEEK_SET); - write(device_fd, &buffer[0], original_ls.block_size); + if (buffer.size() != original_ls.block_size) { + return error(EINVAL, "Failed to read original sector"); + } + + if (lseek64(device_fd, 0, SEEK_SET) != 0) { + return error(EINVAL, "Failed to seek to sector 0"); + } + if (write(device_fd, &buffer[0], original_ls.block_size) != + static_cast(original_ls.block_size)) { + return error(EINVAL, "Failed to write original sector"); + } return Status::ok(); }