Refactoring helps optimize the codebase and combat technical debt. Use these six code refactoring techniques to improve code structure, maintainability and overall quality.
Code refactoring is an important process in software development that involves restructuring code without changing its functional behavior.
Refactoring is important for many reasons, with the main reason being to improve the code's design and maintainability. As teams add changes and features to any form of software, there always arises the opportunity to refactor. In this way, refactoring is a crucial step between having written some code and writing new code.
After writing code for a new feature or a bug fix for an existing feature, developers should take a step back to consider how the changes they've introduced might or might not fit into the overall design of the software. With the addition of new changes, developers might consider an overall structure that better suits the codebase, existing methods that could be simplified or new tests that would better assert the code's correctness.
6 code refactoring techniques to know
Take a look at some different refactoring techniques and examples of those techniques in practice. Some examples -- such as "dealing with generalization" and "composing methods" -- are broader categories with more specific approaches described within.
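1) Red-green-refactor
Red-green-refactor is the cycle at the heart of test-driven development. It consists of three steps:
Red. Write a failing test for the new desired functionality.
Green. Write the minimal code required to pass the failing test.
Refactor. Clean up the new code while keeping the tests passing.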
As an example, we might want to write a new method to reverse a string in Python. The tests encourage the developer to consider the interface to the code first -- how the method will look regarding the arguments it takes or the values it returns.
The tests might look something like the following:
import unittest

# We will define the function here later

class TestReverseString(unittest.TestCase):
    def test_empty_string(self):
        # Test case for an empty string
        self.assertEqual(reverse_string(""), "")

    def test_single_character_string(self):
        # Test case for a single character string
        self.assertEqual(reverse_string("a"), "a")

    def test_palindrome_string(self):
        # Test case for a palindrome
        self.assertEqual(reverse_string("madam"), "madam")

    def test_normal_string(self):
        # Test case for a regular string
        self.assertEqual(reverse_string("hello"), "olleh")
The first iteration of the new method might then look something like the below.
import unittest

def reverse_string(s):
    # Minimal code to pass the tests using a loop
    reversed_chars = []
    for char in s:
        reversed_chars.append(char)
    reversed_chars.reverse()  # Reverse the list of characters
    return "".join(reversed_chars)  # Join them back into a string

class TestReverseString(unittest.TestCase):
    def test_empty_string(self):
        self.assertEqual(reverse_string(""), "")

    def test_single_character_string(self):
        self.assertEqual(reverse_string("a"), "a")

    def test_palindrome_string(self):
        self.assertEqual(reverse_string("madam"), "madam")

    def test_normal_string(self):
        self.assertEqual(reverse_string("hello"), "olleh")

if __name__ == '__main__':
    unittest.main()
A refactor of the reverse_string method could look like the following:
import unittest

def reverse_string(s):
    # Refactored code - simpler and more Pythonic using slicing
    return s[::-1]
2) Refactor by abstraction
Refactoring by abstraction is a technique that identifies and pulls common functionality from multiple places into a higher-level abstraction, such as an abstract class or interface.
This process aims to reduce code duplication, improve modularity and make the codebase more maintainable and accessible. Centralizing shared logic simplifies new code changes or additions. This leads to cleaner, more reliable software design and reduces the likelihood of introducing errors.
As an example, start with two classes, each responsible for sending some type of message. The classes could be abstracted into a generic MessageSender class, where both specific message-sending classes can reuse the common logic.
Below are the two classes before refactoring.
class EmailSender:
    def __init__(self, server, port):
        self.server = server
        self.port = port
        self.connection = None

    def connect(self):
        print(f"Connecting to email server {self.server}:{self.port}...")
        # Simulate connection
        self.connection = f"EmailConnection:{self.server}"
        print("Email connected.")

    def send_message(self, recipient, message):
        if not self.connection:
            print("Error: Not connected to email server.")
            return
        print(f"Sending email to {recipient} via {self.connection}: {message}")
        # Simulate sending

    def disconnect(self):
        if self.connection:
            print(f"Disconnecting from {self.connection}...")
            self.connection = None
            print("Email disconnected.")

class SMSSender:
    def __init__(self, gateway, api_key):
        self.gateway = gateway
        self.api_key = api_key
        self.connection = None

    def connect(self):
        print(f"Connecting to SMS gateway {self.gateway} using API key {self.api_key}...")
        # Simulate connection
        self.connection = f"SMSConnection:{self.gateway}"
        print("SMS connected.")

    def send_message(self, recipient, message):
        if not self.connection:
            print("Error: Not connected to SMS gateway.")
            return
        print(f"Sending SMS to {recipient} via {self.connection}: {message}")
        # Simulate sending

    def disconnect(self):
        if self.connection:
            print(f"Disconnecting from {self.connection}...")
            self.connection = None
            print("SMS disconnected.")

# Usage
email_sender = EmailSender("smtp.example.com", 587)
email_sender.connect()
email_sender.send_message("[email protected]", "Hello via email!")
email_sender.disconnect()
print("-" * 20)
sms_sender = SMSSender("sms.gateway.com", "abcdef12345")
sms_sender.connect()
sms_sender.send_message("123-456-7890", "Hello via SMS!")
sms_sender.disconnect()
Below is the code after refactoring the common code between the two classes into one abstract class.
import abc

class MessageSender(abc.ABC):
    def __init__(self):
        self.connection = None

    @abc.abstractmethod
    def connect(self):
        # Abstract method to establish a connection
        pass

    @abc.abstractmethod
    def _send(self, recipient, message):
        # Abstract method for the specific sending logic
        pass

    def send_message(self, recipient, message):
        # Handles connection check and calls the specific send logic
        if not self.connection:
            print("Error: Not connected.")  # More generic message
            return
        self._send(recipient, message)  # Delegate to the specific implementation

    def disconnect(self):
        # Handles disconnection
        if self.connection:
            print(f"Disconnecting from {self.connection}...")
            self.connection = None
            print("Disconnected.")  # More generic message

class EmailSender(MessageSender):
    def __init__(self, server, port):
        super().__init__()
        self.server = server
        self.port = port

    def connect(self):
        print(f"Connecting to email server {self.server}:{self.port}...")
        # Simulate connection
        self.connection = f"EmailConnection:{self.server}"
        print("Email connected.")

    def _send(self, recipient, message):
        print(f"Sending email to {recipient} via {self.connection}: {message}")
        # Simulate sending

class SMSSender(MessageSender):
    def __init__(self, gateway, api_key):
        super().__init__()
        self.gateway = gateway
        self.api_key = api_key

    def connect(self):
        print(f"Connecting to SMS gateway {self.gateway} using API key {self.api_key}...")
        # Simulate connection
        self.connection = f"SMSConnection:{self.gateway}"
        print("SMS connected.")

    def _send(self, recipient, message):
        print(f"Sending SMS to {recipient} via {self.connection}: {message}")
        # Simulate sending

# Usage
email_sender = EmailSender("smtp.example.com", 587)
email_sender.connect()
email_sender.send_message("[email protected]", "Hello via email!")
email_sender.disconnect()
print("-" * 20)
sms_sender = SMSSender("sms.gateway.com", "abcdef12345")
sms_sender.connect()
sms_sender.send_message("123-456-7890", "Hello via SMS!")
sms_sender.disconnect()  # Mirrors the usage in the pre-refactor listing
3) Composing methods
Composing methods focuses on improving the clarity and structure of individual methods, primarily by managing their length and complexity.
The main idea is to break down lengthy, hard-to-understand methods into smaller, more focused units, making each piece of logic easier to grasp and maintain. Some techniques central to this approach include the following:
Extract method, which turns a block of code within a method into a new, separately callable method.
Inline method, which brings the body of a simple method directly into its caller.
These aim to ensure each method operates at a single level of abstraction and has a clear, singular purpose. Refactoring in this way also makes testing methods much more concise, as each method has a clearer purpose.
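The two techniques named above can be sketched in a few lines. This is a minimal illustration rather than a definitive recipe; all function names here (print_invoice, total_due, is_bulk, shipping_tier) are hypothetical and chosen only for the example.

```python
# --- Extract method ---
# Before: the summing logic is buried inside the printing function.
def print_invoice_before(customer, amounts):
    print(f"Invoice for {customer}")
    total = 0
    for amount in amounts:
        total += amount
    print(f"Total due: {total}")
    return total


# After: the summing block becomes its own separately callable function.
def total_due(amounts):
    return sum(amounts)


def print_invoice(customer, amounts):
    print(f"Invoice for {customer}")
    total = total_due(amounts)
    print(f"Total due: {total}")
    return total


# --- Inline method ---
# Before: a one-line helper that adds indirection without adding clarity.
def is_bulk(quantity):
    return quantity > 100


def shipping_tier_before(quantity):
    return "freight" if is_bulk(quantity) else "parcel"


# After: the helper's body is folded back into its only caller.
def shipping_tier(quantity):
    return "freight" if quantity > 100 else "parcel"
```

In both cases the before and after versions return the same values for the same inputs, which is what makes the change a refactor rather than a behavior change.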
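Consider a class that processes orders, where a bloated process_order method calculates the subtotal, applies the discount, adds tax and formats the summary all in one body. Such a method is ripe for refactoring via the composing methods technique.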
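A bloated version of such a method might look roughly like the sketch below, with the subtotal, discount, tax and summary formatting all living in one body. This is a hypothetical reconstruction for illustration, not the original code; the class name BloatedOrderProcessor and the exact summary layout are assumptions.

```python
class BloatedOrderProcessor:
    """Hypothetical 'before' version: one method does everything."""

    def __init__(self, order_items, discount_rate, tax_rate):
        self.order_items = order_items
        self.discount_rate = discount_rate
        self.tax_rate = tax_rate

    def process_order(self):
        # Step 1: subtotal of all items
        subtotal = 0
        for item in self.order_items:
            subtotal += item["price"] * item["quantity"]

        # Step 2: discount
        discount_amount = subtotal * self.discount_rate
        price_after_discount = subtotal - discount_amount

        # Step 3: tax on the discounted price
        tax_amount = price_after_discount * self.tax_rate
        final_price = price_after_discount + tax_amount

        # Step 4: build and print the summary
        summary = "--- Order Summary ---\n"
        for item in self.order_items:
            summary += f"- {item['name']}: ${item['price']} x {item['quantity']}\n"
        summary += f"Subtotal: ${subtotal:.2f}\n"
        summary += f"Discount: -${discount_amount:.2f}\n"
        summary += f"Tax: ${tax_amount:.2f}\n"
        summary += f"Final Total: ${final_price:.2f}"
        print(summary)
        return final_price


items = [
    {"name": "Laptop", "price": 1000, "quantity": 1},
    {"name": "Mouse", "price": 25, "quantity": 2},
]
final_price = BloatedOrderProcessor(items, 0.10, 0.08).process_order()
```

Each commented step in the method body is a candidate for extract method, since the comment is already standing in for a method name.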
The following refactor extracts the logic of the process_order method into shorter, more concise methods:
class OrderProcessor:
    def __init__(self, order_items, discount_rate, tax_rate):
        self.order_items = order_items
        self.discount_rate = discount_rate
        self.tax_rate = tax_rate

    def calculate_subtotal(self):
        # Calculates the total price of order items before discount and tax
        subtotal = 0
        for item in self.order_items:
            subtotal += item['price'] * item['quantity']
        return subtotal

    def apply_discount(self, subtotal):
        # Applies the discount rate to the subtotal
        discount_amount = subtotal * self.discount_rate
        return subtotal - discount_amount, discount_amount

    def calculate_tax(self, price_after_discount):
        # Calculates the tax amount based on the price after discount
        tax_amount = price_after_discount * self.tax_rate
        return price_after_discount + tax_amount, tax_amount

    def format_summary(self, total_price, discount_amount, tax_amount, final_price):
        # Formats the order summary string
        summary = "--- Order Summary ---\n"
        for item in self.order_items:
            summary += f"- {item['name']}: ${item['price']} x {item['quantity']}\n"
        summary += "---------------------\n"
        summary += f"Subtotal: ${total_price:.2f}\n"
        summary += f"Discount ({self.discount_rate * 100}%): -${discount_amount:.2f}\n"
        summary += f"Tax ({self.tax_rate * 100}%): ${tax_amount:.2f}\n"
        summary += "---------------------\n"
        summary += f"Final Total: ${final_price:.2f}\n"
        summary += "---------------------"
        return summary

    def process_order(self):
        # Orchestrates the order processing steps
        subtotal = self.calculate_subtotal()
        price_after_discount, discount_amount = self.apply_discount(subtotal)
        final_price, tax_amount = self.calculate_tax(price_after_discount)
        summary = self.format_summary(subtotal, discount_amount, tax_amount, final_price)
        print(summary)
        return final_price

# Example Usage
items = [
    {'name': 'Laptop', 'price': 1000, 'quantity': 1},
    {'name': 'Mouse', 'price': 25, 'quantity': 2},
]
processor = OrderProcessor(items, 0.10, 0.08)  # 10% discount, 8% tax
processor.process_order()
4) Simplifying methods
Simplifying methods is a refactoring approach that reduces the internal complexity of existing methods. It makes the logic inside each method clearer and more direct.
Rather than breaking a method into multiple smaller ones, this technique addresses complicated conditional expressions, overly long parameter lists or hard-to-follow method calls within a single method. Some techniques to simplify methods include the following:
Replacing nested conditionals with guard clauses to flatten the structure.
Consolidating duplicate conditional code.
Introducing parameter objects to streamline method signatures.
These techniques make the method's logic easier to follow at a glance, thereby improving maintainability and reducing the risk of introducing bugs during modifications.
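Two of the listed techniques, introducing a parameter object and consolidating duplicate conditional code, can be sketched as follows. The names (Address, format_address, shipping_fee) and the rates used are assumptions made up for this illustration.

```python
from dataclasses import dataclass


# --- Introduce parameter object ---
# Before: five positional arguments that every caller must order correctly.
def format_address_before(name, street, city, postal_code, country):
    return f"{name}\n{street}\n{postal_code} {city}\n{country}"


# After: the related values travel together as one object.
@dataclass(frozen=True)
class Address:
    name: str
    street: str
    city: str
    postal_code: str
    country: str


def format_address(address: Address) -> str:
    return (
        f"{address.name}\n{address.street}\n"
        f"{address.postal_code} {address.city}\n{address.country}"
    )


# --- Consolidate duplicate conditional code ---
# Before: the handling charge is repeated in both branches.
def shipping_fee_before(is_express, weight_kg):
    if is_express:
        fee = weight_kg * 4.0
        fee += 2.0  # handling charge (duplicated)
    else:
        fee = weight_kg * 1.5
        fee += 2.0  # handling charge (duplicated)
    return fee


# After: only the part that differs stays inside the conditional.
def shipping_fee(is_express, weight_kg):
    rate = 4.0 if is_express else 1.5
    return weight_kg * rate + 2.0
```

A frozen dataclass is used here so the parameter object cannot be modified by the function that receives it; a plain class or named tuple would serve the same purpose.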
Here's an example, starting with a very complicated check_access method that nests many conditionals, making it especially hard to read.
def check_access(user_status, user_role, subscription_level, resource_type):
    if user_status == "active":
        if user_role == "admin":
            if resource_type == "sensitive":
                # Admins always have access to sensitive resources if active
                return True
            else:
                # Admins have access to all non-sensitive resources if active
                return True
        else:  # user_role is not admin
            if subscription_level == "premium":
                if resource_type != "sensitive":
                    # Premium users have access to non-sensitive resources if active and not admin
                    return True
                else:
                    # Premium users cannot access sensitive resources even if active and not admin
                    return False
            else:  # subscription_level is not premium
                if resource_type == "public":
                    # Non-premium, non-admin active users can only access public resources
                    return True
                else:
                    # Non-premium, non-admin active users cannot access non-public resources
                    return False
    else:  # user_status is not active
        # Inactive users have no access
        return False

# Example Usage
print(f"Admin active, sensitive resource: {check_access('active', 'admin', 'premium', 'sensitive')}")
print(f"User active, premium, non-sensitive resource: {check_access('active', 'user', 'premium', 'standard')}")
print(f"User active, basic, public resource: {check_access('active', 'user', 'basic', 'public')}")
print(f"User active, basic, sensitive resource: {check_access('active', 'user', 'basic', 'sensitive')}")
print(f"Inactive user, admin, sensitive resource: {check_access('inactive', 'admin', 'premium', 'sensitive')}")
Refactoring using guard clauses can simplify the logic greatly while also improving readability.
def check_access_simplified(user_status, user_role, subscription_level, resource_type):
    # Guard clause: Inactive users have no access
    if user_status != "active":
        return False

    # Guard clause: Admins have full access if active
    if user_role == "admin":
        return True

    # At this point, the user is active and not an admin.
    # Guard clause: Non-premium users can only access public resources
    if subscription_level != "premium":
        return resource_type == "public"

    # At this point, the user is active, not an admin, and has a premium subscription.
    # Guard clause: Premium users cannot access sensitive resources
    if resource_type == "sensitive":
        return False

    # If none of the above conditions led to an early exit,
    # the user is active, not an admin, premium, and accessing a non-sensitive resource.
    return True

# Example Usage (should produce the same results)
print(f"Admin active, sensitive resource: {check_access_simplified('active', 'admin', 'premium', 'sensitive')}")
print(f"User active, premium, non-sensitive resource: {check_access_simplified('active', 'user', 'premium', 'standard')}")
print(f"User active, basic, public resource: {check_access_simplified('active', 'user', 'basic', 'public')}")
print(f"User active, basic, sensitive resource: {check_access_simplified('active', 'user', 'basic', 'sensitive')}")
print(f"Inactive user, admin, sensitive resource: {check_access_simplified('inactive', 'admin', 'premium', 'sensitive')}")
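Because the input space here is small, one way to gain confidence that the two versions agree is to compare them on every combination of inputs instead of a handful of spot checks. The sketch below restates both functions without comments so it can run on its own; the value lists are assumptions drawn from the example calls, and any other values fall through to the same branches.

```python
from itertools import product


def check_access(user_status, user_role, subscription_level, resource_type):
    # Nested version, comments trimmed
    if user_status == "active":
        if user_role == "admin":
            if resource_type == "sensitive":
                return True
            else:
                return True
        else:
            if subscription_level == "premium":
                if resource_type != "sensitive":
                    return True
                else:
                    return False
            else:
                if resource_type == "public":
                    return True
                else:
                    return False
    else:
        return False


def check_access_simplified(user_status, user_role, subscription_level, resource_type):
    # Guard-clause version, comments trimmed
    if user_status != "active":
        return False
    if user_role == "admin":
        return True
    if subscription_level != "premium":
        return resource_type == "public"
    return resource_type != "sensitive"


# Assumed representative values for each argument
STATUSES = ["active", "inactive"]
ROLES = ["admin", "user"]
LEVELS = ["premium", "basic"]
RESOURCES = ["public", "standard", "sensitive"]

combinations = list(product(STATUSES, ROLES, LEVELS, RESOURCES))
mismatches = [
    combo
    for combo in combinations
    if check_access(*combo) != check_access_simplified(*combo)
]
print(f"Checked {len(combinations)} combinations, mismatches: {mismatches}")
```

An empty mismatch list means the refactor preserved behavior for every combination tested, which is a stronger statement than the five example calls alone can make.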
5) Moving features between objects
Moving features between objects focuses on improving the distribution of responsibilities within a codebase by relocating methods or fields to the classes where they are most logically at home. This process addresses situations where a method or piece of data in one class is more closely associated with another, aiming to enhance class cohesion and reduce unnecessary coupling between different parts of the system.
Techniques such as move method or move field ensure that each object encapsulates behaviors and data that are truly relevant to its core purpose, resulting in a more intuitive and maintainable object model where related functionalities reside together.
Below is an example of moving features between objects. Initially, a Customer class has a calculate_order_total method, but the method would logically make more sense in the Order class.
class Order:
    def __init__(self, items):
        self.items = items  # items is a list of dictionaries, e.g., [{'name': 'book', 'price': 20, 'quantity': 1}]

class Customer:
    def __init__(self, name, email):
        self.name = name
        self.email = email
        self.orders = []

    def add_order(self, order):
        self.orders.append(order)

    def calculate_order_total(self, order):
        # Calculates the total price for a specific order
        total = 0
        for item in order.items:
            total += item['price'] * item['quantity']
        return total

# Example Usage
customer = Customer("Alice", "[email protected]")
order1 = Order([{'name': 'Laptop', 'price': 1000, 'quantity': 1}])
order2 = Order([{'name': 'Keyboard', 'price': 75, 'quantity': 1}, {'name': 'Mouse', 'price': 25, 'quantity': 2}])
customer.add_order(order1)
customer.add_order(order2)

# Calculating total using the method in Customer class
order1_total = customer.calculate_order_total(order1)
order2_total = customer.calculate_order_total(order2)
print(f"{customer.name}'s Order 1 Total: ${order1_total:.2f}")
print(f"{customer.name}'s Order 2 Total: ${order2_total:.2f}")
The refactored code below moves the calculation into the Order class as a calculate_total method and removes calculate_order_total from Customer.
class Order:
    def __init__(self, items):
        self.items = items  # items is a list of dictionaries

    def calculate_total(self):
        # Calculates the total price for this order
        total = 0
        for item in self.items:
            total += item['price'] * item['quantity']
        return total

class Customer:
    def __init__(self, name, email):
        self.name = name
        self.email = email
        self.orders = []

    def add_order(self, order):
        self.orders.append(order)

    # The calculate_order_total method has been removed from here

# Example Usage
customer = Customer("Alice", "[email protected]")
order1 = Order([{'name': 'Laptop', 'price': 1000, 'quantity': 1}])
order2 = Order([{'name': 'Keyboard', 'price': 75, 'quantity': 1}, {'name': 'Mouse', 'price': 25, 'quantity': 2}])
customer.add_order(order1)
customer.add_order(order2)

# Calculating total using the method in the Order class
order1_total = order1.calculate_total()
order2_total = order2.calculate_total()
print(f"{customer.name}'s Order 1 Total: ${order1_total:.2f}")
print(f"{customer.name}'s Order 2 Total: ${order2_total:.2f}")
6) Dealing with generalization
Dealing with generalization is a collection of refactoring techniques for managing and refining the use of inheritance and type information within a codebase to improve class structure. The goal is to clarify relationships between classes, remove cumbersome conditional logic dependent on type codes and ensure that shared behaviors and data are appropriately located in the hierarchy. This involves actions such as the following:
Pulling common methods or fields up to a superclass.
Pushing specialized ones down to subclasses.
Extracting interfaces to define common capabilities.
Using polymorphism by replacing type-checking conditionals with distinct subclass implementations.
These techniques produce a cleaner, more flexible design better equipped for future extension.
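Three of these actions, pulling a method up, pushing a method down and extracting an interface, can be sketched together. The classes below (Employee, Engineer, Manager, Contractor) and the Payable interface are hypothetical names invented for this illustration; typing.Protocol is one of several ways to express an interface in Python.

```python
from typing import Iterable, Protocol


class Payable(Protocol):
    """Extracted interface: anything that can report a monthly cost."""

    def monthly_cost(self) -> float: ...


class Employee:
    def __init__(self, first_name, last_name, salary):
        self.first_name = first_name
        self.last_name = last_name
        self.salary = salary

    # Pulled up: imagine each subclass previously carried its own identical copy.
    def display_name(self):
        return f"{self.last_name}, {self.first_name}"

    def monthly_cost(self):
        return self.salary / 12


class Engineer(Employee):
    pass


class Manager(Employee):
    def __init__(self, first_name, last_name, salary, reports):
        super().__init__(first_name, last_name, salary)
        self.reports = reports

    # Pushed down: only managers have direct reports, so this
    # method lives here rather than on Employee.
    def headcount(self):
        return len(self.reports)


class Contractor:
    """Not an Employee, but it satisfies Payable structurally."""

    def __init__(self, hourly_rate, hours_per_month):
        self.hourly_rate = hourly_rate
        self.hours_per_month = hours_per_month

    def monthly_cost(self):
        return self.hourly_rate * self.hours_per_month


def total_monthly_cost(payables: Iterable[Payable]) -> float:
    # Depends only on the interface, not on any concrete class.
    return sum(p.monthly_cost() for p in payables)


ada = Engineer("Ada", "Lovelace", 120000)
grace = Manager("Grace", "Hopper", 144000, reports=[ada])
temp = Contractor(hourly_rate=50, hours_per_month=40)
total = total_monthly_cost([ada, grace, temp])
print(f"Total monthly cost: {total:.2f}")
```

The function total_monthly_cost accepts Contractor even though it shares no base class with Employee, which is the practical payoff of extracting the interface.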
Below is an example of dealing with generalization. In this case, a Product class has a get_shipping_cost method that uses conditionals based on a type. By refactoring to use subclasses of the Product class, the logic for calculating a product's shipping cost can be contained within the subclass, and we can get rid of the long, complicated conditional.
class Product:
    def __init__(self, name, price, type):
        self.name = name
        self.price = price
        self.type = type  # 'book', 'electronics', 'clothing'

    def get_shipping_cost(self):
        if self.type == 'book':
            return 5.00
        elif self.type == 'electronics':
            return 10.00
        elif self.type == 'clothing':
            return 7.50
        else:
            return 99.99  # Default or error case

# Example Usage
book = Product("The Hitchhiker's Guide to the Galaxy", 10.00, 'book')
electronics = Product("Wireless Mouse", 25.00, 'electronics')
clothing = Product("T-Shirt", 15.00, 'clothing')
unknown = Product("Mystery Item", 50.00, 'unknown')
print(f"{book.name} Shipping Cost: ${book.get_shipping_cost():.2f}")
print(f"{electronics.name} Shipping Cost: ${electronics.get_shipping_cost():.2f}")
print(f"{clothing.name} Shipping Cost: ${clothing.get_shipping_cost():.2f}")
print(f"{unknown.name} Shipping Cost: ${unknown.get_shipping_cost():.2f}")
After refactoring, the logic for calculating the product's shipping cost is contained within the subclass.
class Product:
    def __init__(self, name, price):
        self.name = name
        self.price = price

    def get_shipping_cost(self):
        # Default shipping cost or base implementation
        return 99.99  # Could also raise a NotImplementedError for abstract behavior

class BookProduct(Product):
    def __init__(self, name, price):
        super().__init__(name, price)

    def get_shipping_cost(self):
        return 5.00

class ElectronicsProduct(Product):
    def __init__(self, name, price):
        super().__init__(name, price)

    def get_shipping_cost(self):
        return 10.00

class ClothingProduct(Product):
    def __init__(self, name, price):
        super().__init__(name, price)

    def get_shipping_cost(self):
        return 7.50

# Example Usage
book = BookProduct("The Hitchhiker's Guide to the Galaxy", 10.00)
electronics = ElectronicsProduct("Wireless Mouse", 25.00)
clothing = ClothingProduct("T-Shirt", 15.00)
# The 'unknown' type would ideally not exist with this structure,
# but we can demonstrate the base class behavior
unknown = Product("Mystery Item", 50.00)
print(f"{book.name} Shipping Cost: ${book.get_shipping_cost():.2f}")
print(f"{electronics.name} Shipping Cost: ${electronics.get_shipping_cost():.2f}")
print(f"{clothing.name} Shipping Cost: ${clothing.get_shipping_cost():.2f}")
print(f"{unknown.name} Shipping Cost: ${unknown.get_shipping_cost():.2f}")
Refactoring best practices
As developers implement these techniques, it's always advisable to keep some best practices in mind:
Test before refactoring. Establish a baseline of functioning code before embarking on refactoring efforts, validating all critical functionality.
Continuously test. After each modification, no matter how small, execute the test suite. This ensures that refactored code still functions and doesn't introduce defects.
Take advantage of tools. Modern IDEs have useful features for code restructuring, variable and method renaming, function extraction and formatting.
Take an incremental approach. Break refactoring efforts into manageable steps to make any introduced bugs easier to trace.
Schedule refactoring strategically. Refactor before making major changes such as update rollouts or new feature implementation.
Matt Grasberger is a DevOps engineer with experience in test automation, software development and designing automated processes to reduce work.