Tuesday, July 27, 2010

Just-In-Time Compiler for a Managed Platform - Part 3: Calling a Method

Today I'll try to extend the simple JIT compiler to the point where we can call a method from another method.

First, let's create a simple Java class:

/**
 * Sample class used to exercise the JIT: SimpleCall2() calls add() through an
 * invokestatic instruction.
 */
public class Math {
    /** Returns x + y; the sum is kept in local r, so javac emits istore_2/iload_2. */
    public static int add(int x, int y) {
        int r;
        r= x+y;
        return r;
    }

    /** Calls add(17, 29); the JIT test expects 46. */
    public static int SimpleCall2() {
        return add(17, 29);
    }
}

Here SimpleCall2 calls the add method, and our JIT compiler must supply a mechanism to handle such calls. When we compile the class with the Java compiler, we get the following class and method byte code:

public class Math extends java.lang.Object{
public Math();
  Signature: ()V
   0:   aload_0
   1:   invokespecial   #1; 
   4:   return

public static int add(int, int);
  Signature: (II)I
   0:   iload_0
   1:   iload_1
   2:   iadd
   3:   istore_2
   4:   iload_2
   5:   ireturn

public static int SimpleCall2();
  Signature: ()I
   0:   bipush  17
   2:   bipush  29
   4:   invokestatic    #2; //Method add:(II)I
   7:   ireturn

[Note: I do not describe Java Virtual Machine basics again here. For a basic understanding, see my article Home Made Java Virtual Machine.]

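First we need to extend our Context structure to hold some more values. New members must be added at the end. The machine code we generate accesses Context members at fixed byte offsets, so inserting members anywhere else would invalidate the code we have generated so far.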

struct VMEnvironment
    ObjectHeap *pObjectHeap;
    ClassHeap *pClassHeap;
    void **ppHelperMethods;

struct Context
    Variable *stack;
    int stackTop;
    JavaClass *pClass;
    Context *pCallerContext; 
    VMEnvironment *pVMEnv;

We also want to keep track of the native code we generate, so we need another structure:

struct MethodLink
    JavaClass *pClass;
    method_info_ex *pMethod;
    void *pNativeBlock;

Now we need a helper function to keep track of the methods we work with. It uses a simple string-to-pointer map whose key is built from the class name, method name and method descriptor, so a key looks like "Math::add(II)I".

// Resolves the method referenced by the invoke instruction at byte-code offset
// pc of pMethod (a method of pClass) and returns its MethodLink.
//
// Links are cached in a process-wide map keyed by "Class::name(desc)", e.g.
// "Math::add(II)I", so every callee gets a single MethodLink and therefore at
// most one compiled native block. Returns NULL if the callee cannot be resolved.
MethodLink* GetMethod(JavaClass *pClass, method_info_ex *pMethod, u4 pc)
{
    static CMapStringToPtr methodsMap;
    // One class heap shared by all lookups. The previous code created a new heap
    // per call, which leaked it and reloaded the callee class every time.
    // NOTE(review): this heap is separate from VMEnvironment::pClassHeap; using
    // the VM's heap would require passing it in.
    static ClassHeap *pClassHeap = new ClassHeap();

    // The u2 operand of the invoke instruction indexes a method reference in the
    // CALLER's constant pool.
    u2 mi = getu2(&pMethod->pCode_attr->code[pc+1]);
    char *pConstPool = (char *)pClass->constant_pool[mi];
    u2 classIndex = getu2(&pConstPool[1]);
    u2 nameAndTypeIndex = getu2(&pConstPool[3]);

    // classIndex and nameAndTypeIndex belong to the same (caller's) constant
    // pool, so every string below is read from pClass, not from the callee.
    pConstPool = (char *)pClass->constant_pool[classIndex];
    ASSERT(pConstPool[0] == CONSTANT_Class);
    u2 ni = getu2(&pConstPool[1]);
    CString strClassName;
    pClass->GetStringFromConstPool(ni, strClassName);

    pConstPool = (char *)pClass->constant_pool[nameAndTypeIndex];
    ASSERT(pConstPool[0] == CONSTANT_NameAndType);
    u2 name_index = getu2(&pConstPool[1]);
    u2 descriptor_index = getu2(&pConstPool[3]);
    CString strMethodName, strMethodDesc;
    pClass->GetStringFromConstPool(name_index, strMethodName);
    pClass->GetStringFromConstPool(descriptor_index, strMethodDesc);

    // The key is known without loading the callee, so check the cache first.
    CString sign(strClassName+"::"+strMethodName+strMethodDesc);
    MethodLink *pLink = NULL;
    if(methodsMap.Lookup(sign, (void *&)pLink))
        return pLink;

    JavaClass *pClassCallee = pClassHeap->GetClass(strClassName);
    if(NULL == pClassCallee)
        return NULL;

    // GetMethodIndex may update pVirtualClass to the class that declares the
    // method (main() indexes pVirtualClass->methods the same way), so the
    // method and the link's class are both taken from pVirtualClass.
    JavaClass *pVirtualClass = pClassCallee;
    int nIndex = pClassCallee->GetMethodIndex(strMethodName, strMethodDesc, pVirtualClass);
    ASSERT(nIndex >= 0);
    if(nIndex < 0)
        return NULL;

    pLink = new MethodLink();
    pLink->pClass = pVirtualClass;
    pLink->pMethod = &pVirtualClass->methods[nIndex];
    pLink->pNativeBlock = NULL;
    methodsMap.SetAt(sign, pLink);

    return pLink;
}

To keep things simple, for now we do not generate the stack-preparation code for a call in machine code. We will do that after we finish all the other kinds of code generation. Instead, the native code calls back into a C++ function, which in turn calls into the generated code:

// Runtime helper that runs the method described by pMethodLink in context pRE.
// It is called directly by main() and, through the helper table, by generated
// code (see EmitCallMethod). On first use the method is compiled, and the native
// block is cached in pMethodLink->pNativeBlock.
//
// NOTE(review): this listing is incomplete. The braces were lost, and so was the
// tail of the function (running the native block and the bodies of the two
// return-type checks at the end).
void CallMethod(MethodLink *pMethodLink, Context *pRE)

    // NOTE(review): 2 bytes of machine code per byte-code byte is far too small
    // (EmitCallMethod alone emits 25 bytes for a 3-byte invokestatic). It only
    // works because VirtualAlloc rounds the size up to whole pages, and nothing
    // checks the emitters against this size.
    // If native methods carry no Code attribute (as the JVM spec requires),
    // pCode_attr is not valid here for them -- confirm.
    int codeBlockSize = pMethodLink->pMethod->pCode_attr->code_length*2;   //todo guess better

    // Executable buffer for the generated code. The NativeBlock type shows that
    // the generated function takes the Context* as its only argument.
    // NOTE(review): this allocation runs on every call, even when the method is
    // already compiled (the buffer is then never used or freed), and the result
    // is not checked for NULL. Allocate inside the if below instead.
    int (*NativeBlock)(Context *)=(int (*)(Context *)) VirtualAlloc(NULL, codeBlockSize,  MEM_COMMIT, PAGE_EXECUTE_READWRITE);
    u1* codes = (u1*) NativeBlock;

    // Write offset into codes; Compile advances it.
    int ip =0;

    // NOTE(review): unused in the visible code.
    JavaClass *pClass = pMethodLink->pClass;

    // Compile lazily, on the first call only. Presumably both statements below
    // belong to this if (the indentation suggests so; the braces were lost).
    // NOTE(review): Compile's return value (error code) is ignored.
    if(NULL == pMethodLink->pNativeBlock)
        Compile(pMethodLink->pClass, pMethodLink->pMethod, codes, ip);
        pMethodLink->pNativeBlock = codes;

    // name_index/descriptor_index refer to pMethodLink->pClass's constant pool.
    CString strName, strDesc;
    pMethodLink->pClass->GetStringFromConstPool(pMethodLink->pMethod->name_index, strName);
    pMethodLink->pClass->GetStringFromConstPool(pMethodLink->pMethod->descriptor_index, strDesc);

    // Presumably the number of operand-stack slots used by the arguments.
    // NOTE(review): the +1 looks like a receiver (this) slot, but invokestatic
    // has no receiver -- confirm.
    int params=GetMethodParametersStackCount(strDesc)+1;
    //invokestatic: we are only dealing with static methods so far

    int nDiscardStack =params;
    // NOTE(review): for native methods this dereferences pCode_attr too (see above).
    if(pMethodLink->pMethod->access_flags & ACC_NATIVE)
        nDiscardStack+= pMethodLink->pMethod->pCode_attr->max_locals; 

    LOG(_T("Invoking method %s%s, \n"), strName, strDesc);


    //if returns then get on stack    
    // A descriptor ending in ")V" returns void; ")J" returns a long (two stack slots).
    if(strDesc.Find(_T(")V")) < 0)
        if(strDesc.Find(_T(")J")) < 0)


OK, that's the callback we need for now. Next we generate the machine code that uses the MethodLink* value to call back into the CallMethod function. To do this, we keep a table of helper function pointers and store it in the context's VM environment:


// Table of C++ helpers that generated code calls through
// pRE->pVMEnv->ppHelperMethods (main() installs it there).
// NOTE(review): the initializer list is missing from this listing.
// EmitCallMethod's machine code calls entry 0 as (MethodLink*, Context*), so
// CallMethod is presumably entry 0 (CALL_METHOD_HELPER_INDEX == 0) -- confirm.
void*  HelperMethods[] = {

Let us now define the InvokeStatic helper method.

void InvokeStatic(JavaClass *pClass, method_info_ex *pMethod, u4 pc, u1* codes, int &ip)
    MethodLink* pLink = GetMethod(pClass, pMethod, pc);
    EmitCallMethod(codes, ip, pLink);

// Emits 32-bit x86 code equivalent to
//   ((void (*)(MethodLink *, Context *))
//       pRE->pVMEnv->ppHelperMethods[0])((MethodLink *)pLinkAddress, pRE);
// i.e. a call into the CallMethod helper for the given link. The bytes are
// written to code[ip...], and ip is advanced past them.
//
// Assumptions baked into the bytes (keep in sync):
//  - pRE is the generated function's first argument, at [ebp+8] (presumably
//    Prolog() sets up an ebp frame -- confirm);
//  - Context::pVMEnv is at offset 0x10 and VMEnvironment::ppHelperMethods at
//    offset 8 (4-byte pointers and int, i.e. a 32-bit build);
//  - the helper is entry 0 of the table, so CALL_METHOD_HELPER_INDEX must be 0;
//  - arguments are pushed right to left and popped by the caller (cdecl).
void EmitCallMethod(u1* code, int &ip, void* pLinkAddress)
{
    u1 c[] = {
        0x8B, 0x45, 0x08,             // mov  eax, dword ptr [ebp+8]   ; eax = pRE
        0x50,                         // push eax                      ; 2nd argument: pRE
        0x68, 0x00, 0x00, 0x00, 0x00, // push imm32                    ; 1st argument: pLinkAddress (patched below)
        0x8B, 0x4D, 0x08,             // mov  ecx, dword ptr [ebp+8]   ; ecx = pRE
        0x8B, 0x51, 0x10,             // mov  edx, dword ptr [ecx+10h] ; edx = pRE->pVMEnv
        0x8B, 0x42, 0x08,             // mov  eax, dword ptr [edx+8]   ; eax = pVMEnv->ppHelperMethods
        0x8B, 0x08,                   // mov  ecx, dword ptr [eax]     ; ecx = ppHelperMethods[0]
        0xFF, 0xD1,                   // call ecx
        0x83, 0xC4, 0x08,             // add  esp, 8                   ; pop both arguments
    };

    // The imm32 operand of "push imm32" starts at byte 5. Copying 4 bytes of the
    // pointer is only correct on 32-bit targets.
    memcpy(c + 5, &pLinkAddress, 4);
    memcpy(&code[ip], c, sizeof(c));

    // Advance the emit cursor. Without this, the next instruction emitted by the
    // compiler overwrote this call sequence.
    ip += static_cast<int>(sizeof(c));
}

To compile methods, we define a function that generates machine code for Java byte codes. It does not handle branch instructions yet. Handling branches will probably need two passes, since the exact target addresses are not known during the first pass. So here is a large loop that handles the basic instructions:

// Translates the Java byte code of pMethod (declared in pClass) into native x86
// code. The code is written to codes[ip...], and ip is advanced past it.
//
// Branch instructions are not supported yet. They will probably need a second
// pass, because jump targets are not known during the first pass.
//
// Returns 1 without emitting anything for native methods, 0 on success, and 1
// if an unsupported opcode stops compilation.
u4 Compile(JavaClass *pClass, method_info_ex *pMethod, u1 *codes, int &ip)
{
    // Native methods have no byte code to translate.
    if(pMethod->access_flags & ACC_NATIVE)
        return 1;

    Prolog(codes, ip);

    u1 *bc = pMethod->pCode_attr->code;
    u4 codeLength = pMethod->pCode_attr->code_length;
    i4 error = 0;

    // Walk the byte code one instruction at a time. Every case advances pc by
    // the full instruction length (opcode byte plus operand bytes, per the JVM
    // specification) and ends with break, so cases never fall through.
    u4 pc = 0;
    while(pc < codeLength && !error)
    {
        LOG(_T("Opcode = %s\n"),OpcodeDesc[(u1)bc[pc]]);

        switch(bc[pc])
        {
        case nop: // 0
            pc += 1;
            break;

        case bipush: // 16 (0x10): one operand byte
            BiPush(codes, ip, (u1)bc[pc+1]);
            pc += 2;
            break;

        case iload_0: // 26 Load int from local variable 0
            ILoad_0(codes, ip);
            pc += 1;
            break;

        case iload_1: // 27 Load int from local variable 1
            ILoad_1(codes, ip);
            pc += 1;
            break;

        case iload_2: // 28 Load int from local variable 2
            ILoad_2(codes, ip);
            pc += 1;
            break;

        case iload_3: // 29 Load int from local variable 3
            ILoad_3(codes, ip);
            pc += 1;
            break;

        case istore_2: // 61 (0x3d)
            IStore_2(codes, ip);
            pc += 1;
            break;

        case iadd: // 96 (0x60)
            IAdd(codes, ip);
            pc += 1;
            break;

        case invokestatic: // 184 (0xb8): two operand bytes (constant-pool index)
            InvokeStatic(pClass, pMethod, pc, codes, ip);
            pc += 3;
            break;

        case ireturn: // 172 (0xac)
            IReturn(codes, ip);
            pc += 1;
            break;

        default:
            // Anything else (e.g. branches) is not supported yet; stop here
            // rather than silently misreading operand bytes as opcodes.
            LOG(_T("Unsupported opcode %s\n"), OpcodeDesc[(u1)bc[pc]]);
            error = 1;
            break;
        }
    }

    Return0(codes, ip);
    Epilog(codes, ip);

    return error;
}

OK, we are now ready to test our code:

int main()
    Context *pRE = new Context();;
    pRE->stack = new Variable[STACK_SIZE];    
    pRE->stackTop = 0;
    memset(pRE->stack, 0, sizeof(Variable)*STACK_SIZE);

    pRE->pVMEnv = new VMEnvironment();
    pRE->pVMEnv->pClassHeap = new ClassHeap();
    pRE->pVMEnv->pObjectHeap = new ObjectHeap();

    pRE->pVMEnv->ppHelperMethods = HelperMethods;

    ClassHeap* pClsHeap =  pRE->pVMEnv->pClassHeap;
    JavaClass jc;
    pClsHeap->LoadClass("Math", &jc);
    JavaClass *pVirtualClass =&jc,  *pClass1 = &jc;

    int mindex=pClass1->GetMethodIndex(_T("SimpleCall2"),_T("()I"),pVirtualClass); 

    method_info_ex *pMethod = &pVirtualClass->methods[mindex];

    MethodLink *pMethodLink = new MethodLink();
    pMethodLink->pClass = pVirtualClass;
    pMethodLink->pMethod = pMethod;    

    ((void (*)(MethodLink *pMethodLink, Context *pRE))pRE->pVMEnv->ppHelperMethods[CALL_METHOD_HELPER_INDEX])(pMethodLink, pRE);
    LOG(_T("Return Value = %d"), pRE->stack[0].intValue);

    return 0;

Do you see the value 46 (17 + 29) on the stack as the return value? Cool!

No comments:

Post a Comment