🧠 Keras 3:一个框架同时跑JAX、PyTorch、TensorFlow,速度最高提升350%
你还在为一个模型要在TensorFlow和PyTorch之间来回改写而头疼?或者发现同样的架构在JAX上能跑出3倍速度,却因为代码不兼容只能干瞪眼?Keras 3直接把这个选择题给你端平了——它让你用一套API,想切哪个后端就切哪个后端,而且性能不降反升。
项目数据:300万开发者的选择
Keras 3目前在GitHub上有超过62k星标,全球近300万开发者在使用。它支持JAX、TensorFlow、PyTorch三大主流框架作为后端,再加上OpenVINO做纯推理。根据官方基准测试,在相同模型架构下,JAX后端相比其他框架能带来20%到350%的速度提升。这可不是PPT上的数字,你跑个Transformer或ResNet就能明显感觉到。
核心功能拆解:一套代码,三套引擎
1. 多后端自由切换
你不用再纠结“选哪个框架”。Keras 3把底层计算引擎抽象成可替换的组件。今天用JAX跑训练,明天切到TensorFlow做部署,代码一个字都不用改。
2. 兼容tf.keras的旧代码
如果你之前用tf.keras写的模型,只要保存格式是.keras,直接就能在Keras 3里跑。自定义层或训练逻辑也只需要几分钟就能改成后端无关的写法。
3. 原生支持PyTorch DataLoader和tf.data
不管你是PyTorch用户还是TensorFlow用户,都能直接用自己习惯的数据加载方式。Keras 3不挑食。
实操步骤:从安装到跑通第一个模型
第一步:安装核心库
pip install keras --upgrade
第二步:安装你需要的后端
选一个或几个都行:
# 安装TensorFlow后端
pip install tensorflow
# 安装JAX后端
pip install jax jaxlib
# 安装PyTorch后端
pip install torch
第三步:配置后端
在导入Keras之前,设置环境变量:
export KERAS_BACKEND="jax"
或者在Python代码里:
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
注意:导入Keras之后就不能再切换后端了,所以这一步要在最前面做。
第四步:写一个完整的训练代码
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
from keras import layers
# 加载数据
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255
x_test = x_test.reshape(-1, 784).astype("float32") / 255
# 构建模型
model = keras.Sequential([
layers.Dense(512, activation="relu"),
layers.Dropout(0.2),
layers.Dense(10, activation="softmax")
])
# 编译
model.compile(
optimizer="adam",
loss="sparse_categorical_crossentropy",
metrics=["accuracy"]
)
# 训练
model.fit(x_train, y_train, batch_size=128, epochs=5, validation_split=0.2)
# 评估
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"测试准确率: {test_acc:.4f}")
这段代码你换成KERAS_BACKEND="torch"或"tensorflow",一样跑得通。
第五步:用OpenVINO做纯推理(可选)
import os
os.environ["KERAS_BACKEND"] = "openvino"
import keras
# 加载训练好的模型
model = keras.models.load_model("my_model.keras")
# 推理
predictions = model.predict(x_test[:10])
OpenVINO后端只支持model.predict(),不能做训练,但推理速度在Intel硬件上非常快。
避坑指南
-
后端必须在导入Keras前配置。如果你先
import keras再设环境变量,系统不会报错,但后端不会切换,你会一脸懵逼。 -
OpenVINO只能做推理。别想着用它来训练,
model.fit()会直接报错。 -
自定义层要注意后端兼容。如果你用了
tf.Tensor特有的方法(比如tf.shape),切换到JAX或PyTorch后端时会挂。建议用Keras提供的后端无关API,比如keras.ops.shape。 -
模型保存格式要统一。Keras 3推荐用
.keras格式,旧版的.h5虽然也能读,但有些自定义组件可能会丢失。
要点总结
- Keras 3让你用同一套代码跑JAX、TensorFlow、PyTorch,切换后端只需改一个环境变量
- JAX后端在多数模型上能带来20%到350%的速度提升
- 兼容
tf.keras旧代码,迁移成本低 - 支持PyTorch DataLoader和tf.data,数据加载方式随你选
- OpenVINO后端适合Intel硬件上的纯推理场景
- 安装简单,一行
pip install keras就能开始
你现在就可以把手上那个跑得慢的模型,切到JAX后端试试,说不定训练时间直接砍半。