Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

A hand-crafted toy level deep learning framework based on python.

Notifications You must be signed in to change notification settings

MarkXCloud/Ocean

Folders and files

NameName
Last commit message
Last commit date

Latest commit

History

21 Commits

Repository files navigation

Ocean: a hand-crafted toy level deep learning framework

1. 简介

本仓库是一个基于python的练习级别深度学习框架Ocean,主要用于熟悉深度学习算法的各个细节。

Ocean框架具有以下特性:

  • 基于静态图
  • 接近于PyTorch的api风格
  • 清晰、完备的前向、反向传播过程,适用于初学者学习与熟悉
  • 基于CuPy完成的cuda加速

Ocean框架完成的功能:

  • 线性层Linear,具有可学习参数WB
  • 卷积层Conv2d,采用img2col方法转换为GEMM实现,具有可学习的卷积核。
  • 池化层MaxPoolingAveragePoolingGlobalAveragePooling
  • 可切换train_modeeval_modeBatchNorm2dDropout
  • 激活函数,包括SigmoidTanhSoftmax等。
  • 损失函数MSECELoss
  • 优化器SGDAdam

Ocean框架具有非常易读且易于使用的api风格:

class MLP(nn.NodeAdder):
 def __init__(self):
 super().__init__()
 self.fc = nn.Sequential(
 nn.Linear(input_dim=784, output_dim=200),
 nn.Sigmoid(),
 nn.Linear(input_dim=200, output_dim=10),
 nn.Softmax()
 )
 def forward(self, X):
 return self.fc(X)
 
x = Variable()
m = MLP()
pred = m(x)
y = Variable()
loss = nn.MSE()
error = loss(pred=pred, target=y)
optim = SGD(graph=m.model_graph, loss=error, lr=0.1)
for i in range(E):
 # train 
 m.set_train_mode()
 for batch_data, batch_label in tqdm(train_loader, desc=f'epoch {i}'):
 optim.zero_gradient()
 for data, label in zip(batch_data, batch_label):
 x.set_value(data)
 y.set_value(label)
 optim.calculate_grad()
 # test
 m.set_eval_mode()
 for batch_data, batch_label in tqdm(test_loader):
 for data, label in zip(batch_data, batch_label):
 x.set_value(data)
 y.set_value(label)
 error.forward()

/demo中有基于Ocean框架的更多示例。

同时,由于个人的力量有限,Ocean还有许多不足之处有待改进,未来如果有时间会尝试进行更大的改进。

2. 附录

  1. Mnist数据集地址

About

A hand-crafted toy level deep learning framework based on python.

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

AltStyle によって変換されたページ (->オリジナル) /