diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -242,6 +242,21 @@
   let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
 }
 
+def ConvertShapeConstraints: Pass<"convert-shape-constraints", "FuncOp"> {
+  let summary = "Convert shape constraint operations to the standard dialect";
+  let description = [{
+    This pass eliminates shape constraints from the program, converting them
+    to eager (side-effecting) error handling code.
+
+    This pass is separate from the regular convert-shape-to-standard, despite
+    converting between the same dialects, because converting shape constraints
+    can happen at a different point in the lowering pipeline than general
+    shape computation lowering.
+  }];
+  let constructor = "mlir::createConvertShapeConstraintsPass()";
+  let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // SPIRVToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h b/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
--- a/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
+++ b/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
@@ -15,6 +15,7 @@
 class MLIRContext;
 class ModuleOp;
+class FuncOp;
 template <typename T>
 class OperationPass;
 class OwningRewritePatternList;
@@ -24,6 +25,11 @@
 std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToStandardPass();
 
+void populateConvertShapeConstraintsConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+std::unique_ptr<OperationPass<FuncOp>> createConvertShapeConstraintsPass();
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_SHAPETOSTANDARD_SHAPETOSTANDARD_H_
diff --git a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
--- a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
+++ b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_conversion_library(MLIRShapeToStandard
+  ConvertShapeConstraints.cpp
   ShapeToStandard.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
@@ -0,0 +1,147 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+class ConvertCstrBroadcastableOp
+    : public OpRewritePattern<shape::CstrBroadcastableOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getType().isa<shape::ShapeType>() ||
+        op.lhs().getType().isa<shape::ShapeType>() ||
+        op.rhs().getType().isa<shape::ShapeType>()) {
+      return rewriter.notifyMatchFailure(
+          op, "cannot convert error-propagating shapes");
+    }
+
+    auto loc = op.getLoc();
+    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+
+    // Find smaller and greater rank and extent tensor.
+    Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
+    Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
+    Value lhsSmaller =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
+    Type indexTy = rewriter.getIndexType();
+    Type extentTensorTy = op.lhs().getType();
+    auto ifOp = rewriter.create<scf::IfOp>(
+        loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
+        lhsSmaller,
+        [&](OpBuilder &b, Location loc) {
+          b.create<scf::YieldOp>(
+              loc, ValueRange{lhsRank, op.lhs(), rhsRank, op.rhs()});
+        },
+        [&](OpBuilder &b, Location loc) {
+          b.create<scf::YieldOp>(
+              loc, ValueRange{rhsRank, op.rhs(), lhsRank, op.lhs()});
+        });
+    Value lesserRank = ifOp.getResult(0);
+    Value lesserRankOperand = ifOp.getResult(1);
+    Value greaterRank = ifOp.getResult(2);
+    Value greaterRankOperand = ifOp.getResult(3);
+
+    Value rankDiff =
+        rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
+
+    // Compare the shapes extent by extent, and emit errors for
+    // non-broadcast-compatible shapes.
+    // Two extents are broadcast-compatible if
+    // 1. they are both equal, or
+    // 2. at least one of them is 1.
+
+    rewriter.create<scf::ForOp>(
+        loc, rankDiff, greaterRank, one, llvm::None,
+        [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
+          Value greaterRankOperandExtent = b.create<ExtractElementOp>(
+              loc, greaterRankOperand, ValueRange{iv});
+          Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
+          Value lesserRankOperandExtent = b.create<ExtractElementOp>(
+              loc, lesserRankOperand, ValueRange{ivShifted});
+
+          Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
+              loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
+          Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
+              loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
+          Value extentsAgree =
+              b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
+                               lesserRankOperandExtent);
+          auto broadcastIsValid =
+              b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
+                             b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
+                                            lesserRankOperandExtentIsOne));
+          b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
+          b.create<scf::YieldOp>(loc);
+        });
+
+    // Now that we have emitted all the assertions, the witness is trivially
+    // satisfied.
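+    // (Replacing the constraint with a constant-true witness is what later
+    // lets the shape.assuming canonicalization patterns, registered in
+    // populateConvertShapeConstraintsConversionPatterns below, inline any
+    // region guarded by this witness.)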
+    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(shape::CstrRequireOp op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.create<AssertOp>(op.getLoc(), op.pred(), op.msgAttr());
+    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateConvertShapeConstraintsConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ConvertCstrBroadcastableOp>(ctx);
+  patterns.insert<ConvertCstrRequireOp>(ctx);
+  // Add in the canonicalization patterns for shape.assuming so that it gets
+  // inlined when its witness becomes a true constant witness.
+  // TODO: Add further targeted canonicalization patterns as needed.
+  shape::AssumingOp::getCanonicalizationPatterns(patterns, ctx);
+}
+
+namespace {
+// This pass eliminates shape constraints from the program, converting them to
+// eager error handling code.
+class ConvertShapeConstraints
+    : public ConvertShapeConstraintsBase<ConvertShapeConstraints> {
+  void runOnOperation() override {
+    auto func = getOperation();
+    auto *context = &getContext();
+
+    OwningRewritePatternList patterns;
+    populateConvertShapeConstraintsConversionPatterns(patterns, context);
+
+    if (failed(applyPatternsAndFoldGreedily(func, patterns)))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createConvertShapeConstraintsPass() {
+  return std::make_unique<ConvertShapeConstraints>();
+}
diff --git a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt -convert-shape-constraints <%s | FileCheck %s
+
+// There's not much useful to check here beyond pasting the expected output.
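+// In outline: an scf.if selects the lesser/greater-ranked operand pair, and an
+// scf.for walks the greater operand's extents, asserting that each pair of
+// extents is equal or that at least one of the two is 1.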
+// CHECK-LABEL: func @cstr_broadcastable(
+// CHECK-SAME: %[[LHS:.*]]: tensor<?xindex>,
+// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[RET:.*]] = shape.const_witness true
+// CHECK: %[[LHSEXTENT:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
+// CHECK: %[[RHSEXTENT:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+// CHECK: %[[LESSEQUAL:.*]] = cmpi "ule", %[[LHSEXTENT]], %[[RHSEXTENT]] : index
+// CHECK: %[[IFRESULTS:.*]]:4 = scf.if %[[LESSEQUAL]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
+// CHECK: scf.yield %[[LHSEXTENT]], %[[LHS]], %[[RHSEXTENT]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+// CHECK: } else {
+// CHECK: scf.yield %[[RHSEXTENT]], %[[RHS]], %[[LHSEXTENT]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+// CHECK: }
+// CHECK: %[[RANKDIFF:.*]] = subi %[[IFRESULTS]]#2, %[[IFRESULTS]]#0 : index
+// CHECK: scf.for %[[IV:.*]] = %[[RANKDIFF]] to %[[IFRESULTS]]#2 step %[[C1]] {
+// CHECK: %[[GREATERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#3{{\[}}%[[IV]]] : tensor<?xindex>
+// CHECK: %[[IVSHIFTED:.*]] = subi %[[IV]], %[[RANKDIFF]] : index
+// CHECK: %[[LESSERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#1{{\[}}%[[IVSHIFTED]]] : tensor<?xindex>
+// CHECK: %[[GREATERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[C1]] : index
+// CHECK: %[[LESSERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[LESSERRANKOPERANDEXTENT]], %[[C1]] : index
+// CHECK: %[[EXTENTSAGREE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[LESSERRANKOPERANDEXTENT]] : index
+// CHECK: %[[OR_TMP:.*]] = or %[[GREATERRANKOPERANDEXTENTISONE]], %[[LESSERRANKOPERANDEXTENTISONE]] : i1
+// CHECK: %[[BROADCASTISVALID:.*]] = or %[[EXTENTSAGREE]], %[[OR_TMP]] : i1
+// CHECK: assert %[[BROADCASTISVALID]], "invalid broadcast"
+// CHECK: }
+// CHECK: return %[[RET]] : !shape.witness
+// CHECK: }
+func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
+  %witness = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
+  return %witness : !shape.witness
+}
+
+// Check that `shape.assuming` is eliminated after we create the error handling code.
+// CHECK-LABEL: func @assuming
+func @assuming(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> tensor<2xf32> {
+  %witness = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
+  // CHECK-NOT: shape.assuming
+  // CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<2xf32>
+  %0 = shape.assuming %witness -> tensor<2xf32> {
+    %c = constant dense<0.0> : tensor<2xf32>
+    shape.assuming_yield %c : tensor<2xf32>
+  }
+  // CHECK: return %[[CST]]
+  return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @cstr_require
+func @cstr_require(%arg0: i1) -> !shape.witness {
+  // CHECK: %[[RET:.*]] = shape.const_witness true
+  // CHECK: assert %arg0, "msg"
+  // CHECK: return %[[RET]]
+  %witness = shape.cstr_require %arg0, "msg"
+  return %witness : !shape.witness
+}
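
Usage note (not part of the patch): a minimal sketch of scheduling the new pass
from C++, assuming a standard in-tree setup. The helper name
`addShapeConstraintLowering` and the trailing canonicalizer pass are
illustrative choices, not anything this patch adds.

  #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
  #include "mlir/IR/Function.h"
  #include "mlir/Pass/PassManager.h"
  #include "mlir/Transforms/Passes.h"

  // Hypothetical helper: runs convert-shape-constraints on every function,
  // then the general canonicalizer for cleanup. (The pass itself already
  // registers the shape.assuming canonicalization patterns it needs, so the
  // canonicalizer here only tidies up surrounding IR.)
  void addShapeConstraintLowering(mlir::PassManager &pm) {
    pm.addNestedPass<mlir::FuncOp>(mlir::createConvertShapeConstraintsPass());
    pm.addPass(mlir::createCanonicalizerPass());
  }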