Python3 Tensorlfow:增加或者减小矩阵维度的实现

1.增加维度

下面给出两个样例

样例1:

[1, 2, 3] ==> [[1],[2],[3]]

import tensorflow as tf

a = tf.constant([1, 2, 3])
b = tf.expand_dims(a,1)

with tf.Session() as sess:
 a_, b_ = sess.run([a, b])
 print('a:')
 print(a_)
 print('b:')
 print(b_)

输出结果

a:
[1 2 3]
b:
[[1]
 [2]
 [3]]

样例2:

[1, 2, 3] ==> [[1,2,3]]

import tensorflow as tf

a = tf.constant([1, 2, 3])
b = tf.expand_dims(a, 0)

with tf.Session() as sess:
 a_, b_ = sess.run([a, b])
 print('a:')
 print(a_)
 print('b:')
 print(b_)

输出结果:

a:
[1 2 3]
b:
[[1 2 3]]

2.降低维度

样例1:

[[1, 2, 3]] ==> [1, 2, 3]

import tensorflow as tf

a = tf.constant([[1, 2, 3]])
b = tf.squeeze(a)

with tf.Session() as sess:
 a_, b_ = sess.run([a, b])
 print('a:')
 print(a_)
 print('b:')
 print(b_)

输出结果

a:
[[1 2 3]]
b:
[1 2 3]

样例2:

[[1], [2], [3]] ==> [[1, 2, 3]

import tensorflow as tf

a = tf.constant([[1], [2], [3]])
b = tf.squeeze(a, 1)

with tf.Session() as sess:
 a_, b_ = sess.run([a, b])
 print('a:')
 print(a_)
 print('b:')
 print(b_)

补充知识:pytorch中squeeze()、unsqueeze(),以及一些高维数组操作

博主最近阅读YOLO底层代码,Torch中对多数组矩阵有很多高维操作,看过一边之后,记录一下,以防忘记。

torch.squeeze()

功能:取消为1的维度

squeeze(input, dim=None, out=None) -> Tensor

这里一般分不清dim的意思

举个例子:

input=(A , 1 , B , C ,1 , D)
squeeze(input)=(A,B,C,D)
input= (A, 1, B)

squeeze(input, 0)=(A, 1, B) 不会改变 squeeze(input, 1)=(A, B) 会改变

看一个简单用例,size表示维度大小,10是取值范围,a=[:,:,:,4]表示取a最后一维的第四个元素(从0开始第四个),即取[0,0,3],[5,6,1],[0,6,8],[…], 判断大于5为true,否则为false。

注意:b的维度比a少了一维。

以上这篇Python3 Tensorlfow:增加或者减小矩阵维度的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持来客网。