diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ec47f62 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/venv +/migrations +.env \ No newline at end of file diff --git a/Procfile b/Procfile new file mode 100644 index 0000000..c65d50e --- /dev/null +++ b/Procfile @@ -0,0 +1 @@ +web: waitress-serve --port=$PORT flasky:app \ No newline at end of file diff --git a/README.md b/README.md index c89f7a6..3c37217 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,34 @@ -# flask-empty-project-shell -An empty project shell for structuring new Flask projects +# Flask API + +This is designed as a basic json API to serve **products** and **orders** resources. + +## API resources + +| URL | Method | Description | +| -------------------|-----------|------------------------ | +| /products/ | GET | Get all products | +| /products/ | POST | Create new product | +| /products/ | DELETE | Delete product by id | +| /products/ | PUT | Update product by id | +| /orders/ | GET | Get all orders | +| /orders/ | DELETE | Delete order by id | +| /orders/ | PUT | Update order by id | + +## Heroku deployment + +### Configs + +``` +heroku config:set FLASK_APP=flasky.py +heroku config:set FLASK_CONFIG=heroku +``` +The Heroku Postgres add on will set DATABASE_URL. Make sure to set a SECRET_KEY also. + +### Set up database + +This project uses SQLAlchemy to define the database model. The script `manage.py` uses Flask-Migrate and Flask-Script to allow setup, migration and deployment of the model without losing existing data. +``` +heroku run bash +python manage.py db init +python manage.py db migrate +python manage.py db upgrade \ No newline at end of file diff --git a/app/__init__.py b/app/__init__.py index e69de29..a29c5ec 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -0,0 +1,23 @@ +from flask import Flask +from config import config +from flask_sqlalchemy import SQLAlchemy +from flask_cors import CORS + + +db = SQLAlchemy() + +def create_app(config_name): + app = Flask(__name__) + app.config.from_object(config[config_name]) + config[config_name].init_app(app) + + db.init_app(app) + CORS(app, supports_credentials=True) + + from .main import main as main_blueprint + app.register_blueprint(main_blueprint) + + from .api import api as api_blueprint + app.register_blueprint(api_blueprint, url_prefix='/api/v1') + + return app diff --git a/app/api/__init__.py b/app/api/__init__.py new file mode 100644 index 0000000..93df9f1 --- /dev/null +++ b/app/api/__init__.py @@ -0,0 +1,6 @@ +from flask import Blueprint + +api = Blueprint('api', __name__) + +from . import authentication, errors +from . import products, orders diff --git a/app/api/authentication.py b/app/api/authentication.py new file mode 100644 index 0000000..2ad48e1 --- /dev/null +++ b/app/api/authentication.py @@ -0,0 +1,47 @@ +from flask import g, jsonify +from flask_httpauth import HTTPBasicAuth +from flask_login import current_user +from .errors import unauthorized, forbidden +from . import api +from ..models import User + +auth = HTTPBasicAuth() + + +@auth.verify_password +def verify_password(name_or_token, password): + """Return True if login valid; Uses the User method verify_password; + If password is blank, token is assumed""" + print('login:', name_or_token, password) + print('user:', User.verify_auth_token(name_or_token)) + if name_or_token == '': + return False + if password == '': + g.current_user = User.verify_auth_token(name_or_token) + g.token_used = True + return g.current_user is not None + user = User.query.filter_by(username=name_or_token).first() + if not user: + return False + g.current_user = user + g.token_used = False + return user.verify_password(password) + +@auth.error_handler +def auth_error(): + return jsonify({'success': False}) + +@api.route('/login', methods=['POST']) +@auth.login_required +def get_token(): + print('login route') + if g.current_user.is_anonymous or g.token_used: + print('auth failed') + print(g.current_user.is_anonymous) + print(g.token_used) + return jsonify({'success': False}) + print('auth succeed') + print(jsonify({'token': g.current_user.generate_auth_token( + expiration=3600), 'expiration': 3600, 'success': True})) + return jsonify({'token': g.current_user.generate_auth_token( + expiration=3600), 'expiration': 3600, 'success': True}) diff --git a/app/api/decorators.py b/app/api/decorators.py new file mode 100644 index 0000000..d56462d --- /dev/null +++ b/app/api/decorators.py @@ -0,0 +1,14 @@ +from functools import wraps +from flask import g +from .errors import forbidden + + +# def permission_required(permission): +# def decorator(f): +# @wraps(f) +# def decorated_function(*args, **kwargs): +# if not g.current_user.can(permission): +# return forbidden('Insufficient permissions') +# return f(*args, **kwargs) +# return decorated_function +# return decorator diff --git a/app/api/errors.py b/app/api/errors.py new file mode 100644 index 0000000..13b45a1 --- /dev/null +++ b/app/api/errors.py @@ -0,0 +1,27 @@ +from flask import jsonify +from . import api + +class ValidationError(ValueError): + pass + +def bad_request(message): + response = jsonify({'error': 'bad request', 'message': message}) + response.status_code = 400 + return response + + +def unauthorized(message): + response = jsonify({'error': 'unauthorized', 'message': message}) + response.status_code = 401 + return response + + +def forbidden(message): + response = jsonify({'error': 'forbidden', 'message': message}) + response.status_code = 403 + return response + + +@api.errorhandler(ValidationError) +def validation_error(e): + return bad_request(e.args[0]) diff --git a/app/api/orders.py b/app/api/orders.py new file mode 100644 index 0000000..39e5ebc --- /dev/null +++ b/app/api/orders.py @@ -0,0 +1,42 @@ +from marshmallow import ValidationError +from . import api +from ..models import Order, OrderLine, Product +from .. import db +from flask import jsonify, request +from .authentication import auth +from ..services import OrderListService + +from ..schema import OrderSchema, OrderLineSchema + +@api.route('/orders/', methods=['GET']) +def get_orders(): + service = OrderListService() + return jsonify(service.get()) + +@api.route('/orders/', methods=['POST']) +def new_order(): + schema = OrderSchema() + try: + order = schema.load(request.get_json()) + except ValidationError as err: + print('Validation Error: ', err.messages) + db.session.add(order.data) + db.session.commit() + return jsonify(schema.dump(order)), 201 + +@api.route('/orders/', methods=['DELETE']) +def delete_order(id): + order = Order.query.get_or_404(id) + db.session.delete(order) + db.session.commit() + return jsonify({"success": True}) + +@api.route('/orders/', methods=['PUT']) +def edit_order(id): + # get order + order = Order.query.get_or_404(id) + # update order and commit + schema = OrderSchema() + schema.load(request.get_json(), instance=order) + db.session.commit() + return jsonify(schema.dump(order)) diff --git a/app/api/products.py b/app/api/products.py new file mode 100644 index 0000000..0891e9a --- /dev/null +++ b/app/api/products.py @@ -0,0 +1,42 @@ +from . import api +from ..models import Product +from .. import db +from flask import jsonify, request +from .authentication import auth + + +@api.route('/products/', methods=['GET']) +def get_products(): + print('get_products') + products = Product.query.all() + return jsonify([p.to_json() for p in products]), 200 + +@api.route('/products/', methods=['POST']) +@auth.login_required +def new_product(): + product = Product.from_json(request.json) + db.session.add(product) + db.session.commit() + return jsonify(product.to_json()), 201 + +@api.route('/products/', methods=['DELETE']) +@auth.login_required +def delete_product(id): + product = Product.query.get_or_404(id) + db.session.delete(product) + db.session.commit() + return jsonify({"success": True}), 200 + +@api.route('/products/', methods=['PUT']) +@auth.login_required +def edit_product(id): + product = Product.query.get_or_404(id) + + product.name = request.json.get('name') + product.category = request.json.get('category') + product.description = request.json.get('description') + product.price = request.json.get('price') + + db.session.add(product) + db.session.commit() + return jsonify(product.to_json()), 200 diff --git a/app/main/__init__.py b/app/main/__init__.py index e69de29..9ca777c 100644 --- a/app/main/__init__.py +++ b/app/main/__init__.py @@ -0,0 +1,5 @@ +from flask import Blueprint + +main = Blueprint('main', __name__) + +from . import views, errors \ No newline at end of file diff --git a/app/main/errors.py b/app/main/errors.py new file mode 100644 index 0000000..0a88fd3 --- /dev/null +++ b/app/main/errors.py @@ -0,0 +1,10 @@ +from flask import render_template +from . import main + +@main.app_errorhandler(404) +def page_not_found(e): + return render_template('404.html'), 404 + +@main.app_errorhandler(500) +def internal_server_error(e): + return render_template('500.html'), 500 \ No newline at end of file diff --git a/app/main/views.py b/app/main/views.py new file mode 100644 index 0000000..628b546 --- /dev/null +++ b/app/main/views.py @@ -0,0 +1,7 @@ +from flask import render_template +from . import main + + +@main.route('/', methods=['GET']) +def index(): + return render_template('index.html') \ No newline at end of file diff --git a/app/models.py b/app/models.py new file mode 100644 index 0000000..fd7921d --- /dev/null +++ b/app/models.py @@ -0,0 +1,93 @@ +from app import db +from sqlalchemy import Column +from sqlalchemy.orm import relationship + +class Product(db.Model): + __tablename__ = 'products' + + product_id = Column(db.Integer(), primary_key = True) + name = Column(db.String(30)) + category = Column(db.String(50)) + description = Column(db.String(200)) + price = Column(db.Numeric(12, 2)) + + # relationships + products = relationship("OrderLine", back_populates="product") + +class Order(db.Model): + __tablename__ = 'orders' + + def __init__(self, products_sold=None, *args, **kwargs): + super(Order, self).__init__(*args, **kwargs) + products_sold = products_sold or [] + for prod in products_sold: + self.products_sold.append(prod) + order_id = Column(db.Integer(), primary_key = True) + name = Column(db.String(30)) + address = Column(db.String(100)) + city = Column(db.String(50)) + state = Column(db.String(20)) + zip = Column(db.String(7)) + country = Column(db.String(20)) + + # relationships + products_sold = relationship("OrderLine", back_populates="order") + + + +class OrderLine(db.Model): + __tablename__ = 'order_lines' + + order_line_id = Column(db.Integer(), primary_key = True) + order_id = Column(db.Integer(), db.ForeignKey('orders.order_id')) + product_id = Column(db.Integer(), db.ForeignKey('products.product_id')) + quantity = Column(db.Integer()) + + # relationships + product = relationship("Product", back_populates="products") + order = relationship("Order", back_populates="products_sold") + + +# User model +# + +from werkzeug.security import generate_password_hash, check_password_hash +from itsdangerous import TimedJSONWebSignatureSerializer as Serializer +from flask import current_app +from flask_login import UserMixin + +class User(UserMixin, db.Model): + __tablename__ = 'users' + user_id = Column(db.Integer(), primary_key = True) + username = Column(db.String(50)) + email = Column(db.String(50)) + password_hash = Column(db.String(128)) + confirmed = Column(db.Boolean, default=False) + + @property + def password(self): + """Prevent reading of password setter""" + raise AttributeError('password is not a readable attribute') + + @password.setter + def password(self, password): + self.password_hash = generate_password_hash(password) + + def verify_password(self, password): + return check_password_hash(self.password_hash, password) + + def generate_auth_token(self, expiration): + """Generate signed token that encodes user_id""" + s = Serializer(current_app.config['SECRET_KEY'], + expires_in=expiration) + return s.dumps({'id': self.user_id}).decode('utf-8') + + @staticmethod + def verify_auth_token(token): + s = Serializer(current_app.config['SECRET_KEY']) + try: + data = s.loads(token) + except: + return None + print(data) + return User.query.get(data['id']) diff --git a/app/schema.py b/app/schema.py new file mode 100644 index 0000000..88b6ff6 --- /dev/null +++ b/app/schema.py @@ -0,0 +1,42 @@ +from marshmallow import fields, post_load +from marshmallow_sqlalchemy import ModelSchema + +from app import db +from .models import Order, Product, OrderLine + +class OrderSchema(ModelSchema): + products_sold = fields.Nested('OrderLineSchema', many=True) + class Meta(ModelSchema.Meta): + model = Order + sqla_session = db.session + + @post_load + def make_order(self, data): + if type(data) == Order: + return data + return Order(**data) + +class ProductSchema(ModelSchema): + price = fields.Float(data_key='price') # Decimal to float + class Meta(ModelSchema.Meta): + model = Product + sqla_session = db.session + + @post_load + def make_product(self, data): + if type(data) == Product: + return data + return Product(**data) + +class OrderLineSchema(ModelSchema): + product = fields.Nested('ProductSchema', exclude=('products',)) + class Meta(ModelSchema.Meta): + model = OrderLine + sqla_session = db.session + + @post_load + def make_order_line(self, data): + db.session.flush() + if type(data) == OrderLine: + return data + return OrderLine(**data) diff --git a/app/services.py b/app/services.py new file mode 100644 index 0000000..083409b --- /dev/null +++ b/app/services.py @@ -0,0 +1,28 @@ +from sqlalchemy.orm import joinedload, selectinload + +from .models import Order, Product, OrderLine +from .schema import OrderSchema, ProductSchema, OrderLineSchema + +from app import db + +class OrderListService: + ''' + This service intended for use exclusively by /api/orders + ''' + def __init__(self, _session=None): + # your unit tests can pass in _session=MagicMock() + self.session = _session or db.session + + def _parents(self): + return ( self.session.query(Order) + .options(selectinload(Order.products_sold)) + .all() ) + + def get(self): + # [{"address": "59 Arcubus Avenue", "city": "Sheffield", + # "country": "United Kingdom", "name": "Andrew", "order_id": 17, + # "products_sold": [{ "product": { "category": "Technology", + # "description": "A small computer", "name": "Calculator", + # "price": 15.35, "product_id": 1 }, "quantity": 10 }, ...]}] + schema = OrderSchema() + return schema.dump(self._parents(), many=True).data diff --git a/app/templates/404.html b/app/templates/404.html new file mode 100644 index 0000000..57db2e9 --- /dev/null +++ b/app/templates/404.html @@ -0,0 +1 @@ +404 \ No newline at end of file diff --git a/app/templates/500.html b/app/templates/500.html new file mode 100644 index 0000000..eb1f494 --- /dev/null +++ b/app/templates/500.html @@ -0,0 +1 @@ +500 \ No newline at end of file diff --git a/app/templates/index.html b/app/templates/index.html index e69de29..b2d525b 100644 --- a/app/templates/index.html +++ b/app/templates/index.html @@ -0,0 +1 @@ +index \ No newline at end of file diff --git a/config.py b/config.py index e69de29..3b80790 100644 --- a/config.py +++ b/config.py @@ -0,0 +1,46 @@ +import os +basedir = os.path.abspath(os.path.dirname(__file__)) + +class Config: + SECRET_KEY = os.environ.get('SECRET_KEY') or 'hard to guess string' + SQLALCHEMY_TRACK_MODIFICATIONS = False + SSL_REDIRECT = False + + @staticmethod + def init_app(app): + pass + +class DevelopmentConfig(Config): + DEBUG = True + SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL') + +class TestingConfig(Config): + print('Testing config') + TESTING = True + SQLALCHEMY_DATABASE_URI = os.environ.get('TEST_DATABASE_URL') + +class ProductionConfig(Config): + SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL') + +class HerokuConfig(ProductionConfig): + @classmethod + def init_app(cls, app): + ProductionConfig.init_app(app) + + import logging + from logging import StreamHandler + file_handler = StreamHandler() + file_handler.setLevel(logging.INFO) + app.logger.addHandler(file_handler) + + logging.getLogger('flask_cors').level = logging.DEBUG + + SSL_REDIRECT = True if os.environ.get('DYNO') else False + +config = { + 'development': DevelopmentConfig, + 'testing': TestingConfig, + 'production': ProductionConfig, + 'default': DevelopmentConfig, + 'heroku': HerokuConfig +} diff --git a/flasky.py b/flasky.py index e69de29..f1b7680 100644 --- a/flasky.py +++ b/flasky.py @@ -0,0 +1,5 @@ +import os +from app import create_app + +print('flasky.py', os.getenv('FLASK_CONFIG') or 'default') +app = create_app(os.getenv('FLASK_CONFIG') or 'default') diff --git a/manage.py b/manage.py new file mode 100644 index 0000000..1d2ad86 --- /dev/null +++ b/manage.py @@ -0,0 +1,20 @@ +# Intended use is along the lines of: +# > python manage.py db init +# > python manage.py db migrate +# > python manage.py db upgrade + +import os +from flask_script import Manager # class for handling a set of commands +from flask_migrate import Migrate, MigrateCommand +from app import db, create_app +from app import models + +app = create_app(config_name=os.getenv('FLASK_CONFIG')) +migrate = Migrate(app, db) +manager = Manager(app) + +manager.add_command('db', MigrateCommand) + + +if __name__ == '__main__': + manager.run() diff --git a/migrations/Migration Notes.txt b/migrations/Migration Notes.txt deleted file mode 100644 index e69de29..0000000 diff --git a/requirements.txt b/requirements.txt index e69de29..5ee9318 100644 --- a/requirements.txt +++ b/requirements.txt @@ -0,0 +1,20 @@ +alembic==1.0.8 +Click==7.0 +Flask==1.0.2 +Flask-Cors==3.0.7 +Flask-HTTPAuth==3.2.4 +Flask-Login==0.4.1 +Flask-Migrate==2.4.0 +Flask-Script==2.0.6 +Flask-SQLAlchemy==2.3.2 +itsdangerous==1.1.0 +Jinja2==2.10 +Mako==1.0.8 +MarkupSafe==1.1.1 +psycopg2==2.7.7 +python-dateutil==2.8.0 +python-editor==1.0.4 +six==1.12.0 +SQLAlchemy==1.3.1 +waitress==1.2.1 +Werkzeug==0.15.1 diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..5160b3c --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,106 @@ +import unittest +import os +import json +from flask import current_app + +import sys +myPath = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, myPath + '/../') + +from app import create_app, db + +class APITestCase(unittest.TestCase): + """This class represents the API test case""" + + def setUp(self): + """Define test variables and initialize app.""" + self.app = create_app(config_name="testing") + self.app_context = self.app.app_context() + self.app_context.push() + self.client = self.app.test_client() + self.bucketlist = { + 'name': 'Oranges', + 'category': 'Food', + 'description': 'Round citrus fruit', + 'price': 5.59 + } + db.create_all() + + + def test_app_exists(self): + self.assertFalse(current_app is None) + + def test_api_creation(self): + """Test API can create a product (POST request)""" + res = self.client.post('/api/v1/products/', + data=json.dumps(self.bucketlist), + content_type='application/json') + self.assertEqual(res.status_code, 201) + self.assertIn('Round citrus fruit', str(res.data)) + + def test_api_can_get_all_products(self): + """Test API can get a product (GET request).""" + res = self.client.post('/api/v1/products/', + data=json.dumps(self.bucketlist), + content_type='application/json') + self.assertEqual(res.status_code, 201) + res = self.client.get('/api/v1/products/') + self.assertEqual(res.status_code, 200) + self.assertIn('Round citrus fruit', str(res.data)) + + + # not implemented + # def test_api_can_get_product_by_id(self): + # """Test API can get a single product by using it's id.""" + # rv = self.client.post('/api/v1/products/', + # data=json.dumps(self.bucketlist), + # content_type='application/json') + # self.assertEqual(rv.status_code, 201) + # result_in_json = json.loads(rv.data.decode('utf-8').replace("'", "\"")) + # print('id',result_in_json['product_id']) + # result = self.client.get( + # '/api/v1/products/{}'.format(result_in_json['product_id'])) + # self.assertEqual(result.status_code, 200) + # self.assertIn('Round citrus fruit', str(result.data)) + + def test_product_can_be_edited(self): + """Test API can edit an existing bucketlist. (PUT request)""" + rv = self.client.post('/api/v1/products/', + data=json.dumps(self.bucketlist), + content_type='application/json') + self.assertEqual(rv.status_code, 201) + rv = self.client.put( + '/api/v1/products/1', + data = json.dumps({ + 'name': 'Oranges', + 'category': 'Food', + 'description': 'Round bright citrus fruit', + 'price': 5.59 + }), content_type='application/json') + self.assertEqual(rv.status_code, 200) + results = self.client.get('/api/v1/products/') + self.assertIn('Round bright', str(results.data)) + + def test_product_deletion(self): + """Test API can delete an existing product. (DELETE request).""" + rv = self.client.post('/api/v1/products/', + data=json.dumps(self.bucketlist), + content_type='application/json') + self.assertEqual(rv.status_code, 201) + res = self.client.delete('/api/v1/products/1') + self.assertEqual(res.status_code, 200) + # Test to see if it exists, should return a 404 + result = self.client.get('/api/v1/products') + self.assertTrue('Round bright' not in str(result.data)) + #self.assertEqual(result.status_code, 404) + + def tearDown(self): + """teardown all initialized variables.""" + with self.app.app_context(): + # drop all tables + db.session.remove() + db.drop_all() + +# Make the tests conveniently executable +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_user_model.py b/tests/test_user_model.py new file mode 100644 index 0000000..dda0bf5 --- /dev/null +++ b/tests/test_user_model.py @@ -0,0 +1,32 @@ +import unittest +import os + +import sys +myPath = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, myPath + '/../') + +from app.models import User + +class UserModelTestCase(unittest.TestCase): + def test_password_setter(self): + u = User(password='cat') + self.assertTrue(u.password_hash is not None) + + def test_no_password_getter(self): + u = User(password='cat') + with self.assertRaises(AttributeError): + u.password + + def test_password_verification(self): + u = User(password='cat') + self.assertTrue(u.verify_password('cat')) + self.assertFalse(u.verify_password('dog')) + + def test_password_salts_are_random(self): + u1 = User(password='cat') + u2 = User(password='cat') + self.assertFalse(u1.password_hash == u2.password_hash) + +# Make the tests conveniently executable +if __name__ == "__main__": + unittest.main()