6 essential refactoring techniques to know

Refactoring helps optimize the codebase and combat technical debt. Use these six code refactoring techniques to improve code structure, maintainability and overall quality.

Code refactoring is an important process in software development that involves restructuring code without changing its functional behavior.

Teams refactor for many reasons, chiefly to improve the code's design and maintainability. As teams add changes and features to software, opportunities to refactor arise continually. In this way, refactoring is a crucial step between having written some code and writing new code.

After writing code for a new feature or a bug fix for an existing feature, developers should take a step back to consider how the changes they've introduced might or might not fit into the overall design of the software. With the addition of new changes, developers might consider an overall structure that better suits the codebase, existing methods that could be simplified or new tests that would better assert the code's correctness.

6 code refactoring techniques to know

Take a look at some different refactoring techniques and examples of those techniques in practice. Some of them -- such as "dealing with generalization" and "composing methods" -- are broader categories with more specific approaches described within.

1) Red, green, refactor

Red, green, refactor is a test-driven development technique that consists of three main phases:

  • Red. Write a failing test for the new desired functionality.
  • Green. Write the minimal code required to pass the failing test.
  • Refactor. Clean up the new code while keeping the tests passing.

As an example, we might want to write a new method to reverse a string in Python. The tests encourage the developer to consider the interface to the code first -- how the method will look regarding the arguments it takes or the values it returns.

The tests also serve as requirements for the method's functionality. They specify expected output given a certain input.

The tests might look something like the following:

import unittest

# We will define the function here later

class TestReverseString(unittest.TestCase):
    def test_empty_string(self):
        # Test case for an empty string
        self.assertEqual(reverse_string(""), "")

    def test_single_character_string(self):
        # Test case for a single character string
        self.assertEqual(reverse_string("a"), "a")

    def test_palindrome_string(self):
        # Test case for a palindrome
        self.assertEqual(reverse_string("madam"), "madam")

    def test_normal_string(self):
        # Test case for a regular string
        self.assertEqual(reverse_string("hello"), "olleh")

The first iteration of the new method might then look something like the below.

import unittest

def reverse_string(s):
    # Minimal code to pass the tests using a loop
    reversed_chars = []
    for char in s:
        reversed_chars.append(char)
    # Reverse and join only after the loop has collected every character
    reversed_chars.reverse()  # Reverse the list of characters
    return "".join(reversed_chars)  # Join them back into a string

class TestReverseString(unittest.TestCase):
	def test_empty_string(self):
  	  self.assertEqual(reverse_string(""), "")

	def test_single_character_string(self):
  	  self.assertEqual(reverse_string("a"), "a")

	def test_palindrome_string(self):
  	  self.assertEqual(reverse_string("madam"), "madam")

	def test_normal_string(self):
  	  self.assertEqual(reverse_string("hello"), "olleh")

if __name__ == '__main__':
	unittest.main()

A refactor of the reverse_string method could look like the following:

import unittest

def reverse_string(s):
    # Refactored code - simpler and more Pythonic using slicing
    return s[::-1]

2) Refactor by abstraction

Refactoring by abstraction is a technique that identifies and pulls common functionality from multiple places into a higher-level abstraction, such as an abstract class or interface.

This process aims to reduce code duplication, improve modularity and make the codebase more maintainable and accessible. Centralizing shared logic simplifies new code changes or additions. This leads to cleaner, more reliable software design and reduces the likelihood of introducing errors.

As an example, start with two classes, each responsible for sending some type of message. The classes could be abstracted into a generic MessageSender class, where both specific message-sending classes can reuse the common logic.

Below are the two classes before refactoring.

class EmailSender:
	def __init__(self, server, port):
  	  self.server = server
  	  self.port = port
  	  self.connection = None

	def connect(self):
  	  print(f"Connecting to email server {self.server}:{self.port}...")
  	  # Simulate connection
  	  self.connection = f"EmailConnection:{self.server}"
  	  print("Email connected.")

	def send_message(self, recipient, message):
  	  if not self.connection:
    	 print("Error: Not connected to email server.")
    	 return
  	  print(f"Sending email to {recipient} via {self.connection}: {message}")
  	  # Simulate sending

	def disconnect(self):
  	  if self.connection:
    	 print(f"Disconnecting from {self.connection}...")
    	 self.connection = None
    	 print("Email disconnected.")

class SMSSender:
	def __init__(self, gateway, api_key):
  	  self.gateway = gateway
  	  self.api_key = api_key
  	  self.connection = None

	def connect(self):
  	  print(f"Connecting to SMS gateway {self.gateway} using API key {self.api_key}...")
  	  # Simulate connection
  	  self.connection = f"SMSConnection:{self.gateway}"
  	  print("SMS connected.")

	def send_message(self, recipient, message):
  	  if not self.connection:
    	 print("Error: Not connected to SMS gateway.")
    	 return
  	  print(f"Sending SMS to {recipient} via {self.connection}: {message}")
  	  # Simulate sending

	def disconnect(self):
  	  if self.connection:
    	 print(f"Disconnecting from {self.connection}...")
    	 self.connection = None
    	 print("SMS disconnected.")

# Usage
email_sender = EmailSender("smtp.example.com", 587)
email_sender.connect()
email_sender.send_message("[email protected]", "Hello via email!")
email_sender.disconnect()

print("-" * 20)

sms_sender = SMSSender("sms.gateway.com", "abcdef12345")
sms_sender.connect()
sms_sender.send_message("123-456-7890", "Hello via SMS!")
sms_sender.disconnect()

Below is the code after refactoring the common code between the two classes into one abstract class.

import abc

class MessageSender(abc.ABC):
	def __init__(self):
  	  self.connection = None

	@abc.abstractmethod
	def connect(self):
  	  # Abstract method to establish a connection
  	  pass

	@abc.abstractmethod
	def _send(self, recipient, message):
  	  # Abstract method for the specific sending logic
  	  pass

	def send_message(self, recipient, message):
  	  # Handles connection check and calls the specific send logic
  	  if not self.connection:
    	 print("Error: Not connected.") # More generic message
    	 return
  	  self._send(recipient, message) # Delegate to the specific implementation

	def disconnect(self):
  	  # Handles disconnection
  	  if self.connection:
    	 print(f"Disconnecting from {self.connection}...")
    	 self.connection = None
    	 print("Disconnected.") # More generic message

class EmailSender(MessageSender):
	def __init__(self, server, port):
  	  super().__init__()
  	  self.server = server
  	  self.port = port

	def connect(self):
  	  print(f"Connecting to email server {self.server}:{self.port}...")
  	  # Simulate connection
  	  self.connection = f"EmailConnection:{self.server}"
  	  print("Email connected.")

	def _send(self, recipient, message):
  	  print(f"Sending email to {recipient} via {self.connection}: {message}")
  	  # Simulate sending

class SMSSender(MessageSender):
	def __init__(self, gateway, api_key):
  	  super().__init__()
  	  self.gateway = gateway
  	  self.api_key = api_key

	def connect(self):
  	  print(f"Connecting to SMS gateway {self.gateway} using API key {self.api_key}...")
  	  # Simulate connection
  	  self.connection = f"SMSConnection:{self.gateway}"
  	  print("SMS connected.")

	def _send(self, recipient, message):
  	  print(f"Sending SMS to {recipient} via {self.connection}: {message}")
  	  # Simulate sending

# Usage
email_sender = EmailSender("smtp.example.com", 587)
email_sender.connect()
email_sender.send_message("[email protected]", "Hello via email!")
email_sender.disconnect()

print("-" * 20)

sms_sender = SMSSender("sms.gateway.com", "abcdef12345")
sms_sender.connect()
sms_sender.send_message("123-456-7890", "Hello via SMS!")
sms_sender.disconnect()

3) Composing methods

Composing methods focuses on improving the clarity and structure of individual methods, primarily by managing their length and complexity.

The main idea is to break down lengthy, hard-to-understand methods into smaller, more focused units, making each piece of logic easier to grasp and maintain. Some techniques central to this approach include the following:

  • Extract method, which turns a block of code within a method into a new, separately callable method.
  • Inline method, which brings the body of a simple method directly into its caller.

These techniques aim to ensure each method operates at a single level of abstraction and has a clear, singular purpose. Refactoring in this way also makes tests more concise, as each method under test has a clearer purpose.
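The order-processing example that follows demonstrates extract method. As a complement, here is a minimal sketch of the opposite move, inline method. The Courier class and its method names are hypothetical, invented only for this illustration.

```python
class Courier:
    """Hypothetical class used only to illustrate the inline method technique."""

    def __init__(self, late_deliveries):
        self.late_deliveries = late_deliveries

    # Before: the helper adds a layer of indirection without adding clarity.
    def _has_many_late_deliveries(self):
        return self.late_deliveries > 5

    def rating_before(self):
        return 2 if self._has_many_late_deliveries() else 1

    # After: the helper's body is inlined into its only caller.
    def rating_after(self):
        return 2 if self.late_deliveries > 5 else 1


for count in (0, 5, 6):
    courier = Courier(count)
    # The refactor must not change behavior
    assert courier.rating_before() == courier.rating_after()
```

Inlining is usually worthwhile only when the helper's body is as readable as its name and the helper has few callers. Otherwise, keeping the method is often the clearer choice.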

Below is a class that processes orders, starting with a bloated process_order method ripe for refactoring via the composing methods technique.

class OrderProcessor:
	def __init__(self, order_items, discount_rate, tax_rate):
  	  self.order_items = order_items
  	  self.discount_rate = discount_rate
  	  self.tax_rate = tax_rate

	def process_order(self):
  	  # Calculate total price
  	  total_price = 0
  	  for item in self.order_items:
    	 total_price += item['price'] * item['quantity']

  	  # Apply discount
  	  discount_amount = total_price * self.discount_rate
  	  price_after_discount = total_price - discount_amount

  	  # Calculate tax
  	  tax_amount = price_after_discount * self.tax_rate
  	  final_price = price_after_discount + tax_amount

  	  # Format summary
  	  summary = "--- Order Summary ---\n"
  	  for item in self.order_items:
    	 summary += f"- {item['name']}: ${item['price']} x {item['quantity']}\n"
  	  summary += f"---------------------\n"
  	  summary += f"Subtotal: ${total_price:.2f}\n"
  	  summary += f"Discount ({self.discount_rate * 100}%): -${discount_amount:.2f}\n"
  	  summary += f"Tax ({self.tax_rate * 100}%): ${tax_amount:.2f}\n"
  	  summary += f"---------------------\n"
  	  summary += f"Final Total: ${final_price:.2f}\n"
  	  summary += "---------------------"

  	  print(summary)
  	  return final_price

# Example Usage
items = [
 {'name': 'Laptop', 'price': 1000, 'quantity': 1},
 {'name': 'Mouse', 'price': 25, 'quantity': 2},
]
processor = OrderProcessor(items, 0.10, 0.08) # 10% discount, 8% tax
processor.process_order()

The following refactor extracts the logic of the process_order method into shorter, more concise methods:

class OrderProcessor:
	def __init__(self, order_items, discount_rate, tax_rate):
  	  self.order_items = order_items
  	  self.discount_rate = discount_rate
  	  self.tax_rate = tax_rate

	def calculate_subtotal(self):
  	  # Calculates the total price of order items before discount and tax
  	  subtotal = 0
  	  for item in self.order_items:
    	 subtotal += item['price'] * item['quantity']
  	  return subtotal

	def apply_discount(self, subtotal):
  	  # Applies the discount rate to the subtotal
  	  discount_amount = subtotal * self.discount_rate
  	  return subtotal - discount_amount, discount_amount

	def calculate_tax(self, price_after_discount):
  	  # Calculates the tax amount based on the price after discount
  	  tax_amount = price_after_discount * self.tax_rate
  	  return price_after_discount + tax_amount, tax_amount

	def format_summary(self, total_price, discount_amount, tax_amount, final_price):
  	  # Formats the order summary string
  	  summary = "--- Order Summary ---\n"
  	  for item in self.order_items:
    	 summary += f"- {item['name']}: ${item['price']} x {item['quantity']}\n"
  	  summary += f"---------------------\n"
  	  summary += f"Subtotal: ${total_price:.2f}\n"
  	  summary += f"Discount ({self.discount_rate * 100}%): -${discount_amount:.2f}\n"
  	  summary += f"Tax ({self.tax_rate * 100}%): ${tax_amount:.2f}\n"
  	  summary += f"---------------------\n"
  	  summary += f"Final Total: ${final_price:.2f}\n"
  	  summary += "---------------------"
  	  return summary

	def process_order(self):
  	  # Orchestrates the order processing steps
  	  subtotal = self.calculate_subtotal()
  	  price_after_discount, discount_amount = self.apply_discount(subtotal)
  	  final_price, tax_amount = self.calculate_tax(price_after_discount)
  	  summary = self.format_summary(subtotal, discount_amount, tax_amount, final_price)

  	  print(summary)
  	  return final_price

# Example Usage
items = [
 {'name': 'Laptop', 'price': 1000, 'quantity': 1},
 {'name': 'Mouse', 'price': 25, 'quantity': 2},
]
processor = OrderProcessor(items, 0.10, 0.08) # 10% discount, 8% tax
processor.process_order()

4) Simplifying methods

Simplifying methods is a refactoring approach that reduces the internal complexity of existing methods. It makes the logic inside each method clearer and more direct.

Rather than breaking a method into multiple smaller ones, this technique addresses complicated conditional expressions, overly long parameter lists or hard-to-follow method calls within a single method. Some techniques to simplify methods include the following:

  • Replacing nested conditionals with guard clauses to flatten the structure.
  • Consolidating duplicate conditional code.
  • Introducing parameter objects to streamline method signatures.

These techniques make the method's logic easier to follow at a glance, thereby improving maintainability and reducing the risk of introducing bugs during modifications.

Here's an example, starting with a very complicated check_access method that nests many conditionals, making it especially hard to read.

def check_access(user_status, user_role, subscription_level, resource_type):
    if user_status == "active":
        if user_role == "admin":
            if resource_type == "sensitive":
                # Admins always have access to sensitive resources if active
                return True
            else:
                # Admins have access to all non-sensitive resources if active
                return True
        else:  # user_role is not admin
            if subscription_level == "premium":
                if resource_type != "sensitive":
                    # Premium users have access to non-sensitive resources if active and not admin
                    return True
                else:
                    # Premium users cannot access sensitive resources even if active and not admin
                    return False
            else:  # subscription_level is not premium
                if resource_type == "public":
                    # Non-premium, non-admin active users can only access public resources
                    return True
                else:
                    # Non-premium, non-admin active users cannot access non-public resources
                    return False
    else:  # user_status is not active
        # Inactive users have no access
        return False

# Example Usage
print(f"Admin active, sensitive resource: {check_access('active', 'admin', 'premium', 'sensitive')}")
print(f"User active, premium, non-sensitive resource: {check_access('active', 'user', 'premium', 'standard')}")
print(f"User active, basic, public resource: {check_access('active', 'user', 'basic', 'public')}")
print(f"User active, basic, sensitive resource: {check_access('active', 'user', 'basic', 'sensitive')}")
print(f"Inactive user, admin, sensitive resource: {check_access('inactive', 'admin', 'premium', 'sensitive')}")

Refactoring using guard clauses can simplify the logic greatly while also improving readability.

def check_access_simplified(user_status, user_role, subscription_level, resource_type):
    # Guard clause: Inactive users have no access
    if user_status != "active":
        return False

    # Guard clause: Admins have full access if active
    if user_role == "admin":
        return True

    # At this point, the user is active and not an admin.

    # Guard clause: Non-premium users can only access public resources
    if subscription_level != "premium":
        return resource_type == "public"

    # At this point, the user is active, not an admin, and has a premium subscription.

    # Guard clause: Premium users cannot access sensitive resources
    if resource_type == "sensitive":
        return False

    # If none of the above conditions led to an early exit,
    # the user is active, not an admin, premium, and accessing a non-sensitive resource.
    return True

# Example Usage (should produce the same results)
print(f"Admin active, sensitive resource: {check_access_simplified('active', 'admin', 'premium', 'sensitive')}")
print(f"User active, premium, non-sensitive resource: {check_access_simplified('active', 'user', 'premium', 'standard')}")
print(f"User active, basic, public resource: {check_access_simplified('active', 'user', 'basic', 'public')}")
print(f"User active, basic, sensitive resource: {check_access_simplified('active', 'user', 'basic', 'sensitive')}")
print(f"Inactive user, admin, sensitive resource: {check_access_simplified('inactive', 'admin', 'premium', 'sensitive')}")
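The guard-clause version still takes four loosely related arguments. Introducing a parameter object, another technique from the list above, could shorten that signature. The sketch below is one possible shape under stated assumptions: the AccessRequest class and the check_access_with_request function are hypothetical names, and the rules are copied from check_access_simplified.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AccessRequest:
    # Hypothetical parameter object; fields mirror check_access's arguments
    user_status: str
    user_role: str
    subscription_level: str
    resource_type: str


def check_access_with_request(request):
    # Same rules as the guard-clause version, read from a single object
    if request.user_status != "active":
        return False
    if request.user_role == "admin":
        return True
    if request.subscription_level != "premium":
        return request.resource_type == "public"
    return request.resource_type != "sensitive"


# Example usage
request = AccessRequest("active", "user", "premium", "standard")
print(f"User active, premium, non-sensitive resource: {check_access_with_request(request)}")
```

Grouping the arguments also gives them one place to be validated or extended later, without touching every caller's argument list.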

5) Moving features between objects

Moving features between objects focuses on improving the distribution of responsibilities within a codebase by relocating methods or fields to the classes where they are most logically at home. This process addresses situations where a method or piece of data in one class is more closely associated with another, aiming to enhance class cohesion and reduce unnecessary coupling between different parts of the system.

Techniques such as move method or move field ensure that each object encapsulates behaviors and data that are truly relevant to its core purpose, resulting in a more intuitive and maintainable object model where related functionalities reside together.

Below is an example of moving features between objects. Initially, the Customer class has a calculate_order_total method, but the method would logically make more sense in the Order class.

class Order:
	def __init__(self, items):
  	  self.items = items # items is a list of dictionaries, e.g., [{'name': 'book', 'price': 20, 'quantity': 1}]

class Customer:
	def __init__(self, name, email):
  	  self.name = name
  	  self.email = email
  	  self.orders = []

	def add_order(self, order):
  	  self.orders.append(order)

	def calculate_order_total(self, order):
  	  # Calculates the total price for a specific order
  	  total = 0
  	  for item in order.items:
    	 total += item['price'] * item['quantity']
  	  return total

# Example Usage
customer = Customer("Alice", "[email protected]")
order1 = Order([{'name': 'Laptop', 'price': 1000, 'quantity': 1}])
order2 = Order([{'name': 'Keyboard', 'price': 75, 'quantity': 1}, {'name': 'Mouse', 'price': 25, 'quantity': 2}])

customer.add_order(order1)
customer.add_order(order2)

# Calculating total using the method in Customer class
order1_total = customer.calculate_order_total(order1)
order2_total = customer.calculate_order_total(order2)

print(f"{customer.name}'s Order 1 Total: ${order1_total:.2f}")
print(f"{customer.name}'s Order 2 Total: ${order2_total:.2f}")

The refactored code below moves the calculate_order_total method into the Order class and renames it calculate_total.

class Order:
	def __init__(self, items):
  	  self.items = items # items is a list of dictionaries

	def calculate_total(self):
  	  # Calculates the total price for this order
  	  total = 0
  	  for item in self.items:
    	 total += item['price'] * item['quantity']
  	  return total

class Customer:
	def __init__(self, name, email):
  	  self.name = name
  	  self.email = email
  	  self.orders = []

	def add_order(self, order):
  	  self.orders.append(order)

	# The calculate_order_total method has been removed from here

# Example Usage
customer = Customer("Alice", "[email protected]")
order1 = Order([{'name': 'Laptop', 'price': 1000, 'quantity': 1}])
order2 = Order([{'name': 'Keyboard', 'price': 75, 'quantity': 1}, {'name': 'Mouse', 'price': 25, 'quantity': 2}])

customer.add_order(order1)
customer.add_order(order2)

# Calculating total using the method in the Order class
order1_total = order1.calculate_total()
order2_total = order2.calculate_total()

print(f"{customer.name}'s Order 1 Total: ${order1_total:.2f}")
print(f"{customer.name}'s Order 2 Total: ${order2_total:.2f}")

6) Dealing with generalization

Dealing with generalization is a collection of refactoring techniques for managing and refining the use of inheritance and type information within a codebase to improve class structure. The goal is to clarify relationships between classes, remove cumbersome conditional logic dependent on type codes and ensure that shared behaviors and data are appropriately located in the hierarchy. This involves actions such as the following:

  • Pulling common methods or fields up to a superclass.
  • Pushing specialized ones down to subclasses.
  • Extracting interfaces to define common capabilities.
  • Using polymorphism by replacing type-checking conditionals with distinct subclass implementations.

These techniques produce a cleaner, more flexible design better equipped for future extension.
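As a rough illustration of the first two items in the list, the sketch below pulls a shared method up to a superclass while leaving a specialized field in the subclass that needs it. The Employee, Engineer and Salesperson classes are hypothetical and exist only for this example.

```python
class Employee:
    def __init__(self, name, annual_salary):
        self.name = name
        self.annual_salary = annual_salary

    # Pulled up: assume both subclasses previously defined an identical copy
    def monthly_salary(self):
        return self.annual_salary / 12


class Engineer(Employee):
    def role(self):
        return "engineer"


class Salesperson(Employee):
    def __init__(self, name, annual_salary, commission):
        super().__init__(name, annual_salary)
        # Specialized field stays down in the only subclass that uses it
        self.commission = commission

    def role(self):
        return "salesperson"


# Example usage
staff = [Engineer("Ada", 120000), Salesperson("Sam", 60000, 500)]
for person in staff:
    print(f"{person.name} ({person.role()}): ${person.monthly_salary():.2f} per month")
```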

Below is an example of dealing with generalization. In this case, a Product class has a get_shipping_cost method that uses conditionals based on a type. By refactoring to use subclasses of the Product class, the logic for calculating a product's shipping cost can be contained within the subclass, and we can get rid of the long, complicated conditional.

class Product:
	def __init__(self, name, price, type):
  	  self.name = name
  	  self.price = price
  	  self.type = type # 'book', 'electronics', 'clothing'

	def get_shipping_cost(self):
  	  if self.type == 'book':
    	 return 5.00
  	  elif self.type == 'electronics':
    	 return 10.00
  	  elif self.type == 'clothing':
    	 return 7.50
  	  else:
    	 return 99.99 # Default or error case

# Example Usage
book = Product("The Hitchhiker's Guide to the Galaxy", 10.00, 'book')
electronics = Product("Wireless Mouse", 25.00, 'electronics')
clothing = Product("T-Shirt", 15.00, 'clothing')
unknown = Product("Mystery Item", 50.00, 'unknown')

print(f"{book.name} Shipping Cost: ${book.get_shipping_cost():.2f}")
print(f"{electronics.name} Shipping Cost: ${electronics.get_shipping_cost():.2f}")
print(f"{clothing.name} Shipping Cost: ${clothing.get_shipping_cost():.2f}")
print(f"{unknown.name} Shipping Cost: ${unknown.get_shipping_cost():.2f}")

After refactoring, the logic for calculating the product's shipping cost is contained within the subclass.

class Product:
	def __init__(self, name, price):
  	  self.name = name
  	  self.price = price

	def get_shipping_cost(self):
  	  # Default shipping cost or base implementation
  	  return 99.99 # Could also raise a NotImplementedError for abstract behavior

class BookProduct(Product):
	def __init__(self, name, price):
  	  super().__init__(name, price)

	def get_shipping_cost(self):
  	  return 5.00

class ElectronicsProduct(Product):
	def __init__(self, name, price):
  	  super().__init__(name, price)

	def get_shipping_cost(self):
  	  return 10.00

class ClothingProduct(Product):
	def __init__(self, name, price):
  	  super().__init__(name, price)

	def get_shipping_cost(self):
  	  return 7.50

# Example Usage
book = BookProduct("The Hitchhiker's Guide to the Galaxy", 10.00)
electronics = ElectronicsProduct("Wireless Mouse", 25.00)
clothing = ClothingProduct("T-Shirt", 15.00)
# The 'unknown' type would ideally not exist with this structure,
# but we can demonstrate the base class behavior
unknown = Product("Mystery Item", 50.00)


print(f"{book.name} Shipping Cost: ${book.get_shipping_cost():.2f}")
print(f"{electronics.name} Shipping Cost: ${electronics.get_shipping_cost():.2f}")
print(f"{clothing.name} Shipping Cost: ${clothing.get_shipping_cost():.2f}")
print(f"{unknown.name} Shipping Cost: ${unknown.get_shipping_cost():.2f}")

Refactoring best practices

As developers implement these techniques, it's always advisable to keep some best practices in mind:

  • Test before refactoring. Establish a baseline of functioning code before embarking on refactoring efforts, validating all critical functionality.
  • Continuously test. After each modification, no matter how small, execute the test suite. This ensures that refactored code still functions and doesn't introduce defects.
  • Take advantage of tools. Modern IDEs have useful features for code restructuring, variable and method renaming, function extraction and formatting.
  • Take an incremental approach. Break refactoring efforts into manageable steps to make any introduced bugs easier to trace.
  • Schedule refactoring strategically. Refactor before making major changes such as update rollouts or new feature implementation.
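One way to put the first two practices into action is to keep the original implementation around briefly and assert that the refactored version matches it. The following is a minimal sketch using unittest. The function names reverse_string_original and reverse_string_refactored are assumptions made for this example, as is the choice of sample inputs.

```python
import unittest


def reverse_string_original(s):
    # Loop-based version, kept temporarily as the reference behavior
    reversed_chars = []
    for char in s:
        reversed_chars.insert(0, char)
    return "".join(reversed_chars)


def reverse_string_refactored(s):
    # Refactored version using slicing
    return s[::-1]


class TestRefactorPreservesBehavior(unittest.TestCase):
    def test_same_output_for_sample_inputs(self):
        for text in ["", "a", "madam", "hello", "two words"]:
            # subTest reports each failing input separately
            with self.subTest(text=text):
                self.assertEqual(
                    reverse_string_refactored(text),
                    reverse_string_original(text),
                )


# Run the suite programmatically so the script keeps control of the result
suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestRefactorPreservesBehavior)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

Once the refactored version has passed consistently, the original function and this comparison test can be removed, leaving the regular test suite as the safety net.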

Matt Grasberger is a DevOps engineer with experience in test automation, software development and designing automated processes to reduce work.
