มีวิธีใดบ้างที่จะทำให้โค้ดนี้มีประสิทธิภาพและเร็วขึ้น

import numpy as np
import matplotlib.pyplot as plt
import time
from numba import njit, prange

@njit('int_(int_, int_, int_[::1])')
def aloha(N, M, tmp):
  """Simulate one contention round: M devices each pick one of N slots.

  Every device draws a slot uniformly from [0, N) with np.random.randint;
  a slot counts as a success when exactly one device picked it.

  Args:
    N: number of slots; must not exceed tmp.shape[0].
    M: number of contending devices (no draws happen when M <= 0).
    tmp: caller-provided scratch buffer of per-slot pick counts. It is
      zeroed and then overwritten by this call.

  Returns:
    Number of slots that were picked by exactly one device.

  Raises:
    ValueError: if N is larger than the scratch buffer.
  """
  # Numba does not bounds-check array indexing by default, so an undersized
  # buffer would silently corrupt memory instead of raising an error.
  if N > tmp.shape[0]:
    raise ValueError("aloha: N exceeds the length of the tmp buffer")
  tmp.fill(0)
  for _ in range(M):
    tmp[np.random.randint(0, N)] += 1
  success = 0
  for i in range(N):
    # Branch-free count of single-occupancy slots (bool adds as 0/1).
    success += tmp[i] == 1
  return success

@njit('float64(int_, int_, int_[::1])')
def repetition(I, M, tmp):
  """Return the mean number of successes over I independent aloha rounds.

  The slot count passed to aloha is taken from the scratch buffer's length
  rather than a hard-coded constant, so it always matches whatever buffer
  the caller allocated.

  Args:
    I: number of rounds to average over (must be positive; I == 0 divides
      by zero).
    M: number of contending devices in each round.
    tmp: scratch buffer reused by every aloha call; its length is the
      number of slots.

  Returns:
    Average number of single-occupancy slots per round, as a float.
  """
  n_slots = tmp.shape[0]
  total = 0
  for _ in range(I):
    total += aloha(n_slots, M, tmp)
  return total / I
  
@njit('float64[::1](int_)')
def simulation(M):
  """Run one 15-attempt contention simulation for three CE levels.

  Each CE level starts with a backlog of M devices. Every attempt calls
  ``repetition(I, round(backlog), tmp)`` and subtracts the returned mean
  number of successes from that backlog.

  Attempts 1-5:   CE 0, 1, 2 contend separately (I = 8, 32, 128).
  Attempts 6-10:  CE 0 with I = 32; the CE 2 backlog is merged into CE 1
                  once, and the merged pool contends with I = 128.
  Attempts 11-15: everything still pending is merged into CE 0 once and
                  contends with I = 128.

  Args:
    M: initial number of devices in each CE level.

  Returns:
    float64 array [success_0 / M, success_1 / M, success_2 / M].
  """
  tmp = np.zeros(12, np.int_)  # Preallocated per-slot buffer reused by aloha
  success_0 = 0.0
  success_1 = 0.0
  success_2 = 0.0
  M0 = float(M)
  M1 = float(M)
  M2 = float(M)

  # Attempts 1-5: every CE level contends on its own.
  for _ in range(5):
    s = repetition(8, round(M0), tmp)
    M0 -= s
    success_0 += s
    s = repetition(32, round(M1), tmp)
    M1 -= s
    success_1 += s
    s = repetition(128, round(M2), tmp)
    M2 -= s
    success_2 += s

  # Attempts 6-10: merge the CE 2 backlog into CE 1 exactly once, outside
  # the loop, so the same remaining devices are not re-added every attempt.
  M1 += M2
  M2 = 0.0
  for _ in range(5):
    s = repetition(32, round(M0), tmp)
    M0 -= s
    success_0 += s
    s = repetition(128, round(M1), tmp)
    M1 -= s
    # NOTE(review): successes of the merged pool are credited to both CE 1
    # and CE 2, so success_1/M and success_2/M can exceed 1 - confirm this
    # attribution is intended.
    success_1 += s
    success_2 += s

  # Attempts 11-15: merge everything still pending into CE 0 exactly once.
  M0 += M1
  M1 = 0.0
  for _ in range(5):
    # round(), consistent with every other call (int() would truncate).
    s = repetition(128, round(M0), tmp)
    M0 -= s
    success_0 += s
    success_1 += s
    success_2 += s

  return np.array([success_0 / M, success_1 / M, success_2 / M], dtype=np.float64)

@njit('float64[::1](int_, int_)', parallel=True)
def compute_ps_avg(m, sample_size):
  """Average the three success ratios of `sample_size` simulation runs.

  Runs ``simulation(m)`` ``sample_size`` times in a prange loop and returns
  the mean of each of its three outputs, in the same order.

  Args:
    m: initial number of devices per CE level, passed to ``simulation``.
    sample_size: number of independent runs to average (must be positive).

  Returns:
    float64 array [mean ratio CE 0, mean ratio CE 1, mean ratio CE 2].

  Raises:
    ValueError: if sample_size <= 0.
  """
  if sample_size <= 0:
    raise ValueError("compute_ps_avg: sample_size must be positive")
  # Scalar `+=` accumulators inside prange are handled by Numba as parallel
  # reductions, so no per-sample arrays need to be allocated.
  total_0 = 0.0
  total_1 = 0.0
  total_2 = 0.0
  for _ in prange(sample_size):
    arr = simulation(m)
    # Each level keeps its own index: arr[1] is CE 1 and arr[2] is CE 2.
    total_0 += arr[0]
    total_1 += arr[1]
    total_2 += arr[2]
  return np.array([total_0 / sample_size,
                   total_1 / sample_size,
                   total_2 / sample_size], dtype=np.float64)

if __name__ == "__main__":
  start = time.perf_counter()
  SAMPLE_SIZE = 100
  # compute_ps_avg is compiled for integer arguments, so build an integer
  # grid (10, 20, ..., 100) instead of relying on float -> int conversion.
  M = np.arange(10, 101, 10)
  # One row per device count, columns = CE level 0, 1, 2.
  results = np.array([compute_ps_avg(int(m), SAMPLE_SIZE) for m in M])
  print(results)
  s0, s1, s2 = results.T
  elapsed = time.perf_counter() - start
  print("Time used:", elapsed)
  plt.scatter(M, s0, marker='s', facecolors='none', edgecolors='b', label='CE Lv 0')
  plt.scatter(M, s1, marker='o', label='CE Lv 1')
  plt.scatter(M, s2, s=80, facecolors='none', edgecolors='r', label='CE Lv 2')
  plt.ylim(0, 1.1)
  plt.xticks(M)
  plt.legend()
  plt.show()