diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_SHAPE_IR_SHAPE_H
 #define MLIR_SHAPE_IR_SHAPE_H
 
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -18,6 +18,7 @@
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
 
 //===----------------------------------------------------------------------===//
 // Shape op definitions
@@ -492,7 +493,7 @@
 }
 
 def Shape_YieldOp : Shape_Op<"yield",
-    [HasParent<"ReduceOp">,
+    [HasParent<"ReduceOp, FunctionLibraryOp">,
      NoSideEffect,
      ReturnLike,
      Terminator]> {
@@ -780,4 +781,61 @@
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Shape collection ops.
+//===----------------------------------------------------------------------===//
+
+def Shape_FunctionLibraryOp : Shape_Op<"function_library",
+    [AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable,
+     SingleBlockImplicitTerminator<"ShapeFunctionLibraryTerminatorOp">]> {
+  let summary = "Represents shape functions and corresponding ops";
+  let description = [{
+    Represents a list of shape functions and the ops whose shape transfer
+    functions they represent.
+
+    Example:
+
+    ```mlir
+    shape.function_library {
+      func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
+        %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
+        return %0 : !shape.shape
+      }
+    } mapping {
+      std.atan = @same_result_shape
+    }
+    ```
+  }];
+
+  // `mapping` associates op names with symbol references to functions
+  // defined in `body` (see getShapeFunction).
+  let arguments = (ins OptionalAttr<StrAttr>:$sym_name,
+                       OptionalAttr<StrAttr>:$sym_visibility,
+                       DictionaryAttr:$mapping);
+  let regions = (region SizedRegion<1>:$body);
+  let assemblyFormat = [{
+    attr-dict-with-keyword $body `mapping` $mapping
+  }];
+
+  let extraClassDeclaration = [{
+    /// Returns an associated shape function for an operation if defined.
+    FuncOp getShapeFunction(Operation &op);
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// ShapeFunctionLibraryTerminatorOp
+//===----------------------------------------------------------------------===//
+
+def ShapeFunctionLibraryTerminatorOp : Shape_Op<"shape_fn_lib_terminator",
+    [Terminator, HasParent<"FunctionLibraryOp">]> {
+  let summary = "A pseudo op that marks the end of a shape function library";
+  let description = [{
+    `shape_fn_lib_terminator` is a special pseudo terminator operation for the
+    shape function library. It has no semantic meaning beyond keeping the body
+    well-formed.
+  }];
+  let assemblyFormat = "attr-dict";
+}
+
 #endif // SHAPE_OPS
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -558,6 +559,21 @@
   return builder.getIndexTensorAttr(extents);
 }
 
+//===----------------------------------------------------------------------===//
+// FunctionLibraryOp
+//===----------------------------------------------------------------------===//
+
+FuncOp FunctionLibraryOp::getShapeFunction(Operation &op) {
+  // `mapping` is keyed by op name; only a flat symbol reference to a FuncOp
+  // nested in this library resolves, anything else yields a null FuncOp.
+  auto attr = mapping()
+                  .get(op.getName().getIdentifier())
+                  .dyn_cast_or_null<FlatSymbolRefAttr>();
+  if (!attr)
+    return nullptr;
+  return lookupSymbol<FuncOp>(attr);
+}
+
 //===----------------------------------------------------------------------===//
 // GetExtentOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Analysis/test-shape-fn-report.mlir b/mlir/test/Analysis/test-shape-fn-report.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Analysis/test-shape-fn-report.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s --test-shape-function-report -verify-diagnostics
+
+// expected-remark@+1 {{associated shape function: same_result_shape}}
+func @tanh(%arg: tensor<10x20xf32>) -> tensor<10x20xf32>
+    attributes {shape.function = @same_result_shape} {
+  // expected-remark@+1 {{no associated way}}
+  %0 = tanh %arg : tensor<10x20xf32>
+  // expected-remark@+1 {{associated shape function: same_result_shape}}
+  %1 = "test.same_operand_result_type"(%0) : (tensor<10x20xf32>) -> tensor<10x20xf32>
+  return %1 : tensor<10x20xf32>
+}
+
+// The shape function library with some local functions.
+shape.function_library {
+  // Test shape function that returns the shape of input arg as result shape.
+  func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
+    %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
+    return %0 : !shape.shape
+  }
+} mapping {
+  test.same_operand_result_type = @same_result_shape
+}
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(Affine)
+add_subdirectory(Shape)
 add_subdirectory(SPIRV)
 add_subdirectory(Test)
 add_subdirectory(Tosa)
diff --git a/mlir/test/lib/Dialect/Shape/CMakeLists.txt b/mlir/test/lib/Dialect/Shape/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Shape/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRShapeTestPasses
+  TestShapeFunctions.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shape
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRPass
+  MLIRShape
+  MLIRSupport
+  )
diff --git a/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
@@ -0,0 +1,102 @@
+//===- TestShapeFunctions.cpp - Passes to test shape function ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <queue>
+
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/IR/Function.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// This is a pass that reports shape functions associated with ops.
+struct ReportShapeFnPass
+    : public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> {
+  void runOnOperation() override;
+};
+
+/// Emits a remark on every op outside the shape function library describing
+/// how its shape could be refined: via InferTypeOpInterface, via the op-name
+/// mapping of the library, or via a `shape.function` symbol attribute.
+void ReportShapeFnPass::runOnOperation() {
+  auto module = getOperation();
+
+  // Lookup shape function library. At most one library is supported; the
+  // module may also have none, in which case `shapeFnLib` stays null.
+  shape::FunctionLibraryOp shapeFnLib = nullptr;
+  auto res = module.getBodyRegion().walk([&](shape::FunctionLibraryOp lib) {
+    if (shapeFnLib) {
+      lib.emitError("duplicate shape library op")
+              .attachNote(shapeFnLib.getLoc())
+          << "previous mapping";
+      return mlir::WalkResult::interrupt();
+    }
+    shapeFnLib = lib;
+    return mlir::WalkResult::advance();
+  });
+  if (res.wasInterrupted())
+    return signalPassFailure();
+
+  // Report the shape function available to refine the op.
+  auto shapeFnId = Identifier::get("shape.function", &getContext());
+  bool hasUnresolvedShapeFn = false;
+  auto remarkShapeFn = [&](Operation &op) {
+    if (op.isKnownTerminator())
+      return;
+    if (isa<InferTypeOpInterface>(&op)) {
+      op.emitRemark() << "implements InferType op interface";
+      return;
+    }
+    // The library's op-name mapping takes precedence over an explicit
+    // `shape.function` attribute; both need a library to resolve against.
+    if (shapeFnLib) {
+      if (auto fn = shapeFnLib.getShapeFunction(op)) {
+        op.emitRemark() << "associated shape function: " << fn.getName();
+        return;
+      }
+    }
+    if (auto attr = op.getAttrOfType<FlatSymbolRefAttr>(shapeFnId)) {
+      // Report dangling references instead of dereferencing a null FuncOp.
+      FuncOp fn;
+      if (shapeFnLib)
+        fn = shapeFnLib.lookupSymbol<FuncOp>(attr);
+      if (!fn) {
+        op.emitError() << "unknown shape function '" << attr.getValue()
+                       << "'";
+        hasUnresolvedShapeFn = true;
+        return;
+      }
+      op.emitRemark() << "associated shape function: " << fn.getName();
+      return;
+    }
+    op.emitRemark() << "no associated way to refine shape";
+  };
+
+  module.getBodyRegion().walk([&](FuncOp func) {
+    // Skip ops in the shape function library.
+    if (isa<shape::FunctionLibraryOp>(func.getParentOp()))
+      return;
+
+    func.walk([&](Operation *op) { remarkShapeFn(*op); });
+  });
+
+  if (hasUnresolvedShapeFn)
+    signalPassFailure();
+}
+
+} // end anonymous namespace
+
+namespace mlir {
+void registerShapeFunctionTestPasses() {
+  PassRegistration<ReportShapeFnPass>(
+      "test-shape-function-report",
+      "Test pass to report associated shape functions");
+}
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -133,6 +133,12 @@
   let results = (outs AnySignlessInteger:$result);
 }
 
+def SameOperandsResultType : TEST_Op<
+    "same_operand_result_type", [SameOperandsAndResultType]> {
+  let arguments = (ins AnyTensor:$operand);
+  let results = (outs AnyTensor:$result);
+}
+
 //===----------------------------------------------------------------------===//
 // Test Results
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -13,6 +13,7 @@
 if(MLIR_INCLUDE_TESTS)
   set(test_libs
     MLIRAffineTransformsTestPasses
+    MLIRShapeTestPasses
     MLIRSPIRVTestPasses
     MLIRTestDialect
     MLIRTestIR
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -32,6 +32,7 @@
 void registerConvertToTargetEnvPass();
 void registerPassManagerTestPass();
 void registerPrintOpAvailabilityPass();
+void registerShapeFunctionTestPasses();
 void registerSideEffectTestPasses();
 void registerSliceAnalysisTestPass();
 void registerSymbolTestPasses();
@@ -96,6 +97,7 @@
   registerConvertToTargetEnvPass();
   registerPassManagerTestPass();
   registerPrintOpAvailabilityPass();
+  registerShapeFunctionTestPasses();
   registerSideEffectTestPasses();
   registerSliceAnalysisTestPass();
   registerSymbolTestPasses();