RISC-V MCU Chinese Community

Getting Started with Keras, Lecture 2: Sequential and Functional Models

Posted on 2023-05-30 09:33:38

Keras Sequential and Functional Models

  • Keras provides two ways to build a model:

    1. Sequential model (which can also be seen as a special case of the functional model)
    2. Functional model

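    Advantages and disadvantages of the two approaches:
    Sequential model. Advantages: single input and single output, simple to build, a linear stack of layers with no branching, fast to compile. Disadvantages: cannot define complex models.
    Functional model. Advantages: flexible, layers can be connected arbitrarily, can define complex models (such as multi-output models, directed acyclic graphs, or models with shared layers). Disadvantages: more complex to build, slow to compile.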

    The Sequential model can also be seen as a special case of the functional model.

    Each approach is illustrated below using MNIST as the example:

    1 Sequential model

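    A Sequential model stacks one layer on top of another, so the model never branches and has a single input and a single output. This means a Sequential model does not support shared layers or multi-input, multi-output models.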

    The code from Lecture 0 that builds the MNIST model as a Sequential model is as follows:

    Approach 1: add layers one by one with model.add

    import tensorflow as tf
    import tensorflow.keras as keras

    print(keras.__version__)

    (x_train, y_train), (x_valid, y_valid) = keras.datasets.mnist.load_data()
    assert x_train.shape == (60000, 28, 28)
    assert x_valid.shape == (10000, 28, 28)
    assert y_train.shape == (60000,)
    assert y_valid.shape == (10000,)

    # step1: use sequential
    model = keras.models.Sequential()

    # step2: add layer
    model.add(keras.layers.Flatten(input_shape=(x_train.shape[1], x_train.shape[2])))
    model.add(keras.layers.Dense(units=784, activation="relu"))
    model.add(keras.layers.Dense(units=10, activation="softmax"))

    # step3: compile model
    model.compile(optimizer="Adam", loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    print("model:")
    model.summary()

    # step4: train
    model.fit(x_train, y_train, batch_size=64, epochs=5)

    # step5: evaluate model
    model.evaluate(x_valid, y_valid)

    # save model
    model.save('keras_mnist.h5')
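The file written by `model.save` can be loaded back later for inference. The following is a minimal sketch rather than part of the original lecture: it uses an untrained stand-in model, a temporary directory, and random arrays in place of real digits, so the predicted labels are meaningless and only the mechanics are shown.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras

# Stand-in for the trained model above (untrained, to keep the sketch short)
model = keras.models.Sequential([
    keras.Input(shape=(28, 28)),
    keras.layers.Flatten(),
    keras.layers.Dense(units=784, activation="relu"),
    keras.layers.Dense(units=10, activation="softmax"),
])

# Save to an HDF5 file in a temporary directory (the location is arbitrary)
model_path = os.path.join(tempfile.mkdtemp(), "keras_mnist.h5")
model.save(model_path)

# Load the file back and classify a batch of images
loaded = keras.models.load_model(model_path, compile=False)
images = np.random.rand(3, 28, 28).astype("float32")  # placeholder for real digits
probabilities = loaded.predict(images, verbose=0)
predicted_digits = np.argmax(probabilities, axis=1)
print(probabilities.shape, predicted_digits)
```

Passing `compile=False` skips restoring the optimizer and loss, which is enough when the model is only used for prediction.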

    Approach 2: define the layers when initializing Sequential

    import tensorflow as tf
    import tensorflow.keras as keras

    print(keras.__version__)

    (x_train, y_train), (x_valid, y_valid) = keras.datasets.mnist.load_data()
    assert x_train.shape == (60000, 28, 28)
    assert x_valid.shape == (10000, 28, 28)
    assert y_train.shape == (60000,)
    assert y_valid.shape == (10000,)

    # step1, step2: Init sequential
    model = keras.models.Sequential([
       keras.layers.Flatten(input_shape=(x_train.shape[1], x_train.shape[2])),
       keras.layers.Dense(units=784, activation="relu"),
       keras.layers.Dense(units=10, activation="softmax")
    ])

    # step3: compile model
    model.compile(optimizer="Adam", loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    print("model:")
    model.summary()

    # step4: train
    model.fit(x_train, y_train, batch_size=64, epochs=5)

    # step5: evaluate model
    model.evaluate(x_valid, y_valid)

    # save model
    model.save('keras_mnist.h5')
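As a quick sanity check, the two construction styles can be compared directly. The sketch below builds the same network both ways and compares layer types and parameter counts. It skips training, and the helper names `build_with_add` and `build_with_list` are made up for this illustration.

```python
from tensorflow import keras

def build_with_add():
    # Approach 1: add layers one at a time
    model = keras.models.Sequential()
    model.add(keras.Input(shape=(28, 28)))
    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(units=784, activation="relu"))
    model.add(keras.layers.Dense(units=10, activation="softmax"))
    return model

def build_with_list():
    # Approach 2: pass the list of layers to the constructor
    return keras.models.Sequential([
        keras.Input(shape=(28, 28)),
        keras.layers.Flatten(),
        keras.layers.Dense(units=784, activation="relu"),
        keras.layers.Dense(units=10, activation="softmax"),
    ])

model_a = build_with_add()
model_b = build_with_list()

# Compare the layer types and the total number of parameters
types_a = [type(layer).__name__ for layer in model_a.layers]
types_b = [type(layer).__name__ for layer in model_b.layers]
print("same layer types:", types_a == types_b)
print("parameter counts:", model_a.count_params(), model_b.count_params())
```

Which style to use is a matter of taste; `model.add` is convenient when layers are added conditionally or in a loop.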

    2 Functional API model

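    The Sequential model lets us build a model layer by layer and solves many problems, but sometimes we need a more complex model, such as one with shared layers or with multiple inputs and outputs. For these cases Keras provides the functional API. It makes each layer's input and output explicit: as in a function call, the output of the previous layer is passed as the argument to the current layer, and the current layer's output is the return value. This is where the name functional API comes from. (In a Sequential model, by contrast, each layer's input and output are fixed by its position in the stack and do not need to be specified.)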
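To make the difference concrete, here is a minimal sketch of a model that Sequential cannot express: two inputs pass through one shared Dense layer, and the model has two outputs. The layer sizes and the names `shared_dense`, `same_digit` and `digit_a` are arbitrary choices for illustration, not part of the MNIST example.

```python
import numpy as np
from tensorflow import keras

# Two separate inputs, each a 28x28 image
input_a = keras.Input(shape=(28, 28), name="image_a")
input_b = keras.Input(shape=(28, 28), name="image_b")

# One Flatten and one Dense instance, reused on both inputs (shared weights)
flatten = keras.layers.Flatten()
shared_dense = keras.layers.Dense(units=64, activation="relu", name="shared_dense")
feat_a = shared_dense(flatten(input_a))
feat_b = shared_dense(flatten(input_b))

# The two branches merge for one output; the other output uses branch A only
merged = keras.layers.Concatenate()([feat_a, feat_b])
same_digit = keras.layers.Dense(units=1, activation="sigmoid", name="same_digit")(merged)
digit_a = keras.layers.Dense(units=10, activation="softmax", name="digit_a")(feat_a)

model = keras.Model(inputs=[input_a, input_b], outputs=[same_digit, digit_a])

# Run random data through the model to inspect the output shapes
batch_a = np.random.rand(4, 28, 28).astype("float32")
batch_b = np.random.rand(4, 28, 28).astype("float32")
pred_same, pred_digit = model.predict([batch_a, batch_b], verbose=0)
print(pred_same.shape, pred_digit.shape)
```

Because `shared_dense` is a single layer instance, its weights are counted once in `model.count_params()` even though it is applied to both inputs.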

    The functional API can also build an MNIST model equivalent to the earlier Sequential one. The code is as follows:

    import tensorflow as tf
    import tensorflow.keras as keras

    print(keras.__version__)

    (x_train, y_train), (x_valid, y_valid) = keras.datasets.mnist.load_data()
    assert x_train.shape == (60000, 28, 28)
    assert x_valid.shape == (10000, 28, 28)
    assert y_train.shape == (60000,)
    assert y_valid.shape == (10000,)

    # step1: Model structure
    # Layer instances are callable: they take a tensor as argument and return a tensor
    inputs = keras.Input(shape=(x_train.shape[1], x_train.shape[2]))
    x = keras.layers.Flatten()(inputs)
    x = keras.layers.Dense(units=784, activation="relu")(x)
    outputs = keras.layers.Dense(units=10, activation="softmax")(x)

    # step2: define Model
    model = keras.Model(inputs=inputs, outputs=outputs)

    # step3: compile model
    model.compile(optimizer="Adam", loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    print("model:")
    model.summary()

    # step4: train
    model.fit(x_train, y_train, batch_size=64, epochs=5)

    # step5: evaluate model
    model.evaluate(x_valid, y_valid)

    # save model
    model.save('keras_mnist.h5')

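    This example shows that the functional API can build the same model as Sequential, so the Sequential model can be seen as a special case of the functional model.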
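One way to check this equivalence is to copy the weights of a Sequential model into a functional model with the same layers and compare their predictions. The sketch below uses random inputs instead of MNIST so that it runs without a download. It illustrates the claim and is not taken from the original lecture.

```python
import numpy as np
from tensorflow import keras

# Sequential version
seq_model = keras.models.Sequential([
    keras.Input(shape=(28, 28)),
    keras.layers.Flatten(),
    keras.layers.Dense(units=784, activation="relu"),
    keras.layers.Dense(units=10, activation="softmax"),
])

# Functional version with the same layers
inputs = keras.Input(shape=(28, 28))
x = keras.layers.Flatten()(inputs)
x = keras.layers.Dense(units=784, activation="relu")(x)
outputs = keras.layers.Dense(units=10, activation="softmax")(x)
func_model = keras.Model(inputs=inputs, outputs=outputs)

# Both models hold weights of the same shapes in the same order,
# so the weights of one can be copied into the other
func_model.set_weights(seq_model.get_weights())

batch = np.random.rand(8, 28, 28).astype("float32")
same = np.allclose(seq_model.predict(batch, verbose=0),
                   func_model.predict(batch, verbose=0), atol=1e-6)
print("identical predictions:", same)
```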

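    There is also a third way to build a model, subclassing Model, in which the forward pass is written in a call method. In my view this differs only in how the code is written and is in essence still the functional API. The code is as follows: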

    import tensorflow as tf
    import tensorflow.keras as keras

    print(keras.__version__)

    (x_train, y_train), (x_valid, y_valid) = keras.datasets.mnist.load_data()
    assert x_train.shape == (60000, 28, 28)
    assert x_valid.shape == (10000, 28, 28)
    assert y_train.shape == (60000,)
    assert y_valid.shape == (10000,)

    # step1: Model structure
    class MnistModel(tf.keras.Model):
       def __init__(self):
           super().__init__()
           self.x0 = keras.layers.Flatten()
           self.x1 = keras.layers.Dense(units=784, activation="relu")
           self.x2 = keras.layers.Dense(units=10, activation="softmax")

       def call(self, inputs):
           # Forward pass: apply the layers in order
           x = self.x0(inputs)
           x = self.x1(x)
           x = self.x2(x)
           return x

    # step2: define Model
    model = MnistModel()

    # step3: compile model
    model.compile(optimizer="Adam", loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    # step4: train
    model.fit(x_train, y_train, batch_size=64, epochs=5)

    # must come after model.fit: a subclassed model is only built once it has been called on data
    print("model:")
    model.summary()

    # step5: evaluate model
    model.evaluate(x_valid, y_valid)

    # save model (does not work for a subclassed model in HDF5 format)
    # model.save('keras_mnist.h5')

    A few issues to note when subclassing Model:

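    1. With a subclassed model, model.summary can only be called after model.fit (or after the model has otherwise been built); otherwise the following error is raised: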

      ValueError: This model has not yet been built. Build the model first by calling `build()` or by calling the model on a batch of data.
    2. With a subclassed model, calling model.save with an HDF5 file name fails:

      NotImplementedError: Saving the model to HDF5 format requires the model to be a Functional model or a Sequential model. It does not work for subclassed models, because such models are defined via the body of a Python method, which isn't safely serializable. Consider saving to the Tensorflow SavedModel format (by setting save_format="tf") or using `save_weights`.
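Both error messages hint at workarounds: build the model by calling it on a batch of data before `summary()`, and save only the weights with `save_weights`. A minimal sketch follows, assuming a reasonably recent TensorFlow. The file name `mnist_subclass.weights.h5`, the temporary directory and the attribute names are arbitrary, and the class simply repeats the subclassed model above.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras

class MnistModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten_layer = keras.layers.Flatten()
        self.hidden_layer = keras.layers.Dense(units=784, activation="relu")
        self.output_layer = keras.layers.Dense(units=10, activation="softmax")

    def call(self, inputs):
        x = self.flatten_layer(inputs)
        x = self.hidden_layer(x)
        return self.output_layer(x)

dummy_batch = np.zeros((1, 28, 28), dtype="float32")

model = MnistModel()
model(dummy_batch)   # calling the model once creates its weights
model.summary()      # now works without calling fit() first

# Save only the weights; the architecture lives in the Python class
weights_path = os.path.join(tempfile.mkdtemp(), "mnist_subclass.weights.h5")
model.save_weights(weights_path)

# Restore into a fresh instance, which is built before loading
restored = MnistModel()
restored(dummy_batch)
restored.load_weights(weights_path)

batch = np.random.rand(4, 28, 28).astype("float32")
match = np.allclose(model.predict(batch, verbose=0),
                    restored.predict(batch, verbose=0), atol=1e-6)
print("restored model matches:", match)
```

The trade-off is that the class definition must be available wherever the weights are loaded, since the file contains no description of the architecture.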


    sureZ-ok