Generic least-squares function
We delve a bit deeper into how to write generic cost functions that work well with iminuit.
Note: cost functions for most use cases can be imported from iminuit.cost, including a generic least-squares function. The built-in cost functions come with extra features and use some insights to make them work better than naive implementations, so they are worth checking out. This tutorial is therefore not a guide to writing the best least-squares cost function; it just illustrates how to write cost functions yourself.
We have seen in the basic tutorial how to make a least-squares function with an explicit signature that iminuit could read to find the parameter names automatically. Part of the structure of a least-squares function is always the same. What changes is the model that predicts the y-values and its parameters.
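As a reminder of that explicit approach, a least-squares function with a fixed signature could be sketched like this (a minimal sketch; the data values are placeholders, not the tutorial's):

```python
import numpy as np

def line(x, a, b):  # straight-line model
    return a + b * x

# placeholder data, chosen so that y = 2 * x exactly
x = np.array([1.0, 2.0, 3.0])
y = np.array([2.0, 4.0, 6.0])
err = np.array([0.5, 0.5, 0.5])

def least_squares(a, b):  # explicit parameters: iminuit can introspect a and b
    ym = line(x, a, b)
    return np.sum((y - ym) ** 2 / err ** 2)

print(least_squares(0.0, 2.0))  # 0.0: the model matches this data exactly
```

The drawback is that the function body and the parameter list must be rewritten for every new model, which is exactly what the generic class below avoids.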
Here we show how to make a generic weighted least-squares class that works with iminuit.
[1]:
import numpy as np
from iminuit import Minuit
from iminuit.util import describe, make_func_code
[2]:
class LeastSquares:
    """
    Generic least-squares cost function with error.
    """

    errordef = Minuit.LEAST_SQUARES  # for Minuit to compute errors correctly

    def __init__(self, model, x, y, err):
        self.model = model  # model predicts y for given x
        self.x = np.asarray(x)
        self.y = np.asarray(y)
        self.err = np.asarray(err)

    def __call__(self, *par):  # we accept a variable number of model parameters
        ym = self.model(self.x, *par)
        return np.sum((self.y - ym) ** 2 / self.err ** 2)
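Independently of iminuit, the instance is already callable like a plain Python function, so we can sanity-check it directly (the class is duplicated here so the snippet is self-contained):

```python
import numpy as np

class LeastSquares:
    """Generic least-squares cost function with error."""

    def __init__(self, model, x, y, err):
        self.model = model  # model predicts y for given x
        self.x = np.asarray(x)
        self.y = np.asarray(y)
        self.err = np.asarray(err)

    def __call__(self, *par):
        ym = self.model(self.x, *par)
        return np.sum((self.y - ym) ** 2 / self.err ** 2)

def line(x, a, b):
    return a + b * x

y = [2, 4, 6, 8, 10]
lsq = LeastSquares(line, [1, 2, 3, 4, 5], y, np.sqrt(y))

print(lsq(0, 2))  # 0.0: chi-square vanishes at the true parameters a=0, b=2
```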
Let’s try it out with iminuit.
[3]:
def line(x, a, b):  # simple straight-line model with explicit parameters
    return a + b * x


x_data = [1, 2, 3, 4, 5]
y_data = [2, 4, 6, 8, 10]
y_err = np.sqrt(y_data)

lsq = LeastSquares(line, x_data, y_data, y_err)

# this fails
try:
    m = Minuit(lsq, a=0, b=0)
    m.errordef = Minuit.LEAST_SQUARES
except:
    import traceback
    traceback.print_exc()
Traceback (most recent call last):
File "/tmp/ipykernel_7145/285277255.py", line 13, in <cell line: 12>
m = Minuit(lsq, a=0, b=0)
File "/build/python-iminuit/src/python-iminuit/build/lib.linux-x86_64-3.10/iminuit/minuit.py", line 617, in __init__
self._init_state = _make_init_state(self._pos2var, start, kwds)
File "/build/python-iminuit/src/python-iminuit/build/lib.linux-x86_64-3.10/iminuit/minuit.py", line 2038, in _make_init_state
raise RuntimeError(
RuntimeError: a is not one of the parameters []
What happened? iminuit uses introspection to detect the parameter names and the number of parameters. It uses the describe utility for that, but describe fails here, since the generic method signature LeastSquares.__call__(self, *par) does not reveal the number and names of the parameters.
The information could be extracted from the signature of the model, but iminuit knows nothing about the signature of line(x, a, b) here. We can fix this by generating a function signature for the LeastSquares class from the signature of the model.
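For intuition, describe builds on standard Python introspection; a plain-Python sketch of how parameter names can be read off the model signature (not iminuit's actual implementation) might look like:

```python
import inspect

def line(x, a, b):  # same straight-line model as above
    return a + b * x

# all parameter names of the model, then drop the first one,
# which is the independent variable 'x', not a fit parameter
par_names = list(inspect.signature(line).parameters)[1:]
print(par_names)  # ['a', 'b']
```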
[4]:
# get the args from line and strip 'x'
describe(line)[1:]
[4]:
['a', 'b']
[5]:
# now inject that into the lsq object with the make_func_code tool
lsq.func_code = make_func_code(describe(line)[1:])
# now we get the right answer
describe(lsq)
[5]:
['a', 'b']
We can put this code into the __init__ method of our class to obtain a generic least-squares class that works with iminuit.
[6]:
class BetterLeastSquares(LeastSquares):
    def __init__(self, model, x, y, err):
        super().__init__(model, x, y, err)
        self.func_code = make_func_code(describe(model)[1:])
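To demystify make_func_code, here is a hedged stand-alone sketch of what it roughly produces, assuming describe reads the co_varnames and co_argcount attributes from a func_code object (this mirrors the idea, not iminuit's exact internals):

```python
import inspect
from types import SimpleNamespace

def line(x, a, b):
    return a + b * x

# model parameter names with the independent variable 'x' stripped
pars = list(inspect.signature(line).parameters)[1:]

# minimal stand-in for make_func_code's result: a small code-like object
# (assumption: describe() looks at co_varnames and co_argcount)
fake_func_code = SimpleNamespace(co_varnames=tuple(pars), co_argcount=len(pars))

print(fake_func_code.co_varnames)  # ('a', 'b')
```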
[7]:
lsq = BetterLeastSquares(line, x_data, y_data, y_err)
[8]:
m = Minuit(lsq, a=0, b=0)
m.migrad()
[8]:
Migrad

FCN = 4.83e-26                       Nfcn = 30
EDM = 4.83e-26 (Goal: 0.0002)
Valid Minimum                        No Parameters at limit
Below EDM threshold (goal x 10)      Below call limit
Covariance: Hesse ok                 Accurate   Pos. def.   Not forced

     Name   Value   Hesse Error   Minos Error-   Minos Error+   Limit-   Limit+   Fixed
 0   a      0.0     1.8
 1   b      2.0     0.7

       a                 b
 a     3.24             -1.08 (-0.854)
 b    -1.08 (-0.854)     0.494
It works :).
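As a final sanity check, we can reproduce the fitted values without iminuit: a weighted straight-line fit has a closed-form solution, available for example via np.polyfit, which expects weights of 1/err for Gaussian uncertainties:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = np.array([2.0, 4.0, 6.0, 8.0, 10.0])
err = np.sqrt(y)

# polyfit applies the weight to the unsquared residual, so w = 1/err
b, a = np.polyfit(x, y, 1, w=1.0 / err)  # returns highest degree first
print(a, b)  # approximately 0.0 and 2.0, matching the Minuit result
```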