
FCN - tensorflow

이게될까 2023. 11. 15. 18:53
import tensorflow as tf
import numpy as np
import pandas as pd
from keras.datasets.mnist import load_data
from keras.models import Sequential, Model
from keras.layers import Dense, Input, Flatten, Dropout
from keras.utils import plot_model, to_categorical
from keras.regularizers import l2
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from google.colab import drive
drive.mount('/content/drive')
tf.debugging.set_log_device_placement(False)
train_data = pd.read_csv('/content/drive/MyDrive/Colab Notebooks/fashion-mnist_train.csv')
test_data = pd.read_csv('/content/drive/MyDrive/Colab Notebooks/fashion-mnist_test.csv')
#tf.debugging.set_log_device_placement(True)
gpus = tf.config.list_physical_devices('GPU')
if gpus:
  # Restrict TensorFlow to use only the first GPU
  try:
    tf.config.set_visible_devices(gpus[0], 'GPU')
  except RuntimeError as e:
    # Visible devices must be set at program startup
    print(e)
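# Convert the DataFrames to NumPy arrays
train_data = np.array(train_data)
test_data = np.array(test_data)
# Column 0 is used as the label; the remaining columns are the pixel values
x = train_data[:, 1:]
y = train_data[:, 0]
x_test = test_data[:, 1:]
y_test = test_data[:, 0]
# Scale pixel values from [0, 255] to [0, 1]
x = x / 255
x_test = x_test / 255
# One-hot encode the labels
y = to_categorical(y)
y_test = to_categorical(y_test)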
# Functional API model: 784 inputs -> five hidden Dense layers -> 10-way softmax
# The shape is passed as a tuple, which newer Keras versions require
input1 = Input(shape=(784,))
f = Dense(392, activation='relu')(input1)
f = Dropout(0.1)(f)
f = Dense(392, activation='relu')(f)
f = Dense(181, activation='relu', kernel_regularizer=l2(0.01))(f)
f = Dense(90, activation='relu')(f)
f = Dense(40, activation='relu', kernel_regularizer=l2(0.01))(f)
output = Dense(10, activation='softmax')(f)
model2 = Model(inputs=input1, outputs=output)
plot_model(model2, show_shapes=True)

model2.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
# Hold out the last 30% of the training data for validation
hist = model2.fit(x, y, epochs=80, batch_size=500, validation_split=0.3)
# Left panel: training vs. validation loss
plt.subplot(1, 2, 1)
plt.plot(hist.history['loss'], label='train loss')
plt.plot(hist.history['val_loss'], label='val loss')
plt.legend()
# Right panel: training vs. validation accuracy
plt.subplot(1, 2, 2)
plt.plot(hist.history['acc'], label='train acc')
plt.plot(hist.history['val_acc'], label='val acc')
plt.legend()

A fully connected network can be built this way. The resulting training curves clearly show overfitting.
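
Overfitting can also be checked numerically rather than only by eye. The following is a minimal sketch, assuming a dictionary shaped like `hist.history` with `loss` and `val_loss` lists. The helper name `summarize_overfitting` and the example numbers are hypothetical and are not results from the run above.

```python
import numpy as np

def summarize_overfitting(history):
    """Summarize the train/validation loss curves.

    `history` is assumed to look like `hist.history`:
    a dict with equally long 'loss' and 'val_loss' lists.
    """
    train_loss = np.asarray(history["loss"], dtype=float)
    val_loss = np.asarray(history["val_loss"], dtype=float)
    best_idx = int(np.argmin(val_loss))  # 0-based index of the lowest val loss
    return {
        # Epoch (1-based) at which validation loss was lowest
        "best_epoch": best_idx + 1,
        "best_val_loss": float(val_loss[best_idx]),
        # Gap between validation and training loss at the last epoch
        "final_gap": float(val_loss[-1] - train_loss[-1]),
        # How much validation loss rose after its minimum
        "val_loss_rise": float(val_loss[-1] - val_loss[best_idx]),
    }

# Hypothetical curves for illustration only (not measured values)
example_history = {
    "loss": [0.9, 0.6, 0.4, 0.3, 0.2, 0.1],
    "val_loss": [0.95, 0.7, 0.55, 0.5, 0.6, 0.75],
}
summary = summarize_overfitting(example_history)
print(summary)
```

With the real run, `summarize_overfitting(hist.history)` would be the call. A training loss that keeps falling while `val_loss_rise` and `final_gap` grow is the usual sign of overfitting, and `best_epoch` suggests roughly where training could have stopped.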

 

 

 
