diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h --- a/mlir/include/mlir/IR/BuiltinOps.h +++ b/mlir/include/mlir/IR/BuiltinOps.h @@ -17,6 +17,7 @@ #include "mlir/IR/OwningOpRef.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/CastInterfaces.h" #include "llvm/Support/PointerLikeTypeTraits.h" //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td --- a/mlir/include/mlir/IR/BuiltinOps.td +++ b/mlir/include/mlir/IR/BuiltinOps.td @@ -17,6 +17,7 @@ include "mlir/IR/BuiltinDialect.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/CastInterfaces.td" // Base class for Builtin dialect ops. class Builtin_Op<string mnemonic, list<OpTrait> traits = []> : @@ -220,4 +221,49 @@ let assemblyFormat = "attr-dict"; } +//===----------------------------------------------------------------------===// +// PartialConversionCastOp +//===----------------------------------------------------------------------===// + +def PartialConversionCastOp : Builtin_Op<"partial_conversion_cast", [ + DeclareOpInterfaceMethods<CastOpInterface> + ]> { + let summary = "An operation that opaquely converts from one type to another"; + let description = [{ + A `partial_conversion_cast` operation represents an opaque conversion from + one set of types to another, which is used to enable the inter-mixing of + different type systems. This operation should not be attributed any special + representational or execution semantics, and is generally only intended to + be used to satisfy the temporary intermixing of type systems during the + conversion of one type system to another. + + This operation may produce results of arity 1-N, and accept as input + operands of arity 0-N. + + Example: + + ```mlir + // An opaque 0-1 cast. + %result = partial_conversion_cast to f64 + + // An opaque 1-1 cast.
+ %result = partial_conversion_cast(%operand : i32) to f64 + + // An opaque 1-N cast. + %results:2 = partial_conversion_cast(%operand : i32) to i32, i32 + + // An opaque N-1 cast. + %result = partial_conversion_cast(%operand, %operand : i32, i32) to i32 + ``` + }]; + + let arguments = (ins Variadic<AnyType>:$inputs); + let results = (outs Variadic<AnyType>:$outputs); + let assemblyFormat = [{ + (`(` $inputs^ `:` type($inputs) `)`)? `to` type($outputs) attr-dict + }]; + let hasCanonicalizer = 1; + let hasFolder = 1; +} + #endif // BUILTIN_OPS diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" #include "llvm/ADT/MapVector.h" using namespace mlir; @@ -236,6 +237,73 @@ return success(); } +//===----------------------------------------------------------------------===// +// PartialConversionCastOp +//===----------------------------------------------------------------------===// + +namespace { +/// Fold dead partial conversions that no longer have uses. +/// `PartialConversionCastOp` does not provide any information about +/// effects (purposefully to remain opaque to users), so we need an explicit +/// canonicalization to erase it if unused.
+struct SimplifyDeadPartialConversionCastOp + : public OpRewritePattern<PartialConversionCastOp> { + using OpRewritePattern<PartialConversionCastOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(PartialConversionCastOp op, + PatternRewriter &rewriter) const override { + if (op.use_empty()) { + rewriter.eraseOp(op); + return success(); + } + return failure(); + } +}; +} // end anonymous namespace + +void PartialConversionCastOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert<SimplifyDeadPartialConversionCastOp>(context); +} + +/// Fold the case where the inputs are exactly the results (all of them, in +/// order) of another partial conversion cast whose operand types match this +/// operation's result types, i.e. the two casts cancel each other out. +static LogicalResult +foldPartialCastOpPassThrough(OperandRange operands, ResultRange results, + SmallVectorImpl<OpFoldResult> &foldResults) { + // Check that the input is a cast with results that all feed into this + // operation, and operand types that directly match the result types of this + // operation. + Value firstInput = operands.front(); + auto inputOp = firstInput.getDefiningOp<PartialConversionCastOp>(); + if (!inputOp || inputOp.getResults() != operands || + inputOp.getOperandTypes() != results.getTypes()) + return failure(); + + // If everything matches up, we can fold the passthrough. + foldResults.append(inputOp->operand_begin(), inputOp->operand_end()); + return success(); +} + +LogicalResult +PartialConversionCastOp::fold(ArrayRef<Attribute> attrOperands, + SmallVectorImpl<OpFoldResult> &foldResults) { + OperandRange operands = inputs(); + if (operands.empty()) + return failure(); + if (succeeded(foldPartialCastOpPassThrough(operands, outputs(), foldResults))) + return success(); + + return failure(); +} + +bool PartialConversionCastOp::areCastCompatible(TypeRange inputs, + TypeRange outputs) { + // `PartialConversionCastOp` is agnostic of the input/output types.
+ return true; +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -37,6 +37,7 @@ MLIRBuiltinOpsIncGen MLIRBuiltinTypesIncGen MLIRCallInterfacesIncGen + MLIRCastInterfacesIncGen MLIROpAsmInterfaceIncGen MLIRRegionKindInterfaceIncGen MLIRSymbolInterfacesIncGen diff --git a/mlir/test/Dialect/Builtin/canonicalize.mlir b/mlir/test/Dialect/Builtin/canonicalize.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Builtin/canonicalize.mlir @@ -0,0 +1,25 @@ +// RUN: mlir-opt %s -canonicalize | FileCheck %s + +//===----------------------------------------------------------------------===// +// PartialConversionCastOp +//===----------------------------------------------------------------------===// + +// Test multiple pass through partial conversion cast.
+// CHECK-LABEL: func @multiple_conversion_casts +// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: +func @multiple_conversion_casts(%arg0: i32, %arg1: i32) -> (i32, i32) { + // CHECK-NOT: partial_conversion_cast + // CHECK: return %[[ARG0]], %[[ARG1]] + %inputs:2 = partial_conversion_cast (%arg0, %arg1 : i32, i32) to i64, i64 + %outputs:2 = partial_conversion_cast (%inputs#0, %inputs#1 : i64, i64) to i32, i32 + return %outputs#0, %outputs#1 : i32, i32 +} + +// CHECK-LABEL: func @multiple_conversion_casts_failure +func @multiple_conversion_casts_failure(%arg0: i32, %arg1: i32, %arg2: i64) -> (i32, i32) { + // CHECK: partial_conversion_cast + // CHECK: partial_conversion_cast + %inputs:2 = partial_conversion_cast (%arg0, %arg1 : i32, i32) to i64, i64 + %outputs:2 = partial_conversion_cast (%arg2, %inputs#1 : i64, i64) to i32, i32 + return %outputs#0, %outputs#1 : i32, i32 +} diff --git a/mlir/test/Dialect/Builtin/invalid.mlir b/mlir/test/Dialect/Builtin/invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Builtin/invalid.mlir @@ -0,0 +1,7 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// expected-error@+1 {{expected at least one result for cast operation}} +"partial_conversion_cast"() : () -> () + +// ----- + diff --git a/mlir/test/Dialect/Builtin/ops.mlir b/mlir/test/Dialect/Builtin/ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Builtin/ops.mlir @@ -0,0 +1,19 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect | mlir-opt -allow-unregistered-dialect + +//===----------------------------------------------------------------------===// +// PartialConversionCastOp +//===----------------------------------------------------------------------===// + +%operand = "foo.op"() : () -> i32 + +// An opaque 0-1 cast. +%result = partial_conversion_cast to f64 + +// An opaque 1-1 cast. +%result1 = partial_conversion_cast(%operand : i32) to f64 + +// An opaque 1-N cast.
+%results2:2 = partial_conversion_cast(%operand : i32) to i32, i32 + +// An opaque N-1 cast. +%result3 = partial_conversion_cast(%operand, %operand : i32, i32) to i32