diff --git a/java/com/google/testing/builddefs/GenTestRules.bzl b/java/com/google/testing/builddefs/GenTestRules.bzl index 3e1ba7301..17acf17b3 100644 --- a/java/com/google/testing/builddefs/GenTestRules.bzl +++ b/java/com/google/testing/builddefs/GenTestRules.bzl @@ -20,88 +20,87 @@ files. """ - -def GenTestRules(name, - test_files, - deps, - exclude_tests=[], - default_test_size="small", - small_tests=[], - medium_tests=[], - large_tests=[], - enormous_tests=[], - resources=[], - flaky_tests=[], - tags=[], - prefix="", - jvm_flags=[], - args=[], - visibility=None, - shard_count=1): - for test in _get_test_names(test_files): - if test in exclude_tests: - continue - test_size = default_test_size - if test in small_tests: - test_size = "small" - if test in medium_tests: - test_size = "medium" - if test in large_tests: - test_size = "large" - if test in enormous_tests: - test_size = "enormous" - flaky = 0 - if (test in flaky_tests) or ("flaky" in tags): - flaky = 1 - java_class = _package_from_path( - native.package_name() + "/" + _strip_right(test, ".java")) - package = java_class[:java_class.rfind(".")] - native.java_test(name = prefix + test, - runtime_deps = deps, - resources = resources, - size = test_size, - jvm_flags = jvm_flags, - args = args, - flaky = flaky, - tags = tags, - test_class = java_class, - visibility = visibility, - shard_count = shard_count) - +def GenTestRules( + name, + test_files, + deps, + exclude_tests = [], + default_test_size = "small", + small_tests = [], + medium_tests = [], + large_tests = [], + enormous_tests = [], + resources = [], + flaky_tests = [], + tags = [], + prefix = "", + jvm_flags = [], + args = [], + visibility = None, + shard_count = 1): + for test in _get_test_names(test_files): + if test in exclude_tests: + continue + test_size = default_test_size + if test in small_tests: + test_size = "small" + if test in medium_tests: + test_size = "medium" + if test in large_tests: + test_size = "large" + if test in enormous_tests: + test_size = "enormous" + flaky = 0 + if (test in flaky_tests) or ("flaky" in tags): + flaky = 1 + java_class = _package_from_path( + native.package_name() + "/" + _strip_right(test, ".java"), + ) + package = java_class[:java_class.rfind(".")] + native.java_test( + name = prefix + test, + runtime_deps = deps, + resources = resources, + size = test_size, + jvm_flags = jvm_flags, + args = args, + flaky = flaky, + tags = tags, + test_class = java_class, + visibility = visibility, + shard_count = shard_count, + ) def _get_test_names(test_files): - test_names = [] - for test_file in test_files: - if not test_file.endswith("Test.java"): - continue - test_names += [test_file[:-5]] - return test_names - - -def _package_from_path(package_path, src_impls=None): - src_impls = src_impls or ['javatests/', 'java/'] - for src_impl in src_impls: - if not src_impl.endswith('/'): - src_impl += '/' - index = _index_of_end(package_path, src_impl) - if index >= 0: - package_path = package_path[index:] - break - return package_path.replace('/', '.') + test_names = [] + for test_file in test_files: + if not test_file.endswith("Test.java"): + continue + test_names += [test_file[:-5]] + return test_names +def _package_from_path(package_path, src_impls = None): + src_impls = src_impls or ["javatests/", "java/"] + for src_impl in src_impls: + if not src_impl.endswith("/"): + src_impl += "/" + index = _index_of_end(package_path, src_impl) + if index >= 0: + package_path = package_path[index:] + break + return package_path.replace("/", ".") def _strip_right(str, suffix): - """Returns str without the suffix if it ends with suffix.""" - if str.endswith(suffix): - return str[0: len(str) - len(suffix)] - else: - return str - + """Returns str without the suffix if it ends with suffix.""" + if str.endswith(suffix): + return str[0:len(str) - len(suffix)] + else: + return str def _index_of_end(str, part): - """If part is in str, return the index of the first character after part. - Return -1 if part is not in str.""" - index = str.find(part) - if index >= 0: - return index + len(part) - return -1 + """If part is in str, return the index of the first character after part. + Return -1 if part is not in str.""" + index = str.find(part) + if index >= 0: + return index + len(part) + return -1