@@ -50,45 +50,143 @@ leverage_weights(p::LinPred, wt::AbstractVector) = sqrt.(1 .- leverage(p, wt))
5050# beta0
5151# end
5252
53- """
54- DensePredQR
5553
56- A `LinPred` type with a dense, unpivoted QR decomposition of `X`
54+ # #########################################
55+ # ##### DensePredQR
56+ # #########################################
5757
58- # Members
58+ @static if get_pkg_version(GLM) < v" 1.9"
59+ @warn(
60+ " GLM.DensePredQR(X::AbstractMatrix, pivot::Bool=true) is not defined, " *
61+ " fallback to unpivoted RobustModels.DensePredQR definition. " *
62+ " To use pivoted QR, GLM version should be greater than or equal to v1.9."
63+ )
5964
60- - `X`: Model matrix of size `n` × `p` with `n ≥ p`. Should be full column rank.
61- - `beta0`: base coefficient vector of length `p`
62- - `delbeta`: increment to coefficient vector, also of length `p`
63- - `scratchbeta`: scratch vector of length `p`, used in `linpred!` method
64- - `qr`: a `QRCompactWY` object created from `X`, with optional row weights.
65- """
66- DensePredQR
67-
68- PRED_QR_WARNING_ISSUED = false
69-
70- function qrpred(X:: AbstractMatrix , pivot:: Bool = false )
71- try
72- return DensePredCG(Matrix(X), pivot)
73- catch e
74- if e isa MethodError
75- # GLM.DensePredCG(X::AbstractMatrix, pivot::Bool) is not defined
76- global PRED_QR_WARNING_ISSUED
77- if ! PRED_QR_WARNING_ISSUED
78- @warn(
79- " GLM.DensePredCG(X::AbstractMatrix, pivot::Bool) is not defined, " *
80- " fallback to unpivoted QR. GLM version should be >= 1.9."
81- )
82- PRED_QR_WARNING_ISSUED = true
65+ using LinearAlgebra: QRCompactWY, QRPivoted, Diagonal, qr!, qr
66+
67+ """
68+ DensePredQR
69+
70+ A `LinPred` type with a dense QR decomposition of `X`
71+
72+ # Members
73+
74+ - `X`: Model matrix of size `n` × `p` with `n ≥ p`. Should be full column rank.
75+ - `beta0`: base coefficient vector of length `p`
76+ - `delbeta`: increment to coefficient vector, also of length `p`
77+ - `scratchbeta`: scratch vector of length `p`, used in `linpred!` method
78+ - `qr`: a `QRCompactWY` object created from `X`, with optional row weights.
79+ - `scratchm1`: scratch Matrix{T} of the same size as `X`
80+ - `scratchm2`: scratch Matrix{T} of the same size as `X`
81+ - `scratchR`: scratch Matrix{T} of the same size as `qr.R`, a square matrix.
82+ """
83+ mutable struct DensePredQR{T<: BlasReal ,Q<: Union{QRCompactWY,QRPivoted} } <: DensePred
84+ X:: Matrix{T} # model matrix
85+ beta0:: Vector{T} # base coefficient vector
86+ delbeta:: Vector{T} # coefficient increment
87+ scratchbeta:: Vector{T}
88+ qr:: Q
89+ scratchm1:: Matrix{T}
90+ scratchm2:: Matrix{T}
91+ scratchR:: Matrix{T}
92+
93+ function DensePredQR(X:: AbstractMatrix , pivot:: Bool = false )
94+ n, p = size(X)
95+ T = typeof(float(zero(eltype(X))))
96+
97+ if false
98+ # if pivot
99+ F = pivoted_qr!(copy(X))
100+ else
101+ if n >= p
102+ F = qr(X)
103+ else
104+ # adjoint of X so R is square
105+ # cannot use in-place qr!
106+ F = qr(X)
107+ end
83108 end
84- return DensePredCG(Matrix(X))
109+
110+ return new{T,typeof(F)}(
111+ Matrix{T}(X),
112+ zeros(T, p),
113+ zeros(T, p),
114+ zeros(T, p),
115+ F,
116+ similar(X, T),
117+ similar(X, T),
118+ zeros(T, size(F. R)),
119+ )
120+ end
121+ end
122+
123+ # GLM.DensePredQR(X::AbstractMatrix, pivot::Bool) is not defined
124+ function qrpred(X:: AbstractMatrix , pivot:: Bool = false )
125+ return DensePredQR(Matrix(X))
126+ end
127+
128+ # GLM.delbeta!(p::DensePredQR{T}, r::Vector{T}) is ill-defined
129+ function delbeta!(p:: DensePredQR{T,<:QRCompactWY} , r:: Vector{T} ) where {T<: BlasReal }
130+ n, m = size(p. X)
131+ if n >= m
132+ p. delbeta = p. qr \ r
133+ else
134+ p. delbeta = p. qr' \ r
135+ end
136+ return p
137+ end
138+
139+ # GLM.delbeta!(p::DensePredQR{T}, r::Vector{T}, wt::Vector{T}) is not defined
140+ function delbeta!(
141+ p::DensePredQR{T,<:QRCompactWY}, r::Vector{T}, wt::Vector{T}
142+ ) where {T<:BlasReal}
143+ rnk = rank(p.qr.R)
144+ X = p.X
145+ W = Diagonal(wt)
146+ sqrtW = Diagonal(sqrt.(wt))
147+ scratchm1 = p.scratchm1 = similar(X, T)
148+ mul!(scratchm1, sqrtW, X)
149+
150+ n, m = size(X)
151+ if n >= m
152+ # W½ X = Q R , with Q' Q = I
153+ # X'WX β = X'y => R'Q'QR β = X'y
154+ # => β = R⁻¹ R⁻ᵀ X'y
155+ qnr = p. qr = qr(scratchm1)
156+ Rinv = p. scratchR = inv(qnr. R)
157+
158+ scratchm2 = p. scratchm2 = similar(X, T)
159+ mul!(scratchm2, W, X)
160+ mul!(p. delbeta, transpose(scratchm2), r)
161+
162+ p. delbeta = Rinv * Rinv' * p.delbeta
85163 else
86- rethrow()
164+ # (W½ X)' = Q R , with Q' Q = I
165+ # W½X β = W½y => R' Q' β = y
166+ # => β = Q . [R⁻ᵀ y; 0]
167+ qnrT = p.qr = qr(scratchm1' )
168+ RTinv = p. scratchR = inv(qnrT. R)'
169+ @assert 1 <= n <= size(p. delbeta, 1 )
170+ mul!(view(p. delbeta, 1 : n), RTinv, r)
171+ p. delbeta = zeros(size(p. delbeta))
172+ p. delbeta[1 : n] .= RTinv * r
173+ lmul!(qnrT. Q, p. delbeta)
87174 end
175+ return p
88176 end
177+
178+
179+ # # Use DensePredQR from GLM
180+ else
181+ using GLM: DensePredQR
182+ import GLM: qrpred
89183end
90184
91185
186+ # #########################################
187+ # ##### [Dense/Sparse]PredCG
188+ # #########################################
189+
92190"""
93191 DensePredCG
94192
@@ -109,20 +207,8 @@ mutable struct DensePredCG{T<:BlasReal} <: DensePred
109207 scratchbeta:: Vector{T}
110208 scratchm1:: Matrix{T}
111209 scratchr1:: Vector{T}
112- function DensePredCG{T}(X:: Matrix{T} , beta0:: Vector{T} ) where {T}
113- n, p = size(X)
114- length(beta0) == p || throw(DimensionMismatch(" length(β0) ≠ size(X,2)" ))
115- return new{T}(
116- X,
117- beta0,
118- zeros(T, p),
119- zeros(T, (p, p)),
120- zeros(T, p),
121- zeros(T, (n, p)),
122- zeros(T, n),
123- )
124- end
125- function DensePredCG{T}(X:: Matrix{T} ) where {T}
210+
211+ function DensePredCG(X:: Matrix{T} ) where {T<: BlasReal }
126212 n, p = size(X)
127213 return new{T}(
128214 X,
@@ -135,10 +221,8 @@ mutable struct DensePredCG{T<:BlasReal} <: DensePred
135221 )
136222 end
137223end
138- DensePredCG(X:: Matrix , beta0:: Vector ) = DensePredCG{eltype(X)}(X, beta0)
139- DensePredCG(X:: Matrix{T} ) where {T} = DensePredCG{T}(X, zeros(T, size(X, 2 )))
140224function Base. convert(:: Type{DensePredCG{T}} , X:: Matrix{T} ) where {T}
141- return DensePredCG{T}(X, zeros(T, size(X, 2 )) )
225+ return DensePredCG(X )
142226end
143227
144228# Compatibility with cholpred(X, pivot)
0 commit comments