diff --git a/src/quackseq/event.py b/src/quackseq/event.py index 7a6c46a..054256d 100644 --- a/src/quackseq/event.py +++ b/src/quackseq/event.py @@ -1,7 +1,6 @@ import logging from collections import OrderedDict -from quackseq.pulsesequence import PulseSequence from quackseq.pulseparameters import Option from quackseq.helpers import UnitConverter @@ -13,7 +12,7 @@ class Event: Args: name (str): The name of the event - duration (str): The duration of the event + duration (float | str): The duration of the event, either as a float or a string with a unit suffix (n, u, m) pulse_sequence (PulseSequence): The pulse sequence the event belongs to Attributes: @@ -22,7 +21,7 @@ class Event: pulse_sequence (PulseSequence): The pulse sequence the event belongs to """ - def __init__(self, name: str, duration: str, pulse_sequence : PulseSequence) -> None: + def __init__(self, name: str, duration: float | str, pulse_sequence : "PulseSequence") -> None: """Initializes the event.""" self.parameters = OrderedDict() self.name = name @@ -46,15 +45,6 @@ class Event: id(self.parameters[name]), ) - def on_duration_changed(self, duration: str) -> None: - """This method is called when the duration of the event is changed. - - Args: - duration (str): The new duration of the event - """ - logger.debug("Duration of event %s changed to %s", self.name, duration) - self.duration = duration - @classmethod def load_event(cls, event, pulse_parameter_options): """Loads an event from a dict. @@ -95,12 +85,11 @@ class Event: return self._duration @duration.setter - def duration(self, duration: str): + def duration(self, duration: float | str): # Duration needs to be a positive number - try: + if isinstance(duration, str): duration = UnitConverter.to_float(duration) - except ValueError: - raise ValueError("Duration needs to be a number") + if duration < 0: raise ValueError("Duration needs to be a positive number") diff --git a/src/quackseq/pulsesequence.py b/src/quackseq/pulsesequence.py index 6497ba4..23f33bf 100644 --- a/src/quackseq/pulsesequence.py +++ b/src/quackseq/pulsesequence.py @@ -5,7 +5,8 @@ import importlib.metadata from collections import OrderedDict from quackseq.pulseparameters import PulseParameter, TXPulse, RXReadout -from quackseq.functions import Function +from quackseq.functions import Function, RectFunction +from quackseq.event import Event logger = logging.getLogger(__name__) @@ -60,19 +61,24 @@ class PulseSequence: Args: event (Event): The event to add """ + if event.name in self.get_event_names(): + raise ValueError(f"Event with name {event.name} already exists in the pulse sequence") self.events.append(event) - def create_event(self, event_name: str, duration: float) -> "Event": + def create_event(self, event_name: str, duration: str) -> "Event": """Create a new event and return it. Args: - event_name (str): The name of the event + event_name (str): The name of the event with a unit suffix (n, u, m) duration (float): The duration of the event Returns: Event: The created event """ - event = self.Event(event_name, f"{float(duration):.16g}u") + event = Event(event_name, duration, self) + if event.name in self.get_event_names(): + raise ValueError(f"Event with name {event.name} already exists in the pulse sequence") + self.events.append(event) return event @@ -215,6 +221,26 @@ class QuackSequence(PulseSequence): self.add_pulse_parameter_option(self.TX_PULSE, TXPulse) self.add_pulse_parameter_option(self.RX_READOUT, RXReadout) + def add_blank_event(self, event_name: str, duration: float): + event = self.create_event(event_name, duration) + + def add_pulse_event( + self, + event_name: str, + duration: float, + amplitude: float, + phase: float, + shape: Function = RectFunction(), + ): + event = self.create_event(event_name, duration) + self.set_tx_amplitude(event, amplitude) + self.set_tx_phase(event, phase) + self.set_tx_shape(event, shape) + + def add_readout_event(self, event_name: str, duration: float): + event = self.create_event(event_name, duration) + self.set_rx(event, True) + # TX Specific functions def set_tx_amplitude(self, event, amplitude: float) -> None: @@ -259,6 +285,4 @@ class QuackSequence(PulseSequence): event (Event): The event to set the receiver for rx (bool): The receiver state """ - event.parameters[self.RX_READOUT].get_option_by_name( - RXReadout.RX - ).value = rx + event.parameters[self.RX_READOUT].get_option_by_name(RXReadout.RX).value = rx diff --git a/tests/simulator.py b/tests/simulator.py index 840f0e7..9d47031 100644 --- a/tests/simulator.py +++ b/tests/simulator.py @@ -1,52 +1,78 @@ -# Dummy test to communicate the structure +import unittest import logging import matplotlib.pyplot as plt - from quackseq.pulsesequence import QuackSequence from quackseq.event import Event from quackseq.functions import RectFunction from quackseq.spectrometer.simulator import Simulator -logging.basicConfig(level=logging.DEBUG) +# logging.basicConfig(level=logging.DEBUG) -seq = QuackSequence("test") -tx = Event("tx", "10u", seq) +class TestQuackSequence(unittest.TestCase): -seq.add_event(tx) + def test_event_creation(self): + seq = QuackSequence("test - event creation") + seq.add_pulse_event("tx", "10u", 1, 0, RectFunction()) + seq.add_blank_event("blank", "3u") + seq.add_readout_event("rx", "100u") + seq.add_blank_event("TR", "1m") -seq.set_tx_amplitude(tx, 1) -seq.set_tx_phase(tx, 0) + json = seq.to_json() + print(json) -rect = RectFunction() + sim = Simulator() + sim.set_averages(100) -seq.set_tx_shape(tx, rect) + result = sim.run_sequence(seq) + self.assertIsNotNone(result) + self.assertTrue(hasattr(result, "tdx")) + self.assertTrue(hasattr(result, "tdy")) + self.assertGreater(len(result.tdx), 0) + self.assertGreater(len(result.tdy), 0) -blank = Event("blank", "3u", seq) + # Plotting the result can be useful for visual inspection during development + plt.plot(result.tdx, abs(result.tdy)) + plt.show() -seq.add_event(blank) + def test_simulation_run_sequence(self): + seq = QuackSequence("test - simulation run sequence") -rx = Event("rx", "50u", seq) -#rx.set_rx_phase(0) + tx = Event("tx", "10u", seq) + seq.add_event(tx) + seq.set_tx_amplitude(tx, 1) + seq.set_tx_phase(tx, 0) -seq.set_rx(rx, True) + json = seq.to_json() + print(json) -seq.add_event(rx) + rect = RectFunction() + seq.set_tx_shape(tx, rect) -TR = Event("TR", "1m", seq) + blank = Event("blank", "3u", seq) + seq.add_event(blank) -seq.add_event(TR) + rx = Event("rx", "100u", seq) + seq.set_rx(rx, True) + seq.add_event(rx) -json = seq.to_json() + TR = Event("TR", "1m", seq) + seq.add_event(TR) -print(json) + sim = Simulator() + sim.set_averages(100) -sim = Simulator() + result = sim.run_sequence(seq) + self.assertIsNotNone(result) + self.assertTrue(hasattr(result, "tdx")) + self.assertTrue(hasattr(result, "tdy")) + self.assertGreater(len(result.tdx), 0) + self.assertGreater(len(result.tdy), 0) -sim.set_averages(100) + # Plotting the result can be useful for visual inspection during development + plt.plot(result.tdx, abs(result.tdy)) + plt.show() -# Returns the data at the RX event -result = sim.run_sequence(seq) -plt.plot(result.tdx, abs(result.tdy)) -plt.show() +if __name__ == "__main__": + unittest.main()