diff --git a/src/quackseq/phase_table.py b/src/quackseq/phase_table.py index 630e4e1..e244b75 100644 --- a/src/quackseq/phase_table.py +++ b/src/quackseq/phase_table.py @@ -2,7 +2,6 @@ import logging import numpy as np from collections import OrderedDict -from quackseq.pulsesequence import QuackSequence from quackseq.pulseparameters import TXPulse logger = logging.getLogger(__name__) @@ -13,14 +12,23 @@ class PhaseTable: def __init__(self, quackseq): self.quackseq = quackseq - self.phase_table = self.generate_phase_table() + self.phase_array = self.generate_phase_array() - def generate_phase_table(self): - """Generate a list of phases for each phasecycle in the sequence.""" + def generate_phase_array(self): + """Generate a list of phases for each phasecycle in the sequence. + + Returns: + phase_array (np.array): A table of phase values for each phasecycle. + The columns are the values for the different TX pulse parameters and the rows are the different phase cycles. + """ phase_table = OrderedDict() events = self.quackseq.events + # If there are no events, return an empty array + if not events: + return np.array([]) + for event in events: for parameter in event.parameters.values(): if parameter.name == self.quackseq.TX_PULSE: @@ -38,6 +46,10 @@ class PhaseTable: logger.info(phase_table) + # If no TX events are found, return an empty array + if not phase_table: + return np.array([]) + # First we make sure that all phase groups are in direct sucessive order. E.if there is a a phase group 0 and phase group 2 then phase group 2 will be renamed to 1. phase_groups = [phase_group for phase_group, _ in phase_table.values()] phase_groups = list(set(phase_groups)) @@ -99,15 +111,9 @@ class PhaseTable: phase_values, int(np.ceil(max_phase_values / len(phase_values))) ) pulse_phases[parameter] = phase_values - logger.info( - f"Phase values for parameter {parameter}: {phase_values}" - ) parameters[parameter] = phase_values else: pulse_phases[parameter] = phase_values - logger.info( - f"Phase values for parameter {parameter}: {phase_values}" - ) logger.info(f"Parameters for group {group}: {parameters}") group_phases[group] = parameters @@ -141,7 +147,10 @@ class PhaseTable: for i in range(phase_length): for parameter, phases in group_phases[group].items(): if parameter != first_parameter: - total_group_phases[i] += [parameter, phases[i]] + try: + total_group_phases[i] += [parameter, phases[i]] + except IndexError: + logger.info(f"Index Error: Parameter {parameter}, Phases: {phases}") return total_group_phases @@ -149,31 +158,57 @@ class PhaseTable: for row, phases in enumerate(all_phases): phases = [phases[i : i + 2] for i in range(0, len(phases), 2)] - logger.info(f"Phases: {phases}") for phase in phases: parameter, phase_value = phase column = list(phase_table.keys()).index(parameter) - phase_array[row, column] = phase_value - - # logger.info(f"Phase table: {phase_table}") - # logger.info(f"Pulse phases: {pulse_phases}") - - # for column in range(n_columns): - - # Get the parameter - # parameter = list(phase_table.keys())[column] - - # The division factor is the number of rows divided by the length of the pulse phases of the parameter - # logger.info(f"Len of pulse phases: {(pulse_phases[parameter])}") - # division_factor = (n_rows // len(pulse_phases[parameter])) - - # logger.info(f"Division factor: {division_factor}") - - # for row in range(n_rows): - # The phase value is the row index divided by the division factor - # phase_array[row, column] = pulse_phases[parameter][row // division_factor] + try: + phase_array[row, column] = phase_value + except IndexError: + logger.info(f"Index error: {row}, {column}, {phase_value}") logger.info(phase_array) - return phase_table + return phase_array + + def update_phase_array(self): + """Update the phase array of the sequence.""" + self.phase_array = self.generate_phase_array() + + @property + def phase_array(self) -> np.array: + """The phase array of the sequence.""" + return self._phase_array + + @phase_array.setter + def phase_array(self, phase_array : np.array): + self._phase_array = phase_array + + @property + def rx_phase_sign(self) -> list: + return self._rx_phase_sign + + @rx_phase_sign.setter + def rx_phase_sign(self, rx_phase_sign: list): + """The phase sign of the RX pulse. + + Args: + rx_phase_sign (list): A list of phase signs for the RX pulse. The different entries are tuples with the first element being the sign and the second element being the phase. + """ + + # Check that the rx_phase_sign has the same length as the number of rows in the phase table + + if len(rx_phase_sign) != self.phase_array.shape[0]: + raise ValueError( + f"The number of rows in the phase table is {self.phase_table.shape[0]} but the length of the rx_phase_sign is {len(rx_phase_sign)}" + ) + + self._rx_phase_sign = rx_phase_sign + + @property + def n_phase_cycles(self) -> int: + return self.phase_array.shape[0] + + @property + def n_parameters(self) -> int: + return self.phase_array.shape[1] diff --git a/src/quackseq/pulsesequence.py b/src/quackseq/pulsesequence.py index 092acc8..8305b8e 100644 --- a/src/quackseq/pulsesequence.py +++ b/src/quackseq/pulsesequence.py @@ -7,6 +7,7 @@ from collections import OrderedDict from quackseq.pulseparameters import PulseParameter, TXPulse, RXReadout from quackseq.functions import Function, RectFunction from quackseq.event import Event +from quackseq.phase_table import PhaseTable logger = logging.getLogger(__name__) @@ -226,6 +227,8 @@ class QuackSequence(PulseSequence): self.add_pulse_parameter_option(self.TX_PULSE, TXPulse) self.add_pulse_parameter_option(self.RX_READOUT, RXReadout) + self.phase_table = PhaseTable(self) + def add_blank_event(self, event_name: str, duration: float): """Adds a blank event to the pulse sequence.