跳转至

Python 内存剖析与优化

本文以一个矩阵相乘的场景为例,介绍了 memory_profilermemray 这两个剖析 Python 内存使用情况的工具。

np

数据

矩阵相乘涉及两个数据框,分别是 data_1data_2。生成它们的代码为:

Python
class TestMemory:
    def __init__(self):
        dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
        items_data1 = list(range(1, 501))
        items_data2 = list(range(201, 1001))

        index_data2 = pd.MultiIndex.from_product(
            [dates, items_data2], names=["date", "item"]
        )

        self.data_1 = pd.DataFrame(
            np.random.rand(len(dates), 500), index=dates, columns=items_data1
        )
        self.data_1.index.name = "date"
        self.data_2 = pd.DataFrame(
            np.random.rand(len(index_data2), 800),
            index=index_data2,
            columns=items_data2,
        )

data_1 是一个 365 行、500 列的数据框。它的索引为 2023 年的日频日期,列名为 1 到 500。

image-20240411225318131

data_2 是一个 365*800=292000 行、800 列的数据框。它的索引是一个 MultiIndex 类型,其中两层索引分别为 dateitem。它的列名为 201 到 1000,与 data_1 的列名存在交集。

image-20240411225502474

使用 pandas 直接相乘

Python
def multiply_and_sum_data_pd(self):
    result = (self.data_1 * self.data_2).sum(axis=1)
    return result

上述代码会将 data_1 中的每一行与 data_2 中相应日期的一个矩阵相乘。并且,虽然 data_1data_2 的列名并不一致,pandas 也能自动将相同的列名位置的元素相乘(即 201 到 500)。若没有相同的列名(即 1 到 200 和 501 到 1000),则结果为 NaN

image-20240411230631810

这样的写法十分简洁,阅读起来也容易理解。但是,当 data_1data_2 的数据量较大时,可能会遇到性能瓶颈,例如运算时间太长、消耗内存太大等。

使用 numpy 转换为 numpy.array 后相乘

我们提取 data_2indexcolumns,将待计算的数据用 to_numpy (官方不推荐使用 .values,详见 pandas.DataFrame.values 文档)转换为 numpy.array 后进行计算。改写后的代码更长,但计算逻辑与上文使用 pandas 直接相乘是一致的。

Python
def multiply_and_sum_data_np(self):
    index = self.data_2.index
    columns = self.data_2.columns
    self.data_2 = self.data_2.to_numpy().reshape(
        index.get_level_values("date").nunique(),
        index.get_level_values("item").nunique(),
        -1,
    )
    result = pd.Series(
        np.nansum(
            np.vstack(
                (self.data_1.reindex(columns=columns).to_numpy()[:, np.newaxis, :])
                * self.data_2
            ),
            axis=1,
        ),
        index=index,
    )
    return result

最终的代码如下:

manage-memory.py
import numpy as np
import pandas as pd

np.random.seed(0)


class TestMemory:
    def __init__(self):
        dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
        items_data1 = list(range(1, 501))
        items_data2 = list(range(201, 1001))

        index_data2 = pd.MultiIndex.from_product(
            [dates, items_data2], names=["date", "item"]
        )

        self.data_1 = pd.DataFrame(
            np.random.rand(len(dates), 500), index=dates, columns=items_data1
        )
        self.data_1.index.name = "date"
        self.data_2 = pd.DataFrame(
            np.random.rand(len(index_data2), 800),
            index=index_data2,
            columns=items_data2,
        )

    def multiply_and_sum_data_pd(self):
        result = (self.data_1 * self.data_2).sum(axis=1)
        return result

    def multiply_and_sum_data_np(self):
        index = self.data_2.index
        columns = self.data_2.columns
        self.data_2 = self.data_2.to_numpy().reshape(
            index.get_level_values("date").nunique(),
            index.get_level_values("item").nunique(),
            -1,
        )
        result = pd.Series(
            np.nansum(
                np.vstack(
                    (self.data_1.reindex(columns=columns).to_numpy()[:, np.newaxis, :])
                    * self.data_2
                ),
                axis=1,
            ),
            index=index,
        )
        return result


test_memory = TestMemory()
print(test_memory.multiply_and_sum_data_pd())
print(test_memory.multiply_and_sum_data_np())

下面我们介绍若干剖析 Python 内存使用情况的工具。

memory_profiler

安装 memory_profiler

Bash
pip install -U memory_profiler

定时记录内存使用情况

无需改动代码,我们可以直接在命令行中运行:

Bash
mprof run manage-memory.py # (1)
  1. manage-memory.py 是待剖析的程序文件。

这会定时记录程序运行时所占用的内存,并在当前目录生成一个 mprofile_xxxxx.dat 的文件。我们可以继续使用:

Bash
mprof plot

将绘制内存使用情况。

pd

上图是运行 print(test_memory.multiply_and_sum_data_pd()) 的内存使用情况,它表明使用 pandas 直接相乘的程序耗时约 8 秒,内存占用峰值约为 5400 MB。

下面我们运行 print(test_memory.multiply_and_sum_data_np()),同样使用:

Bash
mprof run manage-memory.py
mprof plot

np

上图表明使用 numpy 的方法耗时仅 3 秒多,内存占用也快速下降。

记录部分代码的内存使用情况

若仅需要记录部分代码(例如函数、方法等)的内存使用情况,可以为函数添加 @profile 装饰器:

Python
from memory_profiler import profile


class TestMemory:
    def __init__(self):
        pass

    @profile
    def multiply_and_sum_data_np(self):
        index = self.data_2.index
        columns = self.data_2.columns
        self.data_2 = self.data_2.to_numpy().reshape(
            index.get_level_values("date").nunique(),
            index.get_level_values("item").nunique(),
            -1,
        )
        result = pd.Series(
            np.nansum(
                np.vstack(
                    (self.data_1.reindex(columns=columns).to_numpy()[:, np.newaxis, :])
                    * self.data_2
                ),
                axis=1,
            ),
            index=index,
        )
        return result

再运行 python manage-memory.py 即可剖析这段函数的内存使用情况。

Text Only
Filename: manage-memory.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    34   1886.9 MiB   1886.9 MiB           1       @profile
    35                                             def multiply_and_sum_data_np(self):
    36   1886.9 MiB      0.0 MiB           1           index = self.data_2.index
    37   1886.9 MiB      0.0 MiB           1           columns = self.data_2.columns
    38   1918.9 MiB      0.0 MiB           2           self.data_2 = self.data_2.to_numpy().reshape(
    39   1902.4 MiB     15.5 MiB           1               index.get_level_values("date").nunique(),
    40   1918.9 MiB     16.5 MiB           1               index.get_level_values("item").nunique(),
    41   1918.9 MiB      0.0 MiB           1               -1,
    42                                                 )
    43   1918.9 MiB      3.9 MiB           2           result = pd.Series(
    44   1918.9 MiB  -1311.8 MiB           2               np.nansum(
    45   1918.9 MiB  -1839.8 MiB           2                   np.vstack(
    46   3703.6 MiB   1784.8 MiB           2                       (self.data_1.reindex(columns=columns).to_numpy()[:, np.newaxis, :])
    47   1921.4 MiB      0.0 MiB           1                       * self.data_2
    48                                                         ),
    49   1863.8 MiB    -55.0 MiB           1                   axis=1,
    50                                                     ),
    51    552.1 MiB  -1366.8 MiB           1               index=index,
    52                                                 )
    53    556.0 MiB  -1362.9 MiB           1           return result

memray

memray 是 Bloomberg 开源的一个内存剖析工具。

安装 memray

Bash
pip install memray

运行:

Bash
memray run --live manage-memory.py

就可以看到实时的内存使用情况。下面分别是使用 pandasnumpy 的方法所呈现的实时内存使用情况。

pandas

pd

numpy

np

可以看到,使用 pandas 方法占用的内存峰值更大、持续时间更长。

更多用法可以参考 memray 官方文档

其他工具

评论