function overfit( option )
% OVERFIT Demonstrates the effects of overfitting a model
%   USE: Type overfit at the MATLAB prompt.
%
%   OVERFIT with no argument builds the demonstration figure: an axes
%   canvas, a "Data Source" panel (generate data or load a data file),
%   a "Plot Data" button and controls to step the fitted polynomial
%   order up or down (orders 1 to 8).  Each fit is drawn over the data
%   and one row of statistics (SSE, R^2, adjusted R^2, s^2, PRESS,
%   Mallows' Cp, number of parameters p) is printed in the Command Window.
%
%   OVERFIT(OPTION) is the callback dispatcher used by the figure's
%   controls.  OPTION is one of
%      1       plot the data (generate new data or use the loaded file)
%      2       overlay the fit whose order matches the selected poly type
%      3       fit the order shown in the order box and print its row
%      4, 5    create the "generate data" / "load data file" controls
%      6       choose and load a data file
%      7       print the header of the statistics table
%      10      validate the error-variance box
%      12, 13  step the order up / down
%      14      order typed into the order box
%      40      polynomial type changed
%      51, 52  "Generate" / "Data File" radio buttons
%   Any other OPTION does nothing.

% Peter Dunn
% 10 August 2000

if nargin == 0
    ofCreateFigure;
    return;
end

switch option
    case 1   % Plot Data
        ofPlotData;

    case 2   % Fit the real order polynomial
        ofPlotTrueOrderFit;

    case 3   % Fit the polynomial of the order in the order box
        % High-order fits are badly conditioned, so warnings are silenced
        % while fitting; the caller's warning state is always restored.
        oldWarn = warning;
        warning off;
        try
            ofFitOrder;
        catch
            warning(oldWarn);
            error(lasterr);
        end
        warning(oldWarn);

    case 4   % Generate data stuff
        ofCreateGenerateControls;

    case 5   % Load data file stuff
        ofCreateLoadControls;

    case 6   % Load data file
        ofLoadFile;

    case 7   % Set up table
        disp(' ');
        disp(['Order SSE R^2 adj R^2 ', ...
              ' s^2 sum(PRESS^2) Cp p']);
        disp('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~');

    case 10  % Error var changed
        ofValidateErrorVar;

    case 12  % UP pressed
        ofSetOrder(1);

    case 13  % DOWN pressed
        ofSetOrder(-1);

    case 14  % Order entered
        ofSetOrder(0);

    case 40  % Linear/Quadratic/Cubic changed
        % The fitting controls stay off until Plot Data (option 1)
        % re-enables them.
        set( findobj('tag','tagOFUp'), 'Enable','off');
        set( findobj('tag','tagOFDown'), 'Enable','off');
        set( findobj('tag','tagOFEnterOrder'), 'Enable','off');

    case 51  % Generate Data radio button
        % Keep this button selected even when it is clicked again, and
        % only build the generate controls if they are not already there.
        set( findobj('tag','tagOFGenerate'), 'Value',1);
        set( findobj('tag','tagOFLoad'), 'Value',0);
        delete( findobj('tag','tagPressToLoad') );
        delete( findobj('tag','tagFileName') );
        delete( findobj(gcf,'tag','tagTextFileName') );
        if isempty( findobj(gcf,'tag','tagPolyType') )
            ofCreateGenerateControls;
        end

    case 52  % Load Data File radio button
        set( findobj('tag','tagOFLoad'), 'Value',1);
        set( findobj('tag','tagOFGenerate'), 'Value',0);
        delete( findobj(gcf,'tag','tagPolyType') );
        delete( findobj(gcf,'tag','tagErrorVar') );
        delete( findobj(gcf,'tag','tagTextPolyType') );
        delete( findobj(gcf,'tag','tagTextErrorVar') );
        if isempty( findobj(gcf,'tag','tagPressToLoad') )
            ofCreateLoadControls;
        end
end


function ofCreateFigure
% OFCREATEFIGURE Build the demonstration figure, its axes and controls.

% Arrow images for the Up/Down buttons: 20x40 RGB tiles, grey (0.7)
% except for a blue triangle whose apex is in the top rows.
% M is the blue channel, Mone the red and green channels.
[J, I] = meshgrid(1:40, 1:20);
M = ones(20,40)*0.7;
M( (J <= 20 & I + J > 20) | (J > 20 & J - I < 21) ) = 1;
Mone = ones( size(M) )*0.7;
Mone(M == 1) = 0;
UPimage = cat(3, Mone, Mone, M);
DOWNimage = UPimage(end:-1:1, :, :);   % same image flipped vertically

% Create figure
hOverFit = figure('Name','Overfitting Demonstration','tag','tagOverFit');

% Set up canvas
axes( 'Position', [0.1 0.1 0.65 0.8] );
set(gca, 'FontSize',14);

% Menu options
uimenu( hOverFit, 'Label','Options', 'tag','tagOFOptions');

% Control Buttons
uicontrol('Style','Frame','Units','normalized',...
    'Position',[0.78 0.02 0.20 0.43] );
uicontrol('Style','PushButton','Units','normalized',...
    'Position',[0.80 0.04 0.16 0.07], 'String','Quit',...
    'Callback','delete(gcf);');
uicontrol('Style','PushButton','Units','normalized',...
    'Position',[0.80 0.12 0.16 0.07], ...
    'CData',DOWNimage, ...
    'Enable','off',...
    'tag','tagOFDown','Callback','overfit(13);');
uicontrol('Style','Edit','Units','normalized',...
    'Position',[0.80 0.20 0.16 0.07], 'String','0',...
    'Callback','overfit(14);', 'tag','tagOFEnterOrder',...
    'Enable','off');
uicontrol('Style','PushButton','Units','normalized',...
    'Position',[0.80 0.28 0.16 0.07], ...
    'Callback','overfit(12);', 'tag','tagOFUp',...
    'Enable','off',...
    'CData',UPimage);
uicontrol('Style','PushButton','Units','normalized',...
    'Position',[0.80 0.36 0.16 0.07], ...
    'String','Plot Data',...
    'Callback','overfit(1);');

% Data Source
uicontrol('Style','Frame','Units','normalized',...
    'Position',[0.78 0.75 0.20 0.23] );
uicontrol('Style','Text','Units','normalized',...
    'Position',[0.8 0.91 0.16 0.06],...
    'String','Data Source');
uicontrol('Style','RadioButton','Units','normalized',...
    'Position',[0.8 0.78 0.16 0.06],...
    'String','Generate',...
    'Value',1, ...
    'tag','tagOFGenerate', ...
    'Callback','overfit(51);');
uicontrol('Style','RadioButton','Units','normalized',...
    'Position',[0.8 0.85 0.16 0.06],...
    'String','Data File',...
    'tag','tagOFLoad', ...
    'Callback','overfit(52);');

% Option Buttons
uicontrol('Style','Frame','Units','normalized',...
    'Position',[0.78 0.47 0.20 0.26] );

% Create stuff for generating data
ofCreateGenerateControls;


function ofCreateGenerateControls
% OFCREATEGENERATECONTROLS Poly-type menu and error-variance box.
uicontrol('Style','Popupmenu','Units','normalized',...
    'Position',[0.80 0.49 0.16 0.06], ...
    'String','linear|quadratic|cubic',...
    'tag','tagPolyType','Callback','overfit(40);');
uicontrol('Style','Text','Units','normalized',...
    'tag','tagTextPolyType',...
    'Position',[0.80 0.545 0.16 0.06], 'String','Poly Type:');
uicontrol('Style','Text','Units','normalized',...
    'tag','tagTextErrorVar',...
    'Position',[0.80 0.65 0.17 0.06], 'String','Error var:');
uicontrol('Style','Edit','Units','normalized',...
    'tag','tagErrorVar',...
    'Position',[0.80 0.605 0.16 0.06], 'String','0.2',...
    'Callback','overfit(10);');


function ofCreateLoadControls
% OFCREATELOADCONTROLS File-name display and the "Press to Load" button.
uicontrol('Style','Text','Units','normalized',...
    'Position',[0.80 0.65 0.17 0.06],...
    'tag','tagTextFileName',...
    'String','Current File:');
uicontrol('Style','Text','Units','normalized',...
    'Position',[0.80 0.60 0.17 0.06],...
    'tag','tagFileName',...
    'String','(none)');
uicontrol('Style','Pushbutton','Units','normalized',...
    'Position',[0.80 0.49 0.17 0.06],...
    'tag','tagPressToLoad',...
    'String','Press to Load', ...
    'Callback','overfit(6);');


function ofPlotData
% OFPLOTDATA Plot the data (freshly generated, or the loaded file), store
%   it as [y, x] in the figure's UserData, enable the order controls,
%   print the table header and reset the order box to 0.
hFig = findobj('tag','tagOverFit');

% Value 1 on the "Generate" radio button means generate; 0 means use the file.
if get( findobj('tag','tagOFGenerate'), 'Value' ) == 1
    % Fixed design points: 17 values, six of which are dropped (11 remain).
    x = [ 1.1 logspace(0.2,0.6,15) 4.7 ];
    x( [4, 6:10] ) = [];
    x = x(:);
    noise = randn( size(x) ) * sqrt( ofValidateErrorVar );
    switch get( findobj('tag','tagPolyType'), 'Value' )
        case 1      % linear
            y = 2 - x + noise;
        case 2      % quadratic
            y = 2 - 6*x + x.^2 + noise;
        otherwise   % cubic
            y = -9 + 18*x - 7.5*x.^2 + x.^3 + noise;
    end
else
    % Use loaded data file
    PlotData = get( hFig, 'UserData' );
    if isempty( PlotData )
        errordlg( 'No data file has been loaded yet.', 'No Data' );
        return;
    end
    y = PlotData(:,1);
    x = PlotData(:,2);
end

figure( hFig );
hold off;
plot(x, y, '+', 'MarkerSize',12, 'LineWidth',2);
xlabel('X values','FontSize',14);
ylabel('y values','FontSize',14);

% Save data for use later
set( hFig, 'UserData', [y, x] );

% Enable other buttons now
set( findobj('tag','tagOFUp'), 'Enable','on');
set( findobj('tag','tagOFDown'), 'Enable','on');
set( findobj('tag','tagOFEnterOrder'), 'Enable','on');

% Set up Table, and restore first fitting order to 0
overfit(7);
set( findobj('tag','tagOFEnterOrder'), 'String','0');


function ofPlotTrueOrderFit
% OFPLOTTRUEORDERFIT Overlay (black) the least-squares polynomial of order
%   1 or 2 when the poly-type menu says linear or quadratic, else order 3.
hold on;
polytype = get( findobj('tag','tagPolyType'), 'Value' );
if isequal(polytype, 1) | isequal(polytype, 2)
    order = polytype;
else
    order = 3;    % cubic (also used when the menu is absent)
end
PlotData = get( findobj('tag','tagOverFit'), 'UserData' );
y = PlotData(:,1);
x = PlotData(:,2);
Xplot = linspace(min(x), max(x), 100)';
beta = ofDesign(x, order) \ y;
plot( Xplot, ofDesign(Xplot, order)*beta, 'k-', 'LineWidth',2);
title( 'Best Fitted Polynomial', 'FontSize', 18 );


function ofFitOrder
% OFFITORDER Fit the polynomial whose order is in the order box, draw it
%   over the data (black when it matches the generating order, red
%   otherwise) and print its row of the statistics table.
PlotData = get( findobj('tag','tagOverFit'), 'UserData' );
y = PlotData(:,1);
x = PlotData(:,2);
n = length(y);

% Redraw the data alone so successive fits do not pile up.
hold off;
plot(x, y, '+', 'MarkerSize',12, 'LineWidth',2);
hold on;

% Reference ("full") model for Cp: the selected generating order, or an
% 8th-order fit when the poly-type menu is absent (data-file mode).
fullOrder = get( findobj('tag','tagPolyType'), 'Value' );
if isempty(fullOrder)
    fullOrder = 8;
end
Xfull = ofDesign(x, fullOrder);
rFull = y - Xfull*(Xfull \ y);
SSEfull = rFull'*rFull;
pFull = fullOrder + 1;

% Only generated data have a known true order to highlight.
if get( findobj('tag','tagOFLoad'), 'Value' ) == 1
    trueOrder = [];
else
    trueOrder = fullOrder;
end

% Fit the requested order and draw it across the data range.
order = str2double( get( findobj('tag','tagOFEnterOrder'), 'String' ) );
X = ofDesign(x, order);
beta = X \ y;
p = length(beta);
Xplot = linspace(min(x), max(x), 100)';
if isequal(order, trueOrder)
    lineStyle = 'k-';
    note = '*';
else
    lineStyle = 'r-';
    note = ' ';
end
plot( Xplot, ofDesign(Xplot, order)*beta, lineStyle, 'LineWidth',2);
title( [ 'Polynomial Order ', num2str(order) ], 'FontSize', 18 );

% Fit statistics
resids = y - X*beta;
SSE = resids'*resids;
SST = y'*y - n*mean(y)^2;
s2 = SSE / (n - p);
R2 = 1 - SSE/SST;
adjR2 = 1 - s2 / ( SST/(n-1) );
% Leverages h_ii (hat-matrix diagonal) from an orthonormal basis of the
% column space of X, instead of forming inv(X'*X).
hat = sum( orth(X).^2, 2 );
% PRESS: sum of squared leave-one-out residuals e_i / (1 - h_ii).
PRESS = sum( (resids ./ (1 - hat)).^2 );
% Mallows' Cp, with s^2 estimated from the full model on its own n - pFull df.
Cp = SSE / ( SSEfull/(n - pFull) ) - ( n - 2*p );

fprintf('%c%1.0f%c %9.5f %9.5f %9.5f %9.5f %9.5f %9.5f %3.0f\n',...
    note, order, note, SSE, R2, adjR2, s2, PRESS, Cp, p);


function X = ofDesign( x, order )
% OFDESIGN Polynomial design matrix [1, x, x.^2, ..., x.^order] for column x.
X = ones(length(x), order + 1);
for k = 1:order
    X(:, k+1) = x.^k;
end


function errvar = ofValidateErrorVar
% OFVALIDATEERRORVAR Read the error-variance box, replace out-of-range
%   values (non-numeric -> 0.2, negative -> 0.05, above 1 -> 1), write the
%   result back to the box and return it.
hBox = findobj('tag','tagErrorVar');
errvar = str2double( get(hBox, 'String') );
if isnan( errvar )       % Eg, not a number entered
    errvar = 0.2;
elseif errvar < 0        % Let it equal zero, for perfect fit?
    errvar = 0.05;
elseif errvar > 1
    errvar = 1;
end
set( hBox, 'String', num2str(errvar) );


function ofSetOrder( step )
% OFSETORDER Add STEP to the order in the order box (STEP = 0 just tidies
%   a typed value), round it, clamp it to 1..8, show it and refit.
maxOrder = 8;
hBox = findobj('tag','tagOFEnterOrder');
order = round( str2double( get(hBox, 'String') ) );
if isnan(order)
    order = 0;    % unreadable entry: start from zero
end
order = min( max(order + step, 1), maxOrder );
set( hBox, 'String', num2str(order) );
overfit(3);


function ofLoadFile
% OFLOADFILE Ask for a data file with two columns (x first, y second),
%   load it and keep it in the figure's UserData as [y, x].
[filename, pathname] = uigetfile('*','Load Data File');
if ~ischar( filename )
    return;    % dialog cancelled: nothing to load, no error shown
end

% Functional form of LOAD: the file cannot overwrite this function's
% variables, and no EVAL on a name derived from the file is needed.
try
    data = load( [pathname, filename] );
catch
    data = [];
end

% A MAT-file comes back as a struct: use the variable named after the
% file (text before the first '.'), or else its only variable.
if isstruct( data )
    names = fieldnames(data);
    stem = strtok(filename, '.');
    if isfield(data, stem)
        data = getfield(data, stem);
    elseif length(names) == 1
        data = getfield(data, names{1});
    else
        data = [];
    end
end

% Accept n-by-2 data, or 2-by-n data (transposed to n-by-2).
if ndims(data) == 2 & size(data,1) == 2 & size(data,2) ~= 2
    data = data';
end

if isempty(data) | ~isnumeric(data) | ndims(data) ~= 2 | size(data,2) ~= 2
    % A problem
    errstr = { [ 'A problem loading file ',pathname, filename ], ...
               'Check it exists, and has two columns the same length.', ...
               'The first column is x, the second column is y.' };
    errordlg( errstr, 'Error Loading File');
    return;
end

% All is OK
data = double(data);
set( findobj('tag','tagFileName'), 'String', filename );
set( findobj('tag','tagOverFit'), 'UserData', [ data(:,2), data(:,1) ] );