TechnoSpacecraft

Another programming blog

Using Partial Function Applications With Recursive Wrappers and Functional Combinators

I know that it's a word salad of a title, but if you're reading this, then I'm sure you A) are old enough to read, B) have some idea of the topic, and C) know that word-salad titles are click-bait.

I cut my teeth on functional programming in Elixir. While it is a great language for many things and a fun language to use... functional it is not. Compared to something like Ocaml, there are many features of a more 'pure' functional language that Elixir lacks. Those same features are either missing in Python or poorly implemented.

I like Python, I'll admit it. I can do many things quickly without much dev overhead. It's a good thing that I like it because my job is roughly 98% Python development. But I'm always on the lookout for ways to do things a bit faster without making the code hard to read or maintain.

One trick that I have picked up from Ocaml is the use of partial function applications. In Ocaml you can do some weird stuff like so:

let adder x y = x + y;;
let add_five = adder 5;;
add_five 7;;
12

This seems like a strange and unnecessary thing to do, but partial function applications are basically simple closures made on the fly, which is quite useful; a point that I will flesh out a bit later.

How Ocaml does this is quite unique. In the Ocaml world there are, technically, no multi-argument functions, only single-argument functions chained together. If you take a look at the signature of adder above, conceptually it looks like the following:

x => y => x+y

To your eyes, this is a bit weird and a little hard to grok unless you understand that everything here is a function. The signature basically says “adder is a function that takes x and returns another function; that returned function takes y and returns x + y.” Apply only the first argument and you get the inner function back, which is exactly what add_five is. This also explains why, typically, functional languages don't use parentheses for function application: they would pile up quite quickly and become semantically noisy.
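
If that still feels abstract, here is the same idea hand-rolled in Python (just a throwaway sketch, not anything from functools): each argument you leave off simply gets you a closure back.

def adder(x):
    def add_x(y):       # a closure that has captured x
        return x + y
    return add_x

add_five = adder(5)     # partial application: only x has been supplied
add_five(7)             # 12, same as the Ocaml version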

You may be saying “But the partial class exists in Python's functools package.” That is true, but using it can give you some pretty wild results.

For example let's say you have a function like so:

def foo(a, b, c):
    return f"{a=}, {b=}, {c=}"

Now let's also say that you need to make 2 partials from this one function because you're going to get the data you need at different times and you want to pass these partials around to other functions. And let's also say that you get a, b and c out of order. You might do something like this:

bar = partial(foo, c=5)
# some other code

foobar = partial(bar, a=2)
# some more other code

foobar(1)

And I will tell you that the above won't work the way you think it will. In fact, it won't work at all. When you make a partial of a partial and use keywords to fill positional parameters out of order (which is a questionable thing to do in Python to begin with), all of your arguments must be keyword arguments. In the situation above the partial class ends up binding the positional 1 to a (not b like you would expect), the stored keyword a=2 collides with it, and Python throws a TypeError.
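
To see why, it helps to expand what the call actually does (the exact exception message may vary between Python versions):

from functools import partial

def foo(a, b, c):
    return f"{a=}, {b=}, {c=}"

bar = partial(foo, c=5)
foobar = partial(bar, a=2)

# foobar(1) expands to bar(1, a=2), which expands to foo(1, c=5, a=2).
# The positional 1 is bound to parameter a, then the stored keyword a=2
# collides with it:
foobar(1)
# TypeError: foo() got multiple values for argument 'a'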

But the problem is that I still need that hot, sweet partial fix. A better solution is to make our own partial wrapper that behaves a bit more sanely. For this, let's write a generic wrapper that uses recursion to do the job:

from functools import wraps, cache
from inspect import signature, Parameter


@cache
def _get_required_arguments(func):
    parameters = signature(func).parameters
    positional_args = 0
    keyword_args = 0

    for k, v in parameters.items():
        match v.kind:
            case v.POSITIONAL_OR_KEYWORD:
                if v.default is Parameter.empty:
                    positional_args += 1
                else:
                    keyword_args += 1
            case v.VAR_POSITIONAL:
                # if there are no positional args,
                # then we are only expecting variadic arguments
                # if there are positional args,
                # then a variadic argument is optional
                # and doesn't increment the required amount
                if positional_args == 0:
                    positional_args += 1
            case v.VAR_KEYWORD:
                if keyword_args == 0:
                    keyword_args += 1
    return (positional_args, keyword_args)


def partialize(func, *args, **kwargs):
    @wraps(func)
    def recurse(*nargs, **nkwargs):
        return partialize(func, *args, *nargs, **kwargs, **nkwargs)

    p_args, k_args = _get_required_arguments(func)
    if len(args) < p_args:
        return recurse
    return func(*args, **kwargs)

Now we can do funny things like this, which work in a similar fashion to what Ocaml offers:

@partialize
def foo(a, b, c):
    return f"{a=}, {b=}, {c=}"
foo(1, 2, 3)
foo(1)(2, 3)
foo(1, 2)(3)
foo(1)(2)(3)

However, it does not fix the out-of-order problem from above. Nor does it fix the double application to a caused by mixing keywords and positionals. What it does do is allow you to wrap a function to make it partialable (not a word, I know) at definition time, something the partial class cannot offer.
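
For comparison, getting that chained style out of the stock partial class means wrapping by hand at every step (a quick sketch):

from functools import partial

def foo(a, b, c):
    return f"{a=}, {b=}, {c=}"

# every stage has to be wrapped explicitly...
step1 = partial(foo, 1)
step2 = partial(step1, 2)
step2(3)    # "a=1, b=2, c=3"
# ...whereas the @partialize version keeps handing back new partials on
# its own until all of the positional arguments have arrived.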

To fix the out of order problem we would need to make a one-off combinator like so:

def foo(a, b, c):
    return (a, b, c)

@partialize
def bar_combinator(c, a, b):
    return foo(a, b, c)

bar_1 = bar_combinator(5)    # c is applied first
bar_2 = bar_1(2)             # now a is applied as well
result1 = bar_1(5, 10)       # foo(5, 10, 5)
result2 = bar_2(1)           # foo(2, 1, 5)

This of course does mean that you must know the order of application ahead of time. If you're in a situation where you have no idea what order you will be getting your arguments applied, then you should stick with only keyword arguments like so:

@partialize
def bar_combinator(a=None, b=None, c=None):
    return foo(a, b, c)

Now, that's a real simple combinator. In fact, it doesn't really combine anything, but just rearranges arguments. So let's look at a real world use case that I had just t'other day.

I needed to build multiple MAC objects that take different arguments and treat them in the same way. The two objects were ISO9797-algo3 and ISO9797-algo5 (aka: CMAC). Algo3 requires that the data to be macked be padded first using one of three different methods also defined in ISO9797. CMAC does not need padding, but it does need to know which crypto algorithm to use. The class definitions for each are like so:

class MacAlgo3:
    def __init__(self, key: bytes):
        self.key = key

    def __call__(self, data: bytes) -> bytes:
        return how_the_mac_is_made

class Cmac:
    def __init__(self, key: bytes, algo: Literal["AES", "TDES"]):
        self.key = key
        self.algo = get_algo(algo)

    def __call__(self, data: bytes) -> bytes:
        return how_the_cmac_is_made

I knew that I needed my mac objects first, and my key and data would come later. To make this work:

@partialize
def comb1(cls, padder, key, data):
    return cls(key)(padder(data))


@partialize
def comb2(cls, algo, key, data):
    return cls(key, algo)(data)


def do_something(macker):
    # derive key and get data
    macker_loaded = macker(key)
    mac = macker_loaded(data)


if option.a:
    macker = comb1(MacAlgo3, algo2)  # algo2 being the padder handed to comb1
else:
    macker = comb2(Cmac, 'AES')

do_something(macker)

In this way I can shoehorn various objects into other objects that share the same signature. If I were to do this just using the partial class without combinators, I'm not sure it could be done. If I were to use combinators but use the partial class, it would probably be a mess. If I had to make a combinator without any partialization and only use hand-rolled closures, then I would probably shoot myself and my foot in the process.

Combinators can make your life EZ, but you have to make partial functions. Make that EZ too.

In my journey of learning Elixir, I've decided that there has not been enough pain. To explain: moving from an OO language to an FP one has had its challenges. I have spent quite a bit of time working on a problem from Exercism, feeling quite satisfied with my solution, only to find that someone else in the community had solved the same problem in a much, much simpler manner. That kind of pain, I think, is normal when learning a new language. Those are mere inconveniences.

I decided that a more advanced exercise, before diving into the world of full-blown web development, would be to implement the DUKPT algorithm.

“DUKPT? Never heard of it.”

DUKPT (Derived Unique Key Per Transaction) is a key derivation algorithm that credit/debit card readers use to ensure that each transaction uses a new encryption key. In the past few years DUKPT has adopted AES and simplified its algorithm. I would have none of that. For this exercise I went with old-school 3DES derivation.

The TLDR of DUKPT

Using a Base Derivation Key (BDK) and a Key Serial Number (KSN, an 80-bit register split into 3 parts: the BDKID, the DerivationID, and the Transaction Counter), the BDK does some mangling of the BDKID and DerivationID portions of the KSN to derive a unique device key called an Initial Key (IK), which is then securely injected into the pin-pad device. Each time the device needs to encrypt a pin/account number package (pinblock), it increments the transaction counter and uses the IK to mangle and encrypt the DerivationID and Transaction Counter portions of the KSN to generate the unique key. The back end can verify the legitimacy of the pinblock by using the BDK and the entire KSN to derive the current key and decrypt the pinblock. Got it? Good. The DUKPT standard is defined in ANSI X9.24; I can't tell you exactly what is in it, I can only summarize.
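
To make the shape of the KSN a bit more concrete, here is a tiny Python sketch of how the tail of the KSN splits into the device portion and the 21-bit counter, using the same masks as the Elixir code later in this post (the helper name and field names are mine, not from the standard):

def split_ksn(ksn: bytes):
    # work on the right-most 8 bytes of the 10-byte KSN
    tail = int.from_bytes(ksn[-8:], "big")
    counter = tail & 0x1FFFFF                                # low 21 bits: transaction counter
    dev_id = (tail & 0xFFFFFFFFFFE00000).to_bytes(8, "big")  # counter bits zeroed out
    return dev_id, counter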

The tricky parts

  • Each binary 1 in the Transaction Counter is another round of mangling/encryption.
  • The Transaction Counter is 21 bits long... That's right. Not an easy to use 24, but 21.
  • The mangling is a lot of XOR and AND operations, meaning that there is a lot of conversion from bytes to integers and back to deal with.

Things that I thought would be easy, but were not

  • As it turns out binaries and bitstrings in Elixir are not enumerable like lists. Convenience functions like Enum.map and Enum.reduce are worthless here.
  • Encryption in Elixir, as far as I can tell, is not something that happens often. Doing plain-ole 3DES encryption means dropping into Erlang.
  • Erlang only uses single DES for ECB modes.

OK, What's tricky about binaries in Elixir?

Working with binaries and bitstrings means that, while they are effectively the same thing, you cannot mix the two in most cases. Binary measurements are for binaries and bit measurements are for bitstrings, except when a bitstring's size lands on an 8-bit boundary.

When dealing with either bitstring or binary string, size specifications are made at the head of the string, not at the tail. If you need to pull from the tail, you will need to calculate how much head to discard and then float the tail.

When dealing with either, the default measurement is in bits. For instance take the following data:

iex> data = Base.decode16!("DEADBEEFFEEDBEEF")
<<222, 173, 190, 239, 254, 237, 190, 239>>

To pull a 4-bit nibble, you cannot do the following:

iex> <<head::4, tail::binary>> = data

But you can do this:

iex> <<head::4, tail::bits>> = data

or this:

iex> <<head::size(4), tail::bits>> = data

To pull a single byte one would do this:

iex> <<head::8, tail::binary>> = data

or this:

iex> <<head::binary-size(1), tail::binary>> = data

or this:

iex> <<head::binary-size(1), tail::bits>> = data

Also note that reading a size and specifying a size in a match do not invoke the same functions: bit_size and byte_size measure, while size() and binary-size() specify.

iex> bit_size(data)
64
iex> <<head::size(5), tail::bits>> = data
<<222, 173, 190, 239, 254, 237, 190, 239>>

iex> byte_size(data)
8
iex> <<head::binary-size(1), tail::binary>> = data
<<222, 173, 190, 239, 254, 237, 190, 239>>

And finally, when using a variable for a size measurement, the variable must be passed to a sizing function; it cannot be used directly.

This will not work:

iex> my_len = 8
8
iex> <<head::my_len, tail::binary>> = data
warning: bitstring specifier "my_len" does not exist and is being expanded to "my_len()", please use parentheses to remove the ambiguity
  iex:15

But, this will:

iex> <<head::size(my_len), tail::binary>> = data
<<222, 173, 190, 239, 254, 237, 190, 239>>

First Things First, Make My Tools

DUKPT uses the simple ECB mode of 3DES encryption, so it's nothing fancy. However, from the Erlang Crypto docs, as far as I can tell, 3DES ECB isn't implemented, so I'll have to resort to 3 rounds of single DES encryption to do the job.

A 3DES key can be either 16 or 24 bytes long and works like this: in the case of a 24-byte key in the form A|B|C, encrypt the data with key part A, decrypt with part B, then encrypt with part C. In the case of a 16-byte key, encrypt the data with part A, decrypt with part B, then encrypt again with part A. Seeing that chaining data is part of the process, Elixir's pipe operator along with pattern matching and multiple function clauses worked wonders for this:

defmodule CryptoTools do

  def des_encrypt(data, <<k1::binary-size(8), k2::binary-size(8), k3::binary-size(8)>>) do
    data
    |> des_encrypt(k1)
    |> des_decrypt(k2)
    |> des_encrypt(k3)
  end

  def des_encrypt(data, <<k1::binary-size(8), k2::binary-size(8)>>) do
    data
    |> des_encrypt(k1)
    |> des_decrypt(k2)
    |> des_encrypt(k1)
  end

  def des_encrypt(data, <<k1::binary-size(8)>>) do
    :crypto.crypto_one_time(:des_ecb, k1, data, encrypt: true)
  end


  def des_decrypt(data, <<k1::binary-size(8), k2::binary-size(8), k3::binary-size(8)>>) do
    data
    |> des_decrypt(k3)
    |> des_encrypt(k2)
    |> des_decrypt(k1)
  end

  def des_decrypt(data, <<k1::binary-size(8), k2::binary-size(8)>>) do
    data
    |> des_decrypt(k1)
    |> des_encrypt(k2)
    |> des_decrypt(k1)
  end

  def des_decrypt(data, key) do
    :crypto.crypto_one_time(:des_ecb, key, data, encrypt: false)
  end
end

From there I needed to work on a way to do my bitwise operations on bytes. Here function capturing worked like a charm.

defmodule CryptoTools do

  ...

  defp byte_wise_operation(arg1, arg2, operator, acc \\ <<>>)

  defp byte_wise_operation(<<>>, <<>>, _operator, acc), do: acc

  defp byte_wise_operation(<<arg1::8, r1::binary>>, <<arg2::8, r2::binary>>, operator, acc) do
    byte_wise_operation(r1, r2, operator, acc <> <<operator.(arg1, arg2)>>)
  end


  def and_bytes(a, b),  do: byte_wise_operation(a, b, &Bitwise.band/2)


  def or_bytes(a, b),   do: byte_wise_operation(a, b, &Bitwise.bor/2)


  def xor_bytes(a, b),  do: byte_wise_operation(a, b, &Bitwise.bxor/2)
end

The Different Approach With Functional Programming

Creating the IK is a trivial issue. The bigger challenge is the current key, specifically the part of the algorithm that does one round of operations per binary 1 in the counter. In an imperative language this is solved with a shift register, a while loop and an if statement: start the loop and AND the shift register with the counter to see if we are sitting on a one. If we are, do the necessary operations. Either way, shift the register right by one and repeat, breaking out of the loop once the register hits zero.
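
For contrast, here is roughly what that imperative loop looks like in Python (just a sketch; mangle_round and the integer arguments are stand-ins, not real code from either implementation):

def derive_current_key(initial_key, counter, dev_id, mangle_round):
    # mangle_round is a stand-in callable for the per-bit derivation step
    key = initial_key
    shift_register = 1 << 20                  # the counter is 21 bits wide
    while shift_register:
        if counter & shift_register:          # sitting on a binary 1?
            dev_id |= shift_register          # fold this bit into the device id
            key = mangle_round(key, dev_id)   # one round of mangling/encryption
        shift_register >>= 1                  # shift right and go again
    return key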

While I could “fake” that loop with a recursive function and still use a shift register and ANDing in an if, that would totally break the spirit of using Elixir's pattern matching on bitstrings. Also, there are a fair number of variables that need to be carried over each recursion. Do I make a recursive function with a ton of arguments? Do I keep the number of arguments low, use a struct, and increase complexity? Or do I keep the number of arguments low and derive what I need via private convenience functions at runtime?

The third option is what I went with, because in the imperative code some of those variables had dual purposes, while in the functional code they were reduced to a single purpose. I guess I could have used a struct/map to pass data between each function call; I just didn't this time.

IK code

defmodule Dukpt.TDes.InitialKey do
  @behaviour Dukpt.TDes.Deriver
  import CryptoTools
  # @variant "C0C0C0C000000000C0C0C0C000000000"
  @variant <<192, 192, 192, 192, 0, 0, 0, 0, 192, 192, 192, 192, 0, 0, 0, 0>>
  @modder <<255, 255, 255, 255, 255, 255, 255, 224, 0, 0>>


  @impl Dukpt.TDes.Deriver
  def derive_key(key_material, ksn) do
    mod_ksn = and_bytes(ksn, @modder)
    modified_bdk = xor_bytes(key_material, @variant)
    {:ok, create_initial_half_key(mod_ksn, key_material) <> create_initial_half_key(mod_ksn, modified_bdk)}
  end


  defp create_initial_half_key(ksn, key) do
    <<device_id::binary-size(8), _::binary>> = ksn
    des_encrypt(device_id, key)
  end
end

Current Key Code

defmodule Dukpt.TDes.CurrentKey do
  @behaviour Dukpt.TDes.Deriver
  import CryptoTools
  # @variant "C0C0C0C000000000C0C0C0C000000000"
  @variant <<192, 192, 192, 192, 0, 0, 0, 0, 192, 192, 192, 192, 0, 0, 0, 0>>
  @dev_id_mask <<255, 255, 255, 255, 255, 224, 0, 0>>


  @impl Dukpt.TDes.Deriver
  def derive_key(initial_key, ksn) do
    discard_bytes = byte_size(ksn) - 8
    <<_::binary-size(discard_bytes), ksnr::binary >> = ksn

    dev_id = and_bytes(ksnr, @dev_id_mask)

    discard_bit_size = bit_size(ksnr) - 21
    <<_::size(discard_bit_size), counter::bits>> = ksnr

    derivation_process(counter, initial_key, dev_id)
  end


  defp update_dev_id(dev_id, rest) do
    tail_padding = bit_size(rest)
    head_padding = bit_size(dev_id) - 1 - tail_padding
    or_bytes(<<0::size(head_padding), 1::size(1), 0::size(tail_padding)>>, dev_id)
  end


  defp derivation_process(counter, current_key, dev_id, ones \\ 0)

  defp derivation_process(_, _, _, ones) when ones >= 10, do: {:error, "Too many binary ones in the counter"}

  defp derivation_process(<<>> = _counter, current_key, _dev_id, _ones) do
    {:ok, current_key}
  end

  defp derivation_process(<<0::1, rest::bits>> = _counter, current_key, dev_id, ones) do
    derivation_process(rest, current_key, dev_id, ones)
  end

  defp derivation_process(<<1::1, rest::bits>> = _counter, current_key, dev_id, ones) do
    <<right_first::binary-size(8), right_second::binary-size(8)>> = current_key
    <<left_first::binary-size(8), left_second::binary-size(8)>> = xor_bytes(current_key, @variant)

    dev_id = update_dev_id(dev_id, rest)

    right_half =
      xor_bytes(right_second, dev_id)
      |> des_encrypt(right_first)
      |> xor_bytes(right_second)

    left_half =
      xor_bytes(left_second, dev_id)
      |> des_encrypt(left_first)
      |> xor_bytes(left_second)

    derivation_process(rest, left_half <> right_half, dev_id, ones + 1)
  end
end

In the end, update_dev_id was the only convenience function needed. In the imperative world the device_id needs to be OR'd with the shift register on every operation. Since I'm not ANDing with a shift register for recursion control, I just recreate it when I need it.

Conclusion

When I implemented this code in Python some four years ago, I also had to create convenience functions for the byte-wise operations. They were not nearly as straightforward as they are in Elixir, due to the fact that converting between bytes and ints in Python is a bit more esoteric.
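
For instance, something like xor_bytes in Python tends to end up as a round trip through int (a sketch of one way to do it, assuming a and b are the same length):

def xor_bytes(a: bytes, b: bytes) -> bytes:
    # convert to ints, do the operation, convert back to bytes
    x = int.from_bytes(a, "big") ^ int.from_bytes(b, "big")
    return x.to_bytes(len(a), "big")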

While I could have used a shift register for recursion control, I didn't need to. I could instead use the much more ergonomic mechanism of pattern matching. In other languages I would have to remember that “ANDing with 1 keeps the bit while ANDing with 0 turns off the bit.” Now I get to focus on the “what” instead of the “how,” which is quite liberating.