『科技之感』提升ImageNet分类准确率且可解释,决策树的复兴?结合神经网络( 三 )
HowitWorks
Neural-Backed决策树的训练与推断过程可分解为如下四个步骤:
为决策树构建称为诱导层级「InducedHierarchy」的层级;
该层级产生了一个称为树监督损失「TreeSupervisionLoss」的独特损失函数;
通过将样本传递给神经网络主干开始推断 。 在最后一层全连接层之前 , 主干网络均为神经网络;
以序列决策法则方式运行最后一层全连接层结束推断 , 研究者将其称为嵌入决策法则「EmbeddedDecisionRules」 。
文章图片
Neural-Backed决策树训练与推断示意图 。
运行嵌入决策法则
这里首先讨论推断问题 。 如前所述 , NBDT使用神经网络主干提取每个样本的特征 。 为便于理解接下来的操作 , 研究者首先构建一个与全连接层等价的退化决策树 , 如下图所示:
文章图片
以上产生了一个矩阵-向量乘法 , 之后变为一个向量的内积 , 这里将其表示为$hat{y}$ 。 以上输出最大值的索引即为对类别的预测 。
文章图片
简单决策树(naivedecisiontree):研究者构建了一个每一类仅包含一个根节点与一个叶节点的基本决策树 , 如上图中「B—Naive」所示 。 每个叶节点均直接与根节点相连 , 并且具有一个表征向量(来自W的行向量) 。
使用从样本提取的特征x进行推断意味着 , 计算x与每个子节点表征向量的内积 。 类似于全连接层 , 最大内积的索引即为所预测的类别 。
全连接层与简单决策树之间的直接等价关系 , 启发研究者提出一种特别的推断方法——使用内积的决策树 。
构建诱导层级
该层级决定了NBDT需要决策的类别集合 。 由于构建该层级时使用了预训练神经网络的权重 , 研究者将其称为诱导层级 。
文章图片
具体地 , 研究者将全连接层中权重矩阵W的每个行向量 , 看做d维空间中的一点 , 如上图「StepB」所示 。 接下来 , 在这些点上进行层级聚类 。 连续聚类之后便产生了这一层级 。
使用树监督损失进行训练
文章图片
考虑上图中的「A-Hard」情形 。 假设绿色节点对应于Horse类 。 这只是一个类 , 同时它也是动物(橙色) 。 对结果而言 , 也可以知道到达根节点(蓝色)的样本应位于右侧的动物处 。 到达节点动物「Animal」的样本也应再次向右转到「Horse」 。 所训练的每个节点用于预测正确的子节点 。 研究者将强制实施这种损失的树称为树监督损失(TreeSupervisionLoss) 。 换句话说 , 这实际上是每个节点的交叉熵损失 。
使用指南
我们可以直接使用Python包管理工具来安装nbdt:
pipinstallnbdt
安装好nbdt后即可在任意一张图片上进行推断 , nbdt支持网页链接或本地图片 。
nbdthttps://images.pexels.com/photos/126407/pexels-photo-126407.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=32
#ORrunonalocalimage
nbdt/imaginary/path/to/local/image.png
不想安装也没关系 , 研究者为我们提供了网页版演示以及Colab示例 , 地址如下:
Demo:http://nbdt.alvinwan.com/demo/
Colab:http://nbdt.alvinwan.com/notebook/
下面的代码展示了如何使用研究者提供的预训练模型进行推断:
fromnbdt.modelimportSoftNBDT
fromnbdt.modelsimportResNet18,wrn28_10_cifar10,wrn28_10_cifar100,wrn28_10#usewrn28_10forTinyImagenet200
推荐阅读
- 小米科技▲卢伟冰再次发力,全球首发骁龙768G,5G新机将在两天后发布!
- 快科技最贵或达5000元,苹果头戴耳机更多细节曝光:包含两款
- 科技迷7nm版年底流片,要放弃美国代工?国产x86转向三星台积电代工
- 骊微电子科技PD充电器应用方案,PN8161+PN8307H高集成18W
- 快科技小米高管都是外人?雷军透露了一个秘密
- 靓科技解读Thing,a16z、5.15亿美金的数据加密股票基金:找寻下一个Big
- 王伯伯说科技流畅用三年,即将开学的学生党准备好了吗?三款高配低价千元机
- 知叔达科技中芯国际早已料到,成功绕开了光刻机,怒了!荷兰ASML再次失约
- 小熊带你玩科技数据成粤企生产新要素,工业互联网深调研〡从经验依赖到数据驱动
- 每日科技果粉大批华人再掀归国潮,美利坚的钱“不香了”?,硅谷科技人才流失