diff --git a/nets/densenet.py b/nets/densenet.py index 3b5fa30..6494232 100644 --- a/nets/densenet.py +++ b/nets/densenet.py @@ -46,7 +46,8 @@ def _conv(inputs, num_filters, kernel_size, stride=1, dropout_rate=None, net = slim.conv2d(net, num_filters, kernel_size) if dropout_rate: - net = tf.nn.dropout(net) + keep_prob = 1.0 - dropout_rate + net = slim.dropout(net, keep_prob=keep_prob) net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)