• Keras 3.0发布:全面拥抱 PyTorch!


     

     

    Keras 3.0 介绍

    https://keras.io/keras_3/

    Keras 3.0 升级是对 Keras 的全面重写,引入了一系列令人振奋的新特性,为深度学习领域带来了全新的可能性。

    多框架支持

    Keras 3.0 的最大亮点之一是支持多框架。Keras 3 实现了完整的 Keras API,并使其可用于 TensorFlow、JAX 和 PyTorch —— 包括一百多个层、数十种度量标准、损失函数、优化器和回调函数,以及 Keras 的训练和评估循环,以及 Keras 的保存和序列化基础设施。所有您熟悉和喜爱的 API 都在这里。

    大规模模型训练和部署

    新版本的 Keras 为大规模模型训练和部署提供了全新的能力。借助优化的算法和性能改进,现在您可以处理更大规模、更复杂的深度学习模型,而无需担心性能问题。

    使用任何来源的数据管道。

    Keras 3 的 fit()/evaluate()/predict()例程兼容 tf.data.Dataset 对象、PyTorch 的 DataLoader 对象、NumPy 数组和 Pandas 数据框,无论您使用的是哪个后端。您可以在 PyTorch 的 DataLoader 上训练 Keras 3 + TensorFlow 模型,或者在 tf.data.Dataset 上训练 Keras 3 + PyTorch 模型。

     

    案例1:搭配Pytorch训练

    https://keras.io/guides/custom_train_step_in_torch/

    • 导入环境

    1. import os
    2. # This guide can only be run with the torch backend.
    3. os.environ["KERAS_BACKEND"] = "torch"
    4. import torch
    5. import keras
    6. from keras import layers
    7. import numpy as np
    • 定义模型

    在 train_step() 方法的主体中,实现了一个常规的训练更新,类似于您已经熟悉的内容。重要的是,我们通过 self.compute_loss() 计算损失,它包装了传递给 compile() 的损失函数。

    1. class CustomModel(keras.Model):
    2.     def train_step(selfdata):
    3.         # Unpack the data. Its structure depends on your model and
    4.         # on what you pass to `fit()`.
    5.         x, y = data
    6.         # Call torch.nn.Module.zero_grad() to clear the leftover gradients
    7.         # for the weights from the previous train step.
    8.         self.zero_grad()
    9.         # Compute loss
    10.         y_pred = self(x, training=True)  # Forward pass
    11.         loss = self.compute_loss(y=y, y_pred=y_pred)
    12.         # Call torch.Tensor.backward() on the loss to compute gradients
    13.         # for the weights.
    14.         loss.backward()
    15.         trainable_weights = [v for v in self.trainable_weights]
    16.         gradients = [v.value.grad for v in trainable_weights]
    17.         # Update weights
    18.         with torch.no_grad():
    19.             self.optimizer.apply(gradients, trainable_weights)
    20.         # Update metrics (includes the metric that tracks the loss)
    21.         for metric in self.metrics:
    22.             if metric.name == "loss":
    23.                 metric.update_state(loss)
    24.             else:
    25.                 metric.update_state(y, y_pred)
    26.         # Return a dict mapping metric names to current value
    27.         # Note that it will include the loss (tracked in self.metrics).
    28.         return {m.name: m.result() for m in self.metrics}
    • 训练模型

    1. # Construct and compile an instance of CustomModel
    2. inputs = keras.Input(shape=(32,))
    3. outputs = keras.layers.Dense(1)(inputs)
    4. model = CustomModel(inputs, outputs)
    5. model.compile(optimizer="adam", loss="mse", metrics=["mae"])
    6. Just use `fit` as usual
    7. = np.random.random((100032))
    8. = np.random.random((10001))
    9. model.fit(x, y, epochs=3)

    案例2:自定义Pytorch流程

    https://keras.io/guides/writing_a_custom_training_loop_in_torch/

    • 导入环境

    1. import os
    2. # This guide can only be run with the torch backend.
    3. os.environ["KERAS_BACKEND"] = "torch"
    4. import torch
    5. import keras
    6. from keras import layers
    7. import numpy as np
    • 定义模型、加载数据集

    1. # Let's consider a simple MNIST model
    2. def get_model():
    3.     inputs = keras.Input(shape=(784,), name="digits")
    4.     x1 = keras.layers.Dense(64, activation="relu")(inputs)
    5.     x2 = keras.layers.Dense(64, activation="relu")(x1)
    6.     outputs = keras.layers.Dense(10, name="predictions")(x2)
    7.     model = keras.Model(inputs=inputs, outputs=outputs)
    8.     return model
    9. # Create load up the MNIST dataset and put it in a torch DataLoader
    10. # Prepare the training dataset.
    11. batch_size = 32
    12. (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    13. x_train = np.reshape(x_train, (-1784)).astype("float32")
    14. x_test = np.reshape(x_test, (-1784)).astype("float32")
    15. y_train = keras.utils.to_categorical(y_train)
    16. y_test = keras.utils.to_categorical(y_test)
    17. # Reserve 10,000 samples for validation.
    18. x_val = x_train[-10000:]
    19. y_val = y_train[-10000:]
    20. x_train = x_train[:-10000]
    21. y_train = y_train[:-10000]
    22. # Create torch Datasets
    23. train_dataset = torch.utils.data.TensorDataset(
    24.     torch.from_numpy(x_train), torch.from_numpy(y_train)
    25. )
    26. val_dataset = torch.utils.data.TensorDataset(
    27.     torch.from_numpy(x_val), torch.from_numpy(y_val)
    28. )
    29. # Create DataLoaders for the Datasets
    30. train_dataloader = torch.utils.data.DataLoader(
    31.     train_datasetbatch_size=batch_sizeshuffle=True
    32. )
    33. val_dataloader = torch.utils.data.DataLoader(
    34.     val_datasetbatch_size=batch_sizeshuffle=False
    35. )
    • 定义优化器

    1. # Instantiate a torch optimizer
    2. model = get_model()
    3. optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    4. # Instantiate a torch loss function
    5. loss_fn = torch.nn.CrossEntropyLoss()
    • 训练模型

    1. epochs = 3
    2. for epoch in range(epochs):
    3.     for step, (inputs, targets) in enumerate(train_dataloader):
    4.         # Forward pass
    5.         logits = model(inputs)
    6.         loss = loss_fn(logits, targets)
    7.         # Backward pass
    8.         model.zero_grad()
    9.         loss.backward()
    10.         # Optimizer variable updates
    11.         optimizer.step()
    12.         # Log every 100 batches.
    13.         if step % 100 == 0:
    14.             print(
    15.                 f"Training loss (for 1 batch) at step {step}{loss.detach().numpy():.4f}"
    16.             )
    17.             print(f"Seen so far: {(step + 1) * batch_size} samples")

     

     

     

  • 相关阅读:
    树莓派(以及各种派)使用指南
    【Vue】VueX 的语法详解(2)
    ELK快速搭建图文详细步骤
    Shiro【散列算法、Shiro会话、退出登录 、权限表设计、注解配置鉴权 】(五)-全面详解(学习总结---从入门到深化)
    SQL 力扣 LeetCode First Day
    【JavaScript复习八】内置对象String和Math
    2022-06-28管理心得
    免费研讨会 | 邀您体验 Ansys Zemax Enterprise 的 STAR 模块
    A星(A*、A Star)路径规划算法详解(附MATLAB代码)
    索引构建磁盘IO太高,巧用tmpfs让内存来帮忙
  • 原文地址:https://blog.csdn.net/m0_58552717/article/details/136343459