Source code for mergernet.data.preprocessing
import tensorflow as tf
def one_hot_factory(n_class):
  """Create a mapping function that one-hot encodes the label of an (X, y) pair.

  Args:
    n_class: Depth of the one-hot encoding, forwarded to ``tf.one_hot``.

  Returns:
    A function taking ``(X, y)`` and returning ``(X, tf.one_hot(y, n_class))``.
    ``X`` is passed through untouched. Because it takes two positional
    arguments and returns a tuple, it can be given to e.g.
    ``tf.data.Dataset.map`` over ``(X, y)`` elements.
  """
  def _encode_label(X, y):
    # Only the label is transformed; the feature tensor is returned as-is.
    encoded = tf.one_hot(y, n_class)
    return X, encoded

  return _encode_label
def load_jpg_with_label(X, y=None, *, channels=3):
  """Read and decode a JPEG file, passing its label through unchanged.

  Args:
    X: Path of the JPEG file, as accepted by ``tf.io.read_file`` (a string
      or scalar string tensor).
    y: Label associated with the image. It is returned as-is. Defaults to None.
    channels: Number of color channels to decode, forwarded to
      ``tf.io.decode_jpeg``: 0 keeps the file's own channel count,
      1 gives grayscale and 3 gives RGB. This argument is keyword-only.
      Defaults to 3.

  Returns:
    A ``(img, y)`` tuple. ``img`` is the 3-D uint8 image tensor produced by
    ``tf.io.decode_jpeg``.
  """
  img_bytes = tf.io.read_file(X)
  img = tf.io.decode_jpeg(img_bytes, channels=channels)
  return img, y
def load_png_with_label(X, y=None, *, channels=3):
  """Read and decode a PNG file, passing its label through unchanged.

  Args:
    X: Path of the PNG file, as accepted by ``tf.io.read_file`` (a string
      or scalar string tensor).
    y: Label associated with the image. It is returned as-is. Defaults to None.
    channels: Number of color channels to decode, forwarded to
      ``tf.io.decode_png``: 0 keeps the file's own channel count,
      1 gives grayscale, 3 gives RGB and 4 gives RGBA. This argument is
      keyword-only. Defaults to 3.

  Returns:
    A ``(img, y)`` tuple. ``img`` is the 3-D image tensor produced by
    ``tf.io.decode_png``, which is uint8 with that op's default dtype.
  """
  img_bytes = tf.io.read_file(X)
  img = tf.io.decode_png(img_bytes, channels=channels)
  return img, y
def load_jpg(X, *, channels=3):
  """Read a JPEG file and return the decoded image tensor.

  Args:
    X: Path of the JPEG file, as accepted by ``tf.io.read_file`` (a string
      or scalar string tensor).
    channels: Number of color channels to decode, forwarded to
      ``tf.io.decode_jpeg``: 0 keeps the file's own channel count,
      1 gives grayscale and 3 gives RGB. This argument is keyword-only.
      Defaults to 3.

  Returns:
    The 3-D uint8 image tensor produced by ``tf.io.decode_jpeg``.
  """
  img_bytes = tf.io.read_file(X)
  img = tf.io.decode_jpeg(img_bytes, channels=channels)
  return img
def load_png(X, *, channels=3):
  """Read a PNG file and return the decoded image tensor.

  Args:
    X: Path of the PNG file, as accepted by ``tf.io.read_file`` (a string
      or scalar string tensor).
    channels: Number of color channels to decode, forwarded to
      ``tf.io.decode_png``: 0 keeps the file's own channel count,
      1 gives grayscale, 3 gives RGB and 4 gives RGBA. This argument is
      keyword-only. Defaults to 3.

  Returns:
    The 3-D image tensor produced by ``tf.io.decode_png``, which is uint8
    with that op's default dtype.
  """
  img_bytes = tf.io.read_file(X)
  img = tf.io.decode_png(img_bytes, channels=channels)
  return img