Using TensorFlow models in Go

Image via www.vpnsrus.com

In earlier posts, I discussed hosting a deep learning model such as ResNet50 on Kubernetes or Azure Container Instances. The model can then be used like any other API that receives its input as JSON and returns its result as JSON.

Naturally, you can also run the model offline, directly from your own code. In this post, I will take a look at calling a TensorFlow model from Go. If you want to follow along, you will need Linux or macOS because the TensorFlow package for Go does not support Windows.

Getting Ready

I installed an Ubuntu Data Science Virtual Machine on Azure and connected to it with X2Go:

Data Science Virtual Machine (Ubuntu) with X2Go

The virtual machine has all the required machine learning tools installed such as TensorFlow and Python. It also has Visual Studio Code. There are some extra requirements though:

  • Go: follow the instructions here to download and install Go
  • TensorFlow C API: follow the instructions here to download and install the C API, which the TensorFlow package for Go requires. It is also recommended to build and run the Hello from TensorFlow C program (near the bottom of the instructions page) to verify that the library works.

After installing Go and the TensorFlow C API, install the TensorFlow Go package with the following command:

go get github.com/tensorflow/tensorflow/tensorflow/go

Test the package with go test:

go test github.com/tensorflow/tensorflow/tensorflow/go

The above command should return:

ok      github.com/tensorflow/tensorflow/tensorflow/go  0.104s

The go get command installed the package in $HOME/go/src/github.com if you did not specify a custom $GOPATH (see this wiki page for more info).

Getting a model

A model describes how the input (e.g. an image for image classification) gets translated to an output (e.g. a list of classes with probabilities). The model contains thousands or even millions of parameters which means a model can be quite large. In this example, we will use NASNetMobile which can be used to classify images.

Now we need some code to save the model in TensorFlow format so that it can be used from a Go program. The code below is based on the sample code on the NASNetMobile page from modeldepot.io. It also does a quick test inference on a cat image.

import keras
from keras.applications.nasnet import NASNetMobile
from keras.preprocessing import image
from keras.applications.nasnet import preprocess_input, decode_predictions
import numpy as np
import tensorflow as tf
from keras import backend as K

sess = tf.Session()
K.set_session(sess)

model = NASNetMobile(weights="imagenet")
img = image.load_img('cat.jpg', target_size=(224,224))
img_arr = np.expand_dims(image.img_to_array(img), axis=0)
x = preprocess_input(img_arr)
preds = model.predict(x)
print('Prediction:', decode_predictions(preds, top=5)[0])

#save the model for use with TensorFlow
builder = tf.saved_model.builder.SavedModelBuilder("nasnet")

#Tag the model, required for Go
builder.add_meta_graph_and_variables(sess, ["atag"])
builder.save()
sess.close()

On the Ubuntu Data Science Virtual Machine, the above code should execute without any issues because all Python packages are already installed. I used the py35 conda environment. Use source activate py35 to make sure you are in that environment.

The above code results in a nasnet folder, which contains the saved_model.pb file for the graph structure. The actual weights are in the variables subfolder. In total, the nasnet folder is around 38MB.

Great! Now we need a way to use the model from our Go program.

Using the saved model from Go

The model can be loaded with the LoadSavedModel function of the TensorFlow package. That package is imported like so:

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
)

LoadSavedModel is used like so:

model, err := tf.LoadSavedModel("nasnet", []string{"atag"}, nil)
if err != nil {
	log.Fatal(err)
}

The above code simply tries to load the model from the nasnet folder. We also need to specify the tag.

Next, we need to load an image and convert the image to a tensor with the following dimensions [1][224][224][3]. This is similar to my earlier ResNet50 post.

Now we need to pass the tensor to the model as input, and retrieve the class predictions as output. The following code achieves this:

output, err := model.Session.Run(
	map[tf.Output]*tf.Tensor{
		model.Graph.Operation("input_1").Output(0): input,
	},
	[]tf.Output{
		model.Graph.Operation("predictions/Softmax").Output(0),
	},
	nil,
)
if err != nil {
	log.Fatal(err)
}

What the heck is this? The Run method is defined as follows:

func (s *Session) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error)

When you build a model, you can give names to tensors and operations. In this case, the input placeholder is the operation called input_1; its first output is used as the key in the feeds map, with our tensor of dimensions [1][224][224][3] as the value. The inference operation is called predictions/Softmax; its first output is requested in the fetches slice.

The actual predictions can be retrieved from the output variable:

predictions, ok := output[0].Value().([][]float32)
if !ok {
	log.Fatalf("output has unexpected type %T", output[0].Value())
}

If you are not very familiar with Go, the code above uses a type assertion to verify that the output value is a two-dimensional slice of float32 ([][]float32). If the type assertion succeeds, the predictions variable will contain the actual predictions: [[<probability class 1 (tench)>, <probability class 2 (goldfish)>, …]]
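If the type assertion is new to you, the small standalone sketch below may help. It does not use TensorFlow at all; asPredictions is a made-up helper that performs the same check on a plain interface{} value, once with the expected type and once with a different one.

```go
package main

import "fmt"

// asPredictions (hypothetical helper) performs the same check as the code
// above: it succeeds only if v holds a [][]float32.
func asPredictions(v interface{}) ([][]float32, error) {
	p, ok := v.([][]float32)
	if !ok {
		return nil, fmt.Errorf("output has unexpected type %T", v)
	}
	return p, nil
}

func main() {
	// Made-up values standing in for output[0].Value().
	var good interface{} = [][]float32{{0.1, 0.9}}
	var bad interface{} = []float64{0.1, 0.9}

	if p, err := asPredictions(good); err == nil {
		fmt.Println("rows:", len(p), "classes:", len(p[0]))
	}
	if _, err := asPredictions(bad); err != nil {
		fmt.Println("error:", err)
	}
}
```

The two-value form (p, ok) never panics on a mismatch, which is why the check can report the unexpected type instead of crashing.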

You can now simply find the top prediction(s) in the array and match them with the list of classes for NASNet (actually the ImageNet classes). I get the following output with a cat image:

Yep, it’s a tabby!

If you are wondering what image I used:

Tabby?
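Finding the top predictions, as mentioned above, could look something like the sketch below. The prediction type and the topK function are names I made up for illustration, and the three probabilities and labels are invented; the real model returns one probability per ImageNet class.

```go
package main

import (
	"fmt"
	"sort"
)

// prediction pairs a class index with its probability (hypothetical type).
type prediction struct {
	Class       int
	Probability float32
}

// topK returns the k most probable classes, highest probability first.
func topK(probs []float32, k int) []prediction {
	preds := make([]prediction, len(probs))
	for i, p := range probs {
		preds[i] = prediction{Class: i, Probability: p}
	}
	sort.Slice(preds, func(a, b int) bool {
		return preds[a].Probability > preds[b].Probability
	})
	if k < 0 {
		k = 0
	}
	if k > len(preds) {
		k = len(preds)
	}
	return preds[:k]
}

func main() {
	// Invented labels and probabilities, shaped like the model output.
	labels := []string{"tench", "goldfish", "tabby"}
	predictions := [][]float32{{0.05, 0.15, 0.80}}

	for _, p := range topK(predictions[0], 2) {
		fmt.Printf("%s: %.2f\n", labels[p.Class], p.Probability)
	}
}
```

Sorting the whole slice is more work than strictly needed for a handful of results, but it keeps the code short and is cheap for a single row of class probabilities.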

Conclusion

With Go’s TensorFlow bindings, you can load TensorFlow models from disk and use them for inference locally, without having to call a remote API. We used Python to prepare the model with some help from Keras.
