summaryrefslogtreecommitdiff
path: root/Rx/v2/src/rxcpp/operators/rx-ref_count.hpp
blob: b68315d8a2b29d9b3fde40c433248dd2e8faf97f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information.

#pragma once

/*! \file rx-ref_count.hpp

    \brief  Make some \c connectable_observable behave like an ordinary \c observable.
            Uses a reference count of the subscribers to control the connection to the published observable.

            The first subscription will cause a call to \c connect(), and the last \c unsubscribe will unsubscribe the connection.

            There are 2 variants of the operator:
            \li \c ref_count(): calls \c connect on the \c source \c connectable_observable.
            \li \c ref_count(other): calls \c connect on the \c other \c connectable_observable.

    \tparam ConnectableObservable the type of the \c other \c connectable_observable (optional)
    \param  other \c connectable_observable to call \c connect on (optional)

    If \c other is omitted, then \c source is used instead (which must be a \c connectable_observable).
    Otherwise, \c source can be a regular \c observable.

    \return An \c observable that emits the items from its \c source.

    \sample
    \snippet ref_count.cpp ref_count other diamond sample
    \snippet output.txt ref_count other diamond sample
 */

#if !defined(RXCPP_OPERATORS_RX_REF_COUNT_HPP)
#define RXCPP_OPERATORS_RX_REF_COUNT_HPP

#include "../rx-includes.hpp"

namespace rxcpp {

namespace operators {

namespace detail {

// Marker value type used by the fallback (invalid-arguments) ref_count overload.
template <typename... Args>
struct ref_count_invalid_arguments {};

// Operator type named as the return type of the catch-all member_overload
// used when ref_count is called with unsupported arguments; 'type' is the
// observable that would wrap it.
template <typename... Args>
struct ref_count_invalid
    : public rxo::operator_base<ref_count_invalid_arguments<Args...>>
{
    using type = observable<ref_count_invalid_arguments<Args...>,
                            ref_count_invalid<Args...>>;
};

// Shorthand for ref_count_invalid<...>::type.
template <typename... Args>
using ref_count_invalid_t = typename ref_count_invalid<Args...>::type;

// Storage and subscribe dispatch for the ref_count(other) form.
//
// ref_count(other) subscribes to a regular observable 'source' and calls
// connect on the connectable_observable 'other'. The choice between this form
// and plain ref_count() is made by template specialization (see the
// <connectable_type, void> specialization) so that 'subscribe' is never
// instantiated against both member types, which would not compile.
template <typename connectable_type, typename observable_type>
struct ref_count_state_base {
    connectable_type connectable; // connect() is called on this.
    observable_type subscribable; // subscribers are attached to this.

    ref_count_state_base(connectable_type other, observable_type source)
        : connectable(std::move(other)),
          subscribable(std::move(source))
    {
    }

    // Forwards the subscriber to the regular (non-connectable) source.
    template <typename Subscriber>
    void subscribe(Subscriber&& subscriber) {
        subscribable.subscribe(std::forward<Subscriber>(subscriber));
    }
};

// Specialization for the plain ref_count() form: there is no separate source
// observable, so the connectable_observable is both subscribed to and
// connected to.
// Note: explicit specializations have to be at namespace scope prior to C++17.
template <typename connectable_type>
struct ref_count_state_base<connectable_type, void> {
    connectable_type connectable; // both subscribed to and connected to.

    explicit ref_count_state_base(connectable_type c)
        : connectable(std::move(c))
    {
    }

    // Forwards the subscriber to the connectable_observable itself.
    template <typename Subscriber>
    void subscribe(Subscriber&& subscriber) {
        connectable.subscribe(std::forward<Subscriber>(subscriber));
    }
};

template<class T,
         class ConnectableObservable,
         class Observable = void> // note: type order flipped versus the operator.
struct ref_count : public operator_base<T>
{
    typedef rxu::decay_t<Observable> observable_type;
    typedef rxu::decay_t<ConnectableObservable> connectable_type;

    // ref_count() == false
    // ref_count(other) == true
    using has_observable_t = rxu::negation<std::is_same<void, Observable>>;
    static constexpr bool has_observable_v = has_observable_t::value;

    // State shared by every subscription made through this operator (and by
    // every copy of the operator, since 'state' is held by shared_ptr).
    struct ref_count_state : public std::enable_shared_from_this<ref_count_state>,
                             public ref_count_state_base<connectable_type, observable_type>
    {
        // The base is instantiated with the decayed types, matching the
        // by-value constructor parameters below, so it always stores values
        // rather than references.
        using base_type = ref_count_state_base<connectable_type, observable_type>;

        template <class HasObservable = has_observable_t,
                  class Enabled = rxu::enable_if_all_true_type_t<
                      rxu::negation<HasObservable>>>
        explicit ref_count_state(connectable_type source)
            : base_type(std::move(source))
            , subscribers(0)
        {
        }

        template <bool HasObservableV = has_observable_v>
        ref_count_state(connectable_type other,
                        typename std::enable_if<HasObservableV, observable_type>::type source)
            : base_type(std::move(other), std::move(source))
            , subscribers(0)
        {
        }

        std::mutex lock;                   // guards 'subscribers' and 'connection'
        long subscribers;                  // number of live downstream subscriptions
        composite_subscription connection; // passed to connect(); unsubscribed and replaced when the count drops to 0
    };
    std::shared_ptr<ref_count_state> state;

    // connectable_observable<T> source = ...;
    // source.ref_count();
    //
    // calls connect on source after the subscribe on source.
    template <class HasObservable = has_observable_t,
              class Enabled = rxu::enable_if_all_true_type_t<
                  rxu::negation<HasObservable>>>
    explicit ref_count(connectable_type source)
        : state(std::make_shared<ref_count_state>(std::move(source)))
    {
    }

    // connectable_observable<?> other = ...;
    // observable<T> source = ...;
    // source.ref_count(other);
    //
    // calls connect on 'other' after the subscribe on 'source'.
    template <bool HasObservableV = has_observable_v>
    ref_count(connectable_type other,
              typename std::enable_if<HasObservableV, observable_type>::type source)
        : state(std::make_shared<ref_count_state>(std::move(other), std::move(source)))
    {
    }

    // Subscribes 'o' to the source. The subscriber that raises the count from
    // 0 to 1 also calls connect(), after subscribing. When a subscriber goes
    // away and the count returns to 0, the current connection is unsubscribed
    // and replaced by a fresh one for the next connect().
    template<class Subscriber>
    void on_subscribe(Subscriber&& o) const {
        auto keepAlive = state;

        std::unique_lock<std::mutex> guard(keepAlive->lock);
        const bool needConnect = ++keepAlive->subscribers == 1;
        // Snapshot the connection while the lock is held. The unsubscribe
        // callback below reassigns keepAlive->connection under the lock, so
        // reading it later without the lock would be a data race. The snapshot
        // also means that if this subscriber's count drops back to 0 before
        // connect() runs, we pass the connection that was just unsubscribed,
        // rather than leaving the fresh replacement connected with no
        // subscribers.
        composite_subscription connection = keepAlive->connection;
        guard.unlock();

        o.add(
            [keepAlive](){
                std::lock_guard<std::mutex> guard_unsubscribe(keepAlive->lock);
                if (--keepAlive->subscribers == 0) {
                    keepAlive->connection.unsubscribe();
                    keepAlive->connection = composite_subscription();
                }
            });
        keepAlive->subscribe(std::forward<Subscriber>(o));
        if (needConnect) {
            keepAlive->connectable.connect(std::move(connection));
        }
    }
};

}

/*! @copydoc rx-ref_count.hpp
*/
template<class... AN>
auto ref_count(AN&&... an)
    -> operator_factory<ref_count_tag, AN...> {
    // Capture the (decayed) arguments now; they are applied to the source
    // observable later through member_overload<ref_count_tag>.
    using factory_type = operator_factory<ref_count_tag, AN...>;
    auto captured = std::make_tuple(std::forward<AN>(an)...);
    return factory_type(std::move(captured));
}
    
}

// Resolves the ref_count operator (ref_count_tag) to the matching
// rxo::detail::ref_count instantiation based on the argument types.
template<>
struct member_overload<ref_count_tag>
{
    // source.ref_count(): 'o' must be a connectable_observable; it is both
    // subscribed to and connected to.
    template<class ConnectableObservable,
        class Enabled = rxu::enable_if_all_true_type_t<
            is_connectable_observable<ConnectableObservable>>,
        class SourceValue = rxu::value_type_t<ConnectableObservable>,
        class RefCount = rxo::detail::ref_count<SourceValue, rxu::decay_t<ConnectableObservable>>,
        class Value = rxu::value_type_t<RefCount>,
        class Result = observable<Value, RefCount>
        >
    static Result member(ConnectableObservable&& o) {
        return Result(RefCount(std::forward<ConnectableObservable>(o)));
    }

    // source.ref_count(other): subscribes to the observable 'o' and calls
    // connect on the connectable_observable 'other'.
    // Note: detail::ref_count takes its type parameters and constructor
    // arguments as (connectable, observable), the reverse of (o, other) here.
    template<class Observable,
        class ConnectableObservable,
        class Enabled = rxu::enable_if_all_true_type_t<
            is_observable<Observable>,
            is_connectable_observable<ConnectableObservable>>,
        class SourceValue = rxu::value_type_t<Observable>,
        class RefCount = rxo::detail::ref_count<SourceValue,
            rxu::decay_t<ConnectableObservable>,
            rxu::decay_t<Observable>>,
        class Value = rxu::value_type_t<RefCount>,
        class Result = observable<Value, RefCount>
        >
    static Result member(Observable&& o, ConnectableObservable&& other) {
        return Result(RefCount(std::forward<ConnectableObservable>(other),
                               std::forward<Observable>(o)));
    }

    // Catch-all for unsupported argument lists. The static_assert condition
    // depends on AN, so it is only evaluated when this overload is actually
    // instantiated (i.e. chosen because neither overload above is viable),
    // and the condition is effectively never true for real calls.
    // The terminate/return only give the body a well-formed shape.
    template<class... AN>
    static operators::detail::ref_count_invalid_t<AN...> member(AN...) {
        std::terminate();
        return {};
        static_assert(sizeof...(AN) == 10000, "ref_count takes (optional ConnectableObservable)");
    }
};
    
}

#endif