Friday, July 30, 2010

Just In Time Compiler for Managed Platform- Part 4: Basic Conditional Branch

Today I'll generate code that can handle conditional branches. This is a really giant step: once conditional branching is enabled, we are ready to execute almost all instructions. It is also more complex to handle than the earlier things.

First let us define the java class with a condition:

// Sample class used to exercise the JIT; its javac bytecode is listed below.
// Only static methods are compiled so far, so <init> is not handled yet.
public class Math
{
    // Returns x + y. The explicit local r shows up as istore_2 / iload_2.
    public static int add(int x, int y)
    {
        int r;
        r= x+y;
        return r;
    }

    // Returns the constant 17.
    public static int SimpleCall()
    {
        return 17;
    }
    // Returns add(17, 29), i.e. 46.
    public static int SimpleCall2()
    {
        return add(17, 29);
    }

    // Returns 100 when p > 100, otherwise add(p, 29).
    // javac compiles the test as "if_icmple 9": it jumps over "return 100"
    // when p <= 100 - this is the conditional branch the JIT must handle.
    public static int SimpleCall4(int p)
    {
        if(p>100) return 100;
        return add(p, 29);
    }

    // Ignores p and returns SimpleCall4(101), i.e. 100
    // (p > 100 path: if_icmple falls through).
    public static int SimpleCall5(int p)
    {
        return SimpleCall4(101);
    }

    // Returns SimpleCall4(17), i.e. 46
    // (p <= 100 path: if_icmple jumps to pc 9).
    public static int SimpleCall3()
    {        
        return SimpleCall4(17);
    }

    // Entry point the test driver looks up by name "Main" and signature "()I"; returns 46.
    public static int Main()
    {
        return SimpleCall3();
    }
}


And the generated byte code:

public class Math extends java.lang.Object{
public Math();
  Signature: ()V
  Code:
   0:   aload_0
   1:   invokespecial   #1; //Method java/lang/Object."<init>":()V
   4:   return

public static int add(int, int);
  Signature: (II)I
  Code:
   0:   iload_0
   1:   iload_1
   2:   iadd
   3:   istore_2
   4:   iload_2
   5:   ireturn

public static int SimpleCall();
  Signature: ()I
  Code:
   0:   bipush  17
   2:   ireturn

public static int SimpleCall2();
  Signature: ()I
  Code:
   0:   bipush  17
   2:   bipush  29
   4:   invokestatic    #2; //Method add:(II)I
   7:   ireturn

public static int SimpleCall4(int);
  Signature: (I)I
  Code:
   0:   iload_0
   1:   bipush  100
   3:   if_icmple       9
   6:   bipush  100
   8:   ireturn
   9:   iload_0
   10:  bipush  29
   12:  invokestatic    #2; //Method add:(II)I
   15:  ireturn

public static int SimpleCall5(int);
  Signature: (I)I
  Code:
   0:   bipush  101
   2:   invokestatic    #3; //Method SimpleCall4:(I)I
   5:   ireturn

public static int SimpleCall3();
  Signature: ()I
  Code:
   0:   bipush  17
   2:   invokestatic    #3; //Method SimpleCall4:(I)I
   5:   ireturn

public static int Main();
  Signature: ()I
  Code:
   0:   invokestatic    #4; //Method SimpleCall3:()I
   3:   ireturn

}

The SimpleCall4 method has a conditional branch instruction, if_icmple. Note that we only deal with static methods so far, so we don't care about the <init> method yet.

Let us first define the helper method for the instruction to generate machine code:

// Emits native x86 code for the JVM if_icmple instruction.
//
// if_icmple pops value2 (top of the operand stack) and value1 (just below it)
// and branches to targetpc when value1 <= value2, using a SIGNED 32-bit
// comparison. The emitted code therefore uses JLE (0x7E); JBE (0x76) would be
// an unsigned compare and mis-handle negative ints (e.g. -1 <= 5).
//
// The operand-stack pop (stackTop -= 2) is emitted BEFORE the conditional
// jump so it executes on both the taken and the fall-through path. Emitting
// it after the jump leaves two stale slots on the stack whenever the branch
// is taken.
//
// The jump is emitted as a short Jcc with a zero rel8 placeholder; its
// address is recorded in pJmpTargetMap under targetpc so FixJmpLocations can
// patch it once the native address of targetpc is known.
//
// code          - native code buffer being filled
// ip            - write position in code; advanced past the emitted bytes
// targetpc      - managed (bytecode) pc of the branch target
// pJmpTargetMap - managed target pc -> JmpTarget (list of jumps to patch)
//
// NOTE(review): IReturn/FixJmpLocations use key 0 to mean "jump to the
// epilog", so a genuine branch to bytecode pc 0 would be mis-resolved.
void IfIcmple(u1* code, int& ip, int targetpc, CMapPtrToPtr *pJmpTargetMap)
{
    // Layout assumed by the encodings (same as the other helpers): pRE at
    // [ebp+8], pRE->stack at offset 0, pRE->stackTop at offset 4, 8-byte
    // stack slots with intValue at offset 0.
    u1 c[] = {
         //pRE->stackTop -= 2;
         0x8B, 0x45, 0x08,       //         mov         eax,dword ptr [pRE]
         0x8B, 0x48, 0x04,       //         mov         ecx,dword ptr [eax+4]        ; ecx = stackTop
         0x83, 0xE9, 0x02,       //         sub         ecx,2
         0x89, 0x48, 0x04,       //         mov         dword ptr [eax+4],ecx        ; store popped stackTop

         //if(pRE->stack[stackTop].intValue <= pRE->stack[stackTop+1].intValue) goto target;
         0x8B, 0x10,             //         mov         edx,dword ptr [eax]          ; edx = stack
         0x8B, 0x04, 0xCA,       //         mov         eax,dword ptr [edx+ecx*8]    ; value1
         0x3B, 0x44, 0xCA, 0x08, //         cmp         eax,dword ptr [edx+ecx*8+8]  ; value2
         0x7E, 0x00,             //         jle         rel8 (patched by FixJmpLocations)
    };

    // The jle is the last instruction in c[] and is 2 bytes long.
    const int jmpIndex = (int)sizeof(c) - 2;

    memcpy(&code[ip], c, sizeof(c));
    u1 *pJmp = &code[ip + jmpIndex];
    ip+=sizeof(c);

    LinkedListNode *pNode = new LinkedListNode(pJmp, NULL);

    LOG(_T("PC = 0x%X Jmp Inst Offset = 0x%X Inst = 0x%X\n"), targetpc, (int)(pJmp - code), *pJmp);

    JmpTarget *pJmpTarget = NULL;
    if(!pJmpTargetMap->Lookup((void *) targetpc, (void *&) pJmpTarget))
    {
        pJmpTarget = new JmpTarget();
        pJmpTarget->pTargetList = NULL;
        pJmpTarget->pc = targetpc;
        pJmpTargetMap->SetAt((void *)targetpc, pJmpTarget);
    }

    // Prepend this jump to the list of jumps that must be patched to reach targetpc.
    pNode->pNext = pJmpTarget->pTargetList;
    pJmpTarget->pTargetList = pNode;
}

The if_icmple instruction pops two int values off the operand stack and branches if the first (deeper) value is less than or equal to the second. This is why javac compiles `if(p>100) return 100;` into `if_icmple 9`, which skips the early return. We do the same in the machine code. We cannot calculate the address of a forward target instruction before generating the code that precedes it. So we leave the jump's target address empty and patch all jump addresses after we have generated machine code for all the instructions. To do this, we maintain a map of jmp instruction locations and also a map from managed code locations to native code locations.

Since one instruction can be the target of multiple jmp instructions, we keep a linked list of the jmp instructions we need to fix for each managed target instruction:

// Singly linked list node. The JIT uses it to collect, for one managed jump
// target, the native addresses of jump instructions whose displacement still
// has to be patched (see FixJmpLocations).
struct LinkedListNode
{
    void *pData;            // payload; for the JIT, the address of a jump opcode
    LinkedListNode *pNext;  // following node, or NULL at the end of the list

    // Node carrying `data`, placed in front of `next`.
    LinkedListNode(void *data, LinkedListNode *next)
        : pData(data), pNext(next)
    {
    }

    // Node with no payload and no successor.
    LinkedListNode()
        : pData(NULL), pNext(NULL)
    {
    }
};

Here we also need to fix our ireturn helper. After a return we must not execute any further instruction, but before returning we must restore the native stack (the epilog). Since we add an epilog at the end of each native function, we simply jump to that location for the cleanup; this also handles managed methods with multiple returns. To track this, we insert an unconditional jmp instruction to the epilog:

// Emits native x86 code for the JVM ireturn instruction.
//
// The int on top of the operand stack (pRE->stack[stackTop-1]) is copied into
// pRE->stack[0]; stackTop itself is not modified here. Control then leaves
// the method through an unconditional near jump (opcode E9 followed by a
// 4-byte displacement) to the method epilog. The displacement is emitted as
// zero; the jump's address is recorded in pJmpTargetMap under key 0, which
// FixJmpLocations resolves to the epilog address.
//
// NOTE(review): key 0 is also a valid bytecode pc, so a real branch to pc 0
// would share this entry and be sent to the epilog. A dedicated sentinel
// would avoid the collision.
void IReturn(u1* code, int& ip, CMapPtrToPtr *pJmpTargetMap)
{
    const int jmpLength = 5; // E9 opcode + rel32 placeholder

    const u1 emitted[] = {
        //pRE->stack[0].intValue=pRE->stack[pRE->stackTop-1].intValue;
        0x8B, 0x45, 0x08,             // mov eax,dword ptr [pRE]
        0x8B, 0x48, 0x04,             // mov ecx,dword ptr [eax+4]
        0x8B, 0x55, 0x08,             // mov edx,dword ptr [pRE]
        0x8B, 0x02,                   // mov eax,dword ptr [edx]
        0x8B, 0x55, 0x08,             // mov edx,dword ptr [pRE]
        0x8B, 0x12,                   // mov edx,dword ptr [edx]
        0x8B, 0x44, 0xC8, 0xF8,       // mov eax,dword ptr [eax+ecx*8-8]
        0x89, 0x02,                   // mov dword ptr [edx],eax

        // jmp <epilog>; the rel32 displacement is filled in by FixJmpLocations
        0xE9, 0x00, 0x00, 0x00, 0x00, // jmp rel32
    };

    memcpy(&code[ip], emitted, sizeof(emitted));
    ip += sizeof(emitted);

    u1 *pJmp = &code[ip - jmpLength];
    LinkedListNode *pNode = new LinkedListNode(pJmp, NULL);

    LOG(_T("PC = RETURN Jmp Inst Offset = 0x%X Inst = 0x%X\n"), (int)(pJmp - code), *pJmp);

    // Key 0 collects every jump that targets the epilog; prepend this one.
    JmpTarget *pEpilog = NULL;
    if(pJmpTargetMap->Lookup((void *) 0, (void *&) pEpilog))
    {
        pNode->pNext = pEpilog->pTargetList;
    }
    else
    {
        pEpilog = new JmpTarget();
        pEpilog->pc = 0;
        pJmpTargetMap->SetAt((void *)0, pEpilog);
    }
    pEpilog->pTargetList = pNode;
}


Let's now see how we fix the addresses:

// Patches the displacement of every jump recorded by the code emitters
// (IfIcmple, IReturn, ...) once all native code for the method exists.
//
// codes               - start of the native code buffer containing the jumps
// pJmpTargetMap       - managed target pc -> JmpTarget; each JmpTarget holds a
//                       linked list of addresses of jump opcodes aimed at that pc
// pManagedtoNativeMap - managed pc -> native address where that instruction's
//                       code starts
// retAddress          - native address of the method epilog, used for key 0
//                       (IReturn's "return" entry)
//
// A jump whose opcode byte is 0xE9 is a near jmp with a 4-byte displacement.
// Any other opcode is treated as a 2-byte short jump (Jcc rel8). x86
// displacements are relative to the address of the NEXT instruction.
//
// Jumps whose target has no native address, or whose short displacement does
// not fit in a signed byte, are logged and asserted on instead of being
// silently patched with a wrong value.
void FixJmpLocations(u1 *codes, CMapPtrToPtr *pJmpTargetMap, CMapPtrToPtr *pManagedtoNativeMap, u1* retAddress)
{
    LOG(_T("Fixing Jmp Target\n"));

    //Iterate through the entire map,
    for (POSITION pos = pJmpTargetMap->GetStartPosition(); pos != NULL;)
    {
        JmpTarget *pJmpTarget;
        int pc;
        pJmpTargetMap->GetNextAssoc(pos, (void *&)pc, (void *&)pJmpTarget);

        ASSERT(pc == pJmpTarget->pc);

        // NOTE(review): pc 0 doubles as the "return to epilog" key, so a genuine
        // branch to bytecode pc 0 is resolved to the epilog instead.
        int target;
        if(0 == pJmpTarget->pc)
        {
            target = (int)retAddress;
        }
        else
        {
            void *pNative = NULL;
            if(!pManagedtoNativeMap->Lookup((void *)pJmpTarget->pc, pNative))
            {
                // Without a native address the computed displacement would point at garbage.
                LOG(_T("No native address for jump target PC = 0x%X\n"), pJmpTarget->pc);
                ASSERT(FALSE);
                continue;
            }
            target = (int)pNative;
        }

        for (LinkedListNode *pNode = pJmpTarget->pTargetList; pNode != NULL; pNode = pNode->pNext)
        {
            u1 *pJmp = (u1 *)pNode->pData;
            const int jmpAddr = (int)pJmp;
            int offset = 0;

            if(0xE9 == *pJmp)
            {
                // jmp rel32: 1 opcode byte + 4 displacement bytes
                offset = target - (jmpAddr + 5);
                memcpy(pJmp + 1, &offset, sizeof(offset));
            }
            else
            {
                // Jcc rel8: 1 opcode byte + 1 displacement byte
                offset = target - (jmpAddr + 2);
                if(offset < -128 || offset > 127)
                {
                    LOG(_T("Short jump at 0x%X cannot reach its target (offset %d)\n"), (int)(pJmp - codes), offset);
                    ASSERT(FALSE);
                }
                pJmp[1] = (u1)offset;
            }

            LOG(_T("Fixed 0x%X with Native Address Offset 0x%X\n"), (int)(pJmp - codes), offset);
        }
    }
}

Note that the map pManagedtoNativeMap is populated in the Compile function, inside the big per-instruction for loop, like this:

// In Compile's per-instruction loop: record the native address at which the code for managed pc begins.
pManagedtoNativeMap->SetAt((void *)pc, &codes[ip]);

That's it. We are now ready to test the code we generate:

// Test driver: builds a runtime Context, loads class "Math", looks up the
// static method Main with signature ()I, invokes it through the helper at
// CALL_METHOD_HELPER_INDEX and logs the int result left in pRE->stack[0].
// Returns 0 on success, 1 if the Main method cannot be found.
int main()
{
    Context *pRE = new Context();
    pRE->stack = new Variable[STACK_SIZE];
    pRE->stackTop = 0;
    memset(pRE->stack, 0, sizeof(Variable)*STACK_SIZE);

    pRE->pVMEnv = new VMEnvironment();
    pRE->pVMEnv->pClassHeap = new ClassHeap();
    pRE->pVMEnv->pObjectHeap = new ObjectHeap();

    pRE->pVMEnv->ppHelperMethods = HelperMethods;

    ClassHeap* pClsHeap = pRE->pVMEnv->pClassHeap;

    JavaClass jc;
    pClsHeap->LoadClass("Math", &jc);
    JavaClass *pVirtualClass = &jc, *pClass1 = &jc;

    int mindex = pClass1->GetMethodIndex(_T("Main"), _T("()I"), pVirtualClass);

    // Indexing methods[] with an invalid index is undefined behavior.
    // NOTE(review): assumes GetMethodIndex reports "not found" with a negative index — confirm.
    if(mindex < 0)
    {
        LOG(_T("Method Main()I not found\n"));
        return 1;
    }

    method_info_ex *pMethod = &pVirtualClass->methods[mindex];

    MethodLink *pMethodLink = new MethodLink();
    pMethodLink->pClass = pVirtualClass;
    pMethodLink->pMethod = pMethod;

    // Dispatch through the JIT's call helper; the result ends up in stack[0].
    ((void (*)(MethodLink *pMethodLink, Context *pRE))pRE->pVMEnv->ppHelperMethods[CALL_METHOD_HELPER_INDEX])(pMethodLink, pRE);

    LOG(_T("Ret = %d"), pRE->stack[0].intValue);

    return 0;
}

That's it. Our generated native code can now handle branching!

No comments:

Post a Comment