aboutsummaryrefslogtreecommitdiff
#!r6rs
(library (tokyo tojo map avl)
  (export map? empty? empty
          add search
          (rename (avl:remove remove))
          size
          (rename (avl:map map)
                  (avl:map/key map/key)
                  (avl:filter filter)
                  (avl:filter/key filter/key)
                  (avl:partition partition)
                  (avl:partition/key partition/key)
                  (avl:for-each for-each)
                  (avl:for-each/key for-each/key)
                  (avl:fold-left fold-left)
                  (avl:fold-left/key fold-left/key)
                  (avl:fold-right fold-right)
                  (avl:fold-right/key fold-right/key))
          union union/key)
  (import (rnrs))
  
  ;; map? : any -> boolean
  ;; True exactly when X is a map value of this library: either the empty
  ;; map or a tree node.
  (define (map? x)
    (cond
     [(empty? x) #t]
     [(node? x) #t]
     [else #f]))

  ;; (define empty (list 'avl:empty))
  ;; (define empty? (lambda (x) (eq? x empty)))
  
  ;; (define node (list 'node))
  ;; (define node? (lambda (x)
  ;;                 (and (pair? x)
  ;;                      (eq? (car x) node))))
  ;; (define make-node
  ;;   (lambda (kv l r h)
  ;;     (list node kv l r h)))
  
  ;; (define key&value (lambda (node) (cadr node)))
  ;; (define left (lambda (node) (caddr node)))
  ;; (define right (lambda (node) (cadddr node)))
  ;; (define %height (lambda (node) (cadddr (cdr node))))
  
  ;; Record type whose instances mark the empty map (and every empty
  ;; subtree).  It has no fields.  Declared opaque, so instances are not
  ;; reflectable through the R6RS record inspection procedures.
  (define-record-type (avl:empty make-empty empty?)
    (opaque #t))

  ;; The single empty-map value; this library only ever calls make-empty
  ;; here, and every tree uses this object for its missing children.
  (define empty (make-empty))

  ;; One node of the AVL tree.  Fields (accessor in parentheses):
  ;;   kv : the (key . value) pair stored here                (key&value)
  ;;   l  : left subtree, keys ordered before kv's key         (left)
  ;;   r  : right subtree, keys ordered after kv's key         (right)
  ;;   h  : cached height of this subtree; read it through
  ;;        `height`, which also handles the empty map         (%height)
  ;; All fields are immutable: every operation builds new nodes rather than
  ;; modifying existing ones.  Opaque, like the empty-map record.
  (define-record-type (avl:node make-node node?)
    (fields (immutable kv key&value)
            (immutable l left)
            (immutable r right)
            (immutable h %height))
    (opaque #t))

  ;; Argument guard used by the public entry points: raise an assertion
  ;; violation attributed to NAME when X is not a map.  Returns an
  ;; unspecified value otherwise.
  (define (error-not-map-tree name x)
    (if (not (map? x))
        (assertion-violation name "not map tree" x)))

  ;; Height of subtree TR: 0 for the empty map, otherwise the value cached
  ;; in the node.
  (define (height tr)
    (cond
     [(empty? tr) 0]
     [else (%height tr)]))

  ;; Height of NODE's left subtree minus the height of its right subtree.
  ;; `balance` treats -1, 0 and 1 as balanced.
  (define (balance-factor node)
    (let ([hl (height (left node))]
          [hr (height (right node))])
      (- hl hr)))

  ;; Height of a node whose children are L and R: one more than the taller
  ;; child.
  (define (height+ l r)
    (let ([taller (max (height l) (height r))])
      (+ taller 1)))

  ;; Single left rotation at P; P's right child Q becomes the new root:
  ;; (-> (a p (b q c))
  ;;     ((a p b) q c))
  ;; Both rebuilt nodes get heights recomputed from their new children.
  (define (rotate-left p)
    (let* ([q (right p)]
           [new-p (make-node (key&value p)
                             (left p) (left q)
                             (height+ (left p) (left q)))])
      (make-node (key&value q)
                 new-p (right q)
                 (height+ new-p (right q)))))

  ;; Single right rotation at Q; Q's left child P becomes the new root:
  ;; (-> ((a p b) q c)
  ;;     (a p (b q c)))
  ;; Both rebuilt nodes get heights recomputed from their new children.
  (define (rotate-right q)
    (let* ([p (left q)]
           [new-q (make-node (key&value q)
                             (right p) (right q)
                             (height+ (right p) (right q)))])
      (make-node (key&value p)
                 (left p) new-q
                 (height+ (left p) new-q))))


  ;; Double rotation at R: first rotate R's left child left, then rotate R
  ;; right.  The resulting tree equals the one-step rebuild shown here:
  ;; (-> ((a p (b q c)) r d)
  ;;     (((a p b) q c) r d)  ; left rotation
  ;;     ((a p b) q (c r d))) ; right rotation
  (define (rotate-left-right r)
    (let* ([l (rotate-left (left r))]
           [r* (make-node (key&value r) l (right r) (height+ l (right r)))])
      (rotate-right r*)))

  ;; Double rotation at P: first rotate P's right child right, then rotate P
  ;; left.  The resulting tree equals the one-step rebuild shown here:
  ;; (-> (a p ((b q c) r d))
  ;;     (a p (b q (c r d)))  ; right rotation
  ;;     ((a p b) q (c r d))) ; left rotation
  (define (rotate-right-left p)
    (let* ([rr (rotate-right (right p))]
           [p* (make-node (key&value p) (left p) rr (height+ (left p) rr))])
      (rotate-left p*)))

  ;; Rebalance the node TR.  add, avl:remove and node-max call this after
  ;; rebuilding a node with one updated subtree.
  ;; Returns TR unchanged when its balance factor is -1, 0 or 1.  For +2/-2
  ;; it applies the matching single or double rotation.  Any other factor
  ;; (at TR or at the heavy child) means the tree invariant was already
  ;; broken, and raises an assertion violation carrying that factor and TR.
  ;; The empty map is returned as is.
  (define balance
    (lambda (tr)
      (if (empty? tr)
          empty
          (let ([bf (balance-factor tr)])
            (case bf
              [(2)
               ;; Left-heavy: a single right rotation, unless the left child
               ;; leans right, which needs the left-right double rotation.
               (let ([child-bf (balance-factor (left tr))])
                 (case child-bf
                   [(0 1) (rotate-right tr)]
                   [(-1) (rotate-left-right tr)]
                   [else (assertion-violation 'balance "error" child-bf tr)]))]
              [(-2)
               ;; Right-heavy: mirror image of the case above.
               (let ([child-bf (balance-factor (right tr))])
                 (case child-bf
                   [(0 -1) (rotate-left tr)]
                   [(1) (rotate-right-left tr)]
                   [else (assertion-violation 'balance "error" child-bf tr)]))]
              [(-1 0 1) tr]
              [else (assertion-violation 'balance "error" bf tr)])))))

  ;; (add tr k v) / (add <? =? tr k v)
  ;; Return a new map equal to TR except that K is bound to V.  An existing
  ;; entry whose key is =? to K has its pair replaced; the tree shape is
  ;; unchanged in that case.  Keys are ordered by <? and compared by =?.
  ;; Both default to the numeric < and =.  TR itself is not modified.
  ;; Raises an assertion violation (who: add) if TR is not a map.
  (define add
    (case-lambda
      [(<? =? tr k v)
       (error-not-map-tree 'add tr)
       (let f ([tr tr])
         (if (empty? tr)
             ;; A new leaf must have height 1 (one more than two empty
             ;; children), matching height+ and the rotations.  A height of
             ;; 0 would make leaves indistinguishable from empty subtrees
             ;; and let the tree drift out of AVL balance.
             (make-node (cons k v) empty empty (height+ empty empty))
             (let ([kv (key&value tr)])
               (cond
                [(=? k (car kv))
                 (make-node (cons k v) (left tr) (right tr) (height tr))]
                [(<? k (car kv))
                 (let ([l (f (left tr))]
                       [r (right tr)])
                   (balance (make-node kv l r (height+ l r))))]
                [else
                 (let ([l (left tr)]
                       [r (f (right tr))])
                   (balance (make-node kv l r (height+ l r))))]))))]
      [(tr k v)
       (add < = tr k v)]))

  ;; (search tr k) / (search <? =? tr k)
  ;; Look up K in TR.  Returns the stored (key . value) pair whose key is
  ;; =? to K, or #f when there is none.  Comparators default to < and =.
  (define search
    (case-lambda
      [(tr k) (search < = tr k)]
      [(<? =? tr k)
       (error-not-map-tree 'search tr)
       (let loop ([node tr])
         (cond
          [(empty? node) #f]
          [(=? k (car (key&value node))) (key&value node)]
          [(<? k (car (key&value node))) (loop (left node))]
          [else (loop (right node))]))]))

  ;; Split off the entry with the greatest key from the non-empty tree NODE.
  ;; Returns two values: the rebalanced tree without that entry, and the
  ;; removed (key . value) pair.
  (define (node-max node)
    (cond
     [(empty? (right node))
      (values (left node) (key&value node))]
     [else
      (let-values ([(rest largest) (node-max (right node))])
        (let ([l (left node)])
          (values (balance (make-node (key&value node) l rest (height+ l rest)))
                  largest)))]))

  

  ;; (avl:remove tr k) / (avl:remove <? =? tr k), exported as `remove`.
  ;; Return a map like TR but without the entry whose key is =? to K.  If no
  ;; key matches, the result has the same entries as TR.  Comparators
  ;; default to the numeric < and =.
  (define avl:remove
    (case-lambda
      [(tr k)
       (avl:remove < = tr k)]
      [(<? =? tr k)
       ;; Rebuild a node from KV and two subtrees, then rebalance it.
       (define join
         (lambda (kv l r)
           (balance (make-node kv l r (height+ l r)))))
       ;; Drop NODE's own entry.  With at most one child, that child takes
       ;; its place; otherwise the greatest entry of the left subtree moves
       ;; up into this position.
       (define splice
         (lambda (node)
           (let ([l (left node)]
                 [r (right node)])
             (cond
              [(empty? l) r]
              [(empty? r) l]
              [else
               (let-values ([(rest kv) (node-max l)])
                 (join kv rest r))]))))
       (error-not-map-tree 'remove tr)
       (let walk ([node tr])
         (if (empty? node)
             empty
             (let* ([kv (key&value node)]
                    [key (car kv)])
               (cond
                [(=? k key) (splice node)]
                [(<? k key) (join kv (walk (left node)) (right node))]
                [else (join kv (left node) (walk (right node)))]))))]))
  
  ;; Exported as `map`: a new map with the same keys as TR and every value
  ;; V replaced by (F V).
  (define (avl:map f tr)
    (error-not-map-tree 'map tr)
    (let ([ignore-key (lambda (k v) (f v))])
      (avl:map/key ignore-key tr)))
  
  ;; Exported as `map/key`: a new map with the same keys, tree shape and
  ;; cached heights as TR, where each value becomes (F key value).
  ;; F is applied in ascending key order: left subtree, then the node, then
  ;; the right subtree.
  (define avl:map/key
    (lambda (f tr)
      (error-not-map-tree 'map/key tr)
      (let g ([tr tr])
        (if (empty? tr)
            empty
            ;; let* fixes the evaluation order.  The arguments of a procedure
            ;; call are evaluated in an unspecified order in Scheme, which
            ;; would make the order of F's calls implementation-dependent.
            (let* ([kv (key&value tr)]
                   [l (g (left tr))]
                   [v (f (car kv) (cdr kv))]
                   [r (g (right tr))])
              (make-node (cons (car kv) v) l r (height tr)))))))
  
  ;; Exported as `for-each`: call (F value) on every entry of TR in
  ;; ascending key order, for effect only.
  (define (avl:for-each f tr)
    (error-not-map-tree 'for-each tr)
    (let ([value-only (lambda (k v) (f v))])
      (avl:for-each/key value-only tr)))
  
  ;; Exported as `for-each/key`: call (F key value) on every entry of TR in
  ;; ascending key order (in-order traversal).  The result is unspecified.
  (define (avl:for-each/key f tr)
    (error-not-map-tree 'for-each/key tr)
    (let visit ([node tr])
      (cond
       [(empty? node) (if #f #f)]
       [else
        (visit (left node))
        (let ([kv (key&value node)])
          (f (car kv) (cdr kv)))
        (visit (right node))])))
  
  ;; Exported as `fold-left`: left fold over TR's values in ascending key
  ;; order, computing (F acc value) at each step, starting from INIT.
  (define (avl:fold-left f init tr)
    (error-not-map-tree 'fold-left tr)
    (let ([drop-key (lambda (acc key val) (f acc val))])
      (avl:fold-left/key drop-key init tr)))
  
  ;; Exported as `fold-left/key`: left fold over TR's entries in ascending
  ;; key order.  The accumulator starts as INIT and becomes
  ;; (F acc key value) for each entry in turn.
  (define (avl:fold-left/key f init tr)
    (error-not-map-tree 'fold-left/key tr)
    (let walk ([acc init] [node tr])
      (if (empty? node)
          acc
          (let* ([kv (key&value node)]
                 [after-left (walk acc (left node))]
                 [after-self (f after-left (car kv) (cdr kv))])
            (walk after-self (right node))))))

  ;; Exported as `fold-right`: right fold over TR's values.  The largest key
  ;; is combined first, computing (F value acc), starting from INIT.
  (define (avl:fold-right f init tr)
    (error-not-map-tree 'fold-right tr)
    (let ([drop-key (lambda (key val acc) (f val acc))])
      (avl:fold-right/key drop-key init tr)))
  
  ;; Exported as `fold-right/key`: right fold over TR's entries.  The
  ;; accumulator starts as INIT and becomes (F key value acc), visiting keys
  ;; from largest to smallest.  Consing entries therefore yields an
  ;; ascending list.
  (define (avl:fold-right/key f init tr)
    (error-not-map-tree 'fold-right/key tr)
    (let walk ([acc init] [node tr])
      (if (empty? node)
          acc
          (let* ([kv (key&value node)]
                 [after-right (walk acc (right node))]
                 [after-self (f (car kv) (cdr kv) after-right)])
            (walk after-self (left node))))))
  
  ;; Number of entries stored in TR.
  (define (size tr)
    (error-not-map-tree 'size tr)
    (let count ([node tr])
      (if (empty? node)
          0
          (+ 1 (count (left node)) (count (right node))))))

  ;; Exported as `filter`: keep only the entries of TR whose value
  ;; satisfies P?.
  (define (avl:filter p? tr)
    (error-not-map-tree 'filter tr)
    (let ([on-value (lambda (k v) (p? v))])
      (avl:filter/key on-value tr)))
  
  ;; Exported as `filter/key`: return a map holding exactly the entries of
  ;; TR for which (P? key value) is true.  P? is called once per entry, in
  ;; ascending key order.
  ;;
  ;; The kept entries are collected in key order, and the result tree is
  ;; built straight from that sorted sequence, so no key comparisons are
  ;; needed.  Deleting rejected keys with avl:remove would instead use the
  ;; default numeric < and =, which fails for maps with non-numeric keys.
  (define avl:filter/key
    (lambda (p? tr)
      ;; Balanced tree from a vector of (key . value) pairs that is already
      ;; in ascending key order.  Sibling subtrees differ in size by at most
      ;; one, so their heights differ by at most one.
      (define sorted->tree
        (lambda (vec)
          (let build ([lo 0] [hi (vector-length vec)])
            (if (>= lo hi)
                empty
                (let* ([mid (div (+ lo hi) 2)]
                       [l (build lo mid)]
                       [r (build (+ mid 1) hi)])
                  (make-node (vector-ref vec mid) l r (height+ l r)))))))
      (error-not-map-tree 'filter/key tr)
      ;; In-order walk; the accumulator holds surviving pairs in reverse.
      (let ([kept (let walk ([node tr] [acc '()])
                    (if (empty? node)
                        acc
                        (let* ([acc (walk (left node) acc)]
                               [kv (key&value node)]
                               [acc (if (p? (car kv) (cdr kv))
                                        (cons kv acc)
                                        acc)])
                          (walk (right node) acc))))])
        (sorted->tree (list->vector (reverse kept))))))
  
  ;; Exported as `partition`: split TR by (P? value).  Returns two values:
  ;; the map of entries whose value satisfies P?, and the map of the rest.
  (define avl:partition
    (lambda (p? tr)
      ;; Report errors under this procedure's own name (was 'filter).
      (error-not-map-tree 'partition tr)
      (avl:partition/key (lambda (k v) (p? v)) tr)))

  ;; Exported as `partition/key`: split TR by (P? key value).  Returns two
  ;; values: the map of entries satisfying P?, and the map of the rest.
  ;; P? is called once per entry, in ascending key order.
  ;;
  ;; Both results are built directly from sorted entry lists, so no key
  ;; comparisons are made.  Removing keys with avl:remove would instead use
  ;; the default numeric < and =, which fails for maps with non-numeric keys.
  (define avl:partition/key
    (lambda (p? tr)
      ;; Balanced tree from a vector of (key . value) pairs that is already
      ;; in ascending key order.  Sibling subtrees differ in size by at most
      ;; one, so their heights differ by at most one.
      (define sorted->tree
        (lambda (vec)
          (let build ([lo 0] [hi (vector-length vec)])
            (if (>= lo hi)
                empty
                (let* ([mid (div (+ lo hi) 2)]
                       [l (build lo mid)]
                       [r (build (+ mid 1) hi)])
                  (make-node (vector-ref vec mid) l r (height+ l r)))))))
      (error-not-map-tree 'partition/key tr)
      ;; In-order walk; YES and NO accumulate pairs in reverse key order.
      (let-values ([(yes no)
                    (let walk ([node tr] [yes '()] [no '()])
                      (if (empty? node)
                          (values yes no)
                          (let-values ([(yes no) (walk (left node) yes no)])
                            (let ([kv (key&value node)])
                              (if (p? (car kv) (cdr kv))
                                  (walk (right node) (cons kv yes) no)
                                  (walk (right node) yes (cons kv no)))))))])
        (values (sorted->tree (list->vector (reverse yes)))
                (sorted->tree (list->vector (reverse no)))))))

  ;; (union f m1 m2) / (union <? =? f m1 m2)
  ;; Merge two maps.  For a key present in both, the stored value is
  ;; (F v1 v2).  Defined in terms of union/key with a combiner that ignores
  ;; the key.
  (define union
    (case-lambda
      [(<? =? f m1 m2)
       (let ([combine (lambda (k v1 v2) (f v1 v2))])
         (union/key <? =? combine m1 m2))]
      [(f m1 m2)
       (union < = f m1 m2)]))

  ;; (union/key f m1 m2) / (union/key <? =? f m1 m2)
  ;; Return a map containing every key of M1 and of M2.  A key found in only
  ;; one map keeps its value.  For keys that are =? in both, the value is
  ;; (F key v1 v2), and M1's key is kept.  Keys are ordered by <? and
  ;; compared by =?, defaulting to numeric < and =.  Both maps are assumed
  ;; to be ordered by the same <?.
  (define union/key
    (case-lambda
      [(f m1 m2)
       ;; Must delegate to union/key itself: `union` expects a value-only,
       ;; two-argument F and would call this three-argument F with the
       ;; wrong arity.
       (union/key < = f m1 m2)]
      [(<? =? f m1 m2)
       (let ([entries
              ;; All (key . value) pairs of M in ascending key order.
              (lambda (m)
                (avl:fold-right/key (lambda (k v acc) (cons (cons k v) acc))
                                    '() m))]
             [add-all
              ;; Insert every pair of PAIRS into the map ACC.
              (lambda (acc pairs)
                (fold-left (lambda (acc p) (add <? =? acc (car p) (cdr p)))
                           acc pairs))])
         ;; Merge the two sorted entry lists, combining equal keys with F.
         (let loop ([l1 (entries m1)] [l2 (entries m2)] [acc empty])
           (cond
            [(null? l1) (add-all acc l2)]
            [(null? l2) (add-all acc l1)]
            [else
             (let ([k1 (caar l1)]
                   [k2 (caar l2)])
               (cond
                [(=? k1 k2)
                 (loop (cdr l1) (cdr l2)
                       (add <? =? acc k1 (f k1 (cdar l1) (cdar l2))))]
                [(<? k1 k2)
                 (loop (cdr l1) l2 (add <? =? acc k1 (cdar l1)))]
                [else
                 (loop l1 (cdr l2) (add <? =? acc k2 (cdar l2)))]))])))])))