"""Implementation of the scalar (quantum) algebra"""
from abc import ABCMeta
from collections import OrderedDict
from itertools import product as cartesian_product
import numpy
import sympy
from sympy.concrete.delta import _simplify_delta
from numpy import complex128, float64, int64
from .abstract_quantum_algebra import (
QuantumExpression, QuantumIndexedSum, QuantumOperation, QuantumPlus,
QuantumTimes, QuantumDerivative)
from .algebraic_properties import (
assoc, assoc_indexed, convert_to_scalars, filter_neutral,
indexed_sum_over_const, indexed_sum_over_kronecker, match_replace,
match_replace_binary, orderby, collect_scalar_summands)
from .hilbert_space_algebra import TrivialSpace
from ...utils.singleton import Singleton, singleton_object
from ...utils.ordering import KeyTuple
from ...utils.indices import SymbolicLabelBase, IdxSym
# Public names of this module (what ``from <module> import *`` exports)
__all__ = [
'Scalar', 'ScalarValue', 'ScalarExpression', 'Zero', 'One', 'ScalarPlus',
'ScalarTimes', 'ScalarIndexedSum', 'ScalarPower', 'ScalarDerivative',
'sqrt', 'KroneckerDelta']
# Names kept out of ``__all__``; ``is_scalar`` is documented further down in
# this module as being for internal use only
__private__ = ['is_scalar']
class Scalar(QuantumExpression, metaclass=ABCMeta):
    """Base class for Scalars

    Implements the arithmetic operators shared by all scalars. Each operator
    first handles the trivial operands (0 and/or 1) and otherwise defers to
    the implementation of the parent class. Operators that the parent class
    may not provide return `NotImplemented`, so that Python can try the
    reflected operation on the other operand.
    """

    #: types that may be wrapped by :class:`ScalarValue`
    _val_types = (
        int, float, complex, sympy.Basic, int64, complex128, float64)

    #: values that cannot be wrapped by :class:`ScalarValue`
    _invalid = {sympy.oo, sympy.zoo, numpy.nan, numpy.inf}

    @property
    def space(self):
        """:obj:`.TrivialSpace`, by definition"""
        return TrivialSpace

    def conjugate(self):
        """Complex conjugate"""
        return self._adjoint()

    @property
    def real(self):
        """Real part"""
        return (self.conjugate() + self) / 2

    @property
    def imag(self):
        """Imaginary part"""
        # (z* - z) * i/2 == Im(z)
        return (self.conjugate() - self) * (sympy.I / 2)

    def __add__(self, other):
        if other == 0:
            return self
        return super().__add__(other)

    def __sub__(self, other):
        if other == 0:
            return self
        return super().__sub__(other)

    def __mul__(self, other):
        if other == 1:
            return self
        elif other == 0:
            return Zero
        return super().__mul__(other)

    def __floordiv__(self, other):
        if other == 1 and not isinstance(other, int):
            # `other` is a non-int representation of one (e.g. `One`):
            # re-dispatch with the plain integer, so that a subclass that
            # knows how to floor-divide by a number can do so.
            # Note: 3.5 // 1 == 3.0 != 3.5; -> NOT self
            # A plain int is deliberately NOT re-dispatched: if we arrive
            # here with an int, no subclass handled it, and `self // 1`
            # would call this very method again without end.
            return self // 1
        elif other == 0:
            raise ZeroDivisionError("integer division or modulo by zero")
        try:
            # noinspection PyUnresolvedReferences
            return super().__floordiv__(other)
        except AttributeError:
            return NotImplemented

    def __truediv__(self, other):
        if other == 1:
            return self
        elif other == 0:
            raise ZeroDivisionError("integer division or modulo by zero")
        elif other == self:
            return One
        if isinstance(other, ScalarValue):
            other = other.val
        # division by a number is multiplication with its inverse; for exact
        # types (int, sympy), the inverse is kept exact by going through sympy
        if isinstance(other, (float, complex, complex128, float64)):
            return (ScalarValue(1/other)) * self
        elif isinstance(other, (int, sympy.Basic, int64)):
            return (ScalarValue(sympy.sympify(1)/other)) * self
        return super().__truediv__(other)

    def __mod__(self, other):
        if other == 1:
            return Zero
        elif other == 0:
            raise ZeroDivisionError("integer division or modulo by zero")
        try:
            # noinspection PyUnresolvedReferences
            return super().__mod__(other)
        except AttributeError:
            return NotImplemented

    # __pow__(self, other) is fully implemented in QuantumExpression

    def __radd__(self, other):
        if other == 0:
            return self
        return super().__radd__(other)

    def __rsub__(self, other):
        if other == 0:
            return -self
        return super().__rsub__(other)

    def __rmul__(self, other):
        if other == 1:
            return self
        elif other == 0:
            return Zero
        return super().__rmul__(other)

    def __rfloordiv__(self, other):
        if other == 0:
            return Zero
        try:
            # noinspection PyUnresolvedReferences
            return super().__rfloordiv__(other)
        except AttributeError:
            return NotImplemented

    def __rtruediv__(self, other):
        if other == 0:
            if self != 0:
                return Zero
            # 0 / 0 is not decided here; it falls through to the parent class
        elif other == 1:
            return ScalarPower.create(self, -1)
        elif is_scalar(other):
            return other * ScalarPower.create(self, -1)
        try:
            # noinspection PyUnresolvedReferences
            return super().__rtruediv__(other)
        except AttributeError:
            return NotImplemented

    def __rmod__(self, other):
        if other == 0:
            return Zero
        try:
            # noinspection PyUnresolvedReferences
            return super().__rmod__(other)
        except AttributeError:
            return NotImplemented

    def __rpow__(self, other):
        if other == 0:
            return Zero
        elif other == 1:
            return One
        try:
            # noinspection PyUnresolvedReferences
            return super().__rpow__(other)
        except AttributeError:
            return NotImplemented
class ScalarValue(Scalar):
    """Wrapper around a numeric or symbolic value

    The wrapped value may be of any of the following types::

        >>> for t in ScalarValue._val_types:
        ...     print(t)
        <class 'int'>
        <class 'float'>
        <class 'complex'>
        <class 'sympy.core.basic.Basic'>
        <class 'numpy.int64'>
        <class 'numpy.complex128'>
        <class 'numpy.float64'>

    A :class:`ScalarValue` behaves exactly like its wrapped value in all
    algebraic contexts::

        >>> 5 * ScalarValue.create(2)
        10

    Any unknown attributes or methods will be forwarded to the wrapped value
    to ensure complete "duck-typing"::

        >>> alpha = ScalarValue(sympy.symbols('alpha', positive=True))
        >>> alpha.is_positive   # same as alpha.val.is_positive
        True
        >>> ScalarValue(5).is_positive
        Traceback (most recent call last):
          ...
        AttributeError: 'int' object has no attribute 'is_positive'
    """

    @classmethod
    def create(cls, val):
        """Instantiate the :class:`ScalarValue` while recognizing
        :class:`Zero` and :class:`One`.

        :class:`Scalar` instances as `val` (including
        :class:`ScalarExpression` instances) are left unchanged. This makes
        :meth:`ScalarValue.create` a safe method for converting unknown
        objects to :class:`Scalar`.

        Raises:
            ValueError: if `val` is one of the values in `_invalid`
            TypeError: if `val` is not of a type that can be wrapped
        """
        if val in cls._invalid:
            raise ValueError("Invalid value %r" % val)
        if val == 0:
            return Zero
        elif val == 1:
            return One
        elif isinstance(val, Scalar):
            return val
        else:
            # We instantiate ScalarValue directly to avoid the overhead of
            # super().create(). Thus, there is no caching for scalars (which is
            # probably a good thing)
            return cls(val)

    def __init__(self, val):
        # `_val` is set before anything else, so that the attribute
        # forwarding in `__getattr__` has something to forward to
        self._val = val
        if not isinstance(val, self._val_types):
            raise TypeError(
                "val must be one of " +
                ", ".join(["%s" % t for t in self._val_types]))
        super().__init__(val)

    def __getattr__(self, name):
        """Forward the lookup of unknown attributes to the wrapped value"""
        if name == '_val':
            # `_val` itself is missing, i.e. the instance has not been
            # initialized (e.g. while it is being copied or unpickled).
            # Forwarding would go through the `val` property, which reads
            # `_val`, which lands here again: endless recursion.
            raise AttributeError(name)
        return getattr(self.val, name)

    def _diff(self, sym):
        if isinstance(self.val, sympy.Basic):
            return ScalarValue.create(sympy.diff(self.val, sym))
        else:
            # a plain number does not depend on any symbol
            return Zero

    def _simplify_scalar(self, func):
        if isinstance(self.val, sympy.Basic):
            return self.__class__.create(func(self.val))
        else:
            return self

    @property
    def val(self):
        """The wrapped scalar value"""
        return self._val

    @property
    def args(self):
        """Tuple containing the wrapped scalar value as its only element"""
        return (self._val,)

    def _series_expand(self, param, about, order):
        if isinstance(self.val, sympy.Basic):
            # shift the expansion point to zero
            if about != 0:
                c = self.val.subs({param: about + param})
            else:
                c = self.val
            # with n=None, sympy returns a generator of the series terms
            series = sympy.series(c, x=param, x0=0, n=None)
            res = []
            next_order = 0
            for term in series:
                c, o = term.as_coeff_exponent(param)
                if o < 0 or o.is_noninteger:
                    raise ValueError(
                        "%s is singular at expansion point %s=%s."
                        % (self, param, about))
                if o > order:
                    break
                # fill in zeros for the orders that the series skipped
                res.extend([0] * (o - next_order))
                res.append(c)
                next_order = o + 1
            # pad with zeros up to the requested order
            res.extend([0] * (order + 1 - next_order))
            return tuple([ScalarValue.create(c) for c in res])
        else:
            # a plain number is its own zeroth-order term
            return tuple([self, ] + [Zero] * order)

    @property
    def real(self):
        """Real part"""
        if hasattr(self.val, 'real'):
            return self.val.real
        else:
            # SymPy
            return self.val.as_real_imag()[0]

    @property
    def imag(self):
        """Imaginary part"""
        if hasattr(self.val, 'imag'):
            return self.val.imag
        else:
            # SymPy
            return self.val.as_real_imag()[1]

    def _adjoint(self):
        return self.__class__(self.val.conjugate())

    def __eq__(self, other):
        if isinstance(other, ScalarValue):
            return self.val == other.val
        else:
            return self.val == other

    def __lt__(self, other):
        if isinstance(other, ScalarValue):
            return self.val < other.val
        else:
            return self.val < other

    def __le__(self, other):
        if isinstance(other, ScalarValue):
            return self.val <= other.val
        else:
            return self.val <= other

    def __gt__(self, other):
        if isinstance(other, ScalarValue):
            return self.val > other.val
        else:
            return self.val > other

    def __ge__(self, other):
        if isinstance(other, ScalarValue):
            return self.val >= other.val
        else:
            return self.val >= other

    def __hash__(self):
        # consistent with __eq__, which compares the wrapped values
        return hash(self.val)

    def __neg__(self):
        return self.create(-self.val)

    def __abs__(self):
        return self.create(abs(self.val))

    def __add__(self, other):
        if isinstance(other, ScalarValue):
            return self.create(self.val + other.val)
        elif isinstance(other, self._val_types):
            return self.create(self.val + other)
        elif other == 1:
            return self.create(self.val + 1)
        else:
            return super().__add__(other)  # other == 0

    def __sub__(self, other):
        if isinstance(other, ScalarValue):
            return self.create(self.val - other.val)
        elif isinstance(other, self._val_types):
            return self.create(self.val - other)
        elif other == 1:
            return self.create(self.val - 1)
        else:
            return super().__sub__(other)  # other == 0

    def __mul__(self, other):
        if isinstance(other, ScalarValue):
            return self.create(self.val * other.val)
        elif isinstance(other, self._val_types):
            return self.create(self.val * other)
        else:
            return super().__mul__(other)  # other == 0, 1

    def __floordiv__(self, other):
        if isinstance(other, ScalarValue):
            return self.create(self.val // other.val)
        elif isinstance(other, self._val_types):
            return self.create(self.val // other)
        else:
            return super().__floordiv__(other)  # other == 0, 1

    def __truediv__(self, other):
        if isinstance(other, ScalarValue):
            return self.create(self.val / other.val)
        elif isinstance(other, self._val_types):
            try:
                return self.create(self.val / other)
            except ValueError:
                # sympy may produce 'infinity', which `create` catches as a
                # ValueError
                raise ZeroDivisionError("integer division or modulo by zero")
        else:
            return super().__truediv__(other)  # other == 0, 1

    def __mod__(self, other):
        if isinstance(other, ScalarValue):
            return self.create(self.val % other.val)
        elif isinstance(other, self._val_types):
            return self.create(self.val % other)
        else:
            return super().__mod__(other)  # other == 0, 1

    def __pow__(self, other):
        if isinstance(other, ScalarValue):
            return self.create(self.val**other.val)
        elif isinstance(other, self._val_types):
            return self.create(self.val**other)
        else:
            return super().__pow__(other)  # other == 0, 1

    def __radd__(self, other):
        if other == 1:
            return self.create(1 + self.val)
        elif isinstance(other, self._val_types):
            return self.create(other + self.val)
        else:
            return super().__radd__(other)  # other == 0

    def __rsub__(self, other):
        if other == 1:
            return self.create(1 - self.val)
        elif isinstance(other, self._val_types):
            return self.create(other - self.val)
        else:
            # must be the reflected *subtraction* of the parent class; the
            # reflected addition would yield `other + self` (or `self`
            # instead of `-self` for other == 0)
            return super().__rsub__(other)  # other == 0

    def __rmul__(self, other):
        if isinstance(other, self._val_types):
            return self.create(other * self.val)
        else:
            return super().__rmul__(other)  # other == 0, 1

    def __rfloordiv__(self, other):
        if other == 1:
            return self.create(1 // self.val)
        elif isinstance(other, self._val_types):
            return self.create(other // self.val)
        else:
            return super().__rfloordiv__(other)  # other == 0

    def __rtruediv__(self, other):
        if other == 1:
            return self.create(1 / self.val)
        elif isinstance(other, self._val_types):
            return self.create(other / self.val)
        else:
            return super().__rtruediv__(other)  # other == 0, 1/x -> x^(-1)

    def __rmod__(self, other):
        if isinstance(other, self._val_types):
            return self.create(other % self.val)
        else:
            return super().__rmod__(other)  # other == 0

    def __rpow__(self, other):
        if isinstance(other, self._val_types):
            return self.create(other**self.val)
        else:
            return super().__rpow__(other)  # other == 0, 1

    def __complex__(self):
        return complex(self.val)

    def __int__(self):
        return int(self.val)

    def __float__(self):
        return float(self.val)

    def _sympy_(self):
        return sympy.sympify(self.val)
class ScalarExpression(Scalar, metaclass=ABCMeta):
    """Abstract base for scalars whose arguments are not scalars themselves

    A :class:`.BraKet`, for instance, is a :class:`Scalar`, but the arguments
    it is built from are states.
    """

    # ScalarValue instances sort first, expression scalars after them
    _order_index = 1

    def __pow__(self, other):
        """Raise the expression to the power `other`, through
        :meth:`ScalarPower.create`"""
        power = ScalarPower.create(self, other)
        return power
@singleton_object
class Zero(Scalar, metaclass=Singleton):
    """The scalar zero, i.e. the neutral element of scalar addition

    It compares equal to the number zero::

        >>> Zero == 0
        True
    """

    _order_name = 'ScalarValue'  # sort like ScalarValue(0)
    _hash_val = 0

    @property
    def args(self):
        """Empty tuple: Zero has no arguments"""
        return ()

    @property
    def val(self):
        """The numeric value (integer zero)"""
        return self._hash_val

    @property
    def real(self):
        """Real part"""
        return self

    @property
    def imag(self):
        """Imaginary part"""
        return self

    @property
    def _order_key(self):
        # same layout of the key as for any other expression, but with the
        # numeric value standing in for the (empty) arguments
        return KeyTuple([
            self._order_index,
            self._order_name or self.__class__.__name__,
            self._order_coeff,
            KeyTuple([self.val, ]),
            self._order_kwargs,
        ])

    def _diff(self, sym):
        return self

    def _adjoint(self):
        return self

    def __abs__(self):
        return self

    __neg__ = __abs__

    # -- comparisons: compare the number 0 against the value of `other` --

    def __lt__(self, other):
        rhs = other.val if isinstance(other, ScalarValue) else other
        return self._hash_val < rhs

    def __le__(self, other):
        rhs = other.val if isinstance(other, ScalarValue) else other
        return self._hash_val <= rhs

    def __gt__(self, other):
        rhs = other.val if isinstance(other, ScalarValue) else other
        return self._hash_val > rhs

    def __ge__(self, other):
        rhs = other.val if isinstance(other, ScalarValue) else other
        return self._hash_val >= rhs

    # -- arithmetic with Zero as the left operand --

    def __add__(self, other):
        if isinstance(other, ScalarValue):
            return other
        if isinstance(other, self._val_types):
            return ScalarValue.create(other)
        return super().__add__(other)

    def __sub__(self, other):
        if isinstance(other, ScalarValue):
            return -ScalarValue.create(other)
        if isinstance(other, self._val_types):
            return ScalarValue.create(-other)
        return super().__sub__(other)

    def __mul__(self, other):
        # if possible, keep the type of `other`, by using its own zero
        return getattr(other, '_zero', self)

    def __floordiv__(self, other):
        if isinstance(other, ScalarValue):
            divisor = other.val
        elif isinstance(other, self._val_types):
            divisor = other
        else:
            return super().__floordiv__(other)
        return ScalarValue.create(0 // divisor)

    def __truediv__(self, other):
        if isinstance(other, ScalarValue):
            divisor = other.val
        elif isinstance(other, self._val_types):
            divisor = other
        else:
            return super().__truediv__(other)
        return ScalarValue.create(0 / divisor)

    def __mod__(self, other):
        if isinstance(other, ScalarValue):
            divisor = other.val
        elif isinstance(other, self._val_types):
            divisor = other
        else:
            return super().__mod__(other)
        return ScalarValue.create(0 % divisor)

    def __pow__(self, other):
        if isinstance(other, ScalarValue):
            exponent = other.val
        elif isinstance(other, self._val_types):
            exponent = other
        else:
            return super().__pow__(other)
        return ScalarValue.create(0**exponent)

    # -- arithmetic with Zero as the right operand --

    def __radd__(self, other):
        if not isinstance(other, self._val_types):
            return super().__radd__(other)
        return ScalarValue.create(other)

    # other - 0 == other + 0
    __rsub__ = __radd__

    def __rmul__(self, other):
        if not isinstance(other, self._val_types):
            return super().__rmul__(other)
        return self

    def __rfloordiv__(self, other):
        if not isinstance(other, self._val_types):
            return super().__rfloordiv__(other)
        return ScalarValue.create(other // 0)

    def __rtruediv__(self, other):
        raise ZeroDivisionError("integer division or modulo by zero")

    __rmod__ = __rtruediv__

    def __rpow__(self, other):
        return One

    # -- conversion to numeric types --

    def __complex__(self):
        return 0j

    def __int__(self):
        return 0

    def __float__(self):
        return 0.0

    def _sympy_(self):
        return sympy.sympify(0)
@singleton_object
class One(Scalar, metaclass=Singleton):
    """The neutral element with respect to scalar multiplication

    Equivalent to the scalar value one::

        >>> One == 1
        True
    """

    _order_name = 'ScalarValue'  # sort like ScalarValue(1)
    _hash_val = 1

    @property
    def args(self):
        """Empty tuple: One has no arguments"""
        return tuple()

    @property
    def val(self):
        """The numeric value (integer one)"""
        return self._hash_val

    @property
    def real(self):
        """Real part"""
        return self

    @property
    def imag(self):
        """Imaginary part"""
        return Zero

    @property
    def _order_key(self):
        # the numeric value stands in for the (empty) arguments
        return KeyTuple([
            self._order_index, self._order_name or self.__class__.__name__,
            self._order_coeff, KeyTuple([self.val, ]), self._order_kwargs])

    def _diff(self, sym):
        return Zero

    def __lt__(self, other):
        if isinstance(other, ScalarValue):
            return self._hash_val < other.val
        else:
            return self._hash_val < other

    def __le__(self, other):
        if isinstance(other, ScalarValue):
            return self._hash_val <= other.val
        else:
            return self._hash_val <= other

    def __gt__(self, other):
        if isinstance(other, ScalarValue):
            return self._hash_val > other.val
        else:
            return self._hash_val > other

    def __ge__(self, other):
        if isinstance(other, ScalarValue):
            return self._hash_val >= other.val
        else:
            return self._hash_val >= other

    def _adjoint(self):
        return self

    def __abs__(self):
        return self

    def __neg__(self):
        return ScalarValue(-1)

    def __add__(self, other):
        if isinstance(other, ScalarValue):
            return ScalarValue.create(1 + other.val)
        elif isinstance(other, self._val_types):
            return ScalarValue.create(1 + other)
        elif other == 1:
            return ScalarValue(2)
        else:
            return super().__add__(other)

    def __sub__(self, other):
        if isinstance(other, ScalarValue):
            return ScalarValue.create(1 - other.val)
        elif isinstance(other, self._val_types):
            return ScalarValue.create(1 - other)
        elif other == 1:
            return Zero
        else:
            # must be the parent's subtraction: its addition would turn
            # `1 - x` into `1 + x`
            return super().__sub__(other)

    def __mul__(self, other):
        if isinstance(other, self._val_types):
            return ScalarValue.create(other)
        else:
            return other

    def __floordiv__(self, other):
        if isinstance(other, ScalarValue):
            return ScalarValue.create(1 // other.val)
        elif isinstance(other, self._val_types):
            return ScalarValue.create(1 // other)
        else:
            return super().__floordiv__(other)

    def __truediv__(self, other):
        if isinstance(other, ScalarValue):
            return ScalarValue.create(1 / other.val)
        elif isinstance(other, self._val_types):
            return ScalarValue.create(1 / other)
        else:
            # NOTE(review): the fallback is the parent's *floor* division,
            # which looks like a copy-paste slip. It is left unchanged,
            # because the parent's __truediv__ is not visible from here and
            # may itself be implemented in terms of `One / other` -- confirm
            # before switching this to super().__truediv__(other)
            return super().__floordiv__(other)

    def __mod__(self, other):
        if isinstance(other, ScalarValue):
            return ScalarValue.create(1 % other.val)
        elif isinstance(other, self._val_types):
            return ScalarValue.create(1 % other)
        else:
            return super().__mod__(other)

    def __pow__(self, other):
        if isinstance(other, ScalarValue):
            return ScalarValue.create(1**other.val)
        elif isinstance(other, self._val_types):
            return ScalarValue.create(1**other)
        else:
            return super().__pow__(other)

    def __radd__(self, other):
        if isinstance(other, self._val_types):
            return ScalarValue.create(other + 1)
        else:
            return super().__radd__(other)

    def __rsub__(self, other):
        if isinstance(other, self._val_types):
            return ScalarValue.create(other - 1)
        else:
            # must be the parent's reflected subtraction, not its reflected
            # addition (which would yield `other + 1`)
            return super().__rsub__(other)

    # other / 1 == other * 1 == other ** 1 == other
    __rtruediv__ = __rmul__ = __rpow__ = __mul__

    def __rfloordiv__(self, other):
        # Not an alias of __mul__: floor division by one changes non-integer
        # values (3.5 // 1 == 3.0 != 3.5)
        if isinstance(other, self._val_types):
            return ScalarValue.create(other // 1)
        else:
            return super().__rfloordiv__(other)

    def __rmod__(self, other):
        if isinstance(other, self._val_types):
            return ScalarValue.create(other % 1)
        else:
            return super().__rmod__(other)

    def __complex__(self):
        # the complex number one, (1+0j) -- not the imaginary unit 1j
        return complex(1)

    def __int__(self):
        return 1

    def __float__(self):
        return 1.0

    def _sympy_(self):
        return sympy.sympify(1)
class ScalarPlus(QuantumPlus, Scalar):
    """Sum of scalars

    Summands that are :class:`ScalarValue` instances are merged into a single
    value right away::

        >>> alpha = ScalarValue.create(sympy.symbols('alpha'))
        >>> print(srepr(alpha + 1))
        ScalarValue(Add(Symbol('alpha'), Integer(1)))

    Only sums involving a :class:`ScalarExpression` stay as an unevaluated
    :class:`ScalarPlus`::

        >>> braket = KetSymbol('Psi', hs=0).dag() * KetSymbol('Phi', hs=0)
        >>> print(srepr(braket + 1, indented=True))
        ScalarPlus(
            One,
            BraKet(
                KetSymbol(
                    'Psi',
                    hs=LocalSpace(
                        '0')),
                KetSymbol(
                    'Phi',
                    hs=LocalSpace(
                        '0'))))
    """

    _neutral_element = Zero
    _binary_rules = OrderedDict()
    simplifications = [
        assoc,
        convert_to_scalars,
        orderby,
        collect_scalar_summands,
        match_replace_binary,
    ]

    def conjugate(self):
        """Complex conjugate of the sum (the sum of the conjugated
        summands)"""
        conjugated_summands = [summand.conjugate() for summand in self.args]
        return self.__class__.create(*conjugated_summands)

    def __pow__(self, other):
        power = ScalarPower.create(self, other)
        return power
class ScalarTimes(QuantumTimes, Scalar):
    """Product of scalars

    Factors that are :class:`ScalarValue` instances are merged into a single
    value right away::

        >>> alpha = ScalarValue.create(sympy.symbols('alpha'))
        >>> print(srepr(alpha * 2))
        ScalarValue(Mul(Integer(2), Symbol('alpha')))

    Only products involving a :class:`ScalarExpression` stay as an
    unevaluated :class:`ScalarTimes`::

        >>> braket = KetSymbol('Psi', hs=0).dag() * KetSymbol('Phi', hs=0)
        >>> print(srepr(braket * 2, indented=True))
        ScalarTimes(
            ScalarValue(
                2),
            BraKet(
                KetSymbol(
                    'Psi',
                    hs=LocalSpace(
                        '0')),
                KetSymbol(
                    'Phi',
                    hs=LocalSpace(
                        '0'))))
    """

    _neutral_element = One
    _binary_rules = OrderedDict()
    simplifications = [assoc, orderby, filter_neutral, match_replace_binary]

    @classmethod
    def create(cls, *operands, **kwargs):
        """Instantiate the product while applying simplification rules

        Operands that are not :class:`Scalar` instances are first converted
        with :meth:`ScalarValue.create`.
        """
        scalars = [
            op if isinstance(op, Scalar) else ScalarValue.create(op)
            for op in operands]
        return super().create(*scalars, **kwargs)

    def conjugate(self):
        """Complex conjugate of the product (the product of the conjugated
        factors, in reverse order)"""
        conjugated_factors = [
            factor.conjugate() for factor in reversed(self.args)]
        return self.__class__.create(*conjugated_factors)

    def __pow__(self, other):
        power = ScalarPower.create(self, other)
        return power

    def _expand(self):
        plus_cls = self.__class__._plus_cls
        times_cls = self.__class__._times_cls

        def summands_of(factor):
            """Tuple of the summands that `factor` consists of"""
            if isinstance(factor, plus_cls):
                return factor.operands
            if (isinstance(factor, ScalarValue) and
                    isinstance(factor.val, sympy.Add)):
                return factor.val.args
            return (factor, )

        expanded_factors = [op.expand() for op in self.operands]
        summands_per_factor = [summands_of(f) for f in expanded_factors]

        # multiply out: every way of picking one summand from each factor
        # contributes one product to the total
        products = [
            times_cls.create(*picked)
            for picked in cartesian_product(*summands_per_factor)]
        total = plus_cls.create(*products)
        if not isinstance(total, plus_cls):
            return total
        return total.expand()
class ScalarIndexedSum(QuantumIndexedSum, Scalar):
    """Indexed sum over scalars"""

    _rules = OrderedDict()
    simplifications = [
        assoc_indexed, indexed_sum_over_kronecker,
        indexed_sum_over_const, match_replace]

    @classmethod
    def create(cls, term, *ranges):
        """Instantiate the indexed sum while applying simplification rules

        A `term` that is not a :class:`Scalar` is first converted with
        :meth:`ScalarValue.create`.
        """
        if not isinstance(term, Scalar):
            term = ScalarValue.create(term)
        return super().create(term, *ranges)

    def __init__(self, term, *ranges):
        if not isinstance(term, Scalar):
            term = ScalarValue.create(term)
        super().__init__(term, *ranges)

    def conjugate(self):
        """Complex conjugate of the indexed sum"""
        return self.__class__.create(self.term.conjugate(), *self.ranges)

    @property
    def real(self):
        """Real part"""
        return self.__class__.create(self.term.real, *self.ranges)

    @property
    def imag(self):
        """Imaginary part"""
        return self.__class__.create(self.term.imag, *self.ranges)

    def _check_val_type(self, other):
        """Whether `other` is a :class:`ScalarValue` or a value that could be
        wrapped in one"""
        return (
            isinstance(other, ScalarValue) or
            isinstance(other, Scalar._val_types))

    def _disjunct_from(self, other):
        """Return the sum to be combined with the factor `other`

        If `other` contains index symbols, this is the result of
        `make_disjunct_indices` for those symbols; otherwise, it is the sum
        itself.
        """
        disjunct_sum = self
        try:
            idx_syms = [
                s for s in other.free_symbols if isinstance(s, IdxSym)]
            if idx_syms:
                disjunct_sum = self.make_disjunct_indices(*idx_syms)
        except AttributeError:
            # e.g. a plain number, which has no `free_symbols`
            pass
        return disjunct_sum

    def __mul__(self, other):
        # For "normal" indexed sums, we prefer to keep scalar factors in front
        # of the sum. For a ScalarIndexedSum, this doesn't make sense, though
        if self._check_val_type(other):
            disjunct_sum = self._disjunct_from(other)
            # term and ranges must come from the *same* sum: after indices
            # were renamed, the ranges of `self` no longer match the term
            return self.__class__.create(
                disjunct_sum.term * other, *disjunct_sum.ranges)
        else:
            return super().__mul__(other)

    def __rmul__(self, other):
        if self._check_val_type(other):
            disjunct_sum = self._disjunct_from(other)
            # see __mul__: term and ranges are taken from the same sum
            return self.__class__.create(
                other * disjunct_sum.term, *disjunct_sum.ranges)
        else:
            return super().__rmul__(other)

    def __pow__(self, other):
        if other == 0:
            return self._one
        elif other == 1:
            return self
        try:
            other_is_int = (other == int(other))
        except TypeError:
            other_is_int = False
        if not other_is_int:
            return ScalarPower.create(self, other)
        if other > 1:
            # integer powers are written out as a repeated product.
            # `other` may be integer-valued without being an int (e.g. 2.0),
            # which `range` would reject, hence the explicit conversion
            res = self
            for _ in range(int(other) - 1):
                res = res * self
            return res
        # what remains is a negative integer: invert the positive power
        return 1 / self**(-other)
class ScalarPower(QuantumOperation, Scalar):
    """A scalar raised to a power

    Generally, :class:`ScalarValue` instances are exponentiated directly::

        >>> alpha = ScalarValue.create(sympy.symbols('alpha'))
        >>> print(srepr(alpha**2))
        ScalarValue(Pow(Symbol('alpha'), Integer(2)))

    An unevaluated :class:`ScalarPower` remains only for
    :class:`ScalarExpression` instances, see e.g. :func:`sqrt`.
    """

    _rules = OrderedDict()
    simplifications = [convert_to_scalars, match_replace]

    def __init__(self, b, e):
        self._base = b
        self._exp = e
        super().__init__(b, e)

    @property
    def base(self):
        """The base of the exponential"""
        return self._base

    @property
    def exp(self):
        """The exponent"""
        return self._exp

    def __pow__(self, other):
        # (b**e)**other is written as b**(e*other)
        return ScalarPower.create(self.base, self.exp * other)

    def _adjoint(self):
        return self.__class__.create(
            self._base.conjugate(), self._exp.conjugate())

    def _diff(self, sym):
        # power rule e * b**(e-1) * b'; only the base is differentiated
        inner = self.base._diff(sym)
        if inner != 0:
            return (
                self.exp * ScalarPower.create(self.base, self.exp-1) * inner)
        else:
            return Zero

    def _series_expand(self, param, about, order):
        """Series expansion, supported for positive integer exponents only

        Raises:
            ValueError: if the exponent is not a positive integer, or if the
                expansion of the equivalent product raises a ValueError
        """
        try:
            try:
                exponent = int(self.exp)
            except TypeError as exc_info:
                # an exponent that cannot be converted to an integer at all
                # (symbolic, complex, ...) is treated like any other
                # exponent that is not a positive integer
                raise ValueError(
                    "self.exp is not a positive integer") from exc_info
            if exponent == self.exp and exponent > 0:
                # delegate to the _series_expand for ScalarTimes
                prod_ops = [self.base for _ in range(exponent)]
                self_as_product = ScalarTimes(*prod_ops)
                return self_as_product.series_expand(param, about, order)
            else:
                raise ValueError("self.exp is not a positive integer")
        except ValueError:
            # We don't know for sure if self is singular. One way to find out
            # is to substitute symbols for every ScalarExpression (assuming
            # they don't depend on param), let sympy do the series expansion,
            # and then substitute back. We'll keep this option for a later day
            raise ValueError(
                "%s MAY be singular at expansion point %s=%s. Report this "
                "as a bug." % (self, param, about))
class ScalarDerivative(QuantumDerivative, Scalar):
    """Symbolic partial derivative of a scalar

    This class adds nothing of its own to its two base classes; see
    :class:`.QuantumDerivative`.
    """
def KroneckerDelta(i, j, simplify=True):
    """Kronecker delta symbol

    The result is :class:`One` if `i` equals `j`, :class:`Zero` if `i` and
    `j` are non-symbolic and unequal, and otherwise a :class:`ScalarValue`
    wrapping SymPy's
    :class:`~sympy.functions.special.tensor_functions.KroneckerDelta`.

        >>> i, j = IdxSym('i'), IdxSym('j')
        >>> KroneckerDelta(i, i)
        One
        >>> KroneckerDelta(1, 2)
        Zero
        >>> KroneckerDelta(i, j)
        KroneckerDelta(i, j)

    By default, the Kronecker delta is returned in a simplified form, e.g::

        >>> KroneckerDelta((i+1)/2, (j+1)/2)
        KroneckerDelta(i, j)

    This may be suppressed by setting `simplify` to False::

        >>> KroneckerDelta((i+1)/2, (j+1)/2, simplify=False)
        KroneckerDelta(i/2 + 1/2, j/2 + 1/2)

    Raises:
        TypeError: if `i` or `j` is not an integer or sympy expression. There
            is no automatic sympification of `i` and `j`.
    """
    from qnet.algebra.core.scalar_algebra import ScalarValue, One
    # validate `i` first, then `j`
    for arg_name, arg in (('i', i), ('j', j)):
        if not isinstance(arg, (int, sympy.Basic)):
            raise TypeError(
                "%s is not an integer or sympy expression: %s"
                % (arg_name, type(arg)))
    if i == j:
        return One
    delta = sympy.KroneckerDelta(i, j)
    if simplify:
        delta = _simplify_delta(delta)
    return ScalarValue.create(delta)
def sqrt(scalar):
    """Square root of a :class:`Scalar` or scalar value

    The result is always a :class:`Scalar`. Unless the argument is a
    floating point (or complex) number, the square root is symbolic::

        >>> sqrt(2)
        sqrt(2)
        >>> sqrt(2.0)
        1.414213...

    For a :class:`ScalarExpression` argument, it returns a
    :class:`ScalarPower` instance::

        >>> braket = KetSymbol('Psi', hs=0).dag() * KetSymbol('Phi', hs=0)
        >>> nrm = sqrt(braket * braket.dag())
        >>> print(srepr(nrm, indented=True))
        ScalarPower(
            ScalarTimes(
                BraKet(
                    KetSymbol(
                        'Phi',
                        hs=LocalSpace(
                            '0')),
                    KetSymbol(
                        'Psi',
                        hs=LocalSpace(
                            '0'))),
                BraKet(
                    KetSymbol(
                        'Psi',
                        hs=LocalSpace(
                            '0')),
                    KetSymbol(
                        'Phi',
                        hs=LocalSpace(
                            '0')))),
            ScalarValue(
                Rational(1, 2)))

    Raises:
        TypeError: if `scalar` is neither a :class:`Scalar` nor a value of a
            supported numeric or symbolic type
    """
    numeric_types = (float, complex, complex128, float64)
    exact_types = (int, sympy.Basic, int64)

    # work on the bare value, if there is one
    if isinstance(scalar, ScalarValue):
        scalar = scalar.val

    if scalar == 1:
        return One
    if scalar == 0:
        return Zero
    if isinstance(scalar, numeric_types):
        return ScalarValue.create(numpy.sqrt(scalar))
    if isinstance(scalar, exact_types):
        return ScalarValue.create(sympy.sqrt(scalar))
    if isinstance(scalar, Scalar):
        one_half = sympy.sympify(1) / 2
        return scalar**one_half
    raise TypeError("Unknown type of scalar: %r" % type(scalar))
def is_scalar(scalar):
    """Tell whether `scalar` is a :class:`Scalar` or a scalar value

    This is True for instances of :class:`Scalar`, and for instances of any
    of the numeric or symbolic types that :class:`ScalarValue` can wrap.

    For internal use only.
    """
    if isinstance(scalar, Scalar):
        return True
    return isinstance(scalar, Scalar._val_types)
# Attach the concrete scalar classes/objects to the `Scalar` base class as
# class attributes. This happens at the end of the module because the
# classes referenced here are defined after `Scalar` itself. Code in this
# module reads e.g. `_zero` (Zero.__mul__), `_one`
# (ScalarIndexedSum.__pow__), and `_plus_cls`/`_times_cls`
# (ScalarTimes._expand); presumably the parent classes use them as well.
Scalar._zero = Zero
Scalar._one = One
Scalar._base_cls = Scalar
Scalar._scalar_times_expr_cls = ScalarTimes
Scalar._plus_cls = ScalarPlus
Scalar._times_cls = ScalarTimes
# There is no dedicated adjoint class for scalars: the "adjoint class" is a
# plain function returning the complex conjugate. It gets a `create`
# attribute that points back to itself, so that it can be called in the same
# way as the `create` class method of an Expression class.
Scalar._adjoint_cls = lambda scalar: scalar.conjugate()
Scalar._adjoint_cls.create = Scalar._adjoint_cls # mock Expression
Scalar._indexed_sum_cls = ScalarIndexedSum
Scalar._derivative_cls = ScalarDerivative