Basic phase table generation with bugs.

This commit is contained in:
jupfi 2024-06-06 19:45:24 +02:00
parent 3cf8000389
commit 1a1a2150f6
3 changed files with 216 additions and 3 deletions

179
src/quackseq/phase_table.py Normal file
View file

@ -0,0 +1,179 @@
import logging
import numpy as np
from collections import OrderedDict
from quackseq.pulsesequence import QuackSequence
from quackseq.pulseparameters import TXPulse
logger = logging.getLogger(__name__)
class PhaseTable:
"""A phase table interprets the TX parameters of a QuackSequence object and then generates a table of different phase values and signs for each phasecycle."""
def __init__(self, quackseq):
self.quackseq = quackseq
self.phase_table = self.generate_phase_table()
def generate_phase_table(self):
"""Generate a list of phases for each phasecycle in the sequence."""
phase_table = OrderedDict()
events = self.quackseq.events
for event in events:
for parameter in event.parameters.values():
if parameter.name == self.quackseq.TX_PULSE:
if (
parameter.get_option_by_name(TXPulse.RELATIVE_AMPLITUDE).value
> 0
):
phase_group = parameter.get_option_by_name(
TXPulse.PHASE_CYCLE_GROUP
).value
phase_values = parameter.get_phases()
phase_table[parameter] = (phase_group, phase_values)
logger.info(phase_table)
# First we make sure that all phase groups are in direct sucessive order. E.if there is a a phase group 0 and phase group 2 then phase group 2 will be renamed to 1.
phase_groups = [phase_group for phase_group, _ in phase_table.values()]
phase_groups = list(set(phase_groups))
phase_groups.sort()
for i, phase_group in enumerate(phase_groups):
if i != phase_group:
for parameter, (group, phase_values) in phase_table.items():
if group == phase_group:
phase_table[parameter] = (i, phase_values)
logger.info(phase_table)
# Now get the maximum phase group
max_phase_group = max([phase_group for phase_group, _ in phase_table.values()])
logger.info(f"Max phase group: {max_phase_group}")
# The columns of the phase table are the number of parameters in the phase table
n_columns = len(phase_table)
logger.info(f"Number of columns: {n_columns}")
# The number of rows is the maximum number of phase values in a phase group multiplied for every phase group
n_rows = 1
for i in range(max_phase_group + 1):
for parameter, (group, phase_values) in phase_table.items():
if group == i:
n_rows *= len(phase_values)
break
logger.info(f"Number of rows: {n_rows}")
# Create the phase table
phase_array = np.zeros((n_rows, n_columns))
groups = [group for group, _ in phase_table.values()]
groups = list(set(groups))
group_phases = {}
pulse_phases = {}
for group in groups:
# All the parameters that belong to the same group
parameters = {
parameter: phase_values
for parameter, (p_group, phase_values) in phase_table.items()
if p_group == group
}
# The maximum number of phase values in the group
max_phase_values = max(
[len(phase_values) for phase_values in parameters.values()]
)
# Fill up the phase tables of parameters that have less than the maximum number of phase values by repeating the values
for parameter, phase_values in parameters.items():
if len(phase_values) < max_phase_values:
phase_values = np.tile(
phase_values, max_phase_values // len(phase_values)
)
pulse_phases[parameter] = phase_values
logger.info(
f"Phase values for parameter {parameter}: {phase_values}"
)
parameters[parameter] = phase_values
else:
pulse_phases[parameter] = phase_values
logger.info(
f"Phase values for parameter {parameter}: {phase_values}"
)
logger.info(f"Parameters for group {group}: {parameters}")
group_phases[group] = parameters
def get_group_phases(group):
current_group_phases = []
phase_length = len(list(group_phases[group].values())[0])
group_parameters = list(group_phases[group].items())
first_parameter = group_parameters[0][0]
for i in range(phase_length):
for parameter, phases in group_phases[group].items():
if parameter == first_parameter:
current_group_phases.append([parameter, phases[i]])
sub_group_phases = []
if group + 1 in group_phases:
sub_group_phases = get_group_phases(group + 1)
total_group_phases = []
for parameter, phase in current_group_phases:
for sub_group_phase in sub_group_phases:
total_group_phases.append([parameter, phase, *sub_group_phase])
if not total_group_phases:
total_group_phases = current_group_phases
for i in range(phase_length):
for parameter, phases in group_phases[group].items():
if parameter != first_parameter:
total_group_phases[i] += [parameter, phases[i]]
return total_group_phases
all_phases = get_group_phases(0)
for row, phases in enumerate(all_phases):
phases = [phases[i : i + 2] for i in range(0, len(phases), 2)]
logger.info(f"Phases: {phases}")
for phase in phases:
parameter, phase_value = phase
column = list(phase_table.keys()).index(parameter)
phase_array[row, column] = phase_value
# logger.info(f"Phase table: {phase_table}")
# logger.info(f"Pulse phases: {pulse_phases}")
# for column in range(n_columns):
# Get the parameter
# parameter = list(phase_table.keys())[column]
# The division factor is the number of rows divided by the length of the pulse phases of the parameter
# logger.info(f"Len of pulse phases: {(pulse_phases[parameter])}")
# division_factor = (n_rows // len(pulse_phases[parameter]))
# logger.info(f"Division factor: {division_factor}")
# for row in range(n_rows):
# The phase value is the row index divided by the division factor
# phase_array[row, column] = pulse_phases[parameter][row // division_factor]
logger.info(phase_array)
return phase_table

View file

@ -7,6 +7,7 @@ Todo:
from __future__ import annotations from __future__ import annotations
import logging import logging
import numpy as np
from numpy.core.multiarray import array as array from numpy.core.multiarray import array as array
from quackseq.options import BooleanOption, FunctionOption, NumericOption, Option from quackseq.options import BooleanOption, FunctionOption, NumericOption, Option
@ -87,7 +88,7 @@ class TXPulse(PulseParameter):
TX_PHASE = "TX Phase" TX_PHASE = "TX Phase"
TX_PULSE_SHAPE = "TX Pulse Shape" TX_PULSE_SHAPE = "TX Pulse Shape"
N_PHASE_CYCLES = "Number of Phase Cycles" N_PHASE_CYCLES = "Number of Phase Cycles"
PHASE_CYCLE_LEVEL = "Phase Cycle Level" PHASE_CYCLE_GROUP = "Phase Cycle Group"
def __init__(self, name: str) -> None: def __init__(self, name: str) -> None:
"""Initializes the TX Pulse Parameter. """Initializes the TX Pulse Parameter.
@ -118,7 +119,7 @@ class TXPulse(PulseParameter):
) )
self.add_option( self.add_option(
NumericOption( NumericOption(
self.PHASE_CYCLE_LEVEL, self.PHASE_CYCLE_GROUP,
0, 0,
is_float=False, is_float=False,
min_value=0, min_value=0,
@ -138,6 +139,18 @@ class TXPulse(PulseParameter):
), ),
) )
def get_phases(self):
"""Gets the phase of the TX Pulse.
Returns:
np.array : The phases of the TX Pulse
"""
n_phase_cycles = self.get_option_by_name(self.N_PHASE_CYCLES).value
phase = self.get_option_by_name(self.TX_PHASE).value
return (np.linspace(0, 360, n_phase_cycles, endpoint=False) + phase) % 360
class RXReadout(PulseParameter): class RXReadout(PulseParameter):
"""Basic PulseParameter for the RX Readout. It includes an option for the RX Readout state. """Basic PulseParameter for the RX Readout. It includes an option for the RX Readout state.

View file

@ -83,7 +83,6 @@ class PulseSequence:
raise ValueError( raise ValueError(
f"Event with name {event.name} already exists in the pulse sequence" f"Event with name {event.name} already exists in the pulse sequence"
) )
self.events.append(event) self.events.append(event)
return event return event
@ -309,6 +308,28 @@ class QuackSequence(PulseSequence):
TXPulse.TX_PULSE_SHAPE TXPulse.TX_PULSE_SHAPE
).value = shape ).value = shape
def set_tx_n_phase_cycles(self, event, n_phase_cycles: int) -> None:
"""Sets the number of phase cycles for the transmit pulse.
Args:
event (Event): The event to set the number of phase cycles for
n_phase_cycles (int): The number of phase cycles
"""
event.parameters[self.TX_PULSE].get_option_by_name(
TXPulse.N_PHASE_CYCLES
).value = n_phase_cycles
def set_tx_phase_cycle_group(self, event, phase_cycle_group: int) -> None:
"""Sets the phase cycle group for the transmit pulse.
Args:
event (Event): The event to set the phase cycle group for
phase_cycle_group (int): The phase cycle group
"""
event.parameters[self.TX_PULSE].get_option_by_name(
TXPulse.PHASE_CYCLE_GROUP
).value = phase_cycle_group
# RX Specific functions # RX Specific functions
def set_rx(self, event, rx: bool) -> None: def set_rx(self, event, rx: bool) -> None: