1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import subprocess

import sys

import re




def get_staged_py_files():
    """Return the staged (added/copied/modified) ``.py`` file paths.

    Paths are as printed by ``git diff --cached --name-only``, i.e. relative
    to the repository root; callers open them relative to the current
    directory, which git sets to the working-tree root when running hooks.

    ``-z`` makes git print raw, NUL-separated paths. Without it, git wraps
    any path containing "unusual" characters (by default every non-ASCII
    byte, e.g. Chinese file names) in double quotes with octal escapes; the
    closing quote made ``endswith(".py")`` fail and such files were skipped.

    Returns:
        A list of path strings. git's exit status is not checked, so when git
        fails (and prints nothing on stdout) the list is simply empty.
    """
    res = subprocess.run(
        ["git", "diff", "--cached", "--name-only", "--diff-filter=ACM", "-z"],
        capture_output=True,
        # git emits paths as raw bytes (normally UTF-8); surrogateescape keeps
        # any undecodable byte round-trippable instead of raising.
        encoding="utf-8",
        errors="surrogateescape",
    )
    return [name for name in res.stdout.split("\0") if name.endswith(".py")]




def process_file(path):
    """Move the module-level ``logger = logging.getLogger(__name__)`` line of
    *path* so that it sits directly after the import section.

    The import section ends at the last line starting with ``import `` or
    ``from `` (unindented), extended through the non-blank lines that
    immediately follow it. That keeps a parenthesized multi-line import with
    no blank lines inside it intact.

    NOTE: every matching line in the file counts, so a late top-level import
    (or a line inside a multi-line string that starts with "from ") becomes
    the anchor if it is the last one.

    The file is written back only when its content changes, and its original
    line endings (LF / CRLF) are preserved.

    Args:
        path: Path of the Python source file to fix in place.

    Returns:
        True if the file was rewritten, False otherwise.
    """
    # newline="" disables newline translation on both read and write, so
    # CRLF files are not silently converted to LF.
    with open(path, encoding="utf-8", newline="") as f:
        lines = f.readlines()
    original = list(lines)

    import_last_idx = -1
    logger_idx = -1
    import_pattern = re.compile(r"^(import |from )")
    # Only an unindented logger is matched: moving a class- or function-level
    # assignment to module level would produce an indentation error.
    logger_pattern = re.compile(r"^logger\s*=\s*logging\.getLogger\(\s*__name__\s*\)")

    # Remember the last logger line and the last import line.
    for idx, line in enumerate(lines):
        if logger_pattern.match(line):
            logger_idx = idx
        if import_pattern.match(line):
            import_last_idx = idx

    # Nothing to move, or no top-level import to anchor the logger to.
    if logger_idx == -1 or import_last_idx == -1:
        return False

    # Extend the anchor through the non-blank run right after the last import.
    # The bounds check prevents an IndexError when that run reaches EOF.
    while import_last_idx + 1 < len(lines) and lines[import_last_idx + 1].strip():
        import_last_idx += 1

    # Move the line.
    logger_line = lines.pop(logger_idx)
    # Popping at or before the anchor shifts the anchor up by one. "<=" (not
    # "<") matters: if the logger itself ends the run, it must go back to the
    # same place instead of one line further (which would oscillate).
    if logger_idx <= import_last_idx:
        import_last_idx -= 1
    insert_at = import_last_idx + 1
    # A final line without a newline must gain one before it is moved into the
    # middle of the file, otherwise it would be joined with the next line.
    if insert_at < len(lines) and not logger_line.endswith(("\n", "\r")):
        logger_line += "\n"
    lines.insert(insert_at, logger_line)

    if lines == original:
        return False

    # Write back.
    with open(path, "w", encoding="utf-8", newline="") as f:
        f.writelines(lines)
    print(f"Moved logger in {path}")
    return True




def main():
    """Apply ``process_file`` to every staged Python file.

    If any file's content changed, list those files and exit with status 1.
    When run as a git pre-commit hook (presumably the intended use, since it
    reads the staged file list), working-tree edits that are not re-added to
    the index are not part of the commit; failing makes the user review and
    ``git add`` the fixed files instead of committing the unfixed version.
    """
    rewritten = []
    for path in get_staged_py_files():
        try:
            with open(path, "rb") as fh:
                before = fh.read()
        except FileNotFoundError:
            # Staged, but deleted from the working tree since: nothing to fix.
            continue
        process_file(path)
        with open(path, "rb") as fh:
            after = fh.read()
        # Detect a rewrite by comparing the file's bytes before and after.
        if after != before:
            rewritten.append(path)
    if rewritten:
        print("Re-stage these files and commit again:")
        for path in rewritten:
            print(f"  {path}")
        sys.exit(1)




if __name__ == "__main__":
    # Script entry point; main() returns None, so the exit status stays 0
    # unless main() itself exits with a different code.
    sys.exit(main())