มีวิธีใดบ้างที่จะทำให้โค้ดนี้มีประสิทธิภาพและเร็วขึ้น

import numpy as np
import matplotlib.pyplot as plt
import time
from numba import njit, prange

@njit('int_(int_, int_, int_[::1])')
def aloha(N, M, tmp):
    """Count the indices in ``range(N)`` that are drawn exactly once.

    ``M`` indices are drawn with ``np.random.randint(0, N)`` and tallied in
    ``tmp``; the return value is how many of the first ``N`` tallies equal 1.

    Parameters
    ----------
    N : int
        Upper bound (exclusive) of the random draw and number of tallies
        inspected. ``tmp`` must hold at least ``N`` entries.
    M : int
        Number of random draws.
    tmp : 1-D contiguous int array
        Caller-owned scratch buffer; it is zeroed and overwritten here.

    Returns
    -------
    int
        Number of indices whose tally is exactly 1.
    """
    # Wipe whatever the previous call left in the shared buffer.
    tmp[:] = 0
    for _ in range(M):
        tmp[np.random.randint(0, N)] += 1
    singles = 0
    for slot in range(N):
        if tmp[slot] == 1:
            singles += 1
    return singles

@njit('float64(int_, int_, int_[::1])')
def repetition(I,M,tmp):
    """Return the mean result of ``I`` calls to ``aloha(12, M, tmp)``.

    Parameters
    ----------
    I : int
        Number of ``aloha`` runs to average over.
    M : int
        Number of random draws forwarded to each ``aloha`` run.
    tmp : 1-D contiguous int array
        Scratch buffer forwarded to ``aloha``; needs at least 12 entries
        because ``aloha`` is always called with ``N = 12``.

    Returns
    -------
    float
        Sum of the ``I`` results divided by ``I``.
    """
    total = 0
    done = 0
    while done < I:
        total += aloha(12, M, tmp)
        done += 1
    return total / I
    
@njit('float64[::1](int_)')
def simulation(M):
    """Run one 15-step simulation over three levels (0, 1, 2).

    Each level starts with a pool of ``M``. Every step calls ``repetition``
    with a level-specific repetition count and the rounded pool size, then
    subtracts the returned average ``s`` from the pool and adds it to that
    level's running total.

    * Steps 1-5: the three levels run separately with 8, 32 and 128
      repetitions.
    * Steps 6-10: level 0 runs with 32 repetitions; the level-2 pool is added
      to the level-1 pool, which runs with 128 repetitions, and the result is
      credited to both level 1 and level 2.
    * Steps 11-15: the level-1 pool is added to the level-0 pool, which runs
      with 128 repetitions, and the result is credited to all three levels.

    Parameters
    ----------
    M : int
        Initial pool size of every level. Must be non-zero, since the totals
        are divided by it.

    Returns
    -------
    1-D float64 array of length 3
        ``[success_0 / M, success_1 / M, success_2 / M]``.
    """
    tmp = np.zeros(12, np.int_)  # Scratch buffer shared by every aloha() call
    # Pools and totals absorb fractional averages, so start them as floats
    # instead of letting them change type from int to float mid-function.
    M0 = float(M)
    M1 = float(M)
    M2 = float(M)
    success_0 = 0.0
    success_1 = 0.0
    success_2 = 0.0
    # Steps 1-5: every level runs on its own
    for _ in range(5):
        # CE Lv 0
        s = repetition(8, round(M0), tmp)
        M0 -= s
        success_0 += s
        # CE Lv 1
        s = repetition(32, round(M1), tmp)
        M1 -= s
        success_1 += s
        # CE Lv 2
        s = repetition(128, round(M2), tmp)
        M2 -= s
        success_2 += s
    # Steps 6-10: levels 1 and 2 share one pool
    for _ in range(5):
        # CE Lv 0
        s = repetition(32, round(M0), tmp)
        M0 -= s
        success_0 += s
        # CE Lv 1 and 2
        # NOTE(review): M2 is not changed inside this loop, so the same M2 is
        # added to M1 on each of the five iterations - confirm this is intended.
        M1 = M1 + M2
        s = repetition(128, round(M1), tmp)
        M1 -= s
        success_1 += s
        success_2 += s
    # Steps 11-15: everything shares the level-0 pool
    for _ in range(5):
        # CE Lv 0
        # NOTE(review): M1 is not changed inside this loop, so the same M1 is
        # added to M0 on each of the five iterations - confirm this is intended.
        M0 = M0 + M1
        # round() rather than int(): every other call rounds the pool size,
        # and truncating here would bias this phase downwards.
        s = repetition(128, round(M0), tmp)
        M0 -= s
        success_0 += s
        success_1 += s
        success_2 += s
    return np.array([success_0/M, success_1/M, success_2/M], dtype=np.float64)

@njit('float64[::1](int_, int_)', parallel=True)
def compute_ps_avg(m, sample_size):
    """Average the output of ``simulation(m)`` over ``sample_size`` runs.

    The runs are distributed with ``prange``; each iteration writes only to
    its own index of the per-level sample arrays.

    Parameters
    ----------
    m : int
        Argument forwarded to ``simulation``.
    sample_size : int
        Number of independent ``simulation`` runs to average.

    Returns
    -------
    1-D float64 array of length 3
        Mean of each of the three values returned by ``simulation``, in the
        same order (level 0, level 1, level 2).
    """
    s0 = np.zeros(sample_size, dtype=np.float64)
    s1 = np.zeros(sample_size, dtype=np.float64)
    s2 = np.zeros(sample_size, dtype=np.float64)
    for i in prange(sample_size):
        arr = simulation(m)
        # Keep the level order of simulation(): index k belongs to level k.
        s0[i] = arr[0]
        s1[i] = arr[1]
        s2[i] = arr[2]
    return np.array([np.mean(s0), np.mean(s1), np.mean(s2)])

if __name__ == "__main__":
    # Script entry: evaluate compute_ps_avg for ten values of M, report the
    # elapsed time, then scatter-plot the three per-level averages against M.
    t_begin = time.perf_counter()
    SAMPLE_SIZE = 100
    M = np.linspace(10, 100, 10)

    Result = []
    for m in M:
        Result.append(compute_ps_avg(m, SAMPLE_SIZE))
    print(Result)

    # One row per value of M, one column per level.
    table = np.reshape(Result, (-1, 3))
    s0 = table[:, 0]
    s1 = table[:, 1]
    s2 = table[:, 2]

    elapsed = time.perf_counter() - t_begin
    print("Time used:", elapsed)

    plt.scatter(M, s0, marker='s', facecolors='none', edgecolors='b')
    plt.scatter(M, s1, marker='o')
    plt.scatter(M, s2, s=80, facecolors='none', edgecolors='r')
    plt.ylim(bottom=0, top=1.1)
    plt.xticks(M)
    plt.show()