Data Structures, Algorithms, & Applications in Java
Chapter 16, Exercise 52

We shall extend the class BinarySearchTree. The ascend method is inherited from BinarySearchTree. The get, put, and remove methods of SplayTree differ from the corresponding methods of BinarySearchTree in that the new code stacks the path from the root up to and including the splay node and also invokes the method splay, which performs the splay operation. The data type of the elements in this stack is StackElement. This data type is defined below.
/** top-level nested class: one entry of the search-path stack, pairing
  * a tree node with the direction in which the path leaves that node */
static class StackElement
{
   // data members
   BinaryTreeNode node;        // node lying on the search path
   boolean left;               // true iff the path continues into node's left child

   /** @param theNode node lying on the search path
     * @param theLeft true iff the path continues into the left child of theNode */
   StackElement(BinaryTreeNode theNode, boolean theLeft)
   {
      this.node = theNode;
      this.left = theLeft;
   }
}



The new code for get, put, and remove is given below.
// stack used by splay method: get, put, and remove push the nodes of the
// search path (root first) as StackElement objects and splay pops them
// NOTE(review): the field is static, so every tree instance shares this one
// stack; presumably only single-threaded use is intended -- confirm
static ArrayStack stack = new ArrayStack();

// ascend is inherited from BinarySearchTree

/** Search for theKey, stacking the nodes on the search path, and then
  * splay the last node reached (the matching node when there is one).
  * @param theKey key to look for; it is cast to Comparable and compared
  *        against the keys stored in the tree
  * @return element whose key is theKey, or null if there is no element
  *         with key theKey */
public Object get(Object theKey)
{
   Object theElement = null;    // element to return

   Comparable searchKey = (Comparable) theKey;

   // pointer p starts at the root and moves down the tree;
   // every node examined is saved on the stack for the splay method
   BinaryTreeNode p = root;
   while (p != null)
   {
      // compare once per node and reuse the result for all three outcomes
      int order = searchKey.compareTo(((Data) p.element).key);

      if (order == 0)
      {// found matching element, it becomes the splay node
         stack.push(new StackElement(p, false));
         theElement = ((Data) p.element).element;
         break;
      }

      // record the direction taken from p, then move to that child
      boolean goLeft = order < 0;
      stack.push(new StackElement(p, goLeft));
      p = goLeft ? p.leftChild : p.rightChild;
   }

   // does nothing when the tree is empty because the stack is then empty
   splay();

   return theElement;
}

/** insert an element with the specified key
  * overwrite old element if there is already an
  * element with the given key; the node that ends up
  * holding the key is splayed (nothing is splayed when
  * the tree was empty)
  * @param theKey key of the element; it is cast to Comparable
  * @param theElement element to store under theKey
  * @return old element (if any) with key = theKey, null otherwise */
public Object put(Object theKey, Object theElement)
{
   Comparable elementKey = (Comparable) theKey;

   // find place to insert theElement, stacking the search path
   BinaryTreeNode p = root;     // search pointer
   while (p != null)
   {
      // compare once per node and reuse the result for all three outcomes
      int order = elementKey.compareTo(((Data) p.element).key);

      if (order == 0)
      {// overwrite element with same key, p becomes the splay node
         stack.push(new StackElement(p, false));
         Data match = (Data) p.element;
         Object elementToReturn = match.element;
         match.element = theElement;
         splay();
         return elementToReturn;
      }

      // record the direction taken from p, then move to that child
      boolean goLeft = order < 0;
      stack.push(new StackElement(p, goLeft));
      p = goLeft ? p.leftChild : p.rightChild;
   }

   // get a node for theElement
   BinaryTreeNode r = new BinaryTreeNode
                          (new Data(elementKey, theElement));

   if (root == null)
   {// insertion into empty tree, nothing was stacked so nothing to splay
      root = r;
      return null;
   }

   // the tree is not empty: the top of the stack is the last node on the
   // search path, and its flag says which side the search fell off
   StackElement pp = (StackElement) stack.peek();
   if (pp.left)
      pp.node.leftChild = r;
   else
      pp.node.rightChild = r;

   // the new node is the splay node
   stack.push(new StackElement(r, false));
   splay();

   return null;
}

/** Remove the element whose key is theKey and splay the parent of the
  * node that was physically unlinked; when there is no matching element,
  * splay the last node on the search path instead.
  * A key matches when compareTo returns 0, the same test that get and
  * put use, so the three methods always agree on which keys are equal.
  * @param theKey key of the element to remove; it is cast to Comparable
  * @return matching element and remove it, or null if no matching element */
public Object remove(Object theKey)
{
   Comparable searchKey = (Comparable) theKey;

   // set p to point to node with key searchKey;
   // only the proper ancestors of p are stacked
   BinaryTreeNode p = root;    // search pointer
   while (p != null)
   {
      // compare once per node and reuse the result
      int order = searchKey.compareTo(((Data) p.element).key);
      if (order == 0)
         break;                // p is the matching node

      // record the direction taken from p, then move to that child
      boolean goLeft = order < 0;
      stack.push(new StackElement(p, goLeft));
      p = goLeft ? p.leftChild : p.rightChild;
   }

   if (p == null) // no element with key searchKey
   {
      splay();
      return null;
   }

   // save element to be removed
   Object theElement = ((Data) p.element).element;

   // pp is parent of p (null when p is the root)
   BinaryTreeNode pp;
   if (stack.empty())
      pp = null;
   else
      pp = ((StackElement) stack.peek()).node;

   // restructure tree
   // handle case when p has two children
   if (p.leftChild != null && p.rightChild != null)
   {// two children
      // convert to zero or one child case
      // find element with largest key in left subtree of p,
      // extending the stacked path down to the parent of that node
      BinaryTreeNode s = p.leftChild;
      stack.push(new StackElement(p, true));
      while (s.rightChild != null)
      {// move to larger element
         stack.push(new StackElement(s, false));
         s = s.rightChild;
      }

      // move largest element from s to p; s is now the node to unlink
      p.element = s.element;
      p = s;
      pp = ((StackElement) stack.peek()).node;
   }

   // p has at most one child, save this child in c
   BinaryTreeNode c = (p.leftChild == null) ? p.rightChild : p.leftChild;

   // remove node p
   if (p == root)
      root = c;
   else if (p == pp.leftChild)
      pp.leftChild = c;
   else
      pp.rightChild = c;

   // the top of the stack (if any) is the parent of the unlinked node
   splay();
   return theElement;
}



The code for the six different splay step types is given below.
/** type L splay step: q takes the place of p, p becomes the right
  * child of q and inherits q's former right subtree as its left subtree;
  * no link into q from above is set here
  * @param q is splay node
  * @param p is parent of splay node */
void lSplay(BinaryTreeNode q, BinaryTreeNode p)
{
   BinaryTreeNode qRight = q.rightChild;   // subtree handed over to p

   q.rightChild = p;
   p.leftChild = qRight;
}

/** type R splay step: q takes the place of p, p becomes the left
  * child of q and inherits q's former left subtree as its right subtree;
  * no link into q from above is set here
  * @param q is splay node
  * @param p is parent of splay node */
void rSplay(BinaryTreeNode q, BinaryTreeNode p)
{
   BinaryTreeNode qLeft = q.leftChild;   // subtree handed over to p

   q.leftChild = p;
   p.rightChild = qLeft;
}

/** type LL splay step: q moves up two levels; afterwards p is the
  * right child of q and gp is the right child of p, with the former
  * right subtrees of q and p reattached as the left subtrees of p and gp;
  * no link into q from above is set here
  * @param q is splay node
  * @param p is parent of splay node
  * @param gp is grandparent of splay node */
void llSplay(BinaryTreeNode q, BinaryTreeNode p, BinaryTreeNode gp)
{
   // capture the subtrees that change owner before any link is rewritten
   BinaryTreeNode qRight = q.rightChild;
   BinaryTreeNode pRight = p.rightChild;

   q.rightChild = p;
   p.leftChild = qRight;
   p.rightChild = gp;
   gp.leftChild = pRight;
}

/** type RR splay step: q moves up two levels; afterwards p is the
  * left child of q and gp is the left child of p, with the former
  * left subtrees of q and p reattached as the right subtrees of p and gp;
  * no link into q from above is set here
  * @param q is splay node
  * @param p is parent of splay node
  * @param gp is grandparent of splay node */
void rrSplay(BinaryTreeNode q, BinaryTreeNode p, BinaryTreeNode gp)
{
   // capture the subtrees that change owner before any link is rewritten
   BinaryTreeNode qLeft = q.leftChild;
   BinaryTreeNode pLeft = p.leftChild;

   q.leftChild = p;
   p.rightChild = qLeft;
   p.leftChild = gp;
   gp.rightChild = pLeft;
}

/** type LR splay step: q moves up two levels; afterwards p is the
  * left child of q and gp is the right child of q, q's former left
  * subtree becomes the right subtree of p and q's former right subtree
  * becomes the left subtree of gp;
  * no link into q from above is set here
  * @param q is splay node
  * @param p is parent of splay node
  * @param gp is grandparent of splay node */
void lrSplay(BinaryTreeNode q, BinaryTreeNode p, BinaryTreeNode gp)
{
   // capture q's subtrees before q's links are rewritten
   BinaryTreeNode qLeft = q.leftChild;
   BinaryTreeNode qRight = q.rightChild;

   q.leftChild = p;
   q.rightChild = gp;
   p.rightChild = qLeft;
   gp.leftChild = qRight;
}

/** type RL splay step: q moves up two levels; afterwards gp is the
  * left child of q and p is the right child of q, q's former left
  * subtree becomes the right subtree of gp and q's former right subtree
  * becomes the left subtree of p;
  * no link into q from above is set here
  * @param q is splay node
  * @param p is parent of splay node
  * @param gp is grandparent of splay node */
void rlSplay(BinaryTreeNode q, BinaryTreeNode p, BinaryTreeNode gp)
{
   // capture q's subtrees before q's links are rewritten
   BinaryTreeNode qLeft = q.leftChild;
   BinaryTreeNode qRight = q.rightChild;

   q.leftChild = gp;
   q.rightChild = p;
   gp.rightChild = qLeft;
   p.leftChild = qRight;
}



The code for the splay operation is given below. This code assumes that the splay node is the node at the top of the stack and that the path from the root to the splay node is saved on the stack. When the code terminates, the splay node becomes the root and the stack is empty.
/** perform the splay operation, splay node is at the top of the stack
  * and its ancestors lie below it with the root at the bottom;
  * on return the stack is empty and, unless the stack was empty
  * on entry, the splay node is the root */
void splay()
{
   if (stack.empty())
      // no splay node
      return;

   // the direction flag stored with the splay node itself is not used
   BinaryTreeNode splayNode = ((StackElement) stack.pop()).node;

   // each pass consumes one ancestor (single step) or two (double step)
   while (!stack.empty())
   {
      StackElement parent = (StackElement) stack.pop();

      if (stack.empty())
      {// splay node is at level 2: one single step finishes the job,
       // and the empty stack ends the loop
         if (parent.left)
            lSplay(splayNode, parent.node);
         else
            rSplay(splayNode, parent.node);
      }
      else
      {// splay node is at level > 2: double step chosen by both directions
         StackElement grandparent = (StackElement) stack.pop();

         if (grandparent.left && parent.left)
            llSplay(splayNode, parent.node, grandparent.node);
         else if (grandparent.left)
            lrSplay(splayNode, parent.node, grandparent.node);
         else if (parent.left)
            rlSplay(splayNode, parent.node, grandparent.node);
         else
            rrSplay(splayNode, parent.node, grandparent.node);

         // the next ancestor still links to the old grandparent; that link
         // is not patched here because the next step overwrites exactly
         // that child pointer, as selected by the ancestor's direction flag
      }
   }

   root = splayNode;
}