Skip to content

Add a string template tag handler for securely composing queries. #1926

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add a string template tag handler for securely composing queries.
This is a rough draft.  It is probably not suitable in its current
form.

https://nodesecroadmap.fyi/chapter-7/query-langs.html describes
this approach as part of a larger discussion about library support
for safe coding practices.

This enables

    connection.query`SELECT * FROM T WHERE x = ${x}, y = ${y}, z = ${z}`(callback)

and similar idioms.
  • Loading branch information
mikesamuel committed Jan 23, 2018
commit b09edcc4086bb399cbb98ccfd618352ac065ea37
41 changes: 32 additions & 9 deletions index.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
var Classes = Object.create(null);
var calledAsTemplateTagQuick = require('template-tag-common')
.calledAsTemplateTagQuick;

/**
* Create a new Connection instance.
Expand Down Expand Up @@ -46,10 +48,18 @@ exports.createPoolCluster = function createPoolCluster(config) {
* @return {Query} New query object
* @public
*/
exports.createQuery = function createQuery(sql, values, callback) {
exports.createQuery = function createQuery(...args) {
var Connection = loadClass('Connection');

return Connection.createQuery(sql, values, callback);
if (calledAsTemplateTagQuick(args[0], args.length)) {
var Template = loadClass('Template');
const sqlFragment = Template.sql(...args);
return function (callback) {
return Connection.createQuery(sqlFragment.content, [], callback);
};
} else {
let [ sql, values, callback ] = args
return Connection.createQuery(sql, values, callback);
}
};

/**
Expand Down Expand Up @@ -106,14 +116,24 @@ exports.raw = function raw(sql) {
return SqlString.raw(sql);
};

/**
* The type constants.
* @public
*/
Object.defineProperty(exports, 'Types', {
get: loadClass.bind(null, 'Types')
Object.defineProperties(exports, {
/**
* The type constants.
* @public
*/
'Types': {
get: loadClass.bind(null, 'Types')
},
/**
* The SQL template tag.
* @public
*/
'sql': {
get: loadClass.bind(null, 'Template')
}
});


/**
* Load the given class.
* @param {string} className Name of class to default
Expand Down Expand Up @@ -147,6 +167,9 @@ function loadClass(className) {
case 'SqlString':
Class = require('./lib/protocol/SqlString');
break;
case 'Template':
Class = require('./lib/Template');
break;
case 'Types':
Class = require('./lib/protocol/constants/types');
break;
Expand Down
232 changes: 232 additions & 0 deletions lib/Template.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
const Mysql = require('../index')
const {
memoizedTagFunction,
trimCommonWhitespaceFromLines,
TypedString
} = require('template-tag-common')

// A simple lexer for SQL.
// SQL has many divergent dialects with subtly different
// conventions for string escaping and comments.
// This just attempts to roughly tokenize MySQL's specific variant.
// See also
// https://www.w3.org/2005/05/22-SPARQL-MySQL/sql_yacc
// https://github.com/twitter/mysql/blob/master/sql/sql_lex.cc
// https://dev.mysql.com/doc/refman/5.7/en/string-literals.html

// "--" followed by whitespace starts a line comment
// "#"
// "/*" starts an inline comment ended at first "*/"
// \N means null
// Prefixed strings x'...' is a hex string, b'...' is a binary string, ....
// '...', "..." are strings. `...` escapes identifiers.
// doubled delimiters and backslash both escape
// doubled delimiters work in `...` identifiers

const PREFIX_BEFORE_DELIMITER = new RegExp(
'^(?:' +
(
// Comment
'--(?=[\\t\\r\\n ])[^\\r\\n]*' +
'|#[^\\r\\n]*' +
'|/[*][\\s\\S]*?[*]/'
) +
'|' +
(
// Run of non-comment non-string starts
'(?:[^\'"`\\-/#]|-(?!-)|/(?![*]))'
) +
')*')
const DELIMITED_BODIES = {
'\'': /^(?:[^'\\]|\\[\s\S]|'')*/,
'"': /^(?:[^"\\]|\\[\s\S]|"")*/,
'`': /^(?:[^`\\]|\\[\s\S]|``)*/
}

/** Template tag that creates a new Error with a message. */
function msg (strs, ...dyn) {
let message = String(strs[0])
for (let i = 0; i < dyn.length; ++i) {
message += JSON.stringify(dyn[i]) + strs[i + 1]
}
return message
}

/**
* Returns a function that can be fed chunks of input and which
* returns a delimiter context.
*/
function makeLexer () {
let errorMessage = null
let delimiter = null
return (text) => {
if (errorMessage) {
// Replay the error message if we've already failed.
throw new Error(errorMessage)
}
text = String(text)
while (text) {
const pattern = delimiter
? DELIMITED_BODIES[delimiter]
: PREFIX_BEFORE_DELIMITER
const match = pattern.exec(text)
if (!match) {
throw new Error(
errorMessage = msg`Failed to lex starting at ${text}`)
}
let nConsumed = match[0].length
if (text.length > nConsumed) {
const chr = text.charAt(nConsumed)
if (delimiter) {
if (chr === delimiter) {
delimiter = null
++nConsumed
} else {
throw new Error(
errorMessage = msg`Expected ${chr} at ${text}`)
}
} else if (Object.hasOwnProperty.call(DELIMITED_BODIES, chr)) {
delimiter = chr
++nConsumed
} else {
throw new Error(
errorMessage = msg`Expected delimiter at ${text}`)
}
}
text = text.substring(nConsumed)
}
return delimiter
}
}

/** A string wrapper that marks its content as a SQL identifier. */
class Identifier extends TypedString {}

/**
* A string wrapper that marks its content as a series of
* well-formed SQL tokens.
*/
class SqlFragment extends TypedString {}

/**
* Analyzes the static parts of the tag content.
*
* @return An record like { delimiters, chunks }
* where delimiter is a contextual cue and chunk is
* the adjusted raw text.
*/
function computeStatic (strings) {
const { raw } = trimCommonWhitespaceFromLines(strings)

const delimiters = []
const chunks = []

const lexer = makeLexer()

let delimiter = null
for (let i = 0, len = raw.length; i < len; ++i) {
let chunk = String(raw[i])
if (delimiter === '`') {
// Treat raw \` in an identifier literal as an ending delimiter.
chunk = chunk.replace(/^([^\\`]|\\[\s\S])*\\`/, '$1`')
}
const newDelimiter = lexer(chunk)
if (newDelimiter === '`' && !delimiter) {
// Treat literal \` outside a string context as starting an
// identifier literal
chunk = chunk.replace(
/((?:^|[^\\])(?:\\\\)*)\\(`(?:[^`\\]|\\[\s\S])*)$/, '$1$2')
}

chunks.push(chunk)
delimiters.push(newDelimiter)
delimiter = newDelimiter
}

if (delimiter) {
throw new Error(`Unclosed quoted string: ${delimiter}`)
}

return { raw, delimiters, chunks }
}

function interpolateSqlIntoFragment (
{ raw, delimiters, chunks }, strings, values) {
// A buffer to accumulate output.
let [ result ] = chunks
for (let i = 1, len = raw.length; i < len; ++i) {
const chunk = chunks[i]
// The count of values must be 1 less than the surrounding
// chunks of literal text.
if (i !== 0) {
const delimiter = delimiters[i - 1]
const value = values[i - 1]
if (delimiter) {
result += escapeDelimitedValue(value, delimiter)
} else {
result = appendValue(result, value, chunk)
}
}

result += chunk
}

return new SqlFragment(result)
}

function escapeDelimitedValue (value, delimiter) {
if (delimiter === '`') {
return Mysql.escapeId(String(value)).replace(/^`|`$/g, '')
}
const escaped = Mysql.escape(String(value))
return escaped.substring(1, escaped.length - 1)
}

function appendValue (resultBefore, value, chunk) {
let needsSpace = false
let result = resultBefore
const valueArray = Array.isArray(value) ? value : [ value ]
for (let i = 0, nValues = valueArray.length; i < nValues; ++i) {
if (i) {
result += ', '
}

const one = valueArray[i]
let valueStr = null
if (one instanceof SqlFragment) {
if (!/(?:^|[\n\r\t ,\x28])$/.test(result)) {
result += ' '
}
valueStr = one.toString()
needsSpace = i + 1 === nValues
} else if (one instanceof Identifier) {
valueStr = Mysql.escapeId(one.toString())
} else {
// If we need to handle nested arrays, we would recurse here.
valueStr = Mysql.format('?', one)
}
result += valueStr
}

if (needsSpace && chunk && !/^[\n\r\t ,\x29]/.test(chunk)) {
result += ' '
}

return result
}

/**
* Template tag function that contextually autoescapes values
* producing a SqlFragment.
*/
const sql = memoizedTagFunction(computeStatic, interpolateSqlIntoFragment)
sql.Identifier = Identifier
sql.Fragment = SqlFragment

module.exports = sql

if (global.test) {
// Expose for testing.
// Harmless if this leaks
exports.makeLexer = makeLexer
}
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
"bignumber.js": "4.0.4",
"readable-stream": "2.3.3",
"safe-buffer": "5.1.1",
"sqlstring": "2.3.0"
"sqlstring": "2.3.0",
"template-tag-common": "1.0.8"
},
"devDependencies": {
"after": "0.8.2",
Expand Down
1 change: 1 addition & 0 deletions test/common.js
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ common.Parser = require(common.lib + '/protocol/Parser');
common.PoolConfig = require(common.lib + '/PoolConfig');
common.PoolConnection = require(common.lib + '/PoolConnection');
common.SqlString = require(common.lib + '/protocol/SqlString');
common.Template = require(common.lib + '/Template');
common.Types = require(common.lib + '/protocol/constants/types');

var Mysql = require(path.resolve(common.lib, '../index'));
Expand Down
14 changes: 7 additions & 7 deletions test/integration/connection/test-query.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@ var common = require('../../common');
common.getTestConnection(function (err, connection) {
assert.ifError(err);

connection.query('SELECT 1', function (err, rows, fields) {
function callback (err, rows, fields) {
assert.ifError(err);
assert.deepEqual(rows, [{1: 1}]);
assert.equal(fields[0].name, '1');
});
}

connection.query({ sql: 'SELECT ?' }, [ 1 ], function (err, rows, fields) {
assert.ifError(err);
assert.deepEqual(rows, [{1: 1}]);
assert.equal(fields[0].name, '1');
});
connection.query('SELECT 1', callback);

connection.query({ sql: 'SELECT ?' }, [ 1 ], callback);

connection.query`SELECT ${ 1 }`(callback);

connection.end(assert.ifError);
});
Loading