diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -2100,4 +2100,79 @@ }]; } +//===----------------------------------------------------------------------===// +// TypeRelaxationOp +//===----------------------------------------------------------------------===// + +def TypeRelaxationOp : Std_Op<"type_relaxation", [NoSideEffect]> { + let summary = "Relax the value type."; + let description = [{ + It is typically introduced by the type inference pass to satisfy type + constraints, when a less specialized type is required. + + Example: + ``` + %value = "getValue"() : () -> (tensor<1x2xi32>) + %relaxed_value = type_relaxation %value : tensor<1x2xi32> to tensor + "op_disallowing_input_refinement"(%relaxed_value) : (tensor<*xi32>) -> () + ``` + }]; + + let arguments = (ins AnyType:$in); + let results = (outs AnyType:$out); + + let builders = [ + OpBuilder<(ins "Value":$source, "Type":$destType), [{ + impl::buildCastOp($_builder, $_state, source, destType); + }]> + ]; + + let parser = [{ + return impl::parseCastOp(parser, result); + }]; + let printer = [{ + return printStandardCastOp(this->getOperation(), p); + }]; + + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// TypeSpecializationOp +//===----------------------------------------------------------------------===// + +def TypeSpecializationOp : Std_Op<"type_specialization", [NoSideEffect]> { + let summary = "Specialize the value type."; + let description = [{ + It is typically introduced by the type inference pass to satisfy type + constraints, when a more specialized type is required. 
+ + Example: + ``` + %join = "join_types"(%x, %y) {allowOutputRefinement = false} : (tensor<1x?xi32>, tensor) -> tensor<*xi32> + %specialized_join = type_specialization %join : tensor<*xi32> to tensor + "op_allowing_input_refinement"(%specialized_join) : (tensor) -> () + ``` + }]; + + let arguments = (ins AnyType:$in); + let results = (outs AnyType:$out); + + let builders = [ + OpBuilder<(ins "Value":$source, "Type":$destType), [{ + impl::buildCastOp($_builder, $_state, source, destType); + }]> + ]; + + let parser = [{ + return impl::parseCastOp(parser, result); + }]; + let printer = [{ + return printStandardCastOp(this->getOperation(), p); + }]; + + let hasFolder = 1; +} + + #endif // STANDARD_OPS diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -178,4 +178,45 @@ ]; } +def AllowsInputTypesRefinementInterface : + OpInterface<"AllowsInputTypesRefinementInterface"> { + let description = [{ + Interface to specify what inputs may be specialized. + }]; + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod< + /*desc=*/[{Indicate whether the specified operand type may be specialized}], + /*retTy=*/"bool", + /*methodName=*/"allowsInputTypeRefinement", + /*args=*/(ins "unsigned":$operandIndex), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{return true;}] + >, + ]; +} + +def AllowsOutputTypesRefinementInterface : + OpInterface<"AllowsOutputTypesRefinementInterface"> { + let description = [{ + Interface to specify what outputs may be specialized. + + Support for output types refinement is by default conditional on support for + `InferTypeOpInterface`. + When implemented, this interface allows overriding this default to + explicitly specify what outputs may be specialized. 
+ }]; + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod< + /*desc=*/[{Indicate whether the specified output type can be specialized}], + /*retTy=*/"bool", + /*methodName=*/"allowsOutputTypeRefinement", + /*args=*/(ins "unsigned":$outputIndex), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{return true;}] + >, + ]; +} + #endif // MLIR_INFERTYPEOPINTERFACE diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -137,6 +137,9 @@ /// (identity) layout map. std::unique_ptr> createNormalizeMemRefsPass(); +/// Creates a pass to infer operation types. +std::unique_ptr createTypeInferencePass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -691,6 +691,89 @@ let constructor = "mlir::createSymbolDCEPass()"; } +def TypeInferencePass : Pass<"infer-types"> { + let summary = "Infer the types of values"; + let description = [{ + This pass propagates type information through the IR to refine value types. + + The operation interface `InferTypeOpInterface` allows expressing + per-operation type inference logic. + + Support for input types refinement is fully controlled by + `AllowsInputTypesRefinementInterface`. + + Support for output types refinement is explicit if + `AllowsOutputTypesRefinementInterface` is implemented. If not, it is + implicit if `InferTypeOpInterface` is implemented. 
+ + For example, running type inference on the input + + ``` + func @test_input_refinement(%x : tensor<1x?xi32>, %y : tensor) { + %tmp = "test.ti.join_types"(%x, %y) : (tensor<1x?xi32>, tensor) -> tensor<*xi32> + "test.ti.disallow_input_refinement"(%tmp) : (tensor<*xi32>) -> () + "test.ti.allow_input_refinement"(%tmp) : (tensor<*xi32>) -> () + "test.ti.conditionally_allow_input_refinement"(%tmp, %tmp) : (tensor<*xi32>, tensor<*xi32>) -> () + return + } + ``` + + will output + + ``` + func @test_input_refinement(%arg0: tensor<1x?xi32>, %arg1: tensor) { + %0 = "test.ti.join_types"(%arg0, %arg1) : (tensor<1x?xi32>, tensor) -> tensor + %1 = type_relaxation %0 : tensor to tensor<*xi32> + "test.ti.disallow_input_refinement"(%1) : (tensor<*xi32>) -> () + "test.ti.allow_input_refinement"(%0) : (tensor) -> () + "test.ti.conditionally_allow_input_refinement"(%0, %1) : (tensor, tensor<*xi32>) -> () + return + } + ``` + . + + And running on the input + + ``` + func @test_branch_join(%cond: i1, %x : tensor<1xi32>, %y : tensor<2xi32>) { + cond_br %cond, ^bb1, ^bb2 + ^bb1: + %val_true = "test.ti.join_types"(%x) : (tensor<1xi32>) -> tensor<*xi32> + br ^bb3(%val_true : tensor<*xi32>) + ^bb2: + %val_false = "test.ti.join_types"(%y) : (tensor<2xi32>) -> tensor<*xi32> + br ^bb3(%val_false : tensor<*xi32>) + ^bb3(%val: tensor<*xi32>): + "test.ti.allow_input_refinement"(%val) : (tensor<*xi32>) -> () + "test.ti.disallow_input_refinement"(%val) : (tensor<*xi32>) -> () + return + } + ``` + + will output + + ``` + func @test_branch_join(%arg0: i1, %arg1: tensor<1xi32>, %arg2: tensor<2xi32>) { + cond_br %arg0, ^bb1, ^bb2 + ^bb1: // pred: ^bb0 + %0 = "test.ti.join_types"(%arg1) : (tensor<1xi32>) -> tensor<1xi32> + %1 = type_relaxation %0 : tensor<1xi32> to tensor<*xi32> + br ^bb3(%1 : tensor<*xi32>) + ^bb2: // pred: ^bb0 + %2 = "test.ti.join_types"(%arg2) : (tensor<2xi32>) -> tensor<2xi32> + %3 = type_relaxation %2 : tensor<2xi32> to tensor<*xi32> + br ^bb3(%3 : tensor<*xi32>) + ^bb3(%4: 
tensor<*xi32>): // 2 preds: ^bb1, ^bb2 + %5 = type_specialization %4 : tensor<*xi32> to tensor + "test.ti.allow_input_refinement"(%5) : (tensor) -> () + "test.ti.disallow_input_refinement"(%4) : (tensor<*xi32>) -> () + return + ``` + . + }]; + let constructor = "mlir::createTypeInferencePass()"; +} + def ViewOpGraphPass : Pass<"view-op-graph"> { let summary = "Print Graphviz visualization of an operation"; let description = [{ diff --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt --- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt @@ -13,6 +13,7 @@ MLIRCastInterfaces MLIRControlFlowInterfaces MLIRIR + MLIRJoinMeetTypeInterface MLIRSideEffectInterfaces MLIRVectorInterfaces ) diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" +#include "mlir/Interfaces/JoinMeetTypeInterface.h" #include "mlir/Support/MathExtras.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/APFloat.h" @@ -2807,6 +2808,36 @@ return success(); } +//===----------------------------------------------------------------------===// +// TypeRelaxationOp +//===----------------------------------------------------------------------===// + +OpFoldResult TypeRelaxationOp::fold(ArrayRef operands) { + return impl::foldCastOp(*this); +} + +static LogicalResult verify(TypeRelaxationOp op) { + if (!isLessSpecialized(op.getType(), op.in().getType())) + return op.emitOpError( + "output type is not less specialized than the input type"); + return success(); +} + +//===----------------------------------------------------------------------===// +// TypeSpecializationOp +//===----------------------------------------------------------------------===// + +OpFoldResult 
TypeSpecializationOp::fold(ArrayRef operands) { + return impl::foldCastOp(*this); +} + +static LogicalResult verify(TypeSpecializationOp op) { + if (!isMoreSpecialized(op.getType(), op.in().getType())) + return op.emitOpError( + "output type is not more specialized than the input type"); + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -20,6 +20,7 @@ SCCP.cpp StripDebugInfo.cpp SymbolDCE.cpp + TypeInference.cpp ViewOpGraph.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Transforms/TypeInference.cpp b/mlir/lib/Transforms/TypeInference.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Transforms/TypeInference.cpp @@ -0,0 +1,351 @@ +//===- TypeInference.cpp - Infer Types of MLIR operations -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// The type inference analysis uses the dataflow analysis infrastructure to +// infer MLIR value types over the IR. +// The type inference pass uses the results of the analysis to update MLIR value +// types, querying operation interfaces to appropriately update types in place +// or insert type relaxation or specialization ops. 
+// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Analysis/DataFlowAnalysis.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/JoinMeetTypeInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Format.h" + +#define DEBUG_TYPE "type-inference" + +using namespace mlir; +using llvm::dbgs; + +namespace { +struct TypeLatticeValue { + TypeLatticeValue() = default; + TypeLatticeValue(Type type) : type(type) {} + + static TypeLatticeValue getPessimisticValueState(MLIRContext *context) { + return TypeLatticeValue(); + } + + static TypeLatticeValue getPessimisticValueState(Value value) { + return value.getType(); + } + + static TypeLatticeValue join(const TypeLatticeValue &lhs, + const TypeLatticeValue &rhs) { + return joinTypes(lhs.type, rhs.type); + } + + bool operator==(const TypeLatticeValue &rhs) const { + return type == rhs.type; + } + + Type type; +}; + +struct TypeInferenceAnalysis + : public ForwardDataFlowAnalysis { +public: + using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis; + TypeInferenceAnalysis(MLIRContext *context) + : ForwardDataFlowAnalysis(context), status(success()) {} + ~TypeInferenceAnalysis() override = default; + + ChangeResult visitOperation( + Operation *op, + ArrayRef *> operands) override; + + LogicalResult getStatus() const { return status; } + +private: + ChangeResult visitInferTypeOpInterface( + Operation *op, ArrayRef *> operands); + ChangeResult visitSameOperandsAndResultShape( + Operation *op, ArrayRef *> operands); + ChangeResult visitSameOperandsAndResultType( + Operation *op, ArrayRef *> operands); + + // A wrapper allowing tracking updates for debugging purposes. 
+ ChangeResult updateLatticeElement(Value value, Type ty) { + LatticeElement &latticeElement = getLatticeElement(value); + LLVM_DEBUG(dbgs() << "update lattice element for :\t"; value.print(dbgs()); + dbgs() << "\n\toriginal type:\t"; + if (latticeElement.isUninitialized()) dbgs() << "uninitialized"; + else latticeElement.getValue().type.print(dbgs()); + dbgs() << "\n\tjoining with:\t"; ty.print(dbgs());); + ChangeResult change = latticeElement.join(ty); + LLVM_DEBUG(dbgs() << "\n\tyielding:\t"; + latticeElement.getValue().type.print(dbgs()); dbgs() << "\n"); + return change; + } + + LogicalResult status; +}; + +ChangeResult TypeInferenceAnalysis::visitOperation( + Operation *op, ArrayRef *> operands) { + // Do not continue with the analysis if something went wrong. + // TODO: Having a way to interrupt the analysis on error would be + // convenient. + if (LLVM_UNLIKELY(failed(status))) + return ChangeResult::NoChange; + + if (auto relaxation = dyn_cast(op)) { + // We cannot specialize the output type of `type_relaxation`, but we want + // to propagate the right type in the analysis. + return updateLatticeElement(relaxation.out(), operands[0]->getValue().type); + } + + // Handle different ways to infer the result types. + // Process in order of "strongest" (e.g. full type vs shape) and "simplest" + // (e.g. `SameOperandsAndResultType` vs `InferTypeOpInterface`). + if (op->hasTrait()) + return visitSameOperandsAndResultType(op, operands); + if (isa(op)) + return visitInferTypeOpInterface(op, operands); + if (op->hasTrait()) + return visitSameOperandsAndResultShape(op, operands); + // TODO: Extend support to other interfaces. + + // By default, simply use the current result types. 
+ ChangeResult change = ChangeResult::NoChange; + for (Value result : op->getResults()) + change |= updateLatticeElement(result, result.getType()); + return change; +} + +ChangeResult TypeInferenceAnalysis::visitInferTypeOpInterface( + Operation *op, ArrayRef *> operands) { + // Ephemerally override operand types with inferred types. + llvm::SmallVector originalOperandTys; + const unsigned numOperands = op->getNumOperands(); + originalOperandTys.resize(numOperands); + for (unsigned i = 0; i != numOperands; ++i) { + auto operand = op->getOperand(i); + originalOperandTys[i] = operand.getType(); + operand.setType(operands[i]->getValue().type); + } + + // Infer types for the operation. + SmallVector inferredReturnTypes; + status = cast(op).inferReturnTypes( + op->getContext(), op->getLoc(), op->getOperands(), + op->getAttrDictionary(), op->getRegions(), inferredReturnTypes); + + // Immediately restore original operand types. + for (unsigned i = 0; i != numOperands; ++i) + op->getOperand(i).setType(originalOperandTys[i]); + + if (failed(status)) { + op->emitError("failed to infer types"); + return ChangeResult::NoChange; + } + + ChangeResult change = ChangeResult::NoChange; + + for (unsigned int i = 0; i < op->getNumResults(); i++) + change |= updateLatticeElement(op->getResult(i), inferredReturnTypes[i]); + + return change; +} + +ChangeResult TypeInferenceAnalysis::visitSameOperandsAndResultType( + Operation *op, ArrayRef *> operands) { + if (operands.empty()) + return ChangeResult::NoChange; + + Type ty = operands[0]->getValue().type; + for (auto *operand : operands) + ty = meetTypes(ty, operand->getValue().type); + + ChangeResult change = ChangeResult::NoChange; + for (auto result : op->getResults()) + change |= updateLatticeElement(result, ty); + + return change; +} + +ChangeResult TypeInferenceAnalysis::visitSameOperandsAndResultShape( + Operation *op, ArrayRef *> operands) { + if (operands.empty()) + return ChangeResult::NoChange; + + // TODO: Introduce helper so 
that we can handle scalars. + SmallVector shape; + ArrayRef initialShape = + operands[0]->getValue().type.cast().getShape(); + shape.insert(shape.end(), initialShape.begin(), initialShape.end()); + + for (auto *operand : operands) + shape = *meetShapes(shape, + operand->getValue().type.cast().getShape()); + + ChangeResult change = ChangeResult::NoChange; + for (auto result : op->getResults()) + change |= updateLatticeElement( + result, result.getType().cast().clone(shape)); + + return change; +} + +} // end anonymous namespace + +namespace { +class TypeInference { +public: + explicit TypeInference(Operation *op) : op(op), analysis(op->getContext()) {} + + /// Analyzes and updates types for the processed operation. + LogicalResult run() { + analysis.run(op); + if (failed(analysis.getStatus())) + return analysis.getStatus(); + return updateOp(op); + } + +private: + LogicalResult updateOp(Operation *op); + LogicalResult updateBlock(Block &block); + LogicalResult updateOpResults(Operation *op); + LogicalResult updateType(Value result); + + Operation *op; + TypeInferenceAnalysis analysis; +}; + +LogicalResult TypeInference::updateOp(Operation *op) { + for (auto &region : op->getRegions()) { + for (auto &block : region) { + if (failed(updateBlock(block))) + return failure(); + } + } + return updateOpResults(op); +} + +LogicalResult TypeInference::updateBlock(Block &block) { + for (auto blockArg : block.getArguments()) + if (failed(updateType(blockArg))) + return failure(); + for (auto &nestedOp : block) + if (failed(updateOp(&nestedOp))) + return failure(); + return success(); +} + +LogicalResult TypeInference::updateOpResults(Operation *op) { + for (auto result : op->getResults()) + if (failed(updateType(result))) + return failure(); + return success(); +} + +LogicalResult TypeInference::updateType(Value value) { + LatticeElement *latticeElement = + analysis.lookupLatticeElement(value); + if (!latticeElement) + return success(); + + Type originalTy = value.getType(); + Type 
inferredTy = latticeElement->getValue().type; + if (inferredTy == originalTy) + return success(); + + Operation *definingOp = value.getDefiningOp(); + + // Never specialize `type_relaxation`. + if (dyn_cast_or_null(definingOp)) + return success(); + + if (!isMoreSpecialized(inferredTy, originalTy)) + return emitError(value.getLoc(), "inferred type ") + << inferredTy << " is not more specialized than the original type" + << originalTy; + + // Values and uses may or may not allow specializing input or output types. + // Keep track of the "value" with its original type, and after type + // specialization. + Value valueWithOriginalType = value; + Value valueWithInferredType = {}; + + // Attempt to specialize the output type. + auto interface = + dyn_cast_or_null(definingOp); + bool explicitlyAllowedOutputTypeRefinement = + interface && interface.allowsOutputTypeRefinement( + value.cast().getResultNumber()); + bool implicitlyAllowedOutputTypeRefinement = + !interface && dyn_cast_or_null(definingOp); + if (explicitlyAllowedOutputTypeRefinement || + implicitlyAllowedOutputTypeRefinement) { + value.setType(inferredTy); + valueWithOriginalType = {}; + valueWithInferredType = value; + } + + // Update uses with the appropriate "value", specialized or not. 
+ auto getValueWithOriginalType = [&]() -> Value { + if (!valueWithOriginalType) { + OpBuilder builder(value.getContext()); + builder.setInsertionPointAfterValue(value); + valueWithOriginalType = builder.createOrFold( + value.getLoc(), value, originalTy); + } + return valueWithOriginalType; + }; + auto getValueWithSpecializedType = [&]() -> Value { + if (!valueWithInferredType) { + OpBuilder builder(value.getContext()); + builder.setInsertionPointAfterValue(value); + valueWithInferredType = builder.createOrFold( + value.getLoc(), value, inferredTy); + } + return valueWithInferredType; + }; + unsigned nSpecializedUses = 0; + for (auto &use : llvm::make_early_inc_range(value.getUses())) { + auto interface = + dyn_cast(use.getOwner()); + bool allowsInputTypeRefinement = + interface && + interface.allowsInputTypeRefinement(use.getOperandNumber()); + use.set(allowsInputTypeRefinement ? getValueWithSpecializedType() + : getValueWithOriginalType()); + nSpecializedUses += allowsInputTypeRefinement; + } + + LLVM_DEBUG(dbgs() << "updated type from "; originalTy.print(dbgs()); + dbgs() << " to "; inferredTy.print(dbgs()); + dbgs() << llvm::format( + " (with %u/%u specialized uses) for ", nSpecializedUses, + std::distance(value.getUses().begin(), value.getUses().end())); + value.print(dbgs()); dbgs() << "\n"); + + return success(); +} + +struct TypeInferencePass : public TypeInferencePassBase { + void runOnOperation() override { + TypeInference typeInference(getOperation()); + if (failed(typeInference.run())) + signalPassFailure(); + } +}; +} // end anonymous namespace + +std::unique_ptr mlir::createTypeInferencePass() { + return std::make_unique(); +} diff --git a/mlir/test/Interfaces/ControlFlowInterfaces/operand-types.mlir b/mlir/test/Interfaces/ControlFlowInterfaces/operand-types.mlir --- a/mlir/test/Interfaces/ControlFlowInterfaces/operand-types.mlir +++ b/mlir/test/Interfaces/ControlFlowInterfaces/operand-types.mlir @@ -37,7 +37,7 @@ return ^bb2: - 
"test.cfe.br"(%0)[^bb3] : (tensor<1xi32>) -> () + "test.ti.br"(%0)[^bb3] : (tensor<1xi32>) -> () ^bb3(%arg2: tensor): return } @@ -47,7 +47,7 @@ func @succ_type_dynamic_to_fixed() { ^bb0: %0 = "getValue"() : () -> tensor - "test.cfe.br"(%0)[^bb1] : (tensor) -> () + "test.ti.br"(%0)[^bb1] : (tensor) -> () // expected-error@-1 {{type mismatch for bb argument #0 of successor #0}} ^bb1(%arg3: tensor<1xi32>): return @@ -95,17 +95,17 @@ test.region_if_yield %arg1 : memref } - %tmp2 = test.cfe.region_if %0 : memref<1xi32> -> (memref) then { + %tmp2 = test.ti.region_if %0 : memref<1xi32> -> (memref) then { ^bb0(%arg1 : memref<1xi32>): %true_value = "getValue"(%arg1) : (memref<1xi32>) -> (memref<2xi32>) - test.cfe.region_if_yield %true_value : memref<2xi32> + test.ti.region_if_yield %true_value : memref<2xi32> } else { ^bb0(%arg1 : memref): %false_value = "getValue"(%arg1) : (memref) -> (memref<3xi32>) - test.cfe.region_if_yield %false_value : memref<3xi32> + test.ti.region_if_yield %false_value : memref<3xi32> } join { ^bb0(%arg1 : memref): - test.cfe.region_if_yield %arg1 : memref + test.ti.region_if_yield %arg1 : memref } return diff --git a/mlir/test/Transforms/type-inference.mlir b/mlir/test/Transforms/type-inference.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/type-inference.mlir @@ -0,0 +1,319 @@ +// RUN: mlir-opt -allow-unregistered-dialect -infer-types %s -split-input-file -o %t1 +// RUN: FileCheck %s < %t1 +// Verify idempotence. 
+// RUN: mlir-opt -allow-unregistered-dialect -infer-types %t1 -split-input-file -o %t2 +// RUN: FileCheck %s < %t2 + +func @test_input_refinement(%x : tensor<1x?xi32>, %y : tensor) { + %tmp = "test.ti.join_types"(%x, %y) : (tensor<1x?xi32>, tensor) -> tensor<*xi32> + "test.ti.disallow_input_refinement"(%tmp) : (tensor<*xi32>) -> () + "test.ti.allow_input_refinement"(%tmp) : (tensor<*xi32>) -> () + "test.ti.conditionally_allow_input_refinement"(%tmp, %tmp) : (tensor<*xi32>, tensor<*xi32>) -> () + return +} + +// CHECK: [[TMP:%.+]] = "test.ti.join_types"{{.+}} : (tensor<1x?xi32>, tensor) -> tensor +// CHECK: [[RELAXED_TMP:%.+]] = type_relaxation [[TMP]] : tensor to tensor<*xi32> +// CHECK: "test.ti.disallow_input_refinement"([[RELAXED_TMP]]) : (tensor<*xi32>) -> () +// CHECK: "test.ti.allow_input_refinement"([[TMP]]) : (tensor) -> () +// CHECK: "test.ti.conditionally_allow_input_refinement"([[TMP]], [[RELAXED_TMP]]) : (tensor, tensor<*xi32>) -> () + +// ----- + +func @test_output_refinement(%x : tensor<1x?xi32>, %y : tensor) { + %join1 = "test.ti.join_types"(%x, %y) { allowOutputRefinement = false } : (tensor<1x?xi32>, tensor) -> tensor<*xi32> + "test.ti.allow_input_refinement"(%join1) : (tensor<*xi32>) -> () + + %join2 = "test.ti.join_types"(%x, %y) { allowOutputRefinement = true } : (tensor<1x?xi32>, tensor) -> tensor<*xi32> + "test.ti.allow_input_refinement"(%join2) : (tensor<*xi32>) -> () + return +} + +// CHECK: [[JOIN1:%.+]] = "test.ti.join_types"({{.+}}) {allowOutputRefinement = false} : (tensor<1x?xi32>, tensor) -> tensor<*xi32> +// CHECK: [[SPECIALIZED_JOIN1:%.+]] = type_specialization [[JOIN1]] : tensor<*xi32> to tensor +// CHECK: "test.ti.allow_input_refinement"([[SPECIALIZED_JOIN1]]) : (tensor) -> () +// CHECK: [[JOIN2:%.+]] = "test.ti.join_types"({{.+}}) {allowOutputRefinement = true} : (tensor<1x?xi32>, tensor) -> tensor +// CHECK: "test.ti.allow_input_refinement"([[JOIN2]]) : (tensor) -> () + +// ----- + +func @test_relaxation_propagation(%value : 
tensor<1xi32>) { + %relaxed_value = type_relaxation %value : tensor<1xi32> to tensor<*xi32> + %res = "test.ti.join_types"(%relaxed_value) : (tensor<*xi32>) -> tensor<*xi32> + return +} + +// CHECK: func @test_relaxation_propagation([[VALUE:%.+]]: tensor<1xi32>) { +// CHECK: [[RELAXED_VALUE:%.+]] = type_relaxation [[VALUE]] : tensor<1xi32> to tensor<*xi32> +// CHECK: %1 = "test.ti.join_types"([[RELAXED_VALUE]]) : (tensor<*xi32>) -> tensor<1xi32> +// CHECK: return +// CHECK: } + +// ----- + +func @test_branch_disallowing_input_refinement(%x : tensor<1x?xi32>, %y : tensor) { + %tmp = "test.ti.join_types"(%x, %y) : (tensor<1x?xi32>, tensor) -> tensor<*xi32> + "test.br"(%tmp)[^bb1] : (tensor<*xi32>) -> () +^bb1(%arg: tensor<*xi32>): + "test.ti.disallow_input_refinement"(%arg) : (tensor<*xi32>) -> () + "test.ti.allow_input_refinement"(%arg) : (tensor<*xi32>) -> () + "test.ti.conditionally_allow_input_refinement"(%arg, %arg) : (tensor<*xi32>, tensor<*xi32>) -> () + return +} + +// CHECK: [[TMP:%.+]] = "test.ti.join_types"({{.+}}) : (tensor<1x?xi32>, tensor) -> tensor +// CHECK: [[RELAXED_TMP:%.+]] = type_relaxation [[TMP]] : tensor to tensor<*xi32> +// CHECK: "test.br"([[RELAXED_TMP]])[^bb1] : (tensor<*xi32>) -> () +// CHECK: ^bb1([[ARG:%.+]]: tensor<*xi32>): // pred: ^bb0 +// CHECK: [[SPECIALIZED_ARG:%.+]] = type_specialization [[ARG]] : tensor<*xi32> to tensor +// CHECK: "test.ti.disallow_input_refinement"([[ARG]]) : (tensor<*xi32>) -> () +// CHECK: "test.ti.allow_input_refinement"([[SPECIALIZED_ARG]]) : (tensor) -> () +// CHECK: "test.ti.conditionally_allow_input_refinement"([[SPECIALIZED_ARG]], [[ARG]]) : (tensor, tensor<*xi32>) -> () +// CHECK: return + +// ----- + +func @test_branch_allowing_input_refinement(%x : tensor<1x?xi32>, %y : tensor) { + %tmp = "test.ti.join_types"(%x, %y) : (tensor<1x?xi32>, tensor) -> tensor<*xi32> + "test.ti.br"(%tmp)[^bb1] : (tensor<*xi32>) -> () +^bb1(%arg: tensor<*xi32>): + "test.ti.disallow_input_refinement"(%arg) : (tensor<*xi32>) 
-> () + "test.ti.allow_input_refinement"(%arg) : (tensor<*xi32>) -> () + "test.ti.conditionally_allow_input_refinement"(%arg, %arg) : (tensor<*xi32>, tensor<*xi32>) -> () + return +} + +// CHECK: [[TMP:%.+]] = "test.ti.join_types"({{.+}}) : (tensor<1x?xi32>, tensor) -> tensor +// CHECK: "test.ti.br"([[TMP]])[^bb1] : (tensor) -> () +// CHECK: ^bb1([[ARG:%.+]]: tensor<*xi32>): // pred: ^bb0 +// CHECK: [[SPECIALIZED_ARG:%.+]] = type_specialization [[ARG]] : tensor<*xi32> to tensor +// CHECK: "test.ti.disallow_input_refinement"([[ARG]]) : (tensor<*xi32>) -> () +// CHECK: "test.ti.allow_input_refinement"([[SPECIALIZED_ARG]]) : (tensor) -> () +// CHECK: "test.ti.conditionally_allow_input_refinement"([[SPECIALIZED_ARG]], [[ARG]]) : (tensor, tensor<*xi32>) -> () +// CHECK: return + +// ----- + +func @test_branch_join(%cond: i1, %x : tensor<1xi32>, %y : tensor<2xi32>) { + cond_br %cond, ^bb1, ^bb2 +^bb1: + %val_true = "test.ti.join_types"(%x) : (tensor<1xi32>) -> tensor<*xi32> + br ^bb3(%val_true : tensor<*xi32>) +^bb2: + %val_false = "test.ti.join_types"(%y) : (tensor<2xi32>) -> tensor<*xi32> + br ^bb3(%val_false : tensor<*xi32>) +^bb3(%val: tensor<*xi32>): + "test.ti.allow_input_refinement"(%val) : (tensor<*xi32>) -> () + "test.ti.disallow_input_refinement"(%val) : (tensor<*xi32>) -> () + return +} + +// CHECK: func @test_branch_join([[COND:%.+]]: i1, [[X:%.+]]: tensor<1xi32>, [[Y:%.+]]: tensor<2xi32>) { +// CHECK: cond_br [[COND]], ^bb1, ^bb2 +// CHECK: ^bb1: +// CHECK: [[JOIN:%.+]] = "test.ti.join_types"([[X]]) : (tensor<1xi32>) -> tensor<1xi32> +// CHECK: [[RELAXED_JOIN:%.+]] = type_relaxation [[JOIN]] : tensor<1xi32> to tensor<*xi32> +// CHECK: br ^bb3([[RELAXED_JOIN]] : tensor<*xi32>) +// CHECK: ^bb2: +// CHECK: [[JOIN:%.+]] = "test.ti.join_types"([[Y]]) : (tensor<2xi32>) -> tensor<2xi32> +// CHECK: [[RELAXED_JOIN:%.+]] = type_relaxation [[JOIN]] : tensor<2xi32> to tensor<*xi32> +// CHECK: br ^bb3([[RELAXED_JOIN]] : tensor<*xi32>) +// CHECK: ^bb3([[VAL:%.+]]: 
tensor<*xi32>): +// CHECK: [[SPECIALIZED_VAL:%.+]] = type_specialization [[VAL]] : tensor<*xi32> to tensor +// CHECK: "test.ti.allow_input_refinement"([[SPECIALIZED_VAL]]) : (tensor) -> () +// CHECK: "test.ti.disallow_input_refinement"([[VAL]]) : (tensor<*xi32>) -> () +// CHECK: return +// CHECK: } +// ----- + +func @test_regionbranchopinterface() -> tensor { + %cond = "getValue"() : () -> (i1) + + %res = test.region_if %cond : i1 -> (tensor) then { + ^bb0(%arg1 : i1): + %tmp1 = "getValue"() : () -> (tensor<1x2xi32>) + %then_value = "test.ti.join_types"(%tmp1) : (tensor<1x2xi32>) -> tensor + test.region_if_yield %then_value : tensor + } else { + ^bb0(%arg1 : i1): + %tmp1 = "getValue"() : () -> (tensor<1x9xi32>) + %else_value = "test.ti.join_types"(%tmp1) : (tensor<1x9xi32>) -> tensor + test.region_if_yield %else_value : tensor + } join { + ^bb0(%arg1 : tensor): + test.region_if_yield %arg1 : tensor + } + "test.ti.allow_input_refinement"(%res) : (tensor) -> () + return %res : tensor +} + +// CHECK: [[COND:%.+]] = "getValue"() : () -> i1 +// CHECK: [[RES:%.+]] = test.region_if [[COND]]: i1 -> tensor then { +// CHECK: ^bb0({{.+}}: i1): // no predecessors +// CHECK: [[TMP1:%.+]] = "getValue"() : () -> tensor<1x2xi32> +// CHECK: [[THEN_VALUE:%.+]] = "test.ti.join_types"([[TMP1]]) : (tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: [[RELAXED_THEN_VALUE:%.+]] = type_relaxation [[THEN_VALUE]] : tensor<1x2xi32> to tensor +// CHECK: test.region_if_yield [[RELAXED_THEN_VALUE]] : tensor +// CHECK: } else { +// CHECK: ^bb0({{.+}}: i1): // no predecessors +// CHECK: [[TMP2:%.+]] = "getValue"() : () -> tensor<1x9xi32> +// CHECK: [[ELSE_VALUE:%.+]] = "test.ti.join_types"([[TMP2]]) : (tensor<1x9xi32>) -> tensor<1x9xi32> +// CHECK: [[RELAXED_ELSE_VALUE:%.+]] = type_relaxation [[ELSE_VALUE]] : tensor<1x9xi32> to tensor +// CHECK: test.region_if_yield [[RELAXED_ELSE_VALUE]] : tensor +// CHECK: } join { +// CHECK: ^bb0([[ARG:%.+]]: tensor): // no predecessors +// CHECK: 
test.region_if_yield [[ARG]] : tensor +// CHECK: } +// CHECK: [[SPECIALIZED_RES:%.+]] = type_specialization [[RES]] : tensor to tensor<1x?xi32> +// CHECK: "test.ti.allow_input_refinement"([[SPECIALIZED_RES]]) : (tensor<1x?xi32>) -> () +// CHECK: return [[RES]] : tensor + +// ----- + +func @test_regionbranchopinterface() -> tensor { + %cond = "getValue"() : () -> (i1) + + %res = test.ti.region_if %cond : i1 -> (tensor) then { + ^bb0(%arg1 : i1): + %tmp1 = "getValue"() : () -> (tensor<1x2xi32>) + %then_value = "test.ti.join_types"(%tmp1) : (tensor<1x2xi32>) -> tensor + test.ti.region_if_yield %then_value : tensor + } else { + ^bb0(%arg1 : i1): + %tmp1 = "getValue"() : () -> (tensor<1x9xi32>) + %else_value = "test.ti.join_types"(%tmp1) : (tensor<1x9xi32>) -> tensor + test.ti.region_if_yield %else_value : tensor + } join { + ^bb0(%arg1 : tensor): + test.ti.region_if_yield %arg1 : tensor + } + "test.ti.allow_input_refinement"(%res) : (tensor) -> () + return %res : tensor +} + +// CHECK: [[COND:%.+]] = "getValue"() : () -> i1 +// CHECK: [[RES:%.+]] = test.ti.region_if [[COND]]: i1 -> tensor<1x?xi32> then { +// CHECK: ^bb0({{.+}}: i1): // no predecessors +// CHECK: [[TMP1:%.+]] = "getValue"() : () -> tensor<1x2xi32> +// CHECK: [[THEN_VALUE:%.+]] = "test.ti.join_types"([[TMP1]]) : (tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: test.ti.region_if_yield [[THEN_VALUE]] : tensor<1x2xi32> +// CHECK: } else { +// CHECK: ^bb0({{.+}}: i1): // no predecessors +// CHECK: [[TMP2:%.+]] = "getValue"() : () -> tensor<1x9xi32> +// CHECK: [[ELSE_VALUE:%.+]] = "test.ti.join_types"([[TMP2]]) : (tensor<1x9xi32>) -> tensor<1x9xi32> +// CHECK: test.ti.region_if_yield [[ELSE_VALUE]] : tensor<1x9xi32> +// CHECK: } join { +// CHECK: ^bb0([[ARG:%.+]]: tensor): // no predecessors +// CHECK: [[SPECIALIZED_ARG:%.+]] = type_specialization [[ARG]] : tensor to tensor<1x?xi32> +// CHECK: test.ti.region_if_yield [[SPECIALIZED_ARG]] : tensor<1x?xi32> +// CHECK: } +// CHECK: [[RELAXED_RES:%.+]] = 
type_relaxation [[RES]] : tensor<1x?xi32> to tensor +// CHECK: "test.ti.allow_input_refinement"([[RES]]) : (tensor<1x?xi32>) -> () +// CHECK: return [[RELAXED_RES]] : tensor + +// ----- + +func @test_dynamicize(%arg : tensor<1x2x3xi32>) { + %tmp1 = "test.ti.dynamicize"(%arg) : (tensor<1x2x3xi32>) -> tensor<*xi32> + %tmp2 = "test.ti.dynamicize"(%tmp1) : (tensor<*xi32>) -> tensor<*xi32> + %tmp3 = "test.ti.dynamicize"(%tmp2) : (tensor<*xi32>) -> tensor<*xi32> + %tmp4 = "test.ti.dynamicize"(%tmp3) : (tensor<*xi32>) -> tensor<*xi32> + return +} + +// CHECK: func @test_dynamicize([[ARG:%.+]]: tensor<1x2x3xi32>) { +// CHECK: [[TMP1:%.+]] = "test.ti.dynamicize"([[ARG]]) : (tensor<1x2x3xi32>) -> tensor +// CHECK: [[TMP2:%.+]] = "test.ti.dynamicize"([[TMP1]]) : (tensor) -> tensor +// CHECK: [[TMP3:%.+]] = "test.ti.dynamicize"([[TMP2]]) : (tensor) -> tensor +// CHECK: [[TMP4:%.+]] = "test.ti.dynamicize"([[TMP3]]) : (tensor) -> tensor + +// ----- + +func @test_simple_while(%func_arg : tensor<1x2x3xi32>) { + %relaxed_func_arg = type_relaxation %func_arg : tensor<1x2x3xi32> to tensor<*xi32> + %while = scf.while (%before_arg = %relaxed_func_arg) : (tensor<*xi32>) -> tensor<*xi32> { + %cond = "getValue"() : () -> (i1) + scf.condition(%cond) %before_arg : tensor<*xi32> + } do { + ^bb0(%after_arg: tensor<*xi32>): + %join = "test.ti.join_types"(%after_arg) : (tensor<*xi32>) -> (tensor<*xi32>) + scf.yield %join : tensor<*xi32> + } + "test.ti.allow_input_refinement"(%while) : (tensor<*xi32>) -> () + return +} + +// CHECK: func @test_simple_while([[FUNC_ARG:%.+]]: tensor<1x2x3xi32>) { +// CHECK: [[RELAXED_FUNC_ARG:%.+]] = type_relaxation [[FUNC_ARG]] : tensor<1x2x3xi32> to tensor<*xi32> +// CHECK: [[WHILE:%.+]] = scf.while ([[BEFORE_ARG:%.+]] = [[RELAXED_FUNC_ARG]]) : (tensor<*xi32>) -> tensor<*xi32> { +// CHECK: [[COND:%.+]] = "getValue"() : () -> i1 +// CHECK: scf.condition([[COND]]) [[BEFORE_ARG]] : tensor<*xi32> +// CHECK: } do { +// CHECK: ^bb0([[AFTER_ARG:%.+]]: tensor<*xi32>): // 
no predecessors +// CHECK: [[JOIN:%.+]] = "test.ti.join_types"([[AFTER_ARG]]) : (tensor<*xi32>) -> tensor<1x2x3xi32> +// CHECK: [[RELAXED_JOIN:%.+]] = type_relaxation [[JOIN]] : tensor<1x2x3xi32> to tensor<*xi32> +// CHECK: scf.yield [[RELAXED_JOIN]] : tensor<*xi32> +// CHECK: } +// CHECK: [[SPECIALIZED_WHILE:%.+]] = type_specialization [[WHILE]] : tensor<*xi32> to tensor<1x2x3xi32> +// CHECK: "test.ti.allow_input_refinement"([[SPECIALIZED_WHILE]]) : (tensor<1x2x3xi32>) -> () +// CHECK: return +// CHECK: } + +// ----- + +func @test_while(%func_arg : tensor<1x2x3xi32>) { + %relaxed_func_arg = type_relaxation %func_arg : tensor<1x2x3xi32> to tensor<*xi32> + %while = scf.while (%before_arg = %relaxed_func_arg) : (tensor<*xi32>) -> tensor<*xi32> { + %cond = "getValue"() : () -> (i1) + scf.condition(%cond) %before_arg : tensor<*xi32> + } do { + ^bb0(%after_arg: tensor<*xi32>): + %dyn = "test.ti.dynamicize"(%after_arg) : (tensor<*xi32>) -> (tensor<*xi32>) + scf.yield %dyn : tensor<*xi32> + } + "test.ti.allow_input_refinement"(%while) : (tensor<*xi32>) -> () + return +} + +// CHECK: func @test_while([[FUNC_ARG:%.+]]: tensor<1x2x3xi32>) { +// CHECK: [[RELAXED_FUNC_ARG:%.+]] = type_relaxation [[FUNC_ARG]] : tensor<1x2x3xi32> to tensor<*xi32> +// CHECK: [[WHILE:%.+]] = scf.while ([[BEFORE_ARG:%.+]] = [[RELAXED_FUNC_ARG]]) : (tensor<*xi32>) -> tensor<*xi32> { +// CHECK: [[COND:%.+]] = "getValue"() : () -> i1 +// CHECK: scf.condition([[COND]]) [[BEFORE_ARG]] : tensor<*xi32> +// CHECK: } do { +// CHECK: ^bb0([[AFTER_ARG:%.+]]: tensor<*xi32>): // no predecessors +// CHECK: [[SPECIALIZED_AFTER_ARG:%.+]] = type_specialization [[AFTER_ARG]] : tensor<*xi32> to tensor +// CHECK: [[DYN:%.+]] = "test.ti.dynamicize"([[SPECIALIZED_AFTER_ARG]]) : (tensor) -> tensor +// CHECK: [[RELAXED_DYN:%.+]] = type_relaxation [[DYN]] : tensor to tensor<*xi32> +// CHECK: scf.yield [[RELAXED_DYN]] : tensor<*xi32> +// CHECK: } +// CHECK: [[SPECIALIZED_WHILE:%.+]] = type_specialization [[WHILE]] : 
tensor<*xi32> to tensor +// CHECK: "test.ti.allow_input_refinement"([[SPECIALIZED_WHILE]]) : (tensor) -> () +// CHECK: return +// CHECK: } + +// ----- + +func @test_same_operands_and_result_type(%x : tensor<1x?xi32>, %y : tensor) { + %tmp = "test.variadic_with_same_operand_results"(%x, %y) : (tensor<1x?xi32>, tensor) -> tensor<*xi32> + "test.ti.allow_input_refinement"(%tmp) : (tensor<*xi32>) -> () + return +} + +// CHECK: func @test_same_operands_and_result_type([[X:%.+]]: tensor<1x?xi32>, [[Y:%.+]]: tensor) { +// CHECK: [[RES:%.+]] = "test.variadic_with_same_operand_results"([[X]], [[Y]]) : (tensor<1x?xi32>, tensor) -> tensor<*xi32> +// CHECK: [[SPECIALIZED_RES:%.+]] = type_specialization [[RES]] : tensor<*xi32> to tensor<1x2xi32> +// CHECK: "test.ti.allow_input_refinement"([[SPECIALIZED_RES]]) : (tensor<1x2xi32>) -> () + +// ----- + +func @test_same_operands_and_result_shape(%x : tensor<1x?xi32>, %y : tensor) { + %tmp = "test.same_operand_and_result_shape"(%x, %y) : (tensor<1x?xi32>, tensor) -> tensor<*xi1> + "test.ti.allow_input_refinement"(%tmp) : (tensor<*xi1>) -> () + return +} + +// CHECK: func @test_same_operands_and_result_shape([[X:%.+]]: tensor<1x?xi32>, [[Y:%.+]]: tensor) { +// CHECK: [[RES:%.+]] = "test.same_operand_and_result_shape"([[X]], [[Y]]) : (tensor<1x?xi32>, tensor) -> tensor<*xi1> +// CHECK: [[SPECIALIZED_RES:%.+]] = type_specialization [[RES]] : tensor<*xi1> to tensor<1x2xi1> +// CHECK: "test.ti.allow_input_refinement"([[SPECIALIZED_RES]]) : (tensor<1x2xi1>) -> () diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/JoinMeetTypeInterface.h" #include "mlir/Reducer/ReductionPatternInterface.h" #include "mlir/Transforms/FoldUtils.h" #include 
"mlir/Transforms/InliningUtils.h"
@@ -1090,28 +1091,62 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Control-Flow Edge Operand Types Test Ops
+// Test Type Inferences
 //===----------------------------------------------------------------------===//
 
+/// Infers the result type by folding `joinTypes` over the operand types, from
+/// left to right, and pushes that single type to `inferredReturnTypes`.
+///
+/// Fails when there are no operands or when a join produces a null type, so
+/// that a null type is never handed back to the caller.
+LogicalResult TIJoinTypesOp::inferReturnTypes(
+    ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
+    ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+    ::mlir::RegionRange regions,
+    ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+  if (operands.empty())
+    return failure();
+
+  Type resultTy = operands.front().getType();
+  for (Value operand : operands.drop_front()) {
+    // NOTE(review): assumes `joinTypes` reports "no join" as a null type --
+    // confirm. Bail out rather than feeding a null accumulator back into the
+    // fold or returning it as the inferred type.
+    resultTy = joinTypes(resultTy, operand.getType());
+    if (!resultTy)
+      return failure();
+  }
+  inferredReturnTypes.push_back(resultTy);
+  return success();
+}
+
+/// Returns true if `isMoreSpecializedOrSame` holds for every zipped
+/// (lhs, rhs) type pair, or if `isLessSpecializedOrSame` holds for every pair.
+bool TIJoinTypesOp::isCompatibleReturnTypes(::mlir::TypeRange lhs,
+                                            ::mlir::TypeRange rhs) {
+  return llvm::all_of_zip(lhs, rhs, isMoreSpecializedOrSame) ||
+         llvm::all_of_zip(lhs, rhs, isLessSpecializedOrSame);
+}
+
 Optional<MutableOperandRange>
-CFEBranchOp::getMutableSuccessorOperands(unsigned index) {
+TIBranchOp::getMutableSuccessorOperands(unsigned index) {
   assert(index == 0 && "invalid successor index");
   return targetOperandsMutable();
 }
 
-bool CFEBranchOp::areCompatibleControlFlowEdgeOperandTypes(Type operandTy,
-                                                           Type blockArgTy) {
+bool TIBranchOp::areCompatibleControlFlowEdgeOperandTypes(Type operandTy,
+                                                          Type blockArgTy) {
   return isMoreSpecializedOrSame(operandTy, blockArgTy);
 }
 
-static void print(OpAsmPrinter &p, CFERegionIfOp op) { printRegionIfOp(p, op); }
+static void print(OpAsmPrinter &p, TIRegionIfOp op) { printRegionIfOp(p, op); }
 
-OperandRange CFERegionIfOp::getSuccessorEntryOperands(unsigned index) {
+OperandRange TIRegionIfOp::getSuccessorEntryOperands(unsigned index) {
   assert(index < 2 && "invalid region index");
   return getOperands();
 }
 
-void CFERegionIfOp::getSuccessorRegions(
+void TIRegionIfOp::getSuccessorRegions(
     Optional<unsigned> index, ArrayRef<Attribute> operands,
     SmallVectorImpl<RegionSuccessor> &regions) {
   // We always branch to the join region.
@@ -1128,11 +1163,51 @@
   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
 }
 
-bool CFERegionIfOp::areCompatibleControlFlowEdgeOperandTypes(Type operandTy,
-                                                             Type blockArgTy) {
+bool TIRegionIfOp::areCompatibleControlFlowEdgeOperandTypes(Type operandTy,
+                                                            Type blockArgTy) {
   return isMoreSpecializedOrSame(operandTy, blockArgTy);
 }
 
+/// Infers the result type from the type of the first operand: a ranked tensor
+/// type gets its first static dimension replaced with a dynamic one; any other
+/// type is returned unchanged.
+LogicalResult
+TIDynamicizeOp::inferReturnTypes(MLIRContext *, Optional<Location> location,
+                                 ValueRange operands, DictionaryAttr attributes,
+                                 RegionRange regions,
+                                 SmallVectorImpl<Type> &inferredReturnTypes) {
+  RankedTensorType inputTy = operands[0].getType().dyn_cast<RankedTensorType>();
+  if (!inputTy) {
+    inferredReturnTypes.push_back(operands[0].getType());
+    return success();
+  }
+
+  ArrayRef<int64_t> inputShape = inputTy.getShape();
+
+  // Set the first non-dynamic dimension to dynamic.
+  SmallVector<int64_t> shape;
+  bool modified = false;
+  for (int64_t dim : inputShape) {
+    if (!ShapedType::isDynamic(dim) && !modified) {
+      dim = ShapedType::kDynamicSize;
+      modified = true;
+    }
+    shape.push_back(dim);
+  }
+
+  inferredReturnTypes.push_back(inputTy.clone(shape));
+
+  return success();
+}
+
+/// Returns true if `isMoreSpecializedOrSame` holds for every zipped
+/// (lhs, rhs) type pair, or if `isLessSpecializedOrSame` holds for every pair.
+bool TIDynamicizeOp::isCompatibleReturnTypes(::mlir::TypeRange lhs,
+                                             ::mlir::TypeRange rhs) {
+  return llvm::all_of_zip(lhs, rhs, isMoreSpecializedOrSame) ||
+         llvm::all_of_zip(lhs, rhs, isLessSpecializedOrSame);
+}
+
 #include "TestOpEnums.cpp.inc"
 #include "TestOpInterfaces.cpp.inc"
 #include "TestOpStructs.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -160,8 +160,8 @@
 def VariadicWithSameOperandsResult :
       TEST_Op<"variadic_with_same_operand_results",
               [SameOperandsAndResultType]> {
-  let arguments = (ins Variadic<AnySignlessInteger>:$operands);
-  let results = (outs AnySignlessInteger:$result);
+  let arguments = (ins Variadic<AnyTensor>:$operands);
+  let results = (outs AnyTensor:$result);
 }
 
 def SameOperandsResultType : TEST_Op<
@@ -2328,11
+2328,56 @@ } //===----------------------------------------------------------------------===// -// Test Control-Flow Edge Operand Types +// Test Type Inferences //===----------------------------------------------------------------------===// -def CFEBranchOp : TEST_Op<"cfe.br", - [DeclareOpInterfaceMethods { + let description = [{Disallow type refinement for all operands.}]; + let arguments = (ins Variadic:$inputs); +} + +def TIAllowInputRefinementOp + : TEST_Op<"ti.allow_input_refinement", + [DeclareOpInterfaceMethods]> { + let description = [{Allow type refinement for all operands.}]; + let arguments = (ins Variadic:$inputs); +} + +def TIConditionallyAllowInputRefinementOp + : TEST_Op<"ti.conditionally_allow_input_refinement", + [DeclareOpInterfaceMethods]> { + let description = [{Allow type refinement for evenly-indexed operands.}]; + let arguments = (ins Variadic:$inputs); + let extraClassDeclaration = [{ + bool allowsInputTypeRefinement(unsigned operandIndex) { + return operandIndex % 2 == 0; + } + }]; +} + +def TIJoinTypesOp + : TEST_Op<"ti.join_types", + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let description = [{ + Type inference joins the operand types. + Output type refinement is configurable via a boolean attribute. 
+ }]; + let arguments = (ins Variadic:$inputs, + DefaultValuedAttr:$allowOutputRefinement); + let results = (outs AnyType:$result); + let extraClassDeclaration = [{ + static bool isCompatibleReturnTypes(::mlir::TypeRange lhs, ::mlir::TypeRange rhs); + bool allowsOutputTypeRefinement(unsigned outputIndex) { + return this->allowOutputRefinement(); + } + }]; +} + +def TIBranchOp : TEST_Op<"ti.br", + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, Terminator]> { let description = [{Branch allowing type refinement for all operands.}]; @@ -2340,19 +2385,22 @@ let successors = (successor AnySuccessor:$target); } -def CFERegionIfYieldOp : TEST_Op<"cfe.region_if_yield", - [NoSideEffect, ReturnLike, Terminator]> { +def TIRegionIfYieldOp : TEST_Op<"ti.region_if_yield", + [DeclareOpInterfaceMethods, + NoSideEffect, ReturnLike, Terminator]> { let arguments = (ins Variadic:$results); let assemblyFormat = [{ $results `:` type($results) attr-dict }]; } -def CFERegionIfOp - : TEST_Op<"cfe.region_if", - [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, - SingleBlockImplicitTerminator<"CFERegionIfYieldOp">, + SingleBlockImplicitTerminator<"TIRegionIfYieldOp">, RecursiveSideEffects]> { let description =[{ Represents an abstract if-then-else-join pattern. In this context, the then @@ -2383,4 +2431,20 @@ let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }]; } +def TIDynamicizeOp + : TEST_Op<"ti.dynamicize", + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let description = [{ + An op with particularly weird type inference, that progressively sets each + dimension to dynamic (`?`). + }]; + let arguments = (ins AnyTensor); + let results = (outs AnyTensor); + let extraClassDeclaration = [{ + static bool isCompatibleReturnTypes(::mlir::TypeRange lhs, ::mlir::TypeRange rhs); + }]; +} + #endif // TEST_OPS