Source code for pygeotools.lib.iolib

#! /usr/bin/env python
"""
Functions for IO, mostly wrapped around GDAL

Note: This was all written before RasterIO existed, which might be a better choice. 
"""

import os
import subprocess

import numpy as np
from osgeo import gdal, gdal_array, osr

#Define drivers
#Each handle is looked up once at import time and reused by the functions below
mem_drv, gtif_drv, vrt_drv = (
    gdal.GetDriverByName(drv_name) for drv_name in ('MEM', 'GTiff', 'VRT')
)

#Default GDAL creation options
gdal_opt = ['COMPRESS=LZW', 'TILED=YES', 'BIGTIFF=IF_SAFER']
#gdal_opt += ['BLOCKXSIZE=1024', 'BLOCKYSIZE=1024']
#List that can be used for building commands
#Every creation option is preceded by its own '-co' flag
gdal_opt_co = [token for co in gdal_opt for token in ('-co', co)]

#Add methods to load ma from OpenCV, PIL, etc.
#These formats should be directly readable as np arrays

#Note: want to modify to import all bands as separate arrays in ndarray
#Unless the user requests a single band, or range of bands

#Check for file existence
def fn_check(fn):
    """Report whether a path exists

    Thin wrapper around os.path.exists.

    Parameters
    ----------
    fn : str
        Input filename string.

    Returns
    -------
    bool
        True if the path exists, False otherwise.
    """
    exists = os.path.exists(fn)
    return exists
def fn_check_full(fn):
    """Check that a path is a regular file that can be opened for reading

    Slower than os.path.exists, because the file is actually opened.

    Parameters
    ----------
    fn : str
        Input filename string.

    Returns
    -------
    status : bool
        True if fn is a regular file and can be opened, False otherwise.
    """
    if not os.path.isfile(fn):
        return False
    try:
        #Context manager guarantees the handle is released immediately
        with open(fn):
            pass
    except IOError:
        return False
    return True
def fn_list_check(fn_list):
    """Check that every filename in a list exists

    Prints a message for each path that cannot be found.

    Parameters
    ----------
    fn_list : list of str
        Input filenames.

    Returns
    -------
    bool
        True if all paths exist, False if at least one is missing.
    """
    all_found = True
    for fn in fn_list:
        if os.path.exists(fn):
            continue
        print('Unable to find: %s' % fn)
        all_found = False
    return all_found
def fn_list_valid(fn_list):
    """Filter a list of filenames down to those that exist

    Prints the input count, a message for each missing path, and the
    output count.

    Parameters
    ----------
    fn_list : list of str
        Input filenames.

    Returns
    -------
    out_list : list of str
        Filenames from fn_list that exist, in the original order.
    """
    print('%i input fn' % len(fn_list))
    out_list = []
    for fn in fn_list:
        if os.path.exists(fn):
            out_list.append(fn)
        else:
            print('Unable to find: %s' % fn)
    print('%i output fn' % len(out_list))
    return out_list
#Wrapper around gdal.Open
def fn_getds(fn):
    """Open a file read-only with gdal.Open()

    Parameters
    ----------
    fn : str
        Input filename string.

    Returns
    -------
    ds
        Result of gdal.Open(fn, gdal.GA_ReadOnly), or None (after printing
        a message) when fn does not exist.
    """
    if not fn_check(fn):
        print("Unable to find %s" % fn)
        return None
    return gdal.Open(fn, gdal.GA_ReadOnly)
def fn_getma(fn, bnum=1, return_ds=False):
    """Get masked array from input filename

    Parameters
    ----------
    fn : str
        Input filename string
    bnum : int, optional
        Band number
    return_ds : bool, optional
        If True, also return the opened dataset

    Returns
    -------
    np.ma.array
        Masked array containing raster values. When return_ds is True,
        a tuple of (masked array, dataset) is returned instead.

    Raises
    ------
    IOError
        If no dataset could be obtained for fn.
    """
    ds = fn_getds(fn)
    #fn_getds yields None when the file is missing; fail with a clear message
    #rather than an AttributeError on None further down
    if ds is None:
        raise IOError("Unable to open %s" % fn)
    out = ds_getma(ds, bnum=bnum)
    if return_ds:
        out = (out, ds)
    return out
#Given input dataset, return a masked array for the input band
def ds_getma(ds, bnum=1):
    """Get masked array from input GDAL Dataset

    Parameters
    ----------
    ds : gdal.Dataset
        Input GDAL Dataset
    bnum : int, optional
        Band number

    Returns
    -------
    np.ma.array
        Masked array containing raster values
    """
    return b_getma(ds.GetRasterBand(bnum))
#Given input band, return a masked array
def b_getma(b):
    """Get masked array from input GDAL Band

    Parameters
    ----------
    b : gdal.Band
        Input GDAL Band

    Returns
    -------
    np.ma.array
        Masked array containing raster values
    """
    ndv = get_ndv_b(b)
    band_data = b.ReadAsArray()
    #masked_values (rather than masked_equal) compares with a tolerance,
    #which is more appropriate for float data
    return np.ma.masked_values(band_data, ndv)
def get_sub_dim(src_ds, scale=None, maxdim=1024):
    """Compute dimensions of subsampled dataset

    Parameters
    ----------
    src_ds : gdal.Dataset
        Input GDAL Dataset
    scale : int, optional
        Scaling factor. If None, it is derived from maxdim.
    maxdim : int, optional
        Maximum dimension along either axis, in pixels

    Returns
    -------
    ns
        Number of samples in subsampled output
    nl
        Number of lines in subsampled output
    scale
        Final scaling factor
    """
    ns = src_ds.RasterXSize
    nl = src_ds.RasterYSize
    maxdim = float(maxdim)
    if scale is None:
        #Largest axis determines the factor needed to fit within maxdim
        scale = max(ns/maxdim, nl/maxdim)
    #Applies to both a user-supplied and a maxdim-derived scale
    #A scale <= 1 leaves the dimensions unchanged (no upsampling)
    if scale > 1:
        #float() avoids integer division for an int scale under Python 2
        scale_f = float(scale)
        #Keep at least one pixel along each axis for very elongated rasters
        ns = max(1, int(round(ns/scale_f)))
        nl = max(1, int(round(nl/scale_f)))
    return ns, nl, scale
def fn_getma_sub(fn, bnum=1, scale=None, maxdim=1024., return_ds=False):
    """Load a subsampled masked array from input filename

    Opens fn with gdal.Open and forwards all remaining arguments to
    ds_getma_sub.

    Parameters
    ----------
    fn : str
        Input filename string
    bnum : int, optional
        Band number
    scale : int, optional
        Scaling factor
    maxdim : int, optional
        Maximum dimension along either axis, in pixels
    return_ds : bool, optional
        Passed through to ds_getma_sub

    Returns
    -------
    Whatever ds_getma_sub returns for the opened dataset
    """
    src_ds = gdal.Open(fn)
    sub_kwargs = dict(bnum=bnum, scale=scale, maxdim=maxdim, return_ds=return_ds)
    return ds_getma_sub(src_ds, **sub_kwargs)
#Load a subsampled array
#Can specify scale factor or max dimension
#No need to load the entire dataset for stats computation
def ds_getma_sub(src_ds, bnum=1, scale=None, maxdim=1024., return_ds=False):
    """Load a subsampled array, rather than full resolution

    This is useful when working with large rasters

    Uses buf_xsize and buf_ysize options from GDAL ReadAsArray method.

    Parameters
    ----------
    src_ds : gdal.Dataset
        Input GDAL Dataset
    bnum : int, optional
        Band number
    scale : int, optional
        Scaling factor
    maxdim : int, optional
        Maximum dimension along either axis, in pixels
    return_ds : bool, optional
        If True, also build and return an in-memory single-band dataset
        holding the subsampled array

    Returns
    -------
    np.ma.array
        Masked array containing raster values. When return_ds is True,
        a tuple of (masked array, in-memory dataset) is returned instead.
    """
    b = src_ds.GetRasterBand(bnum)
    b_ndv = get_ndv_b(b)
    ns, nl, scale = get_sub_dim(src_ds, scale, maxdim)
    #The buf_size parameters determine the final array dimensions
    b_array = b.ReadAsArray(buf_xsize=ns, buf_ysize=nl)
    bma = np.ma.masked_values(b_array, b_ndv)
    out = bma
    if return_ds:
        #Use the data type of the band that was actually read
        dtype = b.DataType
        src_ds_sub = gdal.GetDriverByName('MEM').Create('', ns, nl, 1, dtype)
        #Derive per-axis factors from the real output size, so the extent is
        #preserved even when the rounded dimensions differ from the nominal
        #scale, or when no subsampling took place
        x_scale = src_ds.RasterXSize/float(ns)
        y_scale = src_ds.RasterYSize/float(nl)
        gt = np.array(src_ds.GetGeoTransform())
        #Terms 1 and 4 multiply the column index, terms 2 and 5 the row index
        gt[[1,4]] = gt[[1,4]]*x_scale
        gt[[2,5]] = gt[[2,5]]*y_scale
        src_ds_sub.SetGeoTransform(list(gt))
        src_ds_sub.SetProjection(src_ds.GetProjection())
        b_sub = src_ds_sub.GetRasterBand(1)
        b_sub.WriteArray(bma)
        b_sub.SetNoDataValue(b_ndv)
        out = (bma, src_ds_sub)
    return out
#Note: need to consolidate with warplib.writeout (takes ds, not ma)
#Add option to build overviews when writing GTiff
#Input proj must be WKT
def writeGTiff(a, dst_fn, src_ds=None, bnum=1, ndv=None, gt=None, proj=None, create=False, sparse=False):
    """Write input array to disk as GeoTiff

    Parameters
    ----------
    a : np.array or np.ma.array
        Input array
    dst_fn : str
        Output filename
    src_ds: GDAL Dataset, optional
        Source Dataset to use for creating copy. A filename string is also
        accepted and opened with fn_getds.
    bnum : int, optional
        Output band
    ndv : float, optional
        Output NoData Value
    gt : list, optional
        Output GeoTransform
    proj : str, optional
        Output Projection (OGC WKT or PROJ.4 format)
    create : bool, optional
        Create new dataset
    sparse : bool, optional
        Output should be created with sparse options
    """
    #If input is not np.ma, this creates a new ma, which has default fill_value of 1E20
    #Must manually override with ndv
    #Also consumes a lot of memory
    #Should bypass if input is bool
    from pygeotools.lib.malib import checkma
    a = checkma(a, fix=False)
    #Want to preserve fill_value if already specified
    if ndv is not None:
        a.set_fill_value(ndv)
    driver = gtif_drv
    #Currently only support writing singleband rasters
    nbands = 1
    np_dt = a.dtype.name
    if src_ds is not None:
        #If this is a fn, get a ds
        #Note: this saves a lot of unnecessary iolib.fn_getds calls
        if isinstance(src_ds, str):
            src_ds = fn_getds(src_ds)
        src_dt = gdal.GetDataTypeName(src_ds.GetRasterBand(bnum).DataType)
        src_gt = src_ds.GetGeoTransform()
        #This is WKT
        src_proj = src_ds.GetProjection()
        #Fall back to the source georeferencing when none was supplied
        if gt is None:
            gt = src_gt
        if proj is None:
            proj = src_proj

    #Need to create a new copy of the default options
    opt = list(gdal_opt)

    #Note: packbits is better for sparse data
    if sparse:
        opt.remove('COMPRESS=LZW')
        opt.append('COMPRESS=PACKBITS')
        #Not sure if VW can handle sparse tif
        #opt.append('SPARSE_OK=TRUE')

    #Use predictor=3 for floating point data
    if 'float' in np_dt.lower() and 'COMPRESS=LZW' in opt:
        opt.append('PREDICTOR=3')

    #If input ma is same as src_ds, write out array using CreateCopy from existing dataset
    #Should compare srs.IsSame(src_srs)
    can_copy = False
    if not create and (src_ds is not None):
        same_shape = (a.shape[0] == src_ds.RasterYSize) and (a.shape[1] == src_ds.RasterXSize)
        same_dtype = np_dt.lower() == src_dt.lower()
        #Compare as tuples, so a list or ndarray gt matches the tuple from GDAL
        #instead of always comparing unequal (list) or raising (ndarray)
        same_gt = tuple(gt) == tuple(src_gt)
        same_proj = src_proj == proj
        can_copy = same_shape and same_dtype and same_gt and same_proj

    if can_copy:
        #Note: third option is strict flag, set to false
        dst_ds = driver.CreateCopy(dst_fn, src_ds, 0, options=opt)
    #Otherwise, use Create
    else:
        a_dtype = a.dtype
        gdal_dtype = np2gdal_dtype(a_dtype)
        if a_dtype.name == 'bool':
            #Set ndv to 0
            a.fill_value = False
            #LZW is already gone when sparse=True; only swap it if present
            if 'COMPRESS=LZW' in opt:
                opt.remove('COMPRESS=LZW')
                opt.append('COMPRESS=DEFLATE')
            #opt.append('NBITS=1')
        #Create(fn, nx, ny, nbands, dtype, opt)
        dst_ds = driver.Create(dst_fn, a.shape[1], a.shape[0], nbands, gdal_dtype, options=opt)
        #Note: Need GeoMA here to make this work, or accept gt as argument
        #Could also do ds creation in calling script
        if gt is not None:
            dst_ds.SetGeoTransform(gt)
        if proj is not None:
            dst_ds.SetProjection(proj)

    dst_ds.GetRasterBand(bnum).WriteArray(a.filled())
    dst_ds.GetRasterBand(bnum).SetNoDataValue(float(a.fill_value))
    #Drop the reference so GDAL flushes and closes the output
    dst_ds = None
def writevrt(out_csv,srs='EPSG:4326',x='field_1',y='field_2'):
    """Write out a vrt to accompany a csv of points

    The vrt is written next to the csv, with the same basename and a
    '.vrt' extension, and refers to the csv by its basename.

    Parameters
    ----------
    out_csv : str
        Filename of the csv containing the points
    srs : str, optional
        Value written to the LayerSRS element
    x : str, optional
        Name of the csv column holding x coordinates
    y : str, optional
        Name of the csv column holding y coordinates
    """
    out_vrt = os.path.splitext(out_csv)[0]+'.vrt'
    out_csv = os.path.split(out_csv)[-1]
    layer_name = os.path.splitext(out_csv)[0]
    #Context manager closes the file even if one of the writes fails
    with open(out_vrt, 'w') as f:
        f.write('<OGRVRTDataSource>\n')
        f.write(' <OGRVRTLayer name="%s">\n' % layer_name)
        f.write(' <SrcDataSource>%s</SrcDataSource>\n' % out_csv)
        f.write(' <GeometryType>wkbPoint</GeometryType>\n')
        f.write(' <LayerSRS>%s</LayerSRS>\n' % srs)
        f.write(' <GeometryField encoding="PointFromColumns" x="%s" y="%s"/>\n' % (x, y))
        f.write(' </OGRVRTLayer>\n')
        f.write('</OGRVRTDataSource>\n')
#Move to geolib? #Look up equivalent GDAL data type
def np2gdal_dtype(d):
    """Get GDAL RasterBand datatype that corresponds with NumPy datatype

    Parameters
    ----------
    d : np.ndarray, NumPy scalar or np.dtype
        Input array or dtype

    Returns
    -------
    gdal_dt : int
        Key of gdal_array.codes whose value matches the dtype. int8 and bool
        are both returned as 1 (written out as Byte). None, after printing a
        message, if the input is not a NumPy array or dtype.
    """
    dt_dict = gdal_array.codes
    if isinstance(d, (np.ndarray, np.generic)):
        d = d.dtype
    if not isinstance(d, np.dtype):
        print("Input must be NumPy array or NumPy dtype")
        return None
    #Write out int8 and bool as Byte
    if d.name in ('int8', 'bool'):
        return 1
    #Reverse lookup: find the key whose value equals the dtype
    gdal_codes = list(dt_dict.keys())
    np_types = list(dt_dict.values())
    return gdal_codes[np_types.index(d)]
def gdal2np_dtype(b):
    """Get NumPy datatype that corresponds with GDAL RasterBand datatype

    Parameters
    ----------
    b : str, gdal.Dataset, gdal.Band or int
        Filename, GDAL Dataset (band 1 is used), GDAL RasterBand, or GDAL
        integer dtype

    Returns
    -------
    np_dtype
        Value of gdal_array.codes for the resolved integer type, or None
        (after printing a message) if the input could not be resolved to
        an integer.
    """
    dt_dict = gdal_array.codes
    #Resolve step by step: filename -> dataset -> band -> integer type code
    if isinstance(b, str):
        b = gdal.Open(b)
    if isinstance(b, gdal.Dataset):
        b = b.GetRasterBand(1)
    if isinstance(b, gdal.Band):
        b = b.DataType
    if not isinstance(b, int):
        print("Input must be GDAL Dataset or RasterBand object")
        return None
    return dt_dict[b]
#Replace nodata value in GDAL band
def replace_ndv(b, new_ndv):
    """Replace the nodata value of a GDAL band, rewriting affected pixels

    Parameters
    ----------
    b : gdal.Band
        Input GDAL Band, modified in place
    new_ndv : float
        New NoData value

    Returns
    -------
    b : gdal.Band
        The input band
    """
    old_ndv = get_ndv_b(b)
    masked = np.ma.masked_values(b.ReadAsArray(), old_ndv)
    masked.set_fill_value(new_ndv)
    #Pixels that matched the old nodata value are written back as new_ndv
    filled = masked.filled()
    b.WriteArray(filled)
    b.SetNoDataValue(new_ndv)
    return b
def set_ndv(dst_fn, ndv):
    """Set the NoData value on every band of a raster file

    Only the NoData metadata is updated; pixel values are not modified.

    Parameters
    ----------
    dst_fn : str
        Filename of the raster, opened in update mode
    ndv : float
        NoData value to set
    """
    dst_ds = gdal.Open(dst_fn, gdal.GA_Update)
    for n in range(1, dst_ds.RasterCount+1):
        #Use the loop index, so each band is updated rather than band 1 repeatedly
        b = dst_ds.GetRasterBand(n)
        b.SetNoDataValue(ndv)
    #Drop the reference so GDAL flushes and closes the file
    dst_ds = None
#Should overload these functions to handle fn, ds, or b #Perhaps abstract, as many functions will need this functionality
def get_ndv_fn(fn):
    """Get NoData value for a raster file

    Opens fn read-only and delegates to get_ndv_ds with its default band.

    Parameters
    ----------
    fn : str
        Input filename string

    Returns
    -------
    NoData value as returned by get_ndv_ds
    """
    return get_ndv_ds(gdal.Open(fn, gdal.GA_ReadOnly))
#Want to modify to handle multi-band images and return list of ndv
def get_ndv_ds(ds, bnum=1):
    """Get NoData value for one band of a GDAL Dataset

    Parameters
    ----------
    ds : gdal.Dataset
        Input GDAL Dataset
    bnum : int, optional
        Band number

    Returns
    -------
    NoData value as returned by get_ndv_b
    """
    band = ds.GetRasterBand(bnum)
    ndv = get_ndv_b(band)
    return ndv
#Return nodata value for GDAL band
def get_ndv_b(b):
    """Get NoData value for GDAL band.

    If NoDataValue is not set in the band, the upper left and lower right
    pixel values are inspected: the upper left value is used if it is NaN
    or equal to the lower right value. Otherwise the NoData value is
    assumed to be 0.

    If the band reports NaN as its NoDataValue, NaN is returned for Float
    data types and 0 for all other data types.

    Parameters
    ----------
    b : GDALRasterBand object
        This is the input band.

    Returns
    -------
    b_ndv : float
        NoData value
    """
    b_ndv = b.GetNoDataValue()
    if b_ndv is None:
        #Check ul pixel for ndv
        ns = b.XSize
        nl = b.YSize
        #ReadAsArray returns a (1, 1) array here; index the single element
        #explicitly, as float() on a 2D array is deprecated in recent NumPy
        ul = float(b.ReadAsArray(0, 0, 1, 1)[0, 0])
        lr = float(b.ReadAsArray(ns-1, nl-1, 1, 1)[0, 0])
        #Probably better to use 3/4 corner criterion
        if np.isnan(ul) or ul == lr:
            b_ndv = ul
        else:
            #Assume ndv is 0
            b_ndv = 0
    elif np.isnan(b_ndv):
        b_dt = gdal.GetDataTypeName(b.DataType)
        if 'Float' in b_dt:
            b_ndv = np.nan
        else:
            b_ndv = 0
    return b_ndv
#Write out a recarray as a csv
def write_recarray(outfn, ra):
    """Write out a recarray as a csv

    The first line holds the field names; each following line holds one
    record, with values converted using str().

    Parameters
    ----------
    outfn : str
        Output filename
    ra : np.recarray
        Input record array
    """
    with open(outfn,'w') as f:
        header = ','.join(str(name) for name in ra.dtype.names)
        f.write(header + '\n')
        for record in ra:
            line = ','.join(map(str, record))
            f.write(line + '\n')
#Check to make sure image doesn't contain errors
def image_check(fn):
    """Check to make sure image doesn't contain errors

    Computes the checksum of every band, which forces GDAL to read all
    pixel data, and reports whether GDAL raised any error while doing so.

    Parameters
    ----------
    fn : str
        Input filename string

    Returns
    -------
    status : bool
        True if the file opened and all bands were read without GDAL
        reporting an error, False otherwise.
    """
    #Clear any stale error left over from earlier GDAL calls, which would
    #otherwise be attributed to this file
    gdal.ErrorReset()
    ds = gdal.Open(fn)
    #A file that cannot be opened at all fails the check
    if ds is None:
        return False
    status = True
    for i in range(ds.RasterCount):
        ds.GetRasterBand(i+1).Checksum()
        if gdal.GetLastErrorType() != 0:
            status = False
            #No need to read the remaining bands once one has failed
            break
    return status
#Return number of CPUs
#Logical is "virtual" cpu count with hyperthreading
#Set to False for physical cpu count
def cpu_count(logical=True):
    """Return system CPU count

    Parameters
    ----------
    logical : bool, optional
        If True, return the count reported by multiprocessing.cpu_count().
        If False, return psutil.cpu_count(logical=False), which requires
        the psutil package.

    Returns
    -------
    ncpu
        CPU count
    """
    if not logical:
        #psutil is only needed, and only imported, for the physical count
        import psutil
        return psutil.cpu_count(logical=False)
    import multiprocessing
    return multiprocessing.cpu_count()
def setstripe(dir, threads=None):
    """Set the lustre stripe count for a directory

    Does nothing unless dir exists and the output of 'df -T' mentions
    'lustre'. Otherwise prints and runs 'lfs setstripe dir -c threads'.

    Parameters
    ----------
    dir : str
        Directory to stripe. None is accepted and ignored.
    threads : int, optional
        Stripe count. Defaults to cpu_count(), evaluated at call time.
    """
    #Cheap checks first, so no subprocess is spawned for a missing directory
    if dir is None or not os.path.exists(dir):
        return
    #Better to use 'df -T' to determine filesystem of directory
    #Can do this with psutil Python lib, but need to also find mount point of file
    try:
        df_out = subprocess.check_output(['df','-T'])
    except (OSError, subprocess.CalledProcessError):
        #df is unavailable or does not accept -T here; the filesystem type
        #cannot be determined, so leave the directory untouched
        return
    if 'lustre' not in str(df_out):
        return
    if threads is None:
        threads = cpu_count()
    cmd = ['lfs', 'setstripe', dir, '-c', str(threads)]
    print(' '.join(cmd))
    subprocess.call(cmd)
#This is a shared directory for files like LULC, used by multiple tools
#Default location is $HOME/data
#Can specify in ~/.bashrc or ~/.profile
#export DATADIR=$HOME/data
def get_datadir():
    """Return the shared data directory, creating it if necessary

    The location is taken from the DATADIR environment variable, falling
    back to a 'data' directory under the user's home directory.

    Returns
    -------
    datadir : str
        Path to the data directory
    """
    datadir = os.environ.get('DATADIR')
    if datadir is None:
        datadir = os.path.join(os.path.expanduser('~'), 'data')
    if os.path.exists(datadir):
        return datadir
    os.makedirs(datadir)
    return datadir
#Function to get files using urllib #This works with ftp
def getfile(url, outdir=None):
    """Function to fetch files using urllib

    Works with ftp

    The download is skipped when the output file already exists.

    Parameters
    ----------
    url : str
        URL to retrieve; its last path component is used as the filename
    outdir : str, optional
        Directory in which to place the file

    Returns
    -------
    fn : str
        Output filename
    """
    fn = os.path.split(url)[-1]
    if outdir is not None:
        fn = os.path.join(outdir, fn)
    if os.path.exists(fn):
        return fn
    #Find appropriate urlretrieve for Python 2 and 3
    try:
        from urllib.request import urlretrieve
    except ImportError:
        from urllib import urlretrieve
    print("Retrieving: %s" % url)
    #Add progress bar
    urlretrieve(url, fn)
    return fn
#Function to get files using requests #Works with https authentication
def getfile2(url, auth=None, outdir=None):
    """Function to fetch files using requests

    Works with https authentication

    Parameters
    ----------
    url : str
        URL to retrieve; its last path component is used as the filename
    auth : optional
        Passed to requests.get as the auth argument
    outdir : str, optional
        Directory in which to place the file

    Returns
    -------
    fn : str
        Output filename

    Raises
    ------
    requests.HTTPError
        If the server responds with an error status
    """
    import requests
    print("Retrieving: %s" % url)
    fn = os.path.split(url)[-1]
    if outdir is not None:
        fn = os.path.join(outdir, fn)
    #auth=None is the requests default, so no separate unauthenticated call is needed
    r = requests.get(url, stream=True, auth=auth)
    try:
        #Fail before writing, rather than saving an error page as the output file
        r.raise_for_status()
        chunk_size = 1000000
        with open(fn, 'wb') as fd:
            for chunk in r.iter_content(chunk_size):
                fd.write(chunk)
    finally:
        #Release the streamed connection
        r.close()
    return fn
#Get necessary credentials to access MODSCAG products - hopefully this will soon be archived with NSIDC
def get_auth():
    """Get authorization token for https

    Prompts interactively for the MODSCAG username and password.

    Returns
    -------
    auth : requests.auth.HTTPDigestAuth
        Digest authentication object built from the entered credentials
    """
    import getpass
    from requests.auth import HTTPDigestAuth
    #Prefer raw_input where it exists (Python 2), otherwise use input
    try:
        prompt = raw_input
    except NameError:
        prompt = input
    uname = prompt("MODSCAG Username:")
    pw = getpass.getpass("MODSCAG Password:")
    #wget -A'h8v4*snow_fraction.tif' --user=uname --password=pw
    return HTTPDigestAuth(uname, pw)
def _is_number(s):
    #True if s can be parsed as a float (covers ints, decimals, signs, exponents)
    try:
        float(s)
    except (TypeError, ValueError):
        return False
    return True


def readcsv(fn):
    """Wrapper to read arbitrary csv, check for header

    The first line is treated as a header, and skipped, unless every one
    of its fields parses as a number.

    Needs some work to be more robust, quickly added for demcoreg sampling

    Parameters
    ----------
    fn : str
        Input csv filename, comma-delimited

    Returns
    -------
    pts : np.ndarray
        Array returned by np.loadtxt for the data rows
    """
    import csv
    #Check first line for header
    with open(fn, 'r') as f:
        hdr = csv.DictReader(f).fieldnames
    #Assume there is a header on first line, check
    skiprows = 1
    #fieldnames is None for an empty file
    #Builtin all() is required here: np.all() on a generator is always truthy
    if hdr is None or all(_is_number(field) for field in hdr):
        skiprows = 0
    #Check header for lat/lon/z or x/y/z tags
    #Should probably do genfromtxt here if header exists and dtype of cols is variable
    pts = np.loadtxt(fn, delimiter=',', skiprows=skiprows, dtype=None)
    return pts