Properties for phase table.

This commit is contained in:
jupfi 2024-06-07 14:52:35 +02:00
parent 50ae13f89a
commit 44938f5cc0
2 changed files with 70 additions and 32 deletions

View file

@ -2,7 +2,6 @@ import logging
import numpy as np import numpy as np
from collections import OrderedDict from collections import OrderedDict
from quackseq.pulsesequence import QuackSequence
from quackseq.pulseparameters import TXPulse from quackseq.pulseparameters import TXPulse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -13,14 +12,23 @@ class PhaseTable:
def __init__(self, quackseq): def __init__(self, quackseq):
self.quackseq = quackseq self.quackseq = quackseq
self.phase_table = self.generate_phase_table() self.phase_array = self.generate_phase_array()
def generate_phase_table(self): def generate_phase_array(self):
"""Generate a list of phases for each phasecycle in the sequence.""" """Generate a list of phases for each phasecycle in the sequence.
Returns:
phase_array (np.array): A table of phase values for each phasecycle.
The columns are the values for the different TX pulse parameters and the rows are the different phase cycles.
"""
phase_table = OrderedDict() phase_table = OrderedDict()
events = self.quackseq.events events = self.quackseq.events
# If there are no events, return an empty array
if not events:
return np.array([])
for event in events: for event in events:
for parameter in event.parameters.values(): for parameter in event.parameters.values():
if parameter.name == self.quackseq.TX_PULSE: if parameter.name == self.quackseq.TX_PULSE:
@ -38,6 +46,10 @@ class PhaseTable:
logger.info(phase_table) logger.info(phase_table)
# If no TX events are found, return an empty array
if not phase_table:
return np.array([])
# First we make sure that all phase groups are in direct sucessive order. E.if there is a a phase group 0 and phase group 2 then phase group 2 will be renamed to 1. # First we make sure that all phase groups are in direct sucessive order. E.if there is a a phase group 0 and phase group 2 then phase group 2 will be renamed to 1.
phase_groups = [phase_group for phase_group, _ in phase_table.values()] phase_groups = [phase_group for phase_group, _ in phase_table.values()]
phase_groups = list(set(phase_groups)) phase_groups = list(set(phase_groups))
@ -99,15 +111,9 @@ class PhaseTable:
phase_values, int(np.ceil(max_phase_values / len(phase_values))) phase_values, int(np.ceil(max_phase_values / len(phase_values)))
) )
pulse_phases[parameter] = phase_values pulse_phases[parameter] = phase_values
logger.info(
f"Phase values for parameter {parameter}: {phase_values}"
)
parameters[parameter] = phase_values parameters[parameter] = phase_values
else: else:
pulse_phases[parameter] = phase_values pulse_phases[parameter] = phase_values
logger.info(
f"Phase values for parameter {parameter}: {phase_values}"
)
logger.info(f"Parameters for group {group}: {parameters}") logger.info(f"Parameters for group {group}: {parameters}")
group_phases[group] = parameters group_phases[group] = parameters
@ -141,7 +147,10 @@ class PhaseTable:
for i in range(phase_length): for i in range(phase_length):
for parameter, phases in group_phases[group].items(): for parameter, phases in group_phases[group].items():
if parameter != first_parameter: if parameter != first_parameter:
total_group_phases[i] += [parameter, phases[i]] try:
total_group_phases[i] += [parameter, phases[i]]
except IndexError:
logger.info(f"Index Error: Parameter {parameter}, Phases: {phases}")
return total_group_phases return total_group_phases
@ -149,31 +158,57 @@ class PhaseTable:
for row, phases in enumerate(all_phases): for row, phases in enumerate(all_phases):
phases = [phases[i : i + 2] for i in range(0, len(phases), 2)] phases = [phases[i : i + 2] for i in range(0, len(phases), 2)]
logger.info(f"Phases: {phases}")
for phase in phases: for phase in phases:
parameter, phase_value = phase parameter, phase_value = phase
column = list(phase_table.keys()).index(parameter) column = list(phase_table.keys()).index(parameter)
phase_array[row, column] = phase_value try:
phase_array[row, column] = phase_value
# logger.info(f"Phase table: {phase_table}") except IndexError:
# logger.info(f"Pulse phases: {pulse_phases}") logger.info(f"Index error: {row}, {column}, {phase_value}")
# for column in range(n_columns):
# Get the parameter
# parameter = list(phase_table.keys())[column]
# The division factor is the number of rows divided by the length of the pulse phases of the parameter
# logger.info(f"Len of pulse phases: {(pulse_phases[parameter])}")
# division_factor = (n_rows // len(pulse_phases[parameter]))
# logger.info(f"Division factor: {division_factor}")
# for row in range(n_rows):
# The phase value is the row index divided by the division factor
# phase_array[row, column] = pulse_phases[parameter][row // division_factor]
logger.info(phase_array) logger.info(phase_array)
return phase_table return phase_array
def update_phase_array(self):
"""Update the phase array of the sequence."""
self.phase_array = self.generate_phase_array()
@property
def phase_array(self) -> np.array:
"""The phase array of the sequence."""
return self._phase_array
@phase_array.setter
def phase_array(self, phase_array : np.array):
self._phase_array = phase_array
@property
def rx_phase_sign(self) -> list:
return self._rx_phase_sign
@rx_phase_sign.setter
def rx_phase_sign(self, rx_phase_sign: list):
"""The phase sign of the RX pulse.
Args:
rx_phase_sign (list): A list of phase signs for the RX pulse. The different entries are tuples with the first element being the sign and the second element being the phase.
"""
# Check that the rx_phase_sign has the same length as the number of rows in the phase table
if len(rx_phase_sign) != self.phase_array.shape[0]:
raise ValueError(
f"The number of rows in the phase table is {self.phase_table.shape[0]} but the length of the rx_phase_sign is {len(rx_phase_sign)}"
)
self._rx_phase_sign = rx_phase_sign
@property
def n_phase_cycles(self) -> int:
return self.phase_array.shape[0]
@property
def n_parameters(self) -> int:
return self.phase_array.shape[1]

View file

@ -7,6 +7,7 @@ from collections import OrderedDict
from quackseq.pulseparameters import PulseParameter, TXPulse, RXReadout from quackseq.pulseparameters import PulseParameter, TXPulse, RXReadout
from quackseq.functions import Function, RectFunction from quackseq.functions import Function, RectFunction
from quackseq.event import Event from quackseq.event import Event
from quackseq.phase_table import PhaseTable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -226,6 +227,8 @@ class QuackSequence(PulseSequence):
self.add_pulse_parameter_option(self.TX_PULSE, TXPulse) self.add_pulse_parameter_option(self.TX_PULSE, TXPulse)
self.add_pulse_parameter_option(self.RX_READOUT, RXReadout) self.add_pulse_parameter_option(self.RX_READOUT, RXReadout)
self.phase_table = PhaseTable(self)
def add_blank_event(self, event_name: str, duration: float): def add_blank_event(self, event_name: str, duration: float):
"""Adds a blank event to the pulse sequence. """Adds a blank event to the pulse sequence.