1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
| import subprocess
import sys
import re
def get_staged_py_files():
    """Return the staged (added/copied/modified) ``.py`` file paths.

    Paths are reported by ``git diff --cached --name-only`` and are therefore
    relative to the repository root.

    ``-z`` makes git print each path verbatim, NUL-terminated. Without it,
    git's default ``core.quotePath`` setting wraps paths containing non-ASCII
    characters in double quotes with octal escapes. Such a path no longer ends
    in ``.py`` and would be skipped silently.
    """
    res = subprocess.run(
        ["git", "diff", "--cached", "--name-only", "--diff-filter=ACM", "-z"],
        capture_output=True,
        # git emits path bytes as stored (normally UTF-8); do not rely on the
        # platform's locale encoding to decode them.
        encoding="utf-8",
    )
    # The output is NUL-terminated, so the trailing empty entry is filtered out
    # together with any non-.py paths.
    return [f for f in res.stdout.split("\0") if f.endswith(".py")]
def _reposition_logger(lines):
    """Return *lines* with the module-level logger placed right after the imports.

    *lines* is file content as produced by ``readlines()``. A new list is
    returned, or ``None`` when nothing needs to change. That covers three cases:
    there is no column-0 ``logger = logging.getLogger(__name__)`` line, there
    is no top-level ``import``/``from`` line, or the logger is already in place.
    """
    import_pattern = re.compile(r"^(import |from )")
    # Only an unindented assignment is module-level. An indented one lives
    # inside a function/class, and hoisting it would break the file.
    logger_pattern = re.compile(r"^logger\s*=\s*logging\.getLogger\(\s*__name__\s*\)")

    import_last_idx = -1
    logger_idx = -1
    for idx, line in enumerate(lines):
        if logger_pattern.match(line):
            logger_idx = idx
        if import_pattern.match(line):
            import_last_idx = idx
    if logger_idx == -1 or import_last_idx == -1:
        return None

    # Extend the import block over directly following non-blank lines, so that
    # multi-line (parenthesised) imports stay intact. Stop at end of file.
    while import_last_idx + 1 < len(lines) and lines[import_last_idx + 1].strip():
        import_last_idx += 1

    new_lines = list(lines)
    logger_line = new_lines.pop(logger_idx)
    # Removing a line at or above the block's end shifts that end up by one.
    # `<=` matters: when the logger already follows the imports, the loop
    # above swallows it into the block (logger_idx == import_last_idx).
    if logger_idx <= import_last_idx:
        import_last_idx -= 1
    insert_at = import_last_idx + 1

    # A logger taken from the last line of a file without a trailing newline
    # would otherwise merge with the line it is inserted before.
    if not logger_line.endswith("\n") and insert_at < len(new_lines):
        logger_line += "\r\n" if lines[0].endswith("\r\n") else "\n"
    new_lines.insert(insert_at, logger_line)

    return None if new_lines == lines else new_lines


def process_file(path):
    """Move the module-level logger line of *path* directly below its imports.

    The logger line is ``logger = logging.getLogger(__name__)`` starting at
    column 0. The file is rewritten, and a message printed, only when the line
    actually moves. Repeated runs are therefore no-ops. Files without such a
    logger or without top-level imports are left untouched.
    """
    # newline="" keeps the file's own line endings (e.g. CRLF) on the
    # read/write round trip instead of normalising them.
    with open(path, encoding="utf-8", newline="") as f:
        lines = f.readlines()

    new_lines = _reposition_logger(lines)
    if new_lines is None:
        return

    with open(path, "w", encoding="utf-8", newline="") as f:
        f.writelines(new_lines)
    print(f"Moved logger in {path}")
def main():
    """Reposition the module logger in every staged Python file."""
    # NOTE(review): files rewritten here are not re-staged with `git add`.
    # If this runs as a pre-commit hook, the commit still records the old
    # layout. Confirm whether that is intended.
    for staged_path in get_staged_py_files():
        process_file(staged_path)
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|