mirror of
https://github.com/scylladb/scylladb.git
synced 2026-05-28 18:50:53 +00:00
Adding an auto-backport.py script to handle backport automation instead of Mergify. The rules of backport are as follows: * Merged or Closed PRs with any backport/x.y label (one or more) and promoted-to-master label * Backport PR will be automatically assigned to the original PR author * In case of conflicts the backport PR will be open in the original autoor fork in draft mode. This will give the PR owner the option to resolve conflicts and push those changes to the PR branch (Today in Scylla when we have conflicts, the developers are forced to open another PR and manually close the backport PR opened by Mergify) * Fixing cherry-pick the wrong commit SHA. With the new script, we always take the SHA from the stable branch * Support backport for enterprise releases (from Enterprise branch) Fixes: https://github.com/scylladb/scylladb/issues/18973 Closes scylladb/scylladb#21302
191 lines
8.2 KiB
Python
Executable File
191 lines
8.2 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
import os
|
|
import re
|
|
import sys
|
|
import tempfile
|
|
import logging
|
|
|
|
from github import Github, GithubException
|
|
from git import Repo, GitCommandError
|
|
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
try:
|
|
github_token = os.environ["GITHUB_TOKEN"]
|
|
except KeyError:
|
|
print("Please set the 'GITHUB_TOKEN' environment variable")
|
|
sys.exit(1)
|
|
|
|
|
|
def is_pull_request():
|
|
return '--pull-request' in sys.argv[1:]
|
|
|
|
|
|
def parse_args():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('--repo', type=str, required=True, help='Github repository name')
|
|
parser.add_argument('--base-branch', type=str, default='refs/heads/master', help='Base branch')
|
|
parser.add_argument('--commits', default=None, type=str, help='Range of promoted commits.')
|
|
parser.add_argument('--pull-request', type=int, help='Pull request number to be backported')
|
|
parser.add_argument('--head-commit', type=str, required=is_pull_request(), help='The HEAD of target branch after the pull request specified by --pull-request is merged')
|
|
return parser.parse_args()
|
|
|
|
|
|
def create_pull_request(repo, new_branch_name, base_branch_name, pr, backport_pr_title, commits, is_draft=False):
|
|
pr_body = f'{pr.body}\n\n'
|
|
for commit in commits:
|
|
pr_body += f'- (cherry picked from commit {commit})\n\n'
|
|
pr_body += f'Parent PR: #{pr.number}'
|
|
if is_draft:
|
|
new_branch_name = f'{pr.user.login}:{new_branch_name}'
|
|
try:
|
|
backport_pr = repo.create_pull(
|
|
title=backport_pr_title,
|
|
body=pr_body,
|
|
head=new_branch_name,
|
|
base=base_branch_name,
|
|
draft=is_draft
|
|
)
|
|
logging.info(f"Pull request created: {backport_pr.html_url}")
|
|
backport_pr.add_to_assignees(pr.user)
|
|
logging.info(f"Assigned PR to original author: {pr.user}")
|
|
return backport_pr
|
|
except GithubException as e:
|
|
if 'A pull request already exists' in str(e):
|
|
logging.warning(f'A pull request already exists for {pr.user}:{new_branch_name}')
|
|
else:
|
|
logging.error(f'Failed to create PR: {e}')
|
|
|
|
|
|
def get_pr_commits(repo, pr, stable_branch, start_commit=None):
|
|
commits = []
|
|
if pr.merged:
|
|
merge_commit = repo.get_commit(pr.merge_commit_sha)
|
|
if len(merge_commit.parents) > 1: # Check if this merge commit includes multiple commits
|
|
commits.append(pr.merge_commit_sha)
|
|
else:
|
|
if start_commit:
|
|
promoted_commits = repo.compare(start_commit, stable_branch).commits
|
|
else:
|
|
promoted_commits = repo.get_commits(sha=stable_branch)
|
|
for commit in pr.get_commits():
|
|
for promoted_commit in promoted_commits:
|
|
commit_title = commit.commit.message.splitlines()[0]
|
|
# In Scylla-pkg and scylla-dtest, for example,
|
|
# we don't create a merge commit for a PR with multiple commits,
|
|
# according to the GitHub API, the last commit will be the merge commit,
|
|
# which is not what we need when backporting (we need all the commits).
|
|
# So here, we are validating the correct SHA for each commit so we can cherry-pick
|
|
if promoted_commit.commit.message.startswith(commit_title):
|
|
commits.append(promoted_commit.sha)
|
|
|
|
elif pr.state == 'closed':
|
|
events = pr.get_issue_events()
|
|
for event in events:
|
|
if event.event == 'closed':
|
|
commits.append(event.commit_id)
|
|
return commits
|
|
|
|
|
|
def backport(repo, pr, version, commits, backport_base_branch, user):
|
|
with (tempfile.TemporaryDirectory() as local_repo_path):
|
|
try:
|
|
new_branch_name = f'backport/{pr.number}/to-{version}'
|
|
backport_pr_title = f'[Backport {version}] {pr.title}'
|
|
repo_local = Repo.clone_from(f'https://{user.login}:{github_token}@github.com/{repo.full_name}.git', local_repo_path, branch=backport_base_branch)
|
|
repo_local.git.checkout(b=new_branch_name)
|
|
try:
|
|
fork_repo = pr.user.get_repo(repo.name)
|
|
except Exception as e:
|
|
print(f"Error retrieving repository: {e}")
|
|
# Since Scylla core repo was modified a few years ago to ScyllaDB,
|
|
# some developers may have forks based on the original name `scylla`
|
|
print(f"{pr.user.login} fork repo is not {repo.full_name}, trying {repo.organization.login}/scylla")
|
|
fork_repo = pr.user.get_repo('scylla')
|
|
fork_repo_url = f'https://{user.login}:{github_token}@github.com/{fork_repo.full_name}.git'
|
|
repo_local.create_remote('fork', fork_repo_url)
|
|
remote = 'origin'
|
|
is_draft = False
|
|
for commit in commits:
|
|
try:
|
|
repo_local.git.cherry_pick(commit, '-m1', '-x')
|
|
except GitCommandError as e:
|
|
logging.warning(f'Cherry-pick conflict on commit {commit}: {e}')
|
|
remote = 'fork'
|
|
is_draft = True
|
|
repo_local.git.add(A=True)
|
|
repo_local.git.cherry_pick('--continue')
|
|
repo_local.git.push(remote, new_branch_name, force=True)
|
|
create_pull_request(repo, new_branch_name, backport_base_branch, pr, backport_pr_title, commits,
|
|
is_draft=is_draft)
|
|
|
|
except GitCommandError as e:
|
|
logging.warning(f"GitCommandError: {e}")
|
|
|
|
|
|
def create_pr_comment_and_remove_label(pr):
|
|
comment_body = f':warning: @{pr.user.login} PR body does not contain a valid reference to an issue '
|
|
comment_body += ' based on [linking-a-pull-request-to-an-issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/using-issues/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword)'
|
|
comment_body += ' and can not be backported\n\n'
|
|
comment_body += 'The following labels were removed:\n'
|
|
labels = pr.get_labels()
|
|
pattern = re.compile(r"backport/\d+\.\d+$")
|
|
for label in labels:
|
|
if pattern.match(label.name):
|
|
print(f"Removing label: {label.name}")
|
|
comment_body += f'- {label.name}\n'
|
|
pr.remove_from_labels(label)
|
|
comment_body += f'\nPlease add the relevant backport labels after PR body is fixed'
|
|
pr.create_issue_comment(comment_body)
|
|
|
|
|
|
def main():
|
|
args = parse_args()
|
|
base_branch = args.base_branch.split('/')[2]
|
|
promoted_label = 'promoted-to-master'
|
|
repo_name = args.repo
|
|
if 'scylla-enterprise' in args.repo:
|
|
promoted_label = 'promoted-to-enterprise'
|
|
stable_branch = base_branch
|
|
backport_branch = 'branch-'
|
|
|
|
backport_label_pattern = re.compile(r'backport/\d+\.\d+$')
|
|
|
|
g = Github(github_token)
|
|
repo = g.get_repo(repo_name)
|
|
user = g.get_user()
|
|
closed_prs = []
|
|
start_commit = None
|
|
|
|
if args.commits:
|
|
start_commit, end_commit = args.commits.split('..')
|
|
commits = repo.compare(start_commit, end_commit).commits
|
|
for commit in commits:
|
|
for pr in commit.get_pulls():
|
|
closed_prs.append(pr)
|
|
if args.pull_request:
|
|
start_commit = args.head_commit
|
|
pr = repo.get_pull(args.pull_request)
|
|
closed_prs = [pr]
|
|
|
|
for pr in closed_prs:
|
|
labels = [label.name for label in pr.labels]
|
|
backport_labels = [label for label in labels if backport_label_pattern.match(label)]
|
|
if promoted_label not in labels:
|
|
print(f'no {promoted_label} label: {pr.number}')
|
|
continue
|
|
if not backport_labels:
|
|
print(f'no backport label: {pr.number}')
|
|
continue
|
|
commits = get_pr_commits(repo, pr, stable_branch, start_commit)
|
|
logging.info(f"Found PR #{pr.number} with commit {commits} and the following labels: {backport_labels}")
|
|
for backport_label in backport_labels:
|
|
version = backport_label.replace('backport/', '')
|
|
backport_base_branch = backport_label.replace('backport/', backport_branch)
|
|
backport(repo, pr, version, commits, backport_base_branch, user)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|