
Istvan Szegedi is an IT Technical Architect at Vodafone UK. He has worked at Hewlett-Packard, Nokia Networks, Google, Morgan Stanley and Vodafone. He holds certificates such as Sun Certified System Administrator, Sun Certified Java Programmer, Sun Certified Web Component Developer, Salesforce.com Certified Force.com Developer and TOGAF Certified Enterprise Architect. As a big fan of mobile and cloud computing, he likes to believe that these technologies will eventually push aside the desktop/client-server architecture.

Prediction API: Machine Learning from Google

01.22.2013

Introduction

One of the most exciting of the 50+ APIs offered by Google is the Prediction API. It provides pattern-matching and machine learning capabilities such as recommendations and categorization. The notion is similar to the machine learning capabilities we can see in other solutions (e.g. in Apache Mahout): we train the system with a set of training data, and then applications based on Prediction API can recommend ("predict") what products the user might like, categorize spam, etc. In this post we go through an example of how to classify SMS messages - whether they are spam or legitimate texts ("ham").

Using Prediction API

In order to use Prediction API, the service needs to be enabled via the Google API console. To upload training data, Prediction API also requires Google Cloud Storage. The dataset used in this post is from the UCI Machine Learning Repository, which has 235 datasets publicly available; this post is based on the SMS Spam Collection dataset.

To upload the training data, we first need to create a bucket in Google Cloud Storage. From the Google API console we need to click on Google Cloud Storage and then on Google Cloud Storage Manager. This will open a webpage where we can create new buckets and upload or delete files.

The UCI SMS Spam Collection file is not suitable as-is for Prediction API; it needs to be converted into the following format (the categories - ham/spam - need to be quoted, as well as the SMS text):

"ham" "Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat..."


Google Prediction API offers a handful of commands that can be invoked via a REST interface. The simplest way of testing Prediction API is to use the Prediction API explorer.

Once the training data is available on Google Cloud Storage, we can start training the machine learning system behind Prediction API. To begin training our model, we need to run prediction.trainedmodels.insert. All commands require authentication, which is based on the OAuth 2.0 standard.


In the insert menu we need to specify the fields that we want to be included in the response. In the request body we need to define an id (this will be used as a reference to the model in the commands used later on), a storageDataLocation where we have the training data uploaded (the Google Cloud Storage path) and the modelType (which can be regression or classification; for spam filtering it is classification).

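For illustration, the same insert can also be issued from the Python client library shown later in this post. This is only a sketch: the storageDataLocation is a made-up placeholder for wherever the converted CSV file was uploaded, while the id matches the one used throughout this post:

# Sketch of the insert call with the v1.4 Python client library; 'service' is the
# authorized client built as in the predict.py excerpt below. The bucket/object
# path is a placeholder.
body = {
    'id': 'bighadoop-00001',
    'storageDataLocation': 'bighadoop/sms_spam_collection.csv',
    'modelType': 'classification',
}
service.trainedmodels().insert(body=body).execute()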

The training runs for a while; we can check its status using the prediction.trainedmodels.get command. The status field is going to be RUNNING and then changes to DONE once the training is finished.

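With the client library the status check looks roughly like the sketch below (the trainingStatus field name is an assumption about the v1.4 response, not taken from the article):

# Sketch of polling the training status with the v1.4 Python client library.
status = service.trainedmodels().get(id='bighadoop-00001').execute()
print status.get('trainingStatus')   # RUNNING while training, DONE when finished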

Now we are ready to run a test against the machine learning system, and it is going to classify whether the given text is spam or ham. The Prediction API command for this action is prediction.trainedmodels.predict. In the id field we have to refer to the id that we defined in the prediction.trainedmodels.insert command (bighadoop-00001), and we also need to specify the request body - the input will be csvInstance, and then we enter the text that we want to get categorized (e.g. "Free entry").


The system then returns the category (spam) and the scores (0.822158 for spam, 0.177842 for ham).

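A minimal sketch of the same call from the Python client library is shown below; the response field names (outputLabel and outputMulti) are assumed from the v1.4 API rather than taken from the article:

# Sketch of the predict call with the v1.4 Python client library ('service' as above).
body = {'input': {'csvInstance': ['Free entry']}}
result = service.trainedmodels().predict(id='bighadoop-00001', body=body).execute()
print result.get('outputLabel')   # e.g. "spam"
print result.get('outputMulti')   # per-label scores, e.g. 0.822158 spam / 0.177842 ham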

Google Prediction API libraries

Google also offers a featured sample application, called Try-Prediction, that includes all the code required to run it on Google App Engine; the code is written in both Python and Java. The application can be tested at http://try-prediction.appspot.com. For instance, if we enter a quote for the Language Detection model from Niels Bohr - "Prediction is very difficult, especially if it's about the future." - it will return that it is likely to be an English text (54.4%).

The key part of the Python code is in predict.py: 

class PredictAPI(webapp.RequestHandler):
  '''This class handles Ajax prediction requests, i.e. not user initiated
     web sessions but remote procedure calls initiated from the Javascript
     client code running the browser.
  '''

  def get(self):
    try:
      # Read server-side OAuth 2.0 credentials from datastore and
      # raise an exception if credentials not found.
      credentials = StorageByKeyName(CredentialsModel, USER_AGENT, 
                                    'credentials').locked_get()
      if not credentials or credentials.invalid:
        raise Exception('missing OAuth 2.0 credentials')

      # Authorize HTTP session with server credentials and obtain
      # access to prediction API client library.
      http = credentials.authorize(httplib2.Http())
      service = build('prediction', 'v1.4', http=http)
      papi = service.trainedmodels()

      # Read and parse JSON model description data.
      models = parse_json_file(MODELS_FILE)

      # Get reference to user's selected model.
      model_name = self.request.get('model')
      model = models[model_name]

      # Build prediction data (csvInstance) dynamically based on form input.
      vals = []
      for field in model['fields']:
        label = field['label']
        val = str(self.request.get(label))
        vals.append(val)
      body = {'input' : {'csvInstance' : vals }}
      logging.info('model:' + model_name + ' body:' + str(body))

      # Make a prediction and return JSON results to Javascript client.
      ret = papi.predict(id=model['model_id'], body=body).execute()
      self.response.out.write(json.dumps(ret))

    except Exception, err:
      # Capture any API errors here and pass response from API back to
      # Javascript client embedded in a special error indication tag.
      err_str = str(err)
      if err_str[0:len(ERR_TAG)] != ERR_TAG:
        err_str = ERR_TAG + err_str + ERR_END
      self.response.out.write(err_str)
The Java version of the prediction web application is as follows:
public class PredictServlet extends HttpServlet {

  @Override
  protected void doGet(HttpServletRequest request,
                       HttpServletResponse response) throws ServletException, 
                                                            IOException {
    Entity credentials = null;
    try {
      // Retrieve server credentials from app engine datastore.
      DatastoreService datastore = 
        DatastoreServiceFactory.getDatastoreService();
      Key credsKey = KeyFactory.createKey("Credentials", "Credentials");
      credentials = datastore.get(credsKey);
    } catch (EntityNotFoundException ex) {
      // If can't obtain credentials, send exception back to Javascript client.
      response.setContentType("text/html");
      response.getWriter().println("exception: " + ex.getMessage());
    }

    // Extract tokens from retrieved credentials.
    AccessTokenResponse tokens = new AccessTokenResponse();
    tokens.accessToken = (String) credentials.getProperty("accessToken");
    tokens.expiresIn = (Long) credentials.getProperty("expiresIn");
    tokens.refreshToken = (String) credentials.getProperty("refreshToken");
    String clientId = (String) credentials.getProperty("clientId");
    String clientSecret = (String) credentials.getProperty("clientSecret");
    tokens.scope = IndexServlet.scope;

    // Set up the HTTP transport and JSON factory
    HttpTransport httpTransport = new NetHttpTransport();
    JsonFactory jsonFactory = new JacksonFactory();

    // Get user requested model, if specified.
    String model_name = request.getParameter("model");

    // Parse model descriptions from models.json file.
    Map<String, Object> models =
      IndexServlet.parseJsonFile(IndexServlet.modelsFile);

    // Setup reference to user specified model description.
    Map<String, Object> selectedModel =
      (Map<String, Object>) models.get(model_name);
    
    // Obtain model id (the name under which model was trained), 
    // and iterate over the model fields, building a list of Strings
    // to pass into the prediction request.
    String modelId = (String) selectedModel.get("model_id");
    List<Object> params = new ArrayList<Object>();
    List<Map<String, String>> fields =
      (List<Map<String, String>>) selectedModel.get("fields");
    for (Map<String, String> field : fields) {
      // This loop is populating the input csv values for the prediction call.
      String label = field.get("label");
      String value = request.getParameter(label);
      params.add(value);
    }

    // Set up OAuth 2.0 access of protected resources using the retrieved
    // refresh and access tokens, automatically refreshing the access token 
    // whenever it expires.
    GoogleAccessProtectedResource requestInitializer = 
      new GoogleAccessProtectedResource(tokens.accessToken, httpTransport, 
                                        jsonFactory, clientId, clientSecret, 
                                        tokens.refreshToken);

    // Now populate the prediction data, issue the API call and return the
    // JSON results to the Javascript AJAX client.
    Prediction prediction = new Prediction(httpTransport, requestInitializer, 
                                           jsonFactory);
    Input input = new Input();
    InputInput inputInput = new InputInput();
    inputInput.setCsvInstance(params);
    input.setInput(inputInput);
    Output output = 
      prediction.trainedmodels().predict(modelId, input).execute();
    response.getWriter().println(output.toPrettyString());
  }
}
Besides Python and Java, Google also offers .NET, Objective-C, Ruby, Go, JavaScript, PHP and other client libraries for Prediction API.
Published at DZone with permission of Istvan Szegedi, author and DZone MVB.
