diff --git a/mlir/include/mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h b/mlir/include/mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h --- a/mlir/include/mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h +++ b/mlir/include/mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h @@ -28,6 +28,10 @@ /// Return the modify-reference behavior of `op` on `location`. ModRefResult getModRef(Operation *op, Value location); + +protected: + /// Given the two values, return their aliasing behavior. + virtual AliasResult aliasImpl(Value lhs, Value rhs); }; } // namespace mlir diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp --- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp +++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp @@ -246,7 +246,7 @@ } /// Given the two values, return their aliasing behavior. -static AliasResult aliasImpl(Value lhs, Value rhs) { +AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) { if (lhs == rhs) return AliasResult::MustAlias; Operation *lhsAllocScope = nullptr, *rhsAllocScope = nullptr; diff --git a/mlir/test/Analysis/test-alias-analysis-extending.mlir b/mlir/test/Analysis/test-alias-analysis-extending.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Analysis/test-alias-analysis-extending.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-alias-analysis-extending))' -split-input-file -allow-unregistered-dialect 2>&1 | FileCheck %s + +// CHECK-LABEL: Testing : "restrict" +// CHECK-DAG: func.region0#0 <-> func.region0#1: NoAlias + +// CHECK-DAG: view1#0 <-> view2#0: NoAlias +// CHECK-DAG: view1#0 <-> func.region0#0: MustAlias +// CHECK-DAG: view1#0 <-> func.region0#1: NoAlias +// CHECK-DAG: view2#0 <-> func.region0#0: NoAlias +// CHECK-DAG: view2#0 <-> func.region0#1: MustAlias +func.func @restrict(%arg: memref, %arg1: memref {local_alias_analysis.restrict}) attributes {test.ptr = "func"} { + %0 = 
memref.subview %arg[0][2][1] {test.ptr = "view1"} : memref to memref<2xf32> + %1 = memref.subview %arg1[0][2][1] {test.ptr = "view2"} : memref to memref<2xf32> + return +} diff --git a/mlir/test/lib/Analysis/TestAliasAnalysis.cpp b/mlir/test/lib/Analysis/TestAliasAnalysis.cpp --- a/mlir/test/lib/Analysis/TestAliasAnalysis.cpp +++ b/mlir/test/lib/Analysis/TestAliasAnalysis.cpp @@ -13,6 +13,8 @@ #include "TestAliasAnalysis.h" #include "mlir/Analysis/AliasAnalysis.h" +#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h" +#include "mlir/IR/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" using namespace mlir; @@ -148,6 +150,67 @@ }; } // namespace +//===----------------------------------------------------------------------===// +// Testing LocalAliasAnalysis extending +//===----------------------------------------------------------------------===// + +/// Check if value is a function argument. +static bool isFuncArg(Value val) { + auto blockArg = val.dyn_cast(); + if (!blockArg) + return false; + + return mlir::isa_and_nonnull( + blockArg.getOwner()->getParentOp()); +} + +/// Check if value has "restrict" attribute. Value must be a function argument. +static bool isRestrict(Value val) { + auto blockArg = val.cast(); + auto func = + mlir::cast(blockArg.getOwner()->getParentOp()); + return !!func.getArgAttr(blockArg.getArgNumber(), + "local_alias_analysis.restrict"); +} + +namespace { +/// LocalAliasAnalysis extended to support "restrict" attribute. +class LocalAliasAnalysisRestrict : public LocalAliasAnalysis { +protected: + AliasResult aliasImpl(Value lhs, Value rhs) override { + if (lhs == rhs) + return AliasResult::MustAlias; + + // Assume no aliasing if both values are function arguments and any of them + // have restrict attr. 
+ if (isFuncArg(lhs) && isFuncArg(rhs)) + if (isRestrict(lhs) || isRestrict(rhs)) + return AliasResult::NoAlias; + + return LocalAliasAnalysis::aliasImpl(lhs, rhs); + } +}; + +/// This pass tests adding additional analysis impls to the AliasAnalysis. +struct TestAliasAnalysisExtendingPass + : public test::TestAliasAnalysisBase, + PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasAnalysisExtendingPass) + + StringRef getArgument() const final { + return "test-alias-analysis-extending"; + } + StringRef getDescription() const final { + return "Test alias analysis extending."; + } + void runOnOperation() override { + AliasAnalysis aliasAnalysis(getOperation()); + aliasAnalysis.addAnalysisImplementation(LocalAliasAnalysisRestrict()); + runAliasAnalysisOnOperation(getOperation(), aliasAnalysis); + } +}; +} // namespace + //===----------------------------------------------------------------------===// // Pass Registration //===----------------------------------------------------------------------===// @@ -155,8 +218,9 @@ namespace mlir { namespace test { void registerTestAliasAnalysisPass() { - PassRegistration(); + PassRegistration(); PassRegistration(); + PassRegistration(); } } // namespace test } // namespace mlir