Properties for phase table.

This commit is contained in:
jupfi 2024-06-07 14:52:35 +02:00
parent 50ae13f89a
commit 44938f5cc0
2 changed files with 70 additions and 32 deletions

View file

@ -2,7 +2,6 @@ import logging
import numpy as np
from collections import OrderedDict
from quackseq.pulsesequence import QuackSequence
from quackseq.pulseparameters import TXPulse
logger = logging.getLogger(__name__)
@ -13,14 +12,23 @@ class PhaseTable:
def __init__(self, quackseq):
self.quackseq = quackseq
self.phase_table = self.generate_phase_table()
self.phase_array = self.generate_phase_array()
def generate_phase_table(self):
"""Generate a list of phases for each phasecycle in the sequence."""
def generate_phase_array(self):
"""Generate a list of phases for each phasecycle in the sequence.
Returns:
phase_array (np.array): A table of phase values for each phasecycle.
The columns are the values for the different TX pulse parameters and the rows are the different phase cycles.
"""
phase_table = OrderedDict()
events = self.quackseq.events
# If there are no events, return an empty array
if not events:
return np.array([])
for event in events:
for parameter in event.parameters.values():
if parameter.name == self.quackseq.TX_PULSE:
@ -38,6 +46,10 @@ class PhaseTable:
logger.info(phase_table)
# If no TX events are found, return an empty array
if not phase_table:
return np.array([])
# First we make sure that all phase groups are in direct sucessive order. E.if there is a a phase group 0 and phase group 2 then phase group 2 will be renamed to 1.
phase_groups = [phase_group for phase_group, _ in phase_table.values()]
phase_groups = list(set(phase_groups))
@ -99,15 +111,9 @@ class PhaseTable:
phase_values, int(np.ceil(max_phase_values / len(phase_values)))
)
pulse_phases[parameter] = phase_values
logger.info(
f"Phase values for parameter {parameter}: {phase_values}"
)
parameters[parameter] = phase_values
else:
pulse_phases[parameter] = phase_values
logger.info(
f"Phase values for parameter {parameter}: {phase_values}"
)
logger.info(f"Parameters for group {group}: {parameters}")
group_phases[group] = parameters
@ -141,7 +147,10 @@ class PhaseTable:
for i in range(phase_length):
for parameter, phases in group_phases[group].items():
if parameter != first_parameter:
total_group_phases[i] += [parameter, phases[i]]
try:
total_group_phases[i] += [parameter, phases[i]]
except IndexError:
logger.info(f"Index Error: Parameter {parameter}, Phases: {phases}")
return total_group_phases
@ -149,31 +158,57 @@ class PhaseTable:
for row, phases in enumerate(all_phases):
phases = [phases[i : i + 2] for i in range(0, len(phases), 2)]
logger.info(f"Phases: {phases}")
for phase in phases:
parameter, phase_value = phase
column = list(phase_table.keys()).index(parameter)
phase_array[row, column] = phase_value
# logger.info(f"Phase table: {phase_table}")
# logger.info(f"Pulse phases: {pulse_phases}")
# for column in range(n_columns):
# Get the parameter
# parameter = list(phase_table.keys())[column]
# The division factor is the number of rows divided by the length of the pulse phases of the parameter
# logger.info(f"Len of pulse phases: {(pulse_phases[parameter])}")
# division_factor = (n_rows // len(pulse_phases[parameter]))
# logger.info(f"Division factor: {division_factor}")
# for row in range(n_rows):
# The phase value is the row index divided by the division factor
# phase_array[row, column] = pulse_phases[parameter][row // division_factor]
try:
phase_array[row, column] = phase_value
except IndexError:
logger.info(f"Index error: {row}, {column}, {phase_value}")
logger.info(phase_array)
return phase_table
return phase_array
def update_phase_array(self):
"""Update the phase array of the sequence."""
self.phase_array = self.generate_phase_array()
@property
def phase_array(self) -> np.array:
"""The phase array of the sequence."""
return self._phase_array
@phase_array.setter
def phase_array(self, phase_array : np.array):
self._phase_array = phase_array
@property
def rx_phase_sign(self) -> list:
return self._rx_phase_sign
@rx_phase_sign.setter
def rx_phase_sign(self, rx_phase_sign: list):
"""The phase sign of the RX pulse.
Args:
rx_phase_sign (list): A list of phase signs for the RX pulse. The different entries are tuples with the first element being the sign and the second element being the phase.
"""
# Check that the rx_phase_sign has the same length as the number of rows in the phase table
if len(rx_phase_sign) != self.phase_array.shape[0]:
raise ValueError(
f"The number of rows in the phase table is {self.phase_table.shape[0]} but the length of the rx_phase_sign is {len(rx_phase_sign)}"
)
self._rx_phase_sign = rx_phase_sign
@property
def n_phase_cycles(self) -> int:
return self.phase_array.shape[0]
@property
def n_parameters(self) -> int:
return self.phase_array.shape[1]

View file

@ -7,6 +7,7 @@ from collections import OrderedDict
from quackseq.pulseparameters import PulseParameter, TXPulse, RXReadout
from quackseq.functions import Function, RectFunction
from quackseq.event import Event
from quackseq.phase_table import PhaseTable
logger = logging.getLogger(__name__)
@ -226,6 +227,8 @@ class QuackSequence(PulseSequence):
self.add_pulse_parameter_option(self.TX_PULSE, TXPulse)
self.add_pulse_parameter_option(self.RX_READOUT, RXReadout)
self.phase_table = PhaseTable(self)
def add_blank_event(self, event_name: str, duration: float):
"""Adds a blank event to the pulse sequence.