diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ec47f62 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/venv +/migrations +.env \ No newline at end of file diff --git a/Procfile b/Procfile new file mode 100644 index 0000000..c65d50e --- /dev/null +++ b/Procfile @@ -0,0 +1 @@ +web: waitress-serve --port=$PORT flasky:app \ No newline at end of file diff --git a/README.md b/README.md index c89f7a6..3c37217 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,34 @@ -# flask-empty-project-shell -An empty project shell for structuring new Flask projects +# Flask API + +This is designed as a basic json API to serve **products** and **orders** resources. + +## API resources + +| URL | Method | Description | +| -------------------|-----------|------------------------ | +| /products/ | GET | Get all products | +| /products/ | POST | Create new product | +| /products/ | DELETE | Delete product by id | +| /products/ | PUT | Update product by id | +| /orders/ | GET | Get all orders | +| /orders/ | DELETE | Delete order by id | +| /orders/ | PUT | Update order by id | + +## Heroku deployment + +### Configs + +``` +heroku config:set FLASK_APP=flasky.py +heroku config:set FLASK_CONFIG=heroku +``` +The Heroku Postgres add on will set DATABASE_URL. Make sure to set a SECRET_KEY also. + +### Set up database + +This project uses SQLAlchemy to define the database model. The script `manage.py` uses Flask-Migrate and Flask-Script to allow setup, migration and deployment of the model without losing existing data. +``` +heroku run bash +python manage.py db init +python manage.py db migrate +python manage.py db upgrade \ No newline at end of file diff --git a/app/__init__.py b/app/__init__.py index e69de29..0aefef3 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -0,0 +1,25 @@ +from flask import Flask +from config import config +from flask_sqlalchemy import SQLAlchemy +from flask_cors import CORS + + + +db = SQLAlchemy() + + +def create_app(config_name): + app = Flask(__name__) + app.config.from_object(config[config_name]) + config[config_name].init_app(app) + + db.init_app(app) + CORS(app, supports_credentials=True) + + from .main import main as main_blueprint + app.register_blueprint(main_blueprint) + + from .api import api as api_blueprint + app.register_blueprint(api_blueprint, url_prefix='/api/v1') + + return app diff --git a/app/api/__init__.py b/app/api/__init__.py new file mode 100644 index 0000000..93df9f1 --- /dev/null +++ b/app/api/__init__.py @@ -0,0 +1,6 @@ +from flask import Blueprint + +api = Blueprint('api', __name__) + +from . import authentication, errors +from . import products, orders diff --git a/app/api/authentication.py b/app/api/authentication.py new file mode 100644 index 0000000..2ad48e1 --- /dev/null +++ b/app/api/authentication.py @@ -0,0 +1,47 @@ +from flask import g, jsonify +from flask_httpauth import HTTPBasicAuth +from flask_login import current_user +from .errors import unauthorized, forbidden +from . import api +from ..models import User + +auth = HTTPBasicAuth() + + +@auth.verify_password +def verify_password(name_or_token, password): + """Return True if login valid; Uses the User method verify_password; + If password is blank, token is assumed""" + print('login:', name_or_token, password) + print('user:', User.verify_auth_token(name_or_token)) + if name_or_token == '': + return False + if password == '': + g.current_user = User.verify_auth_token(name_or_token) + g.token_used = True + return g.current_user is not None + user = User.query.filter_by(username=name_or_token).first() + if not user: + return False + g.current_user = user + g.token_used = False + return user.verify_password(password) + +@auth.error_handler +def auth_error(): + return jsonify({'success': False}) + +@api.route('/login', methods=['POST']) +@auth.login_required +def get_token(): + print('login route') + if g.current_user.is_anonymous or g.token_used: + print('auth failed') + print(g.current_user.is_anonymous) + print(g.token_used) + return jsonify({'success': False}) + print('auth succeed') + print(jsonify({'token': g.current_user.generate_auth_token( + expiration=3600), 'expiration': 3600, 'success': True})) + return jsonify({'token': g.current_user.generate_auth_token( + expiration=3600), 'expiration': 3600, 'success': True}) diff --git a/app/api/decorators.py b/app/api/decorators.py new file mode 100644 index 0000000..d56462d --- /dev/null +++ b/app/api/decorators.py @@ -0,0 +1,14 @@ +from functools import wraps +from flask import g +from .errors import forbidden + + +# def permission_required(permission): +# def decorator(f): +# @wraps(f) +# def decorated_function(*args, **kwargs): +# if not g.current_user.can(permission): +# return forbidden('Insufficient permissions') +# return f(*args, **kwargs) +# return decorated_function +# return decorator diff --git a/app/api/errors.py b/app/api/errors.py new file mode 100644 index 0000000..13b45a1 --- /dev/null +++ b/app/api/errors.py @@ -0,0 +1,27 @@ +from flask import jsonify +from . import api + +class ValidationError(ValueError): + pass + +def bad_request(message): + response = jsonify({'error': 'bad request', 'message': message}) + response.status_code = 400 + return response + + +def unauthorized(message): + response = jsonify({'error': 'unauthorized', 'message': message}) + response.status_code = 401 + return response + + +def forbidden(message): + response = jsonify({'error': 'forbidden', 'message': message}) + response.status_code = 403 + return response + + +@api.errorhandler(ValidationError) +def validation_error(e): + return bad_request(e.args[0]) diff --git a/app/api/orders.py b/app/api/orders.py new file mode 100644 index 0000000..b915b25 --- /dev/null +++ b/app/api/orders.py @@ -0,0 +1,90 @@ +from . import api +from ..models import Order, OrderLine, Product +from .. import db +from flask import jsonify, request +from .authentication import auth + +import sqlalchemy +from sqlalchemy.sql import func, literal_column, select + +# Helper function: +# Pulls Postgres SQL function json_agg +# See - https://trvrm.github.io/using-sqlalchemy-and-postgres-functions-to-produce-json-tree-structures-from-sql-joins.html +def json_agg(table): + return func.json_agg(literal_column('"'+table.name+'"')) + +def order_details(db): + + OrderProducts = ( + db.session.query( + func.json_agg(func.json_build_object( + 'name', Product.name, + 'category', Product.category, + 'quantity', OrderLine.quantity + ).label('products')), + Order.order_id) + .group_by(Order.order_id) + ).cte('order_products') + + query = ( + db.session.query( + func.json_build_object( + 'order_id', Order.order_id, + 'name', Order.name, + 'address', Order.address, + 'city', Order.city, + 'state', Order.state, + 'zip', Order.zip, + 'country', Order.country, + 'quantity', OrderLine.quantity), + OrderProducts) + .join(OrderLine, OrderLine.order_id == Order.order_id) + .join(OrderProducts) + ) + # Common Table Expressions (CTEs) + results = query.all() + return results + +@api.route('/orders/', methods=['GET']) +def get_orders(): + orders = Order.query.all() + # TODO: add cart lines + import pprint + pprint.pprint(order_details(db)) + return jsonify([ + o.to_json() + for o in orders]) + +@api.route('/orders/', methods=['POST']) +def new_order(): + order = Order.from_json(request.json) + db.session.add(order) + db.session.flush() + order_id = order.order_id + for line in request.json.get('lines'): + order_line = OrderLine.add_line(order_id, line) + db.session.add(order_line) + db.session.commit() + return jsonify(order.to_json()), 201 + +@api.route('/orders/', methods=['DELETE']) +def delete_order(id): + order = Order.query.get_or_404(id) + db.session.delete(order) + db.session.commit() + return jsonify({"success": True}) + +@api.route('/orders/', methods=['PUT']) +def edit_order(id): + order = Order.query.get_or_404(id) + + order.order_id = request.json.get('order_id') + order.name = request.json.get('name') + order.address = request.json.get('address') + order.city = request.json.get('city') + order.state = request.json.get('state') + order.zip = request.json.get('zip') + order.country = request.json.get('country') + + db.session.add(order) + db.session.commit() diff --git a/app/api/products.py b/app/api/products.py new file mode 100644 index 0000000..0891e9a --- /dev/null +++ b/app/api/products.py @@ -0,0 +1,42 @@ +from . import api +from ..models import Product +from .. import db +from flask import jsonify, request +from .authentication import auth + + +@api.route('/products/', methods=['GET']) +def get_products(): + print('get_products') + products = Product.query.all() + return jsonify([p.to_json() for p in products]), 200 + +@api.route('/products/', methods=['POST']) +@auth.login_required +def new_product(): + product = Product.from_json(request.json) + db.session.add(product) + db.session.commit() + return jsonify(product.to_json()), 201 + +@api.route('/products/', methods=['DELETE']) +@auth.login_required +def delete_product(id): + product = Product.query.get_or_404(id) + db.session.delete(product) + db.session.commit() + return jsonify({"success": True}), 200 + +@api.route('/products/', methods=['PUT']) +@auth.login_required +def edit_product(id): + product = Product.query.get_or_404(id) + + product.name = request.json.get('name') + product.category = request.json.get('category') + product.description = request.json.get('description') + product.price = request.json.get('price') + + db.session.add(product) + db.session.commit() + return jsonify(product.to_json()), 200 diff --git a/app/main/__init__.py b/app/main/__init__.py index e69de29..9ca777c 100644 --- a/app/main/__init__.py +++ b/app/main/__init__.py @@ -0,0 +1,5 @@ +from flask import Blueprint + +main = Blueprint('main', __name__) + +from . import views, errors \ No newline at end of file diff --git a/app/main/errors.py b/app/main/errors.py new file mode 100644 index 0000000..0a88fd3 --- /dev/null +++ b/app/main/errors.py @@ -0,0 +1,10 @@ +from flask import render_template +from . import main + +@main.app_errorhandler(404) +def page_not_found(e): + return render_template('404.html'), 404 + +@main.app_errorhandler(500) +def internal_server_error(e): + return render_template('500.html'), 500 \ No newline at end of file diff --git a/app/main/views.py b/app/main/views.py new file mode 100644 index 0000000..628b546 --- /dev/null +++ b/app/main/views.py @@ -0,0 +1,7 @@ +from flask import render_template +from . import main + + +@main.route('/', methods=['GET']) +def index(): + return render_template('index.html') \ No newline at end of file diff --git a/app/models.py b/app/models.py new file mode 100644 index 0000000..7ac57cf --- /dev/null +++ b/app/models.py @@ -0,0 +1,136 @@ +from app import db +from sqlalchemy import Column + +class Product(db.Model): + __tablename__ = 'products' + + product_id = Column(db.Integer(), primary_key = True) + name = Column(db.String(30)) + category = Column(db.String(50)) + description = Column(db.String(200)) + price = Column(db.Numeric(12, 2)) + + def to_json(self): + json_product = { + 'product_id': self.product_id, + 'name': self.name, + 'category': self.category, + 'description': self.description, + 'price': float(self.price or 0) # Decimal to float + } + return json_product + + @staticmethod + def from_json(json_product): + name = json_product.get('name') + category = json_product.get('category') + description = json_product.get('description') + price = json_product.get('price') + return Product(name=name, category=category, + description=description, price=price) + +class Order(db.Model): + __tablename__ = 'orders' + + order_id = Column(db.Integer(), primary_key = True) + name = Column(db.String(30)) + address = Column(db.String(100)) + city = Column(db.String(50)) + state = Column(db.String(20)) + zip = Column(db.String(7)) + country = Column(db.String(20)) + + def to_json(self): + json_order = { + 'order_id': self.order_id, + 'name': self.name, + 'address': self.address, + 'city': self.city, + 'state': self.state, + 'zip': self.zip, + 'country': self.country + } + return json_order + + @staticmethod + def from_json(json_order): + + order_id = json_order.get('order_id') + name = json_order.get('name') + address = json_order.get('address') + city = json_order.get('city') + state = json_order.get('state') + zip = json_order.get('zip') + country = json_order.get('country') + + return Order(order_id=order_id, name=name, address=address, + city=city, state=state, zip=zip, country=country) + +class OrderLine(db.Model): + __tablename__ = 'order_lines' + + order_line_id = Column(db.Integer(), primary_key = True) + order_id = Column(db.Integer(), db.ForeignKey('orders.order_id')) + product_id = Column(db.Integer(), db.ForeignKey('products.product_id')) + quantity = Column(db.Integer()) + + @staticmethod + def add_line(order_id, line): + product_id = line.get('product_id') + quantity = line.get('quantity') + + return OrderLine(order_id=order_id, product_id=product_id, + quantity=quantity) + + def to_json(self): + json_order_line = { + 'order_line_id': self.order_line_id, + 'order_id': self.order_id, + 'product_id': self.product_id, + 'quantity': self.quantity + } + return json_order_line + +# User model +# + +from werkzeug.security import generate_password_hash, check_password_hash +from itsdangerous import TimedJSONWebSignatureSerializer as Serializer +from flask import current_app +from flask_login import UserMixin + +class User(UserMixin, db.Model): + __tablename__ = 'users' + user_id = Column(db.Integer(), primary_key = True) + username = Column(db.String(50)) + email = Column(db.String(50)) + password_hash = Column(db.String(128)) + confirmed = Column(db.Boolean, default=False) + + @property + def password(self): + """Prevent reading of password setter""" + raise AttributeError('password is not a readable attribute') + + @password.setter + def password(self, password): + self.password_hash = generate_password_hash(password) + + def verify_password(self, password): + return check_password_hash(self.password_hash, password) + + def generate_auth_token(self, expiration): + """Generate signed token that encodes user_id""" + s = Serializer(current_app.config['SECRET_KEY'], + expires_in=expiration) + return s.dumps({'id': self.user_id}).decode('utf-8') + + @staticmethod + def verify_auth_token(token): + s = Serializer(current_app.config['SECRET_KEY']) + try: + data = s.loads(token) + except: + return None + print(data) + return User.query.get(data['id']) diff --git a/app/templates/404.html b/app/templates/404.html new file mode 100644 index 0000000..57db2e9 --- /dev/null +++ b/app/templates/404.html @@ -0,0 +1 @@ +404 \ No newline at end of file diff --git a/app/templates/500.html b/app/templates/500.html new file mode 100644 index 0000000..eb1f494 --- /dev/null +++ b/app/templates/500.html @@ -0,0 +1 @@ +500 \ No newline at end of file diff --git a/app/templates/index.html b/app/templates/index.html index e69de29..b2d525b 100644 --- a/app/templates/index.html +++ b/app/templates/index.html @@ -0,0 +1 @@ +index \ No newline at end of file diff --git a/config.py b/config.py index e69de29..3b80790 100644 --- a/config.py +++ b/config.py @@ -0,0 +1,46 @@ +import os +basedir = os.path.abspath(os.path.dirname(__file__)) + +class Config: + SECRET_KEY = os.environ.get('SECRET_KEY') or 'hard to guess string' + SQLALCHEMY_TRACK_MODIFICATIONS = False + SSL_REDIRECT = False + + @staticmethod + def init_app(app): + pass + +class DevelopmentConfig(Config): + DEBUG = True + SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL') + +class TestingConfig(Config): + print('Testing config') + TESTING = True + SQLALCHEMY_DATABASE_URI = os.environ.get('TEST_DATABASE_URL') + +class ProductionConfig(Config): + SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL') + +class HerokuConfig(ProductionConfig): + @classmethod + def init_app(cls, app): + ProductionConfig.init_app(app) + + import logging + from logging import StreamHandler + file_handler = StreamHandler() + file_handler.setLevel(logging.INFO) + app.logger.addHandler(file_handler) + + logging.getLogger('flask_cors').level = logging.DEBUG + + SSL_REDIRECT = True if os.environ.get('DYNO') else False + +config = { + 'development': DevelopmentConfig, + 'testing': TestingConfig, + 'production': ProductionConfig, + 'default': DevelopmentConfig, + 'heroku': HerokuConfig +} diff --git a/flasky.py b/flasky.py index e69de29..f1b7680 100644 --- a/flasky.py +++ b/flasky.py @@ -0,0 +1,5 @@ +import os +from app import create_app + +print('flasky.py', os.getenv('FLASK_CONFIG') or 'default') +app = create_app(os.getenv('FLASK_CONFIG') or 'default') diff --git a/manage.py b/manage.py new file mode 100644 index 0000000..1d2ad86 --- /dev/null +++ b/manage.py @@ -0,0 +1,20 @@ +# Intended use is along the lines of: +# > python manage.py db init +# > python manage.py db migrate +# > python manage.py db upgrade + +import os +from flask_script import Manager # class for handling a set of commands +from flask_migrate import Migrate, MigrateCommand +from app import db, create_app +from app import models + +app = create_app(config_name=os.getenv('FLASK_CONFIG')) +migrate = Migrate(app, db) +manager = Manager(app) + +manager.add_command('db', MigrateCommand) + + +if __name__ == '__main__': + manager.run() diff --git a/migrations/Migration Notes.txt b/migrations/Migration Notes.txt deleted file mode 100644 index e69de29..0000000 diff --git a/requirements.txt b/requirements.txt index e69de29..5ee9318 100644 --- a/requirements.txt +++ b/requirements.txt @@ -0,0 +1,20 @@ +alembic==1.0.8 +Click==7.0 +Flask==1.0.2 +Flask-Cors==3.0.7 +Flask-HTTPAuth==3.2.4 +Flask-Login==0.4.1 +Flask-Migrate==2.4.0 +Flask-Script==2.0.6 +Flask-SQLAlchemy==2.3.2 +itsdangerous==1.1.0 +Jinja2==2.10 +Mako==1.0.8 +MarkupSafe==1.1.1 +psycopg2==2.7.7 +python-dateutil==2.8.0 +python-editor==1.0.4 +six==1.12.0 +SQLAlchemy==1.3.1 +waitress==1.2.1 +Werkzeug==0.15.1 diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..5160b3c --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,106 @@ +import unittest +import os +import json +from flask import current_app + +import sys +myPath = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, myPath + '/../') + +from app import create_app, db + +class APITestCase(unittest.TestCase): + """This class represents the API test case""" + + def setUp(self): + """Define test variables and initialize app.""" + self.app = create_app(config_name="testing") + self.app_context = self.app.app_context() + self.app_context.push() + self.client = self.app.test_client() + self.bucketlist = { + 'name': 'Oranges', + 'category': 'Food', + 'description': 'Round citrus fruit', + 'price': 5.59 + } + db.create_all() + + + def test_app_exists(self): + self.assertFalse(current_app is None) + + def test_api_creation(self): + """Test API can create a product (POST request)""" + res = self.client.post('/api/v1/products/', + data=json.dumps(self.bucketlist), + content_type='application/json') + self.assertEqual(res.status_code, 201) + self.assertIn('Round citrus fruit', str(res.data)) + + def test_api_can_get_all_products(self): + """Test API can get a product (GET request).""" + res = self.client.post('/api/v1/products/', + data=json.dumps(self.bucketlist), + content_type='application/json') + self.assertEqual(res.status_code, 201) + res = self.client.get('/api/v1/products/') + self.assertEqual(res.status_code, 200) + self.assertIn('Round citrus fruit', str(res.data)) + + + # not implemented + # def test_api_can_get_product_by_id(self): + # """Test API can get a single product by using it's id.""" + # rv = self.client.post('/api/v1/products/', + # data=json.dumps(self.bucketlist), + # content_type='application/json') + # self.assertEqual(rv.status_code, 201) + # result_in_json = json.loads(rv.data.decode('utf-8').replace("'", "\"")) + # print('id',result_in_json['product_id']) + # result = self.client.get( + # '/api/v1/products/{}'.format(result_in_json['product_id'])) + # self.assertEqual(result.status_code, 200) + # self.assertIn('Round citrus fruit', str(result.data)) + + def test_product_can_be_edited(self): + """Test API can edit an existing bucketlist. (PUT request)""" + rv = self.client.post('/api/v1/products/', + data=json.dumps(self.bucketlist), + content_type='application/json') + self.assertEqual(rv.status_code, 201) + rv = self.client.put( + '/api/v1/products/1', + data = json.dumps({ + 'name': 'Oranges', + 'category': 'Food', + 'description': 'Round bright citrus fruit', + 'price': 5.59 + }), content_type='application/json') + self.assertEqual(rv.status_code, 200) + results = self.client.get('/api/v1/products/') + self.assertIn('Round bright', str(results.data)) + + def test_product_deletion(self): + """Test API can delete an existing product. (DELETE request).""" + rv = self.client.post('/api/v1/products/', + data=json.dumps(self.bucketlist), + content_type='application/json') + self.assertEqual(rv.status_code, 201) + res = self.client.delete('/api/v1/products/1') + self.assertEqual(res.status_code, 200) + # Test to see if it exists, should return a 404 + result = self.client.get('/api/v1/products') + self.assertTrue('Round bright' not in str(result.data)) + #self.assertEqual(result.status_code, 404) + + def tearDown(self): + """teardown all initialized variables.""" + with self.app.app_context(): + # drop all tables + db.session.remove() + db.drop_all() + +# Make the tests conveniently executable +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_user_model.py b/tests/test_user_model.py new file mode 100644 index 0000000..dda0bf5 --- /dev/null +++ b/tests/test_user_model.py @@ -0,0 +1,32 @@ +import unittest +import os + +import sys +myPath = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, myPath + '/../') + +from app.models import User + +class UserModelTestCase(unittest.TestCase): + def test_password_setter(self): + u = User(password='cat') + self.assertTrue(u.password_hash is not None) + + def test_no_password_getter(self): + u = User(password='cat') + with self.assertRaises(AttributeError): + u.password + + def test_password_verification(self): + u = User(password='cat') + self.assertTrue(u.verify_password('cat')) + self.assertFalse(u.verify_password('dog')) + + def test_password_salts_are_random(self): + u1 = User(password='cat') + u2 = User(password='cat') + self.assertFalse(u1.password_hash == u2.password_hash) + +# Make the tests conveniently executable +if __name__ == "__main__": + unittest.main()