인공지능/공부

강화학습 과제 1 최종 - N Step TD 구현하기

이게될까 2024. 3. 27. 20:04
728x90
728x90

일단 주어진 코드입니다.

import random
import numpy as np

class GridWorld():
    def __init__(self):
        self.x=0
        self.y=0

    def step(self, a):
        if a == 0:
            self.move_left()
        elif a == 1:
            self.move_up()
        elif a == 2:
            self.move_right()
        elif a == 3:
            self.move_down()

        reward = -1
        done = self.is_done()
        return (self.x, self.y), reward, done

    def move_right(self):
        self.y += 1
        if self.y > 3:
            self.y = 3

    def move_left(self):
        self.y -= 1
        if self.y < 0:
            self.y = 0

    def move_up(self):
        self.x -= 1
        if self.x < 0:
            self.x = 0
            
    def move_down(self):
        self.x += 1
        if self.x > 3:
            self.x = 3

    def is_done(self):
        if self.x == 3 and self.y == 3:
            return True
        else :
            return False

    def get_state(self):
        return (self.x, self.y)

    def reset(self):
        self.x = 0
        self.y = 0
        return (self.x, self.y)

class Agent():
    def __init__(self):
        pass

    def select_action(self):
        coin = random.random()
        if coin < 0.25:
            action = 0
        elif coin < 0.5:
            action = 1
        elif coin < 0.75:
            action = 2
        else :
            action = 3
        return action         
            
            
def main():
    #TD
    env = GridWorld()
    agent = Agent()
    data = [[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]]
    gamma = 1.0
    reward = -1
    alpha = 0.01

    for k in range(50000):
        done = False
        while not done:
            x, y = env.get_state()
            action = agent.select_action()
            (x_prime, y_prime), reward, done = env.step(action)
            data[x][y] = data[x][y] + alpha*(reward+gamma*data[x_prime][y_prime]-data[x][y])
        env.reset()

    for row in data:
        print(row)

if __name__== '__main__':
    main()

TD 식

내 리스트 구조이다.

def main():
    #TD
    env = GridWorld()
    agent = Agent()
    data = [[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]]
    gamma = 1.0
    reward = -1
    alpha = 0.01
    n = 8
    
    for k in range(30000):
        done = False
        x_primes = []
        y_primes = []
        rewards = []
        dones = []
        x_primes.append(0) # t
        y_primes.append(0) # t
        for i in range(n): # t+1 ~ t+n
            action = agent.select_action()
            (x_prime, y_prime), reward, done = env.step(action)
            x_primes.append(x_prime)
            y_primes.append(y_prime)
            dones.append(done)
            if done :
                rewards.append(0)
                break
            rewards.append(reward)
        done = False
        while not done:
            step =data[x_primes[-1]][y_primes[-1]]  # v(S t+n)
            for i in rewards[::-1]:# t+n ~ t+1
                step = step * gamma + i
            #(x_primes,y_primes,rewards, dones,step) # t위치
            x= x_primes.pop(0)#t
            y= y_primes.pop(0)#t
            data[x][y] = data[x][y] + alpha*(step-data[x][y]) # step - v(S t)
            action = agent.select_action()
            (x_prime, y_prime), reward, done = env.step(action)
            if dones[-1] == False:
                x_primes.append(x_prime)
                y_primes.append(y_prime)
                rewards.append(reward)
                dones.append(done)
            reward = rewards.pop(0)
            done = dones.pop(0)
        env.reset()

    for row in data:
        print(row)

이게 맞는거 같은데 살려주세요 ㅎㅎ,...

 

728x90