Merge "cpp codegen redesign, unit test support"

This commit is contained in:
Dennis Shen 2023-07-05 19:03:11 +00:00 committed by Gerrit Code Review
commit 99d4a49d68
7 changed files with 467 additions and 33 deletions

View file

@ -16,13 +16,18 @@
use anyhow::{ensure, Result}; use anyhow::{ensure, Result};
use serde::Serialize; use serde::Serialize;
use std::path::PathBuf;
use tinytemplate::TinyTemplate; use tinytemplate::TinyTemplate;
use crate::codegen; use crate::codegen;
use crate::commands::OutputFile; use crate::commands::{CodegenMode, OutputFile};
use crate::protos::{ProtoFlagPermission, ProtoFlagState, ProtoParsedFlag}; use crate::protos::{ProtoFlagPermission, ProtoFlagState, ProtoParsedFlag};
pub fn generate_cpp_code<'a, I>(package: &str, parsed_flags_iter: I) -> Result<OutputFile> pub fn generate_cpp_code<'a, I>(
package: &str,
parsed_flags_iter: I,
codegen_mode: CodegenMode,
) -> Result<Vec<OutputFile>>
where where
I: Iterator<Item = &'a ProtoParsedFlag>, I: Iterator<Item = &'a ProtoParsedFlag>,
{ {
@ -37,29 +42,66 @@ where
cpp_namespace, cpp_namespace,
package: package.to_string(), package: package.to_string(),
readwrite, readwrite,
for_prod: codegen_mode == CodegenMode::Production,
class_elements, class_elements,
}; };
let files = [
FileSpec {
name: &format!("{}.h", header),
template: include_str!("../templates/cpp_exported_header.template"),
dir: "include",
},
FileSpec {
name: &format!("{}.cc", header),
template: include_str!("../templates/cpp_source_file.template"),
dir: "",
},
FileSpec {
name: &format!("{}_flag_provider.h", header),
template: match codegen_mode {
CodegenMode::Production => {
include_str!("../templates/cpp_prod_flag_provider.template")
}
CodegenMode::Test => include_str!("../templates/cpp_test_flag_provider.template"),
},
dir: "",
},
];
files.iter().map(|file| generate_file(file, &context)).collect()
}
pub fn generate_file(file: &FileSpec, context: &Context) -> Result<OutputFile> {
let mut template = TinyTemplate::new(); let mut template = TinyTemplate::new();
template.add_template("cpp_code_gen", include_str!("../templates/cpp.template"))?; template.add_template(file.name, file.template)?;
let contents = template.render("cpp_code_gen", &context)?; let contents = template.render(file.name, &context)?;
let path = ["aconfig", &(header + ".h")].iter().collect(); let path: PathBuf = [&file.dir, &file.name].iter().collect();
Ok(OutputFile { contents: contents.into(), path }) Ok(OutputFile { contents: contents.into(), path })
} }
#[derive(Serialize)] #[derive(Serialize)]
struct Context { pub struct FileSpec<'a> {
pub name: &'a str,
pub template: &'a str,
pub dir: &'a str,
}
#[derive(Serialize)]
pub struct Context {
pub header: String, pub header: String,
pub cpp_namespace: String, pub cpp_namespace: String,
pub package: String, pub package: String,
pub readwrite: bool, pub readwrite: bool,
pub for_prod: bool,
pub class_elements: Vec<ClassElement>, pub class_elements: Vec<ClassElement>,
} }
#[derive(Serialize)] #[derive(Serialize)]
struct ClassElement { pub struct ClassElement {
pub readwrite: bool, pub readwrite: bool,
pub default_value: String, pub default_value: String,
pub flag_name: String, pub flag_name: String,
pub uppercase_flag_name: String,
pub device_config_namespace: String, pub device_config_namespace: String,
pub device_config_flag: String, pub device_config_flag: String,
} }
@ -73,6 +115,7 @@ fn create_class_element(package: &str, pf: &ProtoParsedFlag) -> ClassElement {
"false".to_string() "false".to_string()
}, },
flag_name: pf.name().to_string(), flag_name: pf.name().to_string(),
uppercase_flag_name: pf.name().to_string().to_ascii_uppercase(),
device_config_namespace: pf.namespace().to_string(), device_config_namespace: pf.namespace().to_string(),
device_config_flag: codegen::create_device_config_ident(package, pf.name()) device_config_flag: codegen::create_device_config_ident(package, pf.name())
.expect("values checked at flag parse time"), .expect("values checked at flag parse time"),
@ -82,51 +125,325 @@ fn create_class_element(package: &str, pf: &ProtoParsedFlag) -> ClassElement {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use std::collections::HashMap;
#[test] const EXPORTED_PROD_HEADER_EXPECTED: &str = r#"
fn test_generate_cpp_code() {
let parsed_flags = crate::test::parse_test_flags();
let generated =
generate_cpp_code(crate::test::TEST_PACKAGE, parsed_flags.parsed_flag.iter()).unwrap();
assert_eq!("aconfig/com_android_aconfig_test.h", format!("{}", generated.path.display()));
let expected = r#"
#ifndef com_android_aconfig_test_HEADER_H #ifndef com_android_aconfig_test_HEADER_H
#define com_android_aconfig_test_HEADER_H #define com_android_aconfig_test_HEADER_H
#include <server_configurable_flags/get_flags.h>
#include <string>
#include <memory>
#include <server_configurable_flags/get_flags.h>
using namespace server_configurable_flags; using namespace server_configurable_flags;
namespace com::android::aconfig::test { namespace com::android::aconfig::test {
static const bool disabled_ro() { class flag_provider_interface {
public:
virtual ~flag_provider_interface() = default;
virtual bool disabled_ro() = 0;
virtual bool disabled_rw() = 0;
virtual bool enabled_ro() = 0;
virtual bool enabled_rw() = 0;
virtual void override_flag(std::string const&, bool) {}
virtual void reset_overrides() {}
};
extern std::unique_ptr<flag_provider_interface> provider_;
extern std::string const DISABLED_RO;
extern std::string const DISABLED_RW;
extern std::string const ENABLED_RO;
extern std::string const ENABLED_RW;
inline bool disabled_ro() {
return false;
}
inline bool disabled_rw() {
return provider_->disabled_rw();
}
inline bool enabled_ro() {
return true;
}
inline bool enabled_rw() {
return provider_->enabled_rw();
}
inline void override_flag(std::string const& name, bool val) {
return provider_->override_flag(name, val);
}
inline void reset_overrides() {
return provider_->reset_overrides();
}
}
#endif
"#;
const EXPORTED_TEST_HEADER_EXPECTED: &str = r#"
#ifndef com_android_aconfig_test_HEADER_H
#define com_android_aconfig_test_HEADER_H
#include <string>
#include <memory>
#include <server_configurable_flags/get_flags.h>
using namespace server_configurable_flags;
namespace com::android::aconfig::test {
class flag_provider_interface {
public:
virtual ~flag_provider_interface() = default;
virtual bool disabled_ro() = 0;
virtual bool disabled_rw() = 0;
virtual bool enabled_ro() = 0;
virtual bool enabled_rw() = 0;
virtual void override_flag(std::string const&, bool) {}
virtual void reset_overrides() {}
};
extern std::unique_ptr<flag_provider_interface> provider_;
extern std::string const DISABLED_RO;
extern std::string const DISABLED_RW;
extern std::string const ENABLED_RO;
extern std::string const ENABLED_RW;
inline bool disabled_ro() {
return provider_->disabled_ro();
}
inline bool disabled_rw() {
return provider_->disabled_rw();
}
inline bool enabled_ro() {
return provider_->enabled_ro();
}
inline bool enabled_rw() {
return provider_->enabled_rw();
}
inline void override_flag(std::string const& name, bool val) {
return provider_->override_flag(name, val);
}
inline void reset_overrides() {
return provider_->reset_overrides();
}
}
#endif
"#;
const PROD_FLAG_PROVIDER_HEADER_EXPECTED: &str = r#"
#ifndef com_android_aconfig_test_flag_provider_HEADER_H
#define com_android_aconfig_test_flag_provider_HEADER_H
#include "com_android_aconfig_test.h"
namespace com::android::aconfig::test {
class flag_provider : public flag_provider_interface {
public:
virtual bool disabled_ro() override {
return false; return false;
} }
static const bool disabled_rw() { virtual bool disabled_rw() override {
return GetServerConfigurableFlag( return GetServerConfigurableFlag(
"aconfig_test", "aconfig_test",
"com.android.aconfig.test.disabled_rw", "com.android.aconfig.test.disabled_rw",
"false") == "true"; "false") == "true";
} }
static const bool enabled_ro() { virtual bool enabled_ro() override {
return true; return true;
} }
static const bool enabled_rw() { virtual bool enabled_rw() override {
return GetServerConfigurableFlag( return GetServerConfigurableFlag(
"aconfig_test", "aconfig_test",
"com.android.aconfig.test.enabled_rw", "com.android.aconfig.test.enabled_rw",
"true") == "true"; "true") == "true";
} }
};
} }
#endif #endif
"#; "#;
const TEST_FLAG_PROVIDER_HEADER_EXPECTED: &str = r#"
#ifndef com_android_aconfig_test_flag_provider_HEADER_H
#define com_android_aconfig_test_flag_provider_HEADER_H
#include "com_android_aconfig_test.h"
#include <unordered_map>
#include <unordered_set>
#include <cassert>
namespace com::android::aconfig::test {
class flag_provider : public flag_provider_interface {
private:
std::unordered_map<std::string, bool> overrides_;
std::unordered_set<std::string> flag_names_;
public:
flag_provider()
: overrides_(),
flag_names_() {
flag_names_.insert(DISABLED_RO);
flag_names_.insert(DISABLED_RW);
flag_names_.insert(ENABLED_RO);
flag_names_.insert(ENABLED_RW);
}
virtual bool disabled_ro() override {
auto it = overrides_.find(DISABLED_RO);
if (it != overrides_.end()) {
return it->second;
} else {
return false;
}
}
virtual bool disabled_rw() override {
auto it = overrides_.find(DISABLED_RW);
if (it != overrides_.end()) {
return it->second;
} else {
return GetServerConfigurableFlag(
"aconfig_test",
"com.android.aconfig.test.disabled_rw",
"false") == "true";
}
}
virtual bool enabled_ro() override {
auto it = overrides_.find(ENABLED_RO);
if (it != overrides_.end()) {
return it->second;
} else {
return true;
}
}
virtual bool enabled_rw() override {
auto it = overrides_.find(ENABLED_RW);
if (it != overrides_.end()) {
return it->second;
} else {
return GetServerConfigurableFlag(
"aconfig_test",
"com.android.aconfig.test.enabled_rw",
"true") == "true";
}
}
virtual void override_flag(std::string const& flag, bool val) override {
assert(flag_names_.count(flag));
overrides_[flag] = val;
}
virtual void reset_overrides() override {
overrides_.clear();
}
};
}
#endif
"#;
const SOURCE_FILE_EXPECTED: &str = r#"
#include "com_android_aconfig_test.h"
#include "com_android_aconfig_test_flag_provider.h"
namespace com::android::aconfig::test {
std::string const DISABLED_RO = "com.android.aconfig.test.disabled_ro";
std::string const DISABLED_RW = "com.android.aconfig.test.disabled_rw";
std::string const ENABLED_RO = "com.android.aconfig.test.enabled_ro";
std::string const ENABLED_RW = "com.android.aconfig.test.enabled_rw";
std::unique_ptr<flag_provider_interface> provider_ =
std::make_unique<flag_provider>();
}
"#;
fn test_generate_cpp_code(mode: CodegenMode) {
let parsed_flags = crate::test::parse_test_flags();
let generated =
generate_cpp_code(crate::test::TEST_PACKAGE, parsed_flags.parsed_flag.iter(), mode)
.unwrap();
let mut generated_files_map = HashMap::new();
for file in generated {
generated_files_map.insert(
String::from(file.path.to_str().unwrap()),
String::from_utf8(file.contents.clone()).unwrap(),
);
}
let mut target_file_path = String::from("include/com_android_aconfig_test.h");
assert!(generated_files_map.contains_key(&target_file_path));
assert_eq!( assert_eq!(
None, None,
crate::test::first_significant_code_diff( crate::test::first_significant_code_diff(
expected, match mode {
&String::from_utf8(generated.contents).unwrap() CodegenMode::Production => EXPORTED_PROD_HEADER_EXPECTED,
CodegenMode::Test => EXPORTED_TEST_HEADER_EXPECTED,
},
generated_files_map.get(&target_file_path).unwrap()
)
);
target_file_path = String::from("com_android_aconfig_test_flag_provider.h");
assert!(generated_files_map.contains_key(&target_file_path));
assert_eq!(
None,
crate::test::first_significant_code_diff(
match mode {
CodegenMode::Production => PROD_FLAG_PROVIDER_HEADER_EXPECTED,
CodegenMode::Test => TEST_FLAG_PROVIDER_HEADER_EXPECTED,
},
generated_files_map.get(&target_file_path).unwrap()
)
);
target_file_path = String::from("com_android_aconfig_test.cc");
assert!(generated_files_map.contains_key(&target_file_path));
assert_eq!(
None,
crate::test::first_significant_code_diff(
SOURCE_FILE_EXPECTED,
generated_files_map.get(&target_file_path).unwrap()
) )
); );
} }
#[test]
fn test_generate_cpp_code_for_prod() {
test_generate_cpp_code(CodegenMode::Production);
}
#[test]
fn test_generate_cpp_code_for_test() {
test_generate_cpp_code(CodegenMode::Test);
}
} }

View file

@ -143,12 +143,12 @@ pub fn create_java_lib(mut input: Input, codegen_mode: CodegenMode) -> Result<Ve
generate_java_code(package, parsed_flags.parsed_flag.iter(), codegen_mode) generate_java_code(package, parsed_flags.parsed_flag.iter(), codegen_mode)
} }
pub fn create_cpp_lib(mut input: Input) -> Result<OutputFile> { pub fn create_cpp_lib(mut input: Input, codegen_mode: CodegenMode) -> Result<Vec<OutputFile>> {
let parsed_flags = input.try_parse_flags()?; let parsed_flags = input.try_parse_flags()?;
let Some(package) = find_unique_package(&parsed_flags) else { let Some(package) = find_unique_package(&parsed_flags) else {
bail!("no parsed flags, or the parsed flags use different packages"); bail!("no parsed flags, or the parsed flags use different packages");
}; };
generate_cpp_code(package, parsed_flags.parsed_flag.iter()) generate_cpp_code(package, parsed_flags.parsed_flag.iter(), codegen_mode)
} }
pub fn create_rust_lib(mut input: Input) -> Result<OutputFile> { pub fn create_rust_lib(mut input: Input) -> Result<OutputFile> {

View file

@ -60,7 +60,13 @@ fn cli() -> Command {
.subcommand( .subcommand(
Command::new("create-cpp-lib") Command::new("create-cpp-lib")
.arg(Arg::new("cache").long("cache").required(true)) .arg(Arg::new("cache").long("cache").required(true))
.arg(Arg::new("out").long("out").required(true)), .arg(Arg::new("out").long("out").required(true))
.arg(
Arg::new("mode")
.long("mode")
.value_parser(EnumValueParser::<commands::CodegenMode>::new())
.default_value("production"),
),
) )
.subcommand( .subcommand(
Command::new("create-rust-lib") Command::new("create-rust-lib")
@ -163,9 +169,12 @@ fn main() -> Result<()> {
} }
Some(("create-cpp-lib", sub_matches)) => { Some(("create-cpp-lib", sub_matches)) => {
let cache = open_single_file(sub_matches, "cache")?; let cache = open_single_file(sub_matches, "cache")?;
let generated_file = commands::create_cpp_lib(cache)?; let mode = get_required_arg::<CodegenMode>(sub_matches, "mode")?;
let generated_files = commands::create_cpp_lib(cache, *mode)?;
let dir = PathBuf::from(get_required_arg::<String>(sub_matches, "out")?); let dir = PathBuf::from(get_required_arg::<String>(sub_matches, "out")?);
write_output_file_realtive_to_dir(&dir, &generated_file)?; generated_files
.iter()
.try_for_each(|file| write_output_file_realtive_to_dir(&dir, file))?;
} }
Some(("create-rust-lib", sub_matches)) => { Some(("create-rust-lib", sub_matches)) => {
let cache = open_single_file(sub_matches, "cache")?; let cache = open_single_file(sub_matches, "cache")?;

View file

@ -0,0 +1,48 @@
#ifndef {header}_HEADER_H
#define {header}_HEADER_H
#include <string>
#include <memory>
{{ if readwrite }}
#include <server_configurable_flags/get_flags.h>
using namespace server_configurable_flags;
{{ endif }}
namespace {cpp_namespace} \{
class flag_provider_interface \{
public:
virtual ~flag_provider_interface() = default;
{{ for item in class_elements}}
virtual bool {item.flag_name}() = 0;
{{ endfor }}
virtual void override_flag(std::string const&, bool) \{}
virtual void reset_overrides() \{}
};
extern std::unique_ptr<flag_provider_interface> provider_;
{{ for item in class_elements}}
extern std::string const {item.uppercase_flag_name};{{ endfor }}
{{ for item in class_elements}}
inline bool {item.flag_name}() \{
{{ if for_prod }}
{{ if not item.readwrite- }}
return {item.default_value};
{{ -else- }}
return provider_->{item.flag_name}();
{{ -endif }}
{{ -else- }}
return provider_->{item.flag_name}();
{{ -endif }}
}
{{ endfor }}
inline void override_flag(std::string const& name, bool val) \{
return provider_->override_flag(name, val);
}
inline void reset_overrides() \{
return provider_->reset_overrides();
}
}
#endif

View file

@ -1,12 +1,12 @@
#ifndef {header}_HEADER_H #ifndef {header}_flag_provider_HEADER_H
#define {header}_HEADER_H #define {header}_flag_provider_HEADER_H
{{ if readwrite }} #include "{header}.h"
#include <server_configurable_flags/get_flags.h>
using namespace server_configurable_flags;
{{ endif }}
namespace {cpp_namespace} \{ namespace {cpp_namespace} \{
class flag_provider : public flag_provider_interface \{
public:
{{ for item in class_elements}} {{ for item in class_elements}}
static const bool {item.flag_name}() \{ virtual bool {item.flag_name}() override \{
{{ if item.readwrite- }} {{ if item.readwrite- }}
return GetServerConfigurableFlag( return GetServerConfigurableFlag(
"{item.device_config_namespace}", "{item.device_config_namespace}",
@ -17,5 +17,6 @@ namespace {cpp_namespace} \{
{{ -endif }} {{ -endif }}
} }
{{ endfor }} {{ endfor }}
};
} }
#endif #endif

View file

@ -0,0 +1,10 @@
#include "{header}.h"
#include "{header}_flag_provider.h"
namespace {cpp_namespace} \{
{{ for item in class_elements}}
std::string const {item.uppercase_flag_name} = "{item.device_config_flag}";{{ endfor }}
std::unique_ptr<flag_provider_interface> provider_ =
std::make_unique<flag_provider>();
}

View file

@ -0,0 +1,49 @@
#ifndef {header}_flag_provider_HEADER_H
#define {header}_flag_provider_HEADER_H
#include "{header}.h"
#include <unordered_map>
#include <unordered_set>
#include <cassert>
namespace {cpp_namespace} \{
class flag_provider : public flag_provider_interface \{
private:
std::unordered_map<std::string, bool> overrides_;
std::unordered_set<std::string> flag_names_;
public:
flag_provider()
: overrides_(),
flag_names_() \{
{{ for item in class_elements}}
flag_names_.insert({item.uppercase_flag_name});{{ endfor }}
}
{{ for item in class_elements}}
virtual bool {item.flag_name}() override \{
auto it = overrides_.find({item.uppercase_flag_name});
if (it != overrides_.end()) \{
return it->second;
} else \{
{{ if item.readwrite- }}
return GetServerConfigurableFlag(
"{item.device_config_namespace}",
"{item.device_config_flag}",
"{item.default_value}") == "true";
{{ -else- }}
return {item.default_value};
{{ -endif }}
}
}
{{ endfor }}
virtual void override_flag(std::string const& flag, bool val) override \{
assert(flag_names_.count(flag));
overrides_[flag] = val;
}
virtual void reset_overrides() override \{
overrides_.clear();
}
};
}
#endif