001/*
002 * Copyright (C) 2022 - 2024, the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *    http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package io.github.ascopes.jct.junit;
017
018import io.github.ascopes.jct.workspaces.Workspace;
019import io.github.ascopes.jct.workspaces.Workspaces;
020import java.lang.reflect.Field;
021import java.lang.reflect.Modifier;
022import java.util.ArrayList;
023import java.util.List;
024import org.apiguardian.api.API;
025import org.apiguardian.api.API.Status;
026import org.jspecify.annotations.Nullable;
027import org.junit.jupiter.api.extension.AfterAllCallback;
028import org.junit.jupiter.api.extension.AfterEachCallback;
029import org.junit.jupiter.api.extension.BeforeAllCallback;
030import org.junit.jupiter.api.extension.BeforeEachCallback;
031import org.junit.jupiter.api.extension.Extension;
032import org.junit.jupiter.api.extension.ExtensionContext;
033import org.slf4j.Logger;
034import org.slf4j.LoggerFactory;
035
036/**
037 * JUnit5 extension that will manage the lifecycle of {@link Managed}-annotated {@link Workspace}
038 * fields within JUnit5 test classes.
039 *
040 * <pre><code>
041 * {@literal @ExtendWith(JctExtension.class)}
042 * class MyTest {
043 *   {@literal @Managed}
044 *   Workspace workspace;
045 *
046 *   {@literal @JavacCompilerTest}
047 *   void myTest(JctCompiler compiler) {
048 *     // Given
049 *     workspace
050 *        .createSourcePathPackage()
051 *        ...;
052 *
053 *     // When
054 *     var compilation = compiler.compile(workspace);
055 *
056 *     // Then
057 *     ...
058 *   }
059 * }
060 * </code></pre>
061 *
062 * @author Ashley Scopes
063 * @since 0.4.0
064 */
065@API(since = "0.4.0", status = Status.STABLE)
066public final class JctExtension
067    implements Extension, BeforeEachCallback, BeforeAllCallback, AfterEachCallback, AfterAllCallback {
068
069  private static final Logger log = LoggerFactory.getLogger(JctExtension.class);
070
071  /**
072   * Initialise this extension.
073   *
074   * <p>You shouldn't ever need to call this directly. See the class description for an example
075   * of how to use this.
076   */
077  public JctExtension() {
078    // Nothing to do.
079  }
080
081  @Override
082  public void beforeAll(ExtensionContext context) throws Exception {
083    for (var field : getManagedWorkspaceFields(context.getRequiredTestClass(), true)) {
084      initWorkspaceForField(field, null);
085    }
086  }
087
088  @Override
089  public void beforeEach(ExtensionContext context) throws Exception {
090    for (var instance : context.getRequiredTestInstances().getAllInstances()) {
091      for (var field : getManagedWorkspaceFields(instance.getClass(), false)) {
092        initWorkspaceForField(field, instance);
093      }
094    }
095  }
096
097  @Override
098  public void afterAll(ExtensionContext context) throws Exception {
099    for (var field : getManagedWorkspaceFields(context.getRequiredTestClass(), true)) {
100      closeWorkspaceForField(field, null);
101    }
102  }
103
104  @Override
105  public void afterEach(ExtensionContext context) throws Exception {
106    for (var instance : context.getRequiredTestInstances().getAllInstances()) {
107      for (var field : getManagedWorkspaceFields(instance.getClass(), false)) {
108        closeWorkspaceForField(field, instance);
109      }
110    }
111  }
112
113  private List<Field> getManagedWorkspaceFields(Class<?> clazz, boolean wantStatic) {
114    var fields = new ArrayList<Field>();
115
116    Class<?> currentClass = clazz;
117
118    do {
119      for (var field : currentClass.getDeclaredFields()) {
120        var isWorkspace = field.getType().equals(Workspace.class);
121        var isManaged = field.isAnnotationPresent(Managed.class);
122        var isDesiredScope = Modifier.isStatic(field.getModifiers()) == wantStatic;
123
124        if (isWorkspace && isManaged && isDesiredScope) {
125          field.setAccessible(true);
126          fields.add(field);
127        }
128      }
129
130      // Only recurse if we are checking instance scope. We don't manage annotated fields
131      // in superclasses that are static as we cannot guarantee they are not shared with a
132      // different class running in parallel.
133      currentClass = wantStatic
134          ? null
135          : currentClass.getSuperclass();
136  
137    } while (currentClass != null);
138
139    return fields;
140  }
141
142  private void initWorkspaceForField(Field field, @Nullable Object instance) throws Exception {
143    log.atTrace()
144        .setMessage("Initialising workspace for field in {}: {} {} on instance {}")
145        .addArgument(() -> field.getDeclaringClass().getSimpleName())
146        .addArgument(() -> field.getType().getSimpleName())
147        .addArgument(field::getName)
148        .addArgument(instance)
149        .log();
150
151    var managedWorkspace = field.getAnnotation(Managed.class);
152    var workspace = Workspaces.newWorkspace(managedWorkspace.pathStrategy());
153    field.set(instance, workspace);
154  }
155
156  private void closeWorkspaceForField(Field field, @Nullable Object instance) throws Exception {
157    log.atTrace()
158        .setMessage("Closing workspace for field in {}: {} {} on instance {}")
159        .addArgument(() -> field.getDeclaringClass().getSimpleName())
160        .addArgument(() -> field.getType().getSimpleName())
161        .addArgument(field::getName)
162        .addArgument(instance)
163        .log();
164
165    var workspace = (Workspace) field.get(instance);
166    workspace.close();
167  }
168}