Created at : 2024-06-23 13:33
Author: Soo.Y

์ฝ”๋žฉ(Colab) ํ™˜๊ฒฝ์—์„œ MLX๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ MLP(Multi-Layer Perceptron)์„ ๊ตฌํ˜„ํ•˜๋Š” ๋‚ด์šฉ์ž…๋‹ˆ๋‹ค. ์˜ˆ์ œ๋กœ ์‚ฌ์šฉํ•˜๋Š” ๋ฐ์ดํ„ฐ๋Š” mnist์ž…๋‹ˆ๋‹ค.

MLX install

์ฝ”๋žฉ์—๋Š” MLX๋ฅผ ์„ค์น˜ํ•˜๋Š” ๋ฐฉ๋ฒ•์€ pip install mlx์„ ํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค.

Mnist data set

์˜ˆ์ œ๋กœ ์‚ฌ์šฉํ•  mnist๋ฅผ ๊ฐ€์ ธ์˜ค๊ธฐ ์œ„ํ•œ ์ฝ”๋“œ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

import gzip
import os
import pickle
from urllib import request
import numpy as np
 
def mnist(
	save_dir="/tmp",
	base_url="https://raw.githubusercontent.com/fgnt/mnist/master/",
	filename="mnist.pkl",
):
	"""Download the MNIST data set (if not cached) and return it.

	Args:
		save_dir: directory where the downloaded files and the pickle cache
			are stored.
		base_url: base URL hosting the four gzipped IDX files.
		filename: name of the cached pickle file inside save_dir.

	Returns:
		(train_images, train_labels, test_images, test_labels) as numpy
		arrays; images are float32 scaled to [0, 1], labels are uint32.
	"""

	def download_and_save(save_file):
		# (dict key, remote file name) pairs for the four MNIST archives.
		files = [
			["training_images", "train-images-idx3-ubyte.gz"],
			["test_images", "t10k-images-idx3-ubyte.gz"],
			["training_labels", "train-labels-idx1-ubyte.gz"],
			["test_labels", "t10k-labels-idx1-ubyte.gz"],
		]

		mnist = {}

		# Download each archive into save_dir (the original hard-coded
		# "/tmp" here, silently ignoring the save_dir argument).
		for name in files:
			out_file = os.path.join(save_dir, name[1])
			request.urlretrieve(base_url + name[1], out_file)

		# Image files: 16-byte IDX header, then uint8 pixels, 28*28 each.
		for name in files[:2]:
			out_file = os.path.join(save_dir, name[1])
			with gzip.open(out_file, "rb") as f:
				mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(
					-1, 28 * 28
				)

		# Label files: 8-byte IDX header, then one uint8 label per sample.
		for name in files[-2:]:
			out_file = os.path.join(save_dir, name[1])
			with gzip.open(out_file, "rb") as f:
				mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)

		# Dump once, after all four parts are loaded. The original dumped
		# inside the label loop, writing the pickle twice and leaving an
		# incomplete data set on disk after the first pass.
		with open(save_file, "wb") as f:
			pickle.dump(mnist, f)

	def preproc(x):
		# Scale uint8 pixel values [0, 255] to float32 [0, 1].
		return x.astype(np.float32) / 255.0

	save_file = os.path.join(save_dir, filename)
	if not os.path.exists(save_file):
		download_and_save(save_file)
	with open(save_file, "rb") as f:
		mnist = pickle.load(f)

	mnist["training_images"] = preproc(mnist["training_images"])
	mnist["test_images"] = preproc(mnist["test_images"])

	return (
		mnist["training_images"],
		mnist["training_labels"].astype(np.uint32),
		mnist["test_images"],
		mnist["test_labels"].astype(np.uint32),
	)

train_x, train_y, test_x, test_y = mnist()

ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ๋Š” 60,000๊ฐœ, ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ๋Š” 10,000๊ฐœ๋กœ ๊ตฌ์„ฑ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.

train_x.shape, train_x

((60000, 784),
array([[0., 0., 0., โ€ฆ, 0., 0., 0.],
[0., 0., 0., โ€ฆ, 0., 0., 0.],
[0., 0., 0., โ€ฆ, 0., 0., 0.],
โ€ฆ,
[0., 0., 0., โ€ฆ, 0., 0., 0.],
[0., 0., 0., โ€ฆ, 0., 0., 0.],
[0., 0., 0., โ€ฆ, 0., 0., 0.]], dtype=float32))

train_y.shape, train_y

((60000,), array([5, 0, 4, โ€ฆ, 5, 6, 8], dtype=uint32))

MLP ๋งŒ๋“ค๊ธฐ

๋ชจ๋“ˆ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ

import sys
import argparse
import time
from functools import partial
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np

MLP class

MLX์—์„œ๋Š” ๋ชจ๋ธ์„ ๊ตฌํ˜„ํ•  ๋•Œ ๋งค์ง๋ฉ”์„œ๋“œ __init__๊ณผ __call__์„ ์‚ฌ์šฉํ•˜์—ฌ ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค. pytorch์—์„œ๋Š” forward ํ•จ์ˆ˜๋ฅผ ์ •์˜ํ•˜๋Š” ๊ฒƒ๊ณผ ์œ ์‚ฌํ•œ ๋ฐฉ๋ฒ•์œผ๋กœ ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.(์‹ค์ œ๋กœ๋Š” pytorch์—์„œ forward๋ฅผ ์ž‘์„ฑํ•˜์ง€๋งŒ __call__์—์„œ ์ด๋ฏธ forward ํ•จ์ˆ˜๋ฅผ ํ˜ธ์ถœํ•˜๋„๋ก ๊ตฌํ˜„๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.

class MLP(nn.Module):
	"""A simple multi-layer perceptron with ReLU hidden activations."""

	def __init__(
		self,
		num_layers: int,  # number of hidden layers
		input_dim: int,
		hidden_dim: int,
		output_dim: int,
	):
		# Fix: in the original, super().__init__() was dedented out of the
		# method body, which raises an IndentationError.
		super().__init__()
		# e.g. [784, 32, 32, 10] for num_layers=2, hidden_dim=32.
		layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
		self.layers = [
			nn.Linear(idim, odim)
			for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
		]

	def __call__(self, x):
		# ReLU after every layer except the last; the final layer's raw
		# logits are returned (softmax happens inside the loss).
		for layer in self.layers[:-1]:
			x = nn.relu(layer(x))
		return self.layers[-1](x)

Batch ํ•จ์ˆ˜

๋ฐฐ์น˜ ๋‹จ์œ„๋งˆ๋‹ค ๋ฐ์ดํ„ฐ๋ฅผ ๋ถˆ๋Ÿฌ์˜ค๋Š” ํ•จ์ˆ˜๋ฅผ ๊ตฌํ˜„ํ•˜๊ณ ์ž ํ•ฉ๋‹ˆ๋‹ค. ํ•จ์ˆ˜์—๋Š” ๋ฐฐ์น˜์˜ ํฌ๊ธฐ์ธ batch_size, ์ž…๋ ฅ ๋ฐ์ดํ„ฐ์ธ X์™€ ์ถœ๋ ฅ ๋ฐ์ดํ„ฐ์˜ ์ •๋‹ต์ธ y๋ฅผ ๋ฐ›๋„๋ก ์ž‘์„ฑํ•ฉ๋‹ˆ๋‹ค. ๋ฐ์ดํ„ฐ์˜ ์…”ํ”Œ์„ ๊ตฌํ˜„ํ•˜๊ธฐ ์œ„ํ•ด์„œ np.random.permutation์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ฌด์ž‘์œ„๋กœ ๋ฐ์ดํ„ฐ๊ฐ€ ์„ž์ด๋„๋ก ํ•˜์˜€์Šต๋‹ˆ๋‹ค. ๋งˆ์ง€๋ง‰์œผ๋กœ mx.array๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ MLX์˜ array ๊ตฌ์กฐ๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค.

def batch_iterate(batch_size, X, y, suffle=True):
	"""Yield (X, y) mini-batches of size batch_size as mx.arrays.

	Args:
		batch_size: number of samples per batch (last batch may be smaller).
		X: input samples, indexable by an integer index array.
		y: labels; y.size determines the number of samples.
		suffle: when True, visit samples in a random permutation.
			NOTE: the keyword is spelled `suffle` (sic) — kept unchanged
			for backward compatibility with existing callers.
	"""
	if suffle:
		perm = mx.array(np.random.permutation(y.size))
	else:
		# Fix: np.range does not exist, so the original raised
		# AttributeError whenever suffle=False; np.arange yields the
		# identity ordering intended here.
		perm = mx.array(np.arange(y.size))
	for s in range(0, y.size, batch_size):
		ids = perm[s : s + batch_size]
		yield mx.array(X[ids]), mx.array(y[ids])

์ž‘์„ฑ๋œ ์ฝ”๋“œ๊ฐ€ ์ž˜ ์ž‘๋™ํ•˜๋Š”์ง€ ์•„๋ž˜ ์ฝ”๋“œ์™€ ๊ฐ™์ด ์‹คํ–‰ํ•ด๋ณด๋ฉด X, y์— ์ฒซ ๋ฒˆ์งธ์™€ ๋‘ ๋ฒˆ์งธ ๋ฐ์ดํ„ฐ๊ฐ€ ๋“ค์–ด๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค.

X, y = next(batch_iterate(2, train_x, train_y,  False))

๋˜ํ•œ, type(X)๋ฅผ ์‹คํ–‰ํ•ด๋ณด๋ฉด mlx.core.array๋กœ ๋‚˜์˜ต๋‹ˆ๋‹ค.

์†์‹คํ•จ์ˆ˜

๋ฌธ์ œ์— ์ ํ•ฉํ•œ ํ•จ์ˆ˜๋ฅผ ์„ ํƒํ•˜์—ฌ ๋ถˆ๋Ÿฌ์˜ค๋ฉด ๋ฉ๋‹ˆ๋‹ค. MLX์—์„œ๋„ loss ํ•จ์ˆ˜๋Š” ๋”ฐ๋กœ ๊ตฌํ˜„๋œ ํ•จ์ˆ˜๋“ค์„ ์‚ฌ์šฉํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค.

def loss_fn(model, X, y):
	"""Return the mean cross-entropy between the model's logits for X and labels y."""
	logits = model(X)
	return nn.losses.cross_entropy(logits, y, reduction="mean")

MLX์—์„œ๋Š” nn.value_and_grad๋ฅผ ์‚ฌ์šฉํ•ด์„œ wrapperํ•จ์ˆ˜๋กœ ๋งŒ๋“ค์–ด์ฃผ๊ณ , ์ด wrapperํ•จ์ˆ˜๋ฅผ ์‹คํ–‰ํ•˜๋ฉด loss ๊ฐ’๊ณผ gradient ๊ฐ’์ด ๊ณ„์‚ฐ๋ฉ๋‹ˆ๋‹ค.

์ฃผ์˜! nn์€ pytorch์—์„œ ์‚ฌ์šฉํ•˜๋Š” nn์ด ์•„๋‹™๋‹ˆ๋‹ค.

# Wrap loss_fn so a single call returns both the loss value and the
# gradients of the model's parameters.
# NOTE(review): `model` must already be an MLP instance at this point —
# confirm it is constructed earlier in the full script.
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grad = loss_and_grad_fn(model, X, y)

Optimizer ๋ถˆ๋Ÿฌ์˜ค๊ธฐ

์‚ฌ์šฉํ•˜๊ณ ์ž ํ•˜๋Š” optimizer๋ฅผ ์„ ํƒํ•˜์—ฌ ๋ถˆ๋Ÿฌ์˜ต๋‹ˆ๋‹ค. Optimizer๋„ MLX์—์„œ ๊ตฌํ˜„๋œ ํ•จ์ˆ˜๋“ค์„ ์‚ฌ์šฉํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค.

optimizer๋งˆ๋‹ค ํŒŒ๋ผ๋ฏธํ„ฐ์˜ ์ข…๋ฅ˜๋Š” ๋‹ค๋ฅด์ง€๋งŒ ๊ณตํ†ต ํŒŒ๋ผ๋ฏธํ„ฐ์ธ learning rate๋งŒ ์ž…๋ ฅํ•ด์ฃผ๋ฉด ๋‚˜๋จธ์ง€๋Š” ๊ธฐ๋ณธ ๊ฐ’์œผ๋กœ ์„ค์ •๋ฉ๋‹ˆ๋‹ค.

# A learning rate of 1 is deliberately chosen so the manual weight-update
# check below (w - grad * 1) is easy to compare against the optimizer.
test_lr = 1
optimizer = optim.SGD(learning_rate=test_lr)

๊ทธ ๋’ค์— ์•ž์—์„œ ๊ณ„์‚ฐํ•œ gradient ๊ฐ’๊ณผ ํ•จ๊ป˜ ๋ชจ๋ธ ๊ฐ์ฒด๋ฅผ ๋„˜๊ฒจ์ฃผ๋ฉด weight๊ฐ€ ์—…๋ฐ์ดํŠธ ๋ฉ๋‹ˆ๋‹ค.

optimizer.update(model, grad)

gradient update ์ง์ ‘ ๊ณ„์‚ฐ

์•„๋ž˜ ์ฝ”๋“œ๋ฅผ ํ†ตํ•ด์„œ gradient update๋ฅผ ์ง์ ‘ ๊ณ„์‚ฐํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋ฅผ MLX์—์„œ updateํ•œ ๊ณ„์‚ฐ ๊ฒฐ๊ณผ์™€ ๋น„๊ตํ•˜์—ฌ ๊ฒ€์ฆํ•ด ๋ด…์‹œ๋‹ค.

# Manual SGD step for one layer's weights: w_new = w_old - lr * grad (lr == 1).
# NOTE(review): `init_parameters` must be a snapshot of model.parameters()
# taken BEFORE optimizer.update() ran — confirm in the full script.
weight_myself = init_parameters['layers'][1]['weight'] - grad['layers'][1]['weight'] * 1
weight_mlx = model.parameters()['layers'][1]['weight']
 
# The two printed arrays should match if the optimizer performed plain SGD.
print(weight_myself)
print(weight_mlx)

Validation

ํ›ˆ๋ จ์ด ์™„๋ฃŒ๋˜์—ˆ์œผ๋ฉด, test data set์œผ๋กœ ๋ชจ๋ธ์„ ๊ฒ€์ฆํ•˜๋Š” ๊ณผ์ •์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ์ด๋ฒˆ ์˜ˆ์ œ์—์„œ๋Š” ์ˆซ์ž 0๋ถ€ํ„ฐ 9๊นŒ์ง€์˜ ๋ ˆ์ด๋ธ”์„ ๋งž์ถ”๋Š” ๋ฌธ์ œ์ด๋ฏ€๋กœ argmax๋ฅผ ์‚ฌ์šฉํ•ด์„œ ๊ฐ€์žฅ ๋†’์€ ํ™•๋ฅ ์ด ๊ณ„์‚ฐ๋œ ๋ ˆ์ด๋ธ”์„ ์ฐพ์•„์ค˜์„œ ์ •๋‹ต๊ณผ ๋น„๊ตํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค.

# Test-set accuracy: fraction of samples whose argmax logit equals the label.
mlx_test_x = mx.array(test_x)
mlx_test_y = mx.array(test_y)
mx.mean(mx.argmax(model(mlx_test_x), axis=1) == mlx_test_y)

MLP ์ „์ฒด ์ฝ”๋“œ

๋งํฌ : ์ฝ”๋“œ๋ฅผ ์–ด๋””์—์„œ ๊ณต์œ ํ•ด์•ผ ํ• ๊นŒ์š”?

๊ด€๋ จ ๋ฌธ์„œ