Source code for mergernet.data.preprocessing
import tensorflow as tf
def one_hot_factory(n_class):
    """Return a mapper that one-hot encodes the label of an ``(X, y)`` pair.

    The returned function leaves ``X`` untouched and replaces ``y`` with
    ``tf.one_hot(y, n_class)``. Its ``(X, y)`` signature suggests it is meant
    for ``tf.data.Dataset.map`` over ``(features, label)`` elements — confirm
    against callers.

    Args:
        n_class: Depth of the one-hot encoding (number of classes).

    Returns:
        A callable ``one_hot(X, y) -> (X, one_hot_y)``.
    """
    def one_hot(X, y):
        # Only the label is transformed; the input passes through as-is.
        encoded = tf.one_hot(y, n_class)
        return X, encoded

    return one_hot
def load_jpg_with_label(X, y=None):
    """Read a JPEG file from disk and decode it, passing the label through.

    Args:
        X: Path of the image file. Presumably a string or scalar string
            tensor coming from a ``tf.data`` pipeline — confirm with callers.
        y: Label returned unchanged alongside the image (default ``None``).

    Returns:
        Tuple ``(image, y)``. ``image`` is the output of
        ``tf.io.decode_jpeg`` with ``channels=3`` (forced 3-channel RGB).
    """
    return tf.io.decode_jpeg(tf.io.read_file(X), channels=3), y
def load_png_with_label(X, y=None):
    """Read a PNG file from disk and decode it, passing the label through.

    Args:
        X: Path of the image file. Presumably a string or scalar string
            tensor coming from a ``tf.data`` pipeline — confirm with callers.
        y: Label returned unchanged alongside the image (default ``None``).

    Returns:
        Tuple ``(image, y)``. ``image`` is the output of
        ``tf.io.decode_png`` with ``channels=3`` (forced 3-channel RGB).
    """
    return tf.io.decode_png(tf.io.read_file(X), channels=3), y
def load_jpg(X):
    """Read a JPEG file from disk and return the decoded image tensor.

    Args:
        X: Path of the image file (presumably a string or scalar string
            tensor — confirm with callers).

    Returns:
        The result of ``tf.io.decode_jpeg`` with ``channels=3``, i.e. a
        3-channel RGB image.
    """
    raw = tf.io.read_file(X)
    return tf.io.decode_jpeg(raw, channels=3)
def load_png(X):
    """Read a PNG file from disk and return the decoded image tensor.

    Args:
        X: Path of the image file (presumably a string or scalar string
            tensor — confirm with callers).

    Returns:
        The result of ``tf.io.decode_png`` with ``channels=3``, i.e. a
        3-channel RGB image.
    """
    return tf.io.decode_png(tf.io.read_file(X), channels=3)