欣淇
发布于 2026-05-29 / 0 阅读
0
0

🔥 keras:64,066 stars

🧠 Keras 3:一个框架同时跑JAX、PyTorch、TensorFlow,速度最高提升350%

你还在为一个模型要在TensorFlow和PyTorch之间来回改写而头疼?或者发现同样的架构在JAX上能跑出3倍速度,却因为代码不兼容只能干瞪眼?Keras 3直接把这个选择题给你端平了——它让你用一套API,想切哪个后端就切哪个后端,而且性能不降反升。

项目数据:300万开发者的选择

Keras 3目前在GitHub上有超过62k星标,全球近300万开发者在使用。它支持JAX、TensorFlow、PyTorch三大主流框架作为后端,再加上OpenVINO做纯推理。根据官方基准测试,在相同模型架构下,JAX后端相比其他框架能带来20%到350%的速度提升。这可不是PPT上的数字,你跑个Transformer或ResNet就能明显感觉到。

核心功能拆解:一套代码,三套引擎

1. 多后端自由切换

你不用再纠结“选哪个框架”。Keras 3把底层计算引擎抽象成可替换的组件。今天用JAX跑训练,明天切到TensorFlow做部署,代码一个字都不用改。

2. 兼容tf.keras的旧代码

如果你之前用tf.keras写的模型,只要保存格式是.keras,直接就能在Keras 3里跑。自定义层或训练逻辑也只需要几分钟就能改成后端无关的写法。

3. 原生支持PyTorch DataLoader和tf.data

不管你是PyTorch用户还是TensorFlow用户,都能直接用自己习惯的数据加载方式。Keras 3不挑食。

实操步骤:从安装到跑通第一个模型

第一步:安装核心库

pip install keras --upgrade

第二步:安装你需要的后端

选一个或几个都行:

# 安装TensorFlow后端
pip install tensorflow

# 安装JAX后端
pip install jax jaxlib

# 安装PyTorch后端
pip install torch

第三步:配置后端

在导入Keras之前,设置环境变量:

export KERAS_BACKEND="jax"

或者在Python代码里:

import os
os.environ["KERAS_BACKEND"] = "jax"

import keras

注意:导入Keras之后就不能再切换后端了,所以这一步要在最前面做。

第四步:写一个完整的训练代码

import os
os.environ["KERAS_BACKEND"] = "jax"

import keras
from keras import layers

# 加载数据
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255
x_test = x_test.reshape(-1, 784).astype("float32") / 255

# 构建模型
model = keras.Sequential([
    layers.Dense(512, activation="relu"),
    layers.Dropout(0.2),
    layers.Dense(10, activation="softmax")
])

# 编译
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

# 训练
model.fit(x_train, y_train, batch_size=128, epochs=5, validation_split=0.2)

# 评估
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"测试准确率: {test_acc:.4f}")

这段代码你换成KERAS_BACKEND="torch""tensorflow",一样跑得通。

第五步:用OpenVINO做纯推理(可选)

import os
os.environ["KERAS_BACKEND"] = "openvino"

import keras

# 加载训练好的模型
model = keras.models.load_model("my_model.keras")

# 推理
predictions = model.predict(x_test[:10])

OpenVINO后端只支持model.predict(),不能做训练,但推理速度在Intel硬件上非常快。

避坑指南

  1. 后端必须在导入Keras前配置。如果你先import keras再设环境变量,系统不会报错,但后端不会切换,你会一脸懵逼。

  2. OpenVINO只能做推理。别想着用它来训练,model.fit()会直接报错。

  3. 自定义层要注意后端兼容。如果你用了tf.Tensor特有的方法(比如tf.shape),切换到JAX或PyTorch后端时会挂。建议用Keras提供的后端无关API,比如keras.ops.shape

  4. 模型保存格式要统一。Keras 3推荐用.keras格式,旧版的.h5虽然也能读,但有些自定义组件可能会丢失。

要点总结

  • Keras 3让你用同一套代码跑JAX、TensorFlow、PyTorch,切换后端只需改一个环境变量
  • JAX后端在多数模型上能带来20%到350%的速度提升
  • 兼容tf.keras旧代码,迁移成本低
  • 支持PyTorch DataLoader和tf.data,数据加载方式随你选
  • OpenVINO后端适合Intel硬件上的纯推理场景
  • 安装简单,一行pip install keras就能开始

你现在就可以把手上那个跑得慢的模型,切到JAX后端试试,说不定训练时间直接砍半。


评论