import subprocess
import sys
import re
def get_staged_py_files(cwd=None):
    """Return the staged ``.py`` files that were added, copied or modified.

    Runs ``git diff --cached --name-only --diff-filter=ACM`` and keeps only
    the entries ending in ``.py``. Paths are returned exactly as git prints
    them.

    Args:
        cwd: Directory to run git in; ``None`` (the default) means the
            current working directory.

    Returns:
        A list of path strings. The git exit status is not checked; when git
        prints nothing the list is empty.
    """
    res = subprocess.run(
        # -z: NUL-terminated, unquoted paths. Without it git C-quotes paths
        # containing non-ASCII/special characters (core.quotePath), e.g.
        # "\344\270\255.py", and the closing quote defeats the suffix check.
        ["git", "diff", "--cached", "--name-only", "--diff-filter=ACM", "-z"],
        capture_output=True,
        # git emits raw path bytes; decode as UTF-8 and keep any undecodable
        # bytes round-trippable (surrogateescape) so open() still finds them.
        encoding="utf-8",
        errors="surrogateescape",
        cwd=cwd,
    )
    # The trailing "" after the final NUL is dropped by the suffix filter.
    return [path for path in res.stdout.split("\0") if path.endswith(".py")]
# A top-level ``import ...`` / ``from ...`` statement (column 0 only).
_IMPORT_PATTERN = re.compile(r"^(import |from )")
# A module-level ``logger = logging.getLogger(__name__)``. Anchored at
# column 0: an indented logger lives inside a function/class body, and
# hoisting it to module scope would produce an IndentationError.
_LOGGER_PATTERN = re.compile(r"^logger\s*=\s*logging\.getLogger\(\s*__name__\s*\)")


def _relocate_logger(lines):
    """Return a copy of *lines* with the logger line moved directly below the imports.

    Returns None when there is no module-level logger line or no top-level
    import. The input list is not modified.
    """
    logger_idx = -1
    import_last_idx = -1
    for idx, line in enumerate(lines):
        if _LOGGER_PATTERN.match(line):
            logger_idx = idx
        if _IMPORT_PATTERN.match(line):
            import_last_idx = idx
    if logger_idx == -1 or import_last_idx == -1:
        return None

    # The import block ends at the first blank line after the last import,
    # so continuation lines of a parenthesised multi-line import stay inside
    # it. Bounds check: the block may run all the way to end of file.
    while import_last_idx + 1 < len(lines) and lines[import_last_idx + 1].strip():
        import_last_idx += 1

    new_lines = list(lines)
    logger_line = new_lines.pop(logger_idx)
    # If the logger sat inside or at the end of the block, popping it shifts
    # the block's last line up by one. ``<=`` matters: when the logger is
    # already the block's last line this puts it back in place, which keeps
    # the hook idempotent.
    if logger_idx <= import_last_idx:
        import_last_idx -= 1
    insert_at = import_last_idx + 1
    # A logger on an unterminated last line would otherwise fuse with the
    # line it is inserted in front of; give it an EOL in the file's style.
    if insert_at < len(new_lines) and not logger_line.endswith(("\n", "\r")):
        logger_line += "\r\n" if new_lines[0].endswith("\r\n") else "\n"
    new_lines.insert(insert_at, logger_line)
    return new_lines


def process_file(path):
    """Move the module-level logger line of *path* directly below its imports.

    The file is rewritten, and a message printed, only when its content
    actually changes. Files without a module-level
    ``logger = logging.getLogger(__name__)`` line, or without a top-level
    import, are left untouched.

    Args:
        path: Path to a UTF-8 encoded Python source file.
    """
    # newline="" disables newline translation on both read and write, so the
    # file's LF/CRLF endings survive the rewrite unchanged.
    with open(path, encoding="utf-8", newline="") as f:
        lines = f.readlines()
    new_lines = _relocate_logger(lines)
    if new_lines is None or new_lines == lines:
        return
    with open(path, "w", encoding="utf-8", newline="") as f:
        f.writelines(new_lines)
    print(f"Moved logger in {path}")
def main():
    """Fix the logger placement in each staged ``.py`` file, one file at a time."""
    staged_paths = get_staged_py_files()
    for staged_path in staged_paths:
        process_file(staged_path)
if __name__ == "__main__":
    # main() returns None, so the exit status is 0 unless it raises.
    sys.exit(main())