import logging
import itertools
from functools import partial
from collections import OrderedDict
import sympy
from sympy.concrete.delta import (
_has_simple_delta, _extract_delta as _sympy_extract_delta)
from .abstract_algebra import LOG, LEVEL, LOG_NO_MATCH
from .exceptions import CannotSimplify
from ..pattern_matching import ProtoExpr, Pattern, match_pattern
from ...utils.indices import IdxSym
# Nothing in this module is exported via ``from ... import *``.
__all__ = []
# NOTE(review): ``__private__`` is not a Python language convention; it
# presumably lists the module's internal-but-documented names for the
# project's API/doc tooling -- confirm against its consumer.
__private__ = [
    'assoc', 'assoc_indexed', 'idem', 'orderby', 'filter_neutral',
    'filter_cid', 'match_replace', 'match_replace_binary', 'check_cdims',
    'convert_to_spaces', 'empty_trivial', 'implied_local_space',
    'delegate_to_method', 'scalars_to_op', 'convert_to_scalars',
    'disjunct_hs_zero', 'commutator_order', 'accept_bras',
    'basis_ket_zero_outside_hs', 'indexed_sum_over_const',
    'indexed_sum_over_kronecker', 'derivative_via_diff', 'collect_summands',
    'collect_scalar_summands']
# If True, `_deltasummation` multiplies the term obtained by resolving a
# KroneckerDelta with ``piecewise_one(value)`` of the eliminated index range.
_RESOLVE_KRONECKER_WITH_PIECEWISE = False
# Handling indexed sums over Kronecker deltas correctly in the most general
# cases requires substituting the delta with a Piecewise function (the delta
# is zero outside of the range covered by the summation index). This incurs
# considerable numerical overhead. In many cases, one gets correct results
# while ignoring the range of the summation, with less effort. In that case,
# you may speed up calculations by setting this flag to False (the default),
# at your own risk.
#
# An example where one gets the wrong result by ignoring the summation range
# is this:
#
#     i, j = symbols('i, j', cls=IdxSym)
#     sum = Sum(i, (1, 2, 3))(Sum(j, (3, 4))(KroneckerDelta(i, j)))
#
# The wrong result is 3 (one hit per value of i); the correct result is 1
# (only i = j = 3 lies in both ranges).
def assoc(cls, ops, kwargs):
    """Associatively expand out nested arguments of the flat class.

    E.g.::

        >>> class Plus(Operation):
        ...     simplifications = [assoc, ]
        >>> Plus.create(1,Plus(2,3))
        Plus(1, 2, 3)

    Args:
        cls: the (associative) operation class
        ops: sequence of operands; any operand that is an instance of `cls`
            is replaced by its own ``operands``
        kwargs: keyword arguments, passed through unchanged

    Returns:
        tuple: ``(flat_ops, kwargs)`` where `flat_ops` is a tuple.
    """
    # chain.from_iterable flattens in a single linear pass. The previous
    # ``sum(expanded, ())`` re-copied the accumulated tuple for every operand
    # (quadratic), and failed if an operand's ``operands`` was a list.
    flat_ops = tuple(itertools.chain.from_iterable(
        o.operands if isinstance(o, cls) else (o,) for o in ops))
    return flat_ops, kwargs
def assoc_indexed(cls, ops, kwargs):
    r"""Flatten nested indexed structures while pulling out possible prefactors

    For example, for an :class:`.IndexedSum`:

    .. math::

        \sum_j \left( a \sum_i \dots \right) = a \sum_{j, i} \dots

    The outer ranges come first in the combined list of ranges. The
    prefactor is pulled in front of the flattened structure only if it does
    not depend on any of the (combined) summation indices; otherwise it is
    moved inside. If the term is not (a scalar multiple of) an instance of
    `cls`, ``(ops, kwargs)`` are returned unchanged.
    """
    from qnet.algebra.core.abstract_quantum_algebra import (
        ScalarTimesQuantumExpression)
    inner, *outer_ranges = ops

    if isinstance(inner, cls):
        prefactor = 1
    elif (isinstance(inner, ScalarTimesQuantumExpression) and
            isinstance(inner.term, cls)):
        prefactor, inner = inner.coeff, inner.term
    else:
        return ops, kwargs  # nothing nested to flatten

    # rename the inner indices so they cannot clash with the outer ones
    inner = inner.make_disjunct_indices(*outer_ranges)
    all_ranges = tuple(outer_ranges) + inner.ranges

    if prefactor == 1:
        return cls.create(inner.term, *all_ranges)

    bound_indices = {r.index_symbol for r in all_ranges}
    if not prefactor.free_symbols.intersection(bound_indices):
        return prefactor * cls.create(inner.term, *all_ranges)
    return cls.create(prefactor * inner.term, *all_ranges)
def idem(cls, ops, kwargs):
    """Drop duplicate operands and sort the rest by ``cls.order_key``.

    E.g.::

        >>> class Set(Operation):
        ...     order_key = lambda val: val
        ...     simplifications = [idem, ]
        >>> Set.create(1,2,3,1,3)
        Set(1, 2, 3)

    Returns:
        tuple: ``(sorted_unique_ops, kwargs)`` with a list of operands.
    """
    deduplicated = set(ops)
    ordered = sorted(deduplicated, key=cls.order_key)
    return ordered, kwargs
def orderby(cls, ops, kwargs):
    """Sort the operands by the class's ``order_key`` key object/function.

    Intended for commutative operations. E.g.::

        >>> class Times(Operation):
        ...     order_key = lambda val: val
        ...     simplifications = [orderby, ]
        >>> Times.create(2,1)
        Times(1, 2)

    Returns:
        tuple: ``(sorted_ops, kwargs)``; `sorted_ops` is a new list, and the
        sort is stable.
    """
    reordered = list(ops)
    reordered.sort(key=cls.order_key)
    return reordered, kwargs
def filter_neutral(cls, ops, kwargs):
    """Remove occurrences of a neutral element from the operand list.

    The class must define ``_neutral_element``, which can be anything that
    allows for an equality check with each operand. E.g.::

        >>> class X(Operation):
        ...     _neutral_element = 1
        ...     simplifications = [filter_neutral, ]
        >>> X.create(2,1,3,1)
        X(2, 3)

    Returns:
        * the neutral element, if `ops` is empty
        * the first operand, if all operands are neutral
        * the single remaining operand, if exactly one is non-neutral
        * ``(remaining_ops, kwargs)`` otherwise
    """
    neutral = cls._neutral_element
    if not ops:
        return neutral
    # the comparison must be written as ``neutral != op``;
    # ``op != neutral`` does NOT work
    remaining = [op for op in ops if neutral != op]
    if not remaining:
        # every operand is a neutral element: keep the first of them
        return ops[0]
    if len(remaining) == 1:
        # the single non-trivial operand replaces the whole operation
        return remaining[0]
    return remaining, kwargs
def collect_summands(cls, ops, kwargs):
    """Merge summands that differ only in their scalar prefactor.

    Summands whose combined coefficient makes them zero are dropped.

    Example:

        >>> A, B, C = (OperatorSymbol(s, hs=0) for s in ('A', 'B', 'C'))
        >>> collect_summands(
        ...     OperatorPlus, (A, B, C, ZeroOperator, 2 * A, B, -C) , {})
        ((3 * A^(0), 2 * B^(0)), {})
        >>> collect_summands(OperatorPlus, (A, -A), {})
        ZeroOperator
        >>> collect_summands(OperatorPlus, (B, A, -B), {})
        A^(0)

    Returns:
        ``cls._zero`` if nothing survives, the single surviving summand, or
        ``(tuple_of_summands, kwargs)``, in first-occurrence order.
    """
    from qnet.algebra.core.abstract_quantum_algebra import (
        ScalarTimesQuantumExpression)
    coeffs = OrderedDict()  # term -> accumulated coefficient
    for summand in ops:
        if isinstance(summand, ScalarTimesQuantumExpression):
            factor, base = summand.coeff, summand.term
        else:
            factor, base = 1, summand
        if base in coeffs:
            coeffs[base] += factor
        else:
            coeffs[base] = factor
    products = (factor * base for base, factor in coeffs.items())
    survivors = [p for p in products if not p.is_zero]
    if not survivors:
        return cls._zero
    if len(survivors) == 1:
        return survivors[0]
    return tuple(survivors), kwargs
def collect_scalar_summands(cls, ops, kwargs):
    """Collect :class:`ScalarValue` and :class:`ScalarExpression` summands

    All plain numbers / :class:`ScalarValue` summands are added up into one
    constant; all other summands are merged by their non-numeric part.

    Example:

        >>> srepr(collect_scalar_summands(Scalar, (1, 2, 3), {}))
        'ScalarValue(6)'
        >>> collect_scalar_summands(Scalar, (1, 1, -1), {})
        One
        >>> collect_scalar_summands(Scalar, (1, -1), {})
        Zero
        >>> Psi = KetSymbol("Psi", hs=0)
        >>> Phi = KetSymbol("Phi", hs=0)
        >>> braket = BraKet.create(Psi, Phi)
        >>> collect_scalar_summands(Scalar, (1, braket, -1), {})
        <Psi|Phi>^(0)
        >>> collect_scalar_summands(Scalar, (1, 2 * braket, 2, 2 * braket), {})
        ((3, 4 * <Psi|Phi>^(0)), {})
        >>> collect_scalar_summands(Scalar, (2 * braket, -braket, -braket), {})
        Zero
    """
    # This routine is required because there is no
    # "ScalarTimesQuantumExpression" for scalars: we have to extract
    # coefficients from ScalarTimes instead
    from qnet.algebra.core.scalar_algebra import (
        Zero, One, Scalar, ScalarTimes, ScalarValue)
    constant = Zero  # running sum of all purely numeric summands
    coeffs = OrderedDict()  # non-numeric term -> accumulated coefficient
    for summand in ops:
        if (isinstance(summand, ScalarValue) or
                isinstance(summand, Scalar._val_types)):
            constant += summand
            continue
        factor, base = One, summand
        if (isinstance(summand, ScalarTimes) and
                isinstance(summand.operands[0], ScalarValue)):
            # leading numeric operand is the coefficient; the product of
            # the remaining operands is the term
            factor, base, *more_factors = summand.operands
            for extra in more_factors:
                base *= extra
        if base in coeffs:
            coeffs[base] += factor
        else:
            coeffs[base] = factor
    survivors = [] if constant == Zero else [constant]
    products = (factor * base for base, factor in coeffs.items())
    survivors.extend(p for p in products if not p.is_zero)
    if not survivors:
        return cls._zero
    if len(survivors) == 1:
        return survivors[0]
    return tuple(survivors), kwargs
def match_replace(cls, ops, kwargs):
    """Match and replace a full operand specification to a function that
    provides a replacement for the whole expression
    or raises a :exc:`.CannotSimplify` exception.

    E.g.

    First define an operation::

        >>> class Invert(Operation):
        ...     _rules = OrderedDict()
        ...     simplifications = [match_replace, ]

    Then some _rules::

        >>> A = wc("A")
        >>> A_float = wc("A", head=float)
        >>> Invert_A = pattern(Invert, A)
        >>> Invert._rules.update([
        ...     ('r1', (pattern_head(Invert_A), lambda A: A)),
        ...     ('r2', (pattern_head(A_float), lambda A: 1./A)),
        ... ])

    Check rule application::

        >>> print(srepr(Invert.create("hallo")))  # matches no rule
        Invert('hallo')
        >>> Invert.create(Invert("hallo"))  # matches first rule
        'hallo'
        >>> Invert.create(.2)  # matches second rule
        5.0

    A pattern can also have the same wildcard appear twice::

        >>> class X(Operation):
        ...     _rules = {
        ...         'r1': (pattern_head(A, A), lambda A: A),
        ...     }
        ...     simplifications = [match_replace, ]
        >>> X.create(1,2)
        X(1, 2)
        >>> X.create(1,1)
        1

    The rules in ``cls._rules`` are tried in iteration order. The result of
    the first rule that matches and whose replacement function does not
    raise :exc:`.CannotSimplify` is returned. If no rule applies, the
    unchanged ``(ops, kwargs)`` are returned.
    """
    expr = ProtoExpr(ops, kwargs)
    # The logger is used under both LOG and LOG_NO_MATCH. Binding it only
    # for LOG would raise NameError when LOG_NO_MATCH is set without LOG.
    if LOG or LOG_NO_MATCH:
        logger = logging.getLogger('QNET.create')
    for key, rule in cls._rules.items():
        pat, replacement = rule
        match_dict = match_pattern(pat, expr)
        if match_dict:
            try:
                replaced = replacement(**match_dict)
                if LOG:
                    logger.debug(
                        "%sRule %s.%s: (%s, %s) -> %s", (" " * (LEVEL)),
                        cls.__name__, key, expr.args, expr.kwargs, replaced)
                return replaced
            except CannotSimplify:
                # the rule matched, but declined to apply: try the next one
                if LOG_NO_MATCH:
                    logger.debug(
                        "%sRule %s.%s: no match: CannotSimplify",
                        (" " * (LEVEL)), cls.__name__, key)
                continue
        else:
            if LOG_NO_MATCH:
                logger.debug(
                    "%sRule %s.%s: no match: %s", (" " * (LEVEL)),
                    cls.__name__, key, match_dict.reason)
    # No matching rules
    return ops, kwargs
def _get_binary_replacement(first, second, cls):
    """Helper function for match_replace_binary

    Try the rules in ``cls._binary_rules`` (in iteration order) on the pair
    ``(first, second)``.

    Returns:
        The result of the first rule that matches and does not raise
        :exc:`.CannotSimplify`, or None if no rule applies.
    """
    expr = ProtoExpr([first, second], {})
    # The logger is used under both LOG and LOG_NO_MATCH. Binding it only
    # for LOG would raise NameError when LOG_NO_MATCH is set without LOG.
    if LOG or LOG_NO_MATCH:
        logger = logging.getLogger('QNET.create')
    for key, rule in cls._binary_rules.items():
        pat, replacement = rule
        match_dict = match_pattern(pat, expr)
        if match_dict:
            try:
                replaced = replacement(**match_dict)
                if LOG:
                    logger.debug(
                        "%sRule %s.%s: (%s, %s) -> %s", (" " * (LEVEL)),
                        cls.__name__, key, expr.args, expr.kwargs, replaced)
                return replaced
            except CannotSimplify:
                # the rule matched, but declined to apply: try the next one
                if LOG_NO_MATCH:
                    logger.debug(
                        "%sRule %s.%s: no match: CannotSimplify",
                        (" " * (LEVEL)), cls.__name__, key)
                continue
        else:
            if LOG_NO_MATCH:
                logger.debug(
                    "%sRule %s.%s: no match: %s", (" " * (LEVEL)),
                    cls.__name__, key, match_dict.reason)
    return None
def match_replace_binary(cls, ops, kwargs):
    """Similar to func:`match_replace`, but for arbitrary length operations,
    such that each two pairs of subsequent operands are matched pairwise.

        >>> A = wc("A")
        >>> class FilterDupes(Operation):
        ...     _binary_rules = {
        ...         'filter_dupes': (pattern_head(A,A), lambda A: A)}
        ...     simplifications = [match_replace_binary, assoc]
        ...     _neutral_element = 0
        >>> FilterDupes.create(1,2,3,4)  # No duplicates
        FilterDupes(1, 2, 3, 4)
        >>> FilterDupes.create(1,2,2,3,4)  # Some duplicates
        FilterDupes(1, 2, 3, 4)

    Note that this only works for *subsequent* duplicate entries:

        >>> FilterDupes.create(1,2,3,2,4)  # No *subsequent* duplicates
        FilterDupes(1, 2, 3, 2, 4)

    Any operation that uses binary reduction must be associative and define a
    neutral element. The binary rules must be compatible with associativity,
    i.e. there is no specific order in which the rules are applied to pairs of
    operands.

    Returns:
        the neutral element if everything cancels, the single remaining
        operand, or ``(remaining_ops, kwargs)`` with a list of operands.
    """
    assert assoc in cls.simplifications, (
        cls.__name__ + " must be associative to use match_replace_binary")
    assert hasattr(cls, '_neutral_element'), (
        cls.__name__ + " must define a neutral element to use "
        "match_replace_binary")
    reduced = _match_replace_binary(cls, list(ops))
    if not reduced:
        return cls._neutral_element
    if len(reduced) == 1:
        return reduced[0]
    return reduced, kwargs
def _match_replace_binary(cls, ops: list) -> list:
    """Reduce the list `ops` by divide and conquer.

    Each half is reduced recursively (left half first), and the two reduced
    halves are then merged by :func:`_match_replace_binary_combine`.
    """
    if len(ops) <= 1:
        return ops
    middle = len(ops) // 2
    left_reduced = _match_replace_binary(cls, ops[:middle])
    right_reduced = _match_replace_binary(cls, ops[middle:])
    return _match_replace_binary_combine(cls, left_reduced, right_reduced)
def _match_replace_binary_combine(cls, a: list, b: list) -> list:
    """Combine two fully reduced lists `a` and `b`.

    Only the seam between the two lists needs checking: the binary rules are
    applied to ``(a[-1], b[0])``.
    * If no rule applies, the lists are concatenated.
    * If the pair reduces to the neutral element, the two operands are
      dropped and the new seam is checked.
    * Otherwise the replacement (unpacked if it is itself an instance of
      `cls`) is merged into the left part first, and then into the right
      part.
    """
    if not a or not b:
        return a + b
    replacement = _get_binary_replacement(a[-1], b[0], cls)
    if replacement is None:
        return a + b
    left_rest, right_rest = a[:-1], b[1:]
    if replacement == cls._neutral_element:
        return _match_replace_binary_combine(cls, left_rest, right_rest)
    if isinstance(replacement, cls):
        middle = list(replacement.args)
    else:
        middle = [replacement]
    merged_left = _match_replace_binary_combine(cls, left_rest, middle)
    return _match_replace_binary_combine(cls, merged_left, right_rest)
def check_cdims(cls, ops, kwargs):
    """Check that all operands (`ops`) have equal channel dimension.

    Raises:
        ValueError: if the operands' ``cdim`` values differ, or if `ops` is
            empty.
    """
    distinct_cdims = {o.cdim for o in ops}
    if len(distinct_cdims) != 1:
        raise ValueError("Not all operands have the same cdim:" + str(ops))
    return ops, kwargs
def filter_cid(cls, ops, kwargs):
    """Remove occurrences of the :func:`.circuit_identity` ``cid(n)`` for any
    ``n``. Cf. :func:`filter_neutral`

    Returns:
        ``CircuitZero`` for empty `ops`, the first operand if all operands
        are identities, the single remaining operand, or
        ``(remaining_ops, kwargs)``.
    """
    from qnet.algebra.core.circuit_algebra import CircuitZero, circuit_identity
    if not ops:
        return CircuitZero
    remaining = [op for op in ops if op != circuit_identity(op.cdim)]
    if not remaining:
        # every operand is an identity: keep the first of them
        return ops[0]
    if len(remaining) == 1:
        # the single non-trivial operand replaces the whole operation
        return remaining[0]
    return remaining, kwargs
def convert_to_spaces(cls, ops, kwargs):
    """Wrap every operand that is not already a :class:`HilbertSpace` in a
    :class:`LocalSpace` (e.g. a str or int label becomes ``LocalSpace(label)``).
    """
    from qnet.algebra.core.hilbert_space_algebra import (
        HilbertSpace, LocalSpace)

    def _as_space(obj):
        if isinstance(obj, HilbertSpace):
            return obj
        return LocalSpace(obj)

    return list(map(_as_space, ops)), kwargs
def empty_trivial(cls, ops, kwargs):
    """Map a ProductSpace of zero Hilbert spaces to the TrivialSpace.

    Non-empty `ops` are passed through unchanged.
    """
    from qnet.algebra.core.hilbert_space_algebra import TrivialSpace
    if ops:
        return ops, kwargs
    return TrivialSpace
def implied_local_space(*, arg_index=None, keys=None):
    """Return a simplification that converts the positional argument
    `arg_index` from (str, int) to a subclass of :class:`.LocalSpace`, as well
    as any keyword argument with one of the given keys.

    The exact type of the resulting Hilbert space is determined by
    the `default_hs_cls` argument of :func:`init_algebra`.

    In many cases, we have :func:`implied_local_space` (in ``create``) in
    addition to a conversion in ``__init__``, so
    that :func:`match_replace` etc can rely on the relevant arguments being a
    :class:`HilbertSpace` instance.

    Args:
        arg_index (int or None): index of the positional argument to convert
        keys (collection or None): names of keyword arguments to convert

    Returns:
        A simplification function ``(cls, args, kwargs) -> (args, kwargs)``.

    Raises:
        ValueError: if neither `arg_index` nor `keys` is given.
    """
    from qnet.algebra.core.hilbert_space_algebra import (
        HilbertSpace, LocalSpace)

    def _label_to_hs(cls, label):
        """Instantiate ``cls._default_hs_cls`` (or LocalSpace) for `label`"""
        # Only the attribute lookup is guarded. An AttributeError raised
        # inside the constructor itself must not be mistaken for "no default
        # Hilbert space class defined".
        try:
            hs_cls = cls._default_hs_cls
        except AttributeError:
            hs_cls = LocalSpace
        return hs_cls(label)

    def args_to_local_space(cls, args, kwargs):
        """Convert (str, int) of selected args to :class:`.LocalSpace`"""
        arg = args[arg_index]
        if isinstance(arg, LocalSpace):
            return args, kwargs
        if isinstance(arg, (int, str)):
            hs = _label_to_hs(cls, arg)
        else:
            hs = arg
            assert isinstance(hs, HilbertSpace)
        new_args = (tuple(args[:arg_index]) + (hs,) +
                    tuple(args[arg_index + 1:]))
        return new_args, kwargs

    def kwargs_to_local_space(cls, args, kwargs):
        """Convert (str, int) of selected kwargs to LocalSpace"""
        # keys that were not passed are simply skipped (previously, a
        # missing key raised KeyError here)
        if all(isinstance(kwargs[key], LocalSpace)
               for key in keys if key in kwargs):
            return args, kwargs
        new_kwargs = {}
        for key, val in kwargs.items():
            if key in keys:
                if isinstance(val, (int, str)):
                    val = _label_to_hs(cls, val)
                assert isinstance(val, HilbertSpace)
            new_kwargs[key] = val
        return args, new_kwargs

    def to_local_space(cls, args, kwargs):
        """Convert (str, int) of selected args and kwargs to LocalSpace"""
        # Both helpers take (cls, args, kwargs). The previous calls omitted
        # `cls` and passed arg_index/keys as a third argument, which shifted
        # every argument by one position.
        new_args, _ = args_to_local_space(cls, args, kwargs)
        _, new_kwargs = kwargs_to_local_space(cls, args, kwargs)
        return new_args, new_kwargs

    if (arg_index is not None) and (keys is None):
        return args_to_local_space
    elif (arg_index is None) and (keys is not None):
        return kwargs_to_local_space
    elif (arg_index is not None) and (keys is not None):
        return to_local_space
    else:
        raise ValueError("must give at least one of arg_index and keys")
def delegate_to_method(mtd):
    """Create a simplification rule that delegates the instantiation to the
    method `mtd` of the operand (if defined).

    The returned rule expects exactly one operand. If that operand has an
    attribute named `mtd`, the result of calling it (with no arguments) is
    returned; otherwise ``(ops, kwargs)`` are passed through unchanged.
    """
    def _delegate_to_method(cls, ops, kwargs):
        assert len(ops) == 1
        (operand,) = ops
        if not hasattr(operand, mtd):
            return ops, kwargs
        return getattr(operand, mtd)()

    return _delegate_to_method
def scalars_to_op(cls, ops, kwargs):
    r'''Convert any scalar $\alpha$ in `ops` into an operator $\alpha
    \identity$

    Scalars (as determined by :func:`is_scalar`) are multiplied with
    ``cls._one``; all other operands are kept. Returns a list of operands.
    '''
    from qnet.algebra.core.scalar_algebra import is_scalar
    converted = [op * cls._one if is_scalar(op) else op for op in ops]
    return converted, kwargs
def convert_to_scalars(cls, ops, kwargs):
    """Convert any entry in `ops` that is not a :class:`.Scalar` instance into
    a :class:`.ScalarValue` instance.

    Returns a list of operands together with the unchanged `kwargs`.
    """
    from qnet.algebra.core.scalar_algebra import Scalar, ScalarValue
    converted = [
        op if isinstance(op, Scalar) else ScalarValue(op) for op in ops]
    return converted, kwargs
def disjunct_hs_zero(cls, ops, kwargs):
    """Return ZeroOperator if all the operators in `ops` have a disjunct
    Hilbert space, or an unchanged `ops`, `kwargs` otherwise.

    Operands without a ``space`` attribute (scalars) count as living in the
    TrivialSpace.
    """
    from qnet.algebra.core.hilbert_space_algebra import TrivialSpace
    from qnet.algebra.core.operator_algebra import ZeroOperator
    seen_spaces = []
    for op in ops:
        space = getattr(op, 'space', TrivialSpace)  # scalars: TrivialSpace
        if any(not space.isdisjoint(other) for other in seen_spaces):
            return ops, kwargs
        seen_spaces.append(space)
    return ZeroOperator
def commutator_order(cls, ops, kwargs):
    """Apply anti-commutative property of the commutator to apply a standard
    ordering of the commutator arguments.

    If the second argument sorts before the first according to
    ``cls.order_key``, return ``-1 * Commutator.create(second, first)``;
    otherwise pass `ops` through unchanged.
    """
    from qnet.algebra.core.operator_algebra import Commutator
    assert len(ops) == 2
    first, second = ops
    if not cls.order_key(second) < cls.order_key(first):
        return ops, kwargs
    return -1 * Commutator.create(second, first)
def accept_bras(cls, ops, kwargs):
    """Accept operands that are all bras, and turn that into the bra of the
    operation applied to all corresponding kets.

    If any operand is not a :class:`Bra`, `ops` is passed through unchanged.
    """
    from qnet.algebra.core.state_algebra import Bra
    if not all(isinstance(operand, Bra) for operand in ops):
        return ops, kwargs
    kets = [operand.ket for operand in ops]
    return Bra.create(cls.create(*kets, **kwargs))
def basis_ket_zero_outside_hs(cls, ops, kwargs):
    """For ``BasisKet.create(ind, hs)`` with an integer label `ind`, return a
    :class:`ZeroKet` if `ind` is outside of the range of the underlying Hilbert
    space.

    A negative index is always out of range. Indices ``>= hs._dimension``
    are out of range only if the dimension is known (not None). Non-integer
    labels are passed through unchanged.
    """
    from qnet.algebra.core.state_algebra import ZeroKet
    (label,) = ops
    hs = kwargs['hs']
    if not isinstance(label, int):
        return ops, kwargs
    out_of_range = label < 0 or (
        hs._dimension is not None and label >= hs._dimension)
    if out_of_range:
        return ZeroKet
    return ops, kwargs
def indexed_sum_over_const(cls, ops, kwargs):
    r'''Execute an indexed sum over a term that does not depend on the
    summation indices

    .. math::

        \sum_{j=1}^{N} a = N a

    >>> a = symbols('a')
    >>> i, j = (IdxSym(s) for s in ('i', 'j'))
    >>> unicode(Sum(i, 1, 2)(a))
    '2 a'
    >>> unicode(Sum(j, 1, 2)(Sum(i, 1, 2)(a * i)))
    '∑_{i=1}^{2} 2 i a'

    Each range whose index symbol is absent from the term's free symbols is
    replaced by a factor ``len(range)``. A range is kept if ``len`` (or the
    multiplication) raises TypeError.

    Returns:
        the new term if no range is left, otherwise
        ``((new_term, *remaining_ranges), kwargs)``.
    '''
    term, *ranges = ops
    result = term
    kept_ranges = []
    for idx_range in ranges:
        if idx_range.index_symbol in term.free_symbols:
            kept_ranges.append(idx_range)
            continue
        try:
            result *= len(idx_range)
        except TypeError:
            kept_ranges.append(idx_range)  # range without a known length
    if not kept_ranges:
        return result
    return (result, ) + tuple(kept_ranges), kwargs
def _ranges_key(r, delta_indices):
"""Sorting key for ranges.
When used with ``reverse=True``, this can be used to sort index ranges into
the order we would prefer to eliminate them by evaluating KroneckerDeltas:
First, eliminate primed indices, then indices names higher in the alphabet.
"""
idx = r.index_symbol
if idx in delta_indices:
return (r.index_symbol.primed, r.index_symbol.name)
else:
# ranges that are not in delta_indices should remain in the original
# order
return (0, ' ')
def indexed_sum_over_kronecker(cls, ops, kwargs):
    """Execute sums over KroneckerDelta prefactors

    Summation indices that appear in a :class:`sympy.KroneckerDelta` inside
    the term are eliminated one at a time via :func:`_deltasummation`. If
    the term first has to be expanded into several summands, each summand is
    processed separately. The result is the sum of all (possibly still
    partially indexed) summands.

    Args:
        cls: the indexed-sum class; used to re-create sums over the index
            ranges that could not be eliminated
        ops: the term of the sum, followed by the summation index ranges
        kwargs: keyword arguments, passed on to ``cls.create``

    Returns:
        The unchanged ``(ops, kwargs)`` if no Kronecker delta could be
        resolved, otherwise the new expression.
    """
    from qnet.algebra.core.abstract_quantum_algebra import QuantumExpression
    term, *ranges = ops
    assert isinstance(term, QuantumExpression)
    if len(ranges) == 0:
        # No summation index to eliminate. The loop below indexes into the
        # ranges and would raise IndexError on an empty list.
        return ops, kwargs
    deltas = set(Pattern(head=sympy.KroneckerDelta).findall(term))
    if len(deltas) == 0:
        return ops, kwargs  # nothing to do
    # the term contains at least one KroneckerDelta
    delta_indices = set.union(*[
        {idx for idx in delta.free_symbols if isinstance(idx, IdxSym)}
        for delta in deltas])
    ranges = sorted(  # sort in the order we'd prefer to eliminate
        ranges,
        key=partial(_ranges_key, delta_indices=delta_indices),
        reverse=True)
    # work list of (term, ranges) items; their sum is the full expression
    buffer = [(term, ranges)]
    i = 0  # position in buffer that we're currently handling
    i_range = 0  # position of index-range for current buffer item
    while i < len(buffer):
        t, rs = buffer[i]
        if rs[i_range].index_symbol in delta_indices:
            new_items, flag = _deltasummation(t, rs, i_range)
            new_rs = new_items[0][1]  # same for all new_items
            buffer = buffer[:i] + new_items + buffer[i + 1:]
            assert flag in [1, 2, 3]
            if flag == 2:
                i_range += 1
            # * for flag == 1, leaving i_range unchanged will
            #   effectively go to the next range (as the current range
            #   was removed)
            # * for flag == 3, buffer[i] has changed, and we'll want to
            #   call it again with the same i_range
        else:
            # if the index symbol doesn't occur in any KroneckerDelta,
            # there is no chance _deltasummation will do anything; so we
            # just skip to the next index
            i_range += 1
            new_rs = rs
        if i_range >= len(new_rs):
            # if we've exhausted the index-ranges for the current buffer
            # item, go to the next buffer item
            i += 1
            i_range = 0
    if len(buffer) == 1 and buffer[0] == (term, ranges):
        return ops, kwargs  # couldn't resolve deltas
    (t, rs) = buffer[0]
    res = t
    if len(rs) > 0:
        res = cls.create(t, *rs, **kwargs)
    for (t, rs) in buffer[1:]:
        if len(rs) > 0:
            t = cls.create(t, *rs, **kwargs)
        res += t
    return res
def _factors_for_expand_delta(expr):
    """Yield factors from expr, mixing sympy and QNET

    Auxiliary routine for :func:`_expand_delta`.
    * A ScalarTimesQuantumExpression yields the factors of its coefficient,
      followed by its term.
    * A ScalarValue yields the factors of its wrapped value.
    * A sympy ``Mul`` yields its arguments.
    * Anything else is yielded as-is.
    """
    from qnet.algebra.core.scalar_algebra import ScalarValue
    from qnet.algebra.core.abstract_quantum_algebra import (
        ScalarTimesQuantumExpression)
    if isinstance(expr, ScalarTimesQuantumExpression):
        yield from _factors_for_expand_delta(expr.coeff)
        yield expr.term
        return
    if isinstance(expr, ScalarValue):
        yield from _factors_for_expand_delta(expr.val)
        return
    if isinstance(expr, sympy.Basic) and expr.is_Mul:
        yield from expr.args
        return
    yield expr
def _expand_delta(expr, idx):
    """Expand the first :class:`sympy.Add` containing a simple
    :class:`sympy.KroneckerDelta`.

    Auxiliary routine for :func:`_deltasummation`. Adapted from SymPy. The
    input `expr` may be a :class:`.QuantumExpression` or a
    :class:`sympy.Basic` instance.

    Returns a list of summands. The elements of the list may be
    :class:`.QuantumExpression` or :class:`sympy.Basic` instances. There is
    no guarantee of type stability: an input :class:`.QuantumExpression` may
    result in a :class:`sympy.Basic` instance in the `summands`.
    """
    summands = None
    expanded = False
    for factor in _factors_for_expand_delta(expr):
        expand_here = (
            not expanded and isinstance(factor, sympy.Basic) and
            factor.is_Add and _has_simple_delta(factor, idx))
        if expand_here:
            expanded = True
            if summands is None:
                summands = list(factor.args)
            else:
                # before the first expansion, `summands` has one element
                summands = [summands[0]*t for t in factor.args]
        elif summands is None:
            summands = [factor, ]
        else:
            summands = [t*factor for t in summands]
    return summands
def _split_sympy_quantum_factor(expr):
    """Split a product into sympy and qnet factors

    This is a helper routine for applying some sympy transformation on an
    arbitrary product-like expression in QNET. The idea is this::

        expr -> sympy_factor, quantum_factor
        sympy_factor -> sympy_function(sympy_factor)
        expr -> sympy_factor * quantum_factor

    Returns:
        tuple: ``(sympy_factor, quantum_factor)``, a :class:`sympy.Basic`
        and a :class:`.QuantumExpression` instance.
    """
    from qnet.algebra.core.abstract_quantum_algebra import (
        QuantumExpression, ScalarTimesQuantumExpression)
    from qnet.algebra.core.scalar_algebra import ScalarValue, ScalarTimes, One
    if isinstance(expr, ScalarTimesQuantumExpression):
        sym_part, qu_part = _split_sympy_quantum_factor(expr.coeff)
        qu_part *= expr.term
    elif isinstance(expr, ScalarValue):
        sym_part, qu_part = expr.val, expr._one
    elif isinstance(expr, ScalarTimes):
        sym_part, qu_part = sympy.S(1), expr._one
        for operand in expr.operands:
            operand_sym, operand_qu = _split_sympy_quantum_factor(operand)
            sym_part *= operand_sym
            qu_part *= operand_qu
    elif isinstance(expr, sympy.Basic):
        sym_part, qu_part = expr, One
    else:
        sym_part, qu_part = sympy.S(1), expr
    assert isinstance(sym_part, sympy.Basic)
    assert isinstance(qu_part, QuantumExpression)
    return sym_part, qu_part
def _extract_delta(expr, idx):
    """Extract a "simple" Kronecker delta containing `idx` from `expr`.

    Assuming `expr` can be written as the product of a Kronecker Delta and a
    `new_expr`, return a tuple of the sympy.KroneckerDelta instance and
    `new_expr`. Otherwise, return a tuple of None and the original `expr`
    (possibly converted to a :class:`.QuantumExpression`).

    On input, `expr` can be a :class:`QuantumExpression` or a
    :class:`sympy.Basic` object. On output, `new_expr` is guaranteed to be a
    :class:`QuantumExpression`.
    """
    from qnet.algebra.core.abstract_quantum_algebra import QuantumExpression
    from qnet.algebra.core.scalar_algebra import ScalarValue
    sym_part, qu_part = _split_sympy_quantum_factor(expr)
    delta, sym_remainder = _sympy_extract_delta(sym_part, idx)
    new_expr = expr if delta is None else sym_remainder * qu_part
    if isinstance(new_expr, ScalarValue._val_types):
        # wrap plain values (e.g. sympy expressions) as a QNET scalar
        new_expr = ScalarValue.create(new_expr)
    assert isinstance(new_expr, QuantumExpression)
    return delta, new_expr
def _deltasummation(term, ranges, i_range):
    """Partially execute a summation for `term` with a Kronecker Delta for one
    of the summation indices.

    This implements the solution to the core sub-problem in
    :func:`indexed_sum_over_kronecker`

    Args:
        term (QuantumExpression): term of the sum
        ranges (list): list of all summation index ranges
            (class:`IndexRangeBase` instances)
        i_range (int): list-index of element in `ranges` which should be
            eliminated

    Returns:
        ``(result, flag)`` where `result` is a list
        of ``(new_term, new_ranges)`` tuples and `flag` is an integer.

    There are three possible cases, indicated by the returned `flag`. Consider
    the following setup::

        >>> i, j, k = symbols('i, j, k', cls=IdxSym)
        >>> i_range = IndexOverList(i, (0, 1))
        >>> j_range = IndexOverList(j, (0, 1))
        >>> ranges = [i_range, j_range]
        >>> def A(i, j):
        ...    from sympy import IndexedBase
        ...    return OperatorSymbol(StrLabel(IndexedBase('A')[i, j]), hs=0)

    1. If executing the sum produces a single non-zero term, result will be
    ``[(new_term, new_ranges)]`` where `new_ranges` contains the input `ranges`
    without the eliminated range specified by `i_range`. This should be the
    most common case for calls to :func:`_deltasummation`::

        >>> term = KroneckerDelta(i, j) * A(i, j)
        >>> result, flag = _deltasummation(term, [i_range, j_range], 1)
        >>> assert result == [(A(i, i), [i_range])]
        >>> assert flag == 1

    2. If executing the sum for the index symbol specified via `index_range`
    does not reduce the sum, the result will be the list ``[(term, ranges)]``
    with unchanged `term` and `ranges`::

        >>> term = KroneckerDelta(j, k) * A(i, j)
        >>> result, flag = _deltasummation(term, [i_range, j_range], 0)
        >>> assert result == [(term, [i_range, j_range])]
        >>> assert flag == 2

    This case also covers if there is no Kronecker delta in the term::

        >>> term = A(i, j)
        >>> result, flag = _deltasummation(term, [i_range, j_range], 0)
        >>> assert result == [(term, [i_range, j_range])]
        >>> assert flag == 2

    3. If `term` does not contain a Kronecker delta as a factor, but in a
    sum that can be expanded, the result will be a list of
    ``[(summand1, ranges), (summand2, ranges), ...]`` for the summands of that
    expansion. In this case, :func:`_deltasummation` should be called again
    for every tuple in the list, with the same `i_range`::

        >>> term = (KroneckerDelta(i, j) + 1) * A(i, j)
        >>> result, flag = _deltasummation(term, [i_range, j_range], 1)
        >>> assert result == [
        ...     (A(i, j), [i_range, j_range]),
        ...     (KroneckerDelta(i,j) * A(i, j), [i_range, j_range])]
        >>> assert flag == 3
    """
    from qnet.algebra.core.abstract_quantum_algebra import QuantumExpression
    index = ranges[i_range].index_symbol
    summands = _expand_delta(term, index)
    if len(summands) > 1:
        # case 3: caller must re-process each summand with the same i_range
        return [(summand, ranges) for summand in summands], 3
    delta, remainder = _extract_delta(summands[0], index)
    if not delta:
        return [(term, ranges)], 2  # case 2: no simple delta for `index`
    solutions = sympy.solve(delta.args[0] - delta.args[1], index)
    assert len(solutions) > 0  # I can't think of an example that might cause this
    if len(solutions) != 1:
        return [(term, ranges)], 2
    value = solutions[0]
    reduced_term = remainder.substitute({index: value})
    if _RESOLVE_KRONECKER_WITH_PIECEWISE:
        reduced_term *= ranges[i_range].piecewise_one(value)
    assert isinstance(reduced_term, QuantumExpression)
    # case 1: the range for `index` has been summed out
    return [(reduced_term, ranges[:i_range] + ranges[i_range+1:])], 1
def derivative_via_diff(cls, ops, kwargs):
    """Implementation of the :meth:`QuantumDerivative.create` interface via the
    use of :meth:`QuantumExpression._diff`.

    Thus, by having :meth:`.QuantumExpression.diff` delegate to
    :meth:`.QuantumDerivative.create`, instead of
    :meth:`.QuantumExpression._diff` directly, we get automatic caching of
    derivatives.

    Args:
        cls: the derivative class (not used)
        ops: a 1-tuple containing the expression to differentiate
        kwargs: must contain ``'derivs'`` (an iterable of
            ``(symbol, order)`` pairs) and ``'vals'`` (values to evaluate the
            derivative at after differentiating, or None)

    Returns:
        The derivative, evaluated at `vals` if given. Returns the ``_zero``
        of the (current) expression's class as soon as a derivative symbol
        does not appear in the expression's free symbols.
    """
    assert len(ops) == 1
    op = ops[0]
    derivs = kwargs['derivs']
    vals = kwargs['vals']
    # both `derivs` and `vals` are guaranteed to be tuples, via the conversion
    # that's happening in `QuantumDerivative.create`
    for (sym, n) in derivs:
        if sym.free_symbols.issubset(op.free_symbols):
            for _ in range(n):
                op = op._diff(sym)
        else:
            return op.__class__._zero
    if vals is None:
        return op
    # Look the method up explicitly instead of wrapping the call in
    # ``try/except AttributeError``. That form also swallowed AttributeErrors
    # raised *inside* ``evaluate_at`` and silently fell back to
    # ``substitute``.
    evaluate_at = getattr(op, 'evaluate_at', None)
    if evaluate_at is not None:
        # for QuantumDerivative instance
        return evaluate_at(vals)
    # for explicit Expression
    return op.substitute(vals)