Add additional bounds checks to NNAPI FMQ deserialize utility functions

This CL adds the following additional bounds checks:
* Adds additional checks of the index of the std::vector before
  accessing the element at the index
* Changes the array index operator [] to the checked std::vector::at
  method

Bug: 256589724
Test: mma
Merged-In: I6bfb02a5cd76258284cc4d797a4508b21e672c4b
Change-Id: I6bfb02a5cd76258284cc4d797a4508b21e672c4b
This commit is contained in:
Michael Butler 2022-11-14 19:00:25 -08:00
parent 0891de19e6
commit 06493bc122

View file

@ -190,12 +190,13 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
size_t index = 0; size_t index = 0;
// validate packet information // validate packet information
if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) { if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::packetInformation) {
return NN_ERROR() << "FMQ Request packet ill-formed"; return NN_ERROR() << "FMQ Request packet ill-formed";
} }
// unpackage packet information // unpackage packet information
const FmqRequestDatum::PacketInformation& packetInfo = data[index].packetInformation(); const FmqRequestDatum::PacketInformation& packetInfo = data.at(index).packetInformation();
index++; index++;
const uint32_t packetSize = packetInfo.packetSize; const uint32_t packetSize = packetInfo.packetSize;
const uint32_t numberOfInputOperands = packetInfo.numberOfInputOperands; const uint32_t numberOfInputOperands = packetInfo.numberOfInputOperands;
@ -212,13 +213,14 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
inputs.reserve(numberOfInputOperands); inputs.reserve(numberOfInputOperands);
for (size_t operand = 0; operand < numberOfInputOperands; ++operand) { for (size_t operand = 0; operand < numberOfInputOperands; ++operand) {
// validate input operand information // validate input operand information
if (data[index].getDiscriminator() != discriminator::inputOperandInformation) { if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::inputOperandInformation) {
return NN_ERROR() << "FMQ Request packet ill-formed"; return NN_ERROR() << "FMQ Request packet ill-formed";
} }
// unpackage operand information // unpackage operand information
const FmqRequestDatum::OperandInformation& operandInfo = const FmqRequestDatum::OperandInformation& operandInfo =
data[index].inputOperandInformation(); data.at(index).inputOperandInformation();
index++; index++;
const bool hasNoValue = operandInfo.hasNoValue; const bool hasNoValue = operandInfo.hasNoValue;
const V1_0::DataLocation location = operandInfo.location; const V1_0::DataLocation location = operandInfo.location;
@ -229,12 +231,13 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
dimensions.reserve(numberOfDimensions); dimensions.reserve(numberOfDimensions);
for (size_t i = 0; i < numberOfDimensions; ++i) { for (size_t i = 0; i < numberOfDimensions; ++i) {
// validate dimension // validate dimension
if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) { if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::inputOperandDimensionValue) {
return NN_ERROR() << "FMQ Request packet ill-formed"; return NN_ERROR() << "FMQ Request packet ill-formed";
} }
// unpackage dimension // unpackage dimension
const uint32_t dimension = data[index].inputOperandDimensionValue(); const uint32_t dimension = data.at(index).inputOperandDimensionValue();
index++; index++;
// store result // store result
@ -251,13 +254,14 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
outputs.reserve(numberOfOutputOperands); outputs.reserve(numberOfOutputOperands);
for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) { for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) {
// validate output operand information // validate output operand information
if (data[index].getDiscriminator() != discriminator::outputOperandInformation) { if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::outputOperandInformation) {
return NN_ERROR() << "FMQ Request packet ill-formed"; return NN_ERROR() << "FMQ Request packet ill-formed";
} }
// unpackage operand information // unpackage operand information
const FmqRequestDatum::OperandInformation& operandInfo = const FmqRequestDatum::OperandInformation& operandInfo =
data[index].outputOperandInformation(); data.at(index).outputOperandInformation();
index++; index++;
const bool hasNoValue = operandInfo.hasNoValue; const bool hasNoValue = operandInfo.hasNoValue;
const V1_0::DataLocation location = operandInfo.location; const V1_0::DataLocation location = operandInfo.location;
@ -268,12 +272,13 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
dimensions.reserve(numberOfDimensions); dimensions.reserve(numberOfDimensions);
for (size_t i = 0; i < numberOfDimensions; ++i) { for (size_t i = 0; i < numberOfDimensions; ++i) {
// validate dimension // validate dimension
if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) { if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::outputOperandDimensionValue) {
return NN_ERROR() << "FMQ Request packet ill-formed"; return NN_ERROR() << "FMQ Request packet ill-formed";
} }
// unpackage dimension // unpackage dimension
const uint32_t dimension = data[index].outputOperandDimensionValue(); const uint32_t dimension = data.at(index).outputOperandDimensionValue();
index++; index++;
// store result // store result
@ -290,12 +295,13 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
slots.reserve(numberOfPools); slots.reserve(numberOfPools);
for (size_t pool = 0; pool < numberOfPools; ++pool) { for (size_t pool = 0; pool < numberOfPools; ++pool) {
// validate input operand information // validate input operand information
if (data[index].getDiscriminator() != discriminator::poolIdentifier) { if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::poolIdentifier) {
return NN_ERROR() << "FMQ Request packet ill-formed"; return NN_ERROR() << "FMQ Request packet ill-formed";
} }
// unpackage operand information // unpackage operand information
const int32_t poolId = data[index].poolIdentifier(); const int32_t poolId = data.at(index).poolIdentifier();
index++; index++;
// store result // store result
@ -303,17 +309,17 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
} }
// validate measureTiming // validate measureTiming
if (data[index].getDiscriminator() != discriminator::measureTiming) { if (index >= data.size() || data.at(index).getDiscriminator() != discriminator::measureTiming) {
return NN_ERROR() << "FMQ Request packet ill-formed"; return NN_ERROR() << "FMQ Request packet ill-formed";
} }
// unpackage measureTiming // unpackage measureTiming
const V1_2::MeasureTiming measure = data[index].measureTiming(); const V1_2::MeasureTiming measure = data.at(index).measureTiming();
index++; index++;
// validate packet information // validate packet information
if (index != packetSize) { if (index != packetSize) {
return NN_ERROR() << "FMQ Result packet ill-formed"; return NN_ERROR() << "FMQ Request packet ill-formed";
} }
// return request // return request
@ -328,12 +334,13 @@ nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::T
size_t index = 0; size_t index = 0;
// validate packet information // validate packet information
if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) { if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::packetInformation) {
return NN_ERROR() << "FMQ Result packet ill-formed"; return NN_ERROR() << "FMQ Result packet ill-formed";
} }
// unpackage packet information // unpackage packet information
const FmqResultDatum::PacketInformation& packetInfo = data[index].packetInformation(); const FmqResultDatum::PacketInformation& packetInfo = data.at(index).packetInformation();
index++; index++;
const uint32_t packetSize = packetInfo.packetSize; const uint32_t packetSize = packetInfo.packetSize;
const V1_0::ErrorStatus errorStatus = packetInfo.errorStatus; const V1_0::ErrorStatus errorStatus = packetInfo.errorStatus;
@ -349,12 +356,13 @@ nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::T
outputShapes.reserve(numberOfOperands); outputShapes.reserve(numberOfOperands);
for (size_t operand = 0; operand < numberOfOperands; ++operand) { for (size_t operand = 0; operand < numberOfOperands; ++operand) {
// validate operand information // validate operand information
if (data[index].getDiscriminator() != discriminator::operandInformation) { if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::operandInformation) {
return NN_ERROR() << "FMQ Result packet ill-formed"; return NN_ERROR() << "FMQ Result packet ill-formed";
} }
// unpackage operand information // unpackage operand information
const FmqResultDatum::OperandInformation& operandInfo = data[index].operandInformation(); const FmqResultDatum::OperandInformation& operandInfo = data.at(index).operandInformation();
index++; index++;
const bool isSufficient = operandInfo.isSufficient; const bool isSufficient = operandInfo.isSufficient;
const uint32_t numberOfDimensions = operandInfo.numberOfDimensions; const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
@ -364,12 +372,13 @@ nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::T
dimensions.reserve(numberOfDimensions); dimensions.reserve(numberOfDimensions);
for (size_t i = 0; i < numberOfDimensions; ++i) { for (size_t i = 0; i < numberOfDimensions; ++i) {
// validate dimension // validate dimension
if (data[index].getDiscriminator() != discriminator::operandDimensionValue) { if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::operandDimensionValue) {
return NN_ERROR() << "FMQ Result packet ill-formed"; return NN_ERROR() << "FMQ Result packet ill-formed";
} }
// unpackage dimension // unpackage dimension
const uint32_t dimension = data[index].operandDimensionValue(); const uint32_t dimension = data.at(index).operandDimensionValue();
index++; index++;
// store result // store result
@ -381,12 +390,13 @@ nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::T
} }
// validate execution timing // validate execution timing
if (data[index].getDiscriminator() != discriminator::executionTiming) { if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::executionTiming) {
return NN_ERROR() << "FMQ Result packet ill-formed"; return NN_ERROR() << "FMQ Result packet ill-formed";
} }
// unpackage execution timing // unpackage execution timing
const V1_2::Timing timing = data[index].executionTiming(); const V1_2::Timing timing = data.at(index).executionTiming();
index++; index++;
// validate packet information // validate packet information