Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
Show All 13 Lines | |||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" | #include "mlir/Dialect/StandardOps/IR/Ops.h" | ||||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | #include "mlir/Dialect/Tensor/IR/Tensor.h" | ||||
#include "mlir/IR/PatternMatch.h" | #include "mlir/IR/PatternMatch.h" | ||||
#include "mlir/Pass/Pass.h" | #include "mlir/Pass/Pass.h" | ||||
#include "mlir/Pass/PassRegistry.h" | #include "mlir/Pass/PassRegistry.h" | ||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||||
using namespace mlir; | using namespace mlir; | ||||
namespace { | namespace { | ||||
class ConvertCstrBroadcastableOp | #include "ShapeToStandard.cpp.inc" | ||||
: public OpRewritePattern<shape::CstrBroadcastableOp> { | |||||
public: | |||||
using OpRewritePattern::OpRewritePattern; | |||||
LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op, | |||||
PatternRewriter &rewriter) const override { | |||||
if (op.getType().isa<shape::ShapeType>() || | |||||
op.lhs().getType().isa<shape::ShapeType>() || | |||||
op.rhs().getType().isa<shape::ShapeType>()) { | |||||
return rewriter.notifyMatchFailure( | |||||
op, "cannot convert error-propagating shapes"); | |||||
} | |||||
auto loc = op.getLoc(); | |||||
Value zero = rewriter.create<ConstantIndexOp>(loc, 0); | |||||
Value one = rewriter.create<ConstantIndexOp>(loc, 1); | |||||
// Find smaller and greater rank and extent tensor. | |||||
Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero); | |||||
Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero); | |||||
Value lhsRankULE = | |||||
rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank); | |||||
Type indexTy = rewriter.getIndexType(); | |||||
Value lesserRank = | |||||
rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank); | |||||
Value greaterRank = | |||||
rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank); | |||||
Value lesserRankOperand = | |||||
rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs()); | |||||
Value greaterRankOperand = | |||||
rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs()); | |||||
Value rankDiff = | |||||
rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank); | |||||
// Generate code to compare the shapes extent by extent, and emit errors for | |||||
// non-broadcast-compatible shapes. | |||||
// Two extents are broadcast-compatible if | |||||
// 1. they are both equal, or | |||||
// 2. at least one of them is 1. | |||||
rewriter.create<scf::ForOp>( | |||||
loc, rankDiff, greaterRank, one, llvm::None, | |||||
[&](OpBuilder &b, Location loc, Value iv, ValueRange) { | |||||
Value greaterRankOperandExtent = b.create<tensor::ExtractOp>( | |||||
loc, greaterRankOperand, ValueRange{iv}); | |||||
Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff); | |||||
Value lesserRankOperandExtent = b.create<tensor::ExtractOp>( | |||||
loc, lesserRankOperand, ValueRange{ivShifted}); | |||||
Value greaterRankOperandExtentIsOne = b.create<CmpIOp>( | |||||
loc, CmpIPredicate::eq, greaterRankOperandExtent, one); | |||||
Value lesserRankOperandExtentIsOne = b.create<CmpIOp>( | |||||
loc, CmpIPredicate::eq, lesserRankOperandExtent, one); | |||||
Value extentsAgree = | |||||
b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent, | |||||
lesserRankOperandExtent); | |||||
auto broadcastIsValid = | |||||
b.create<OrOp>(loc, b.getI1Type(), extentsAgree, | |||||
b.create<OrOp>(loc, greaterRankOperandExtentIsOne, | |||||
lesserRankOperandExtentIsOne)); | |||||
b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast"); | |||||
b.create<scf::YieldOp>(loc); | |||||
}); | |||||
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true); | |||||
return success(); | |||||
} | |||||
}; | |||||
} // namespace | } // namespace | ||||
namespace { | namespace { | ||||
class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> { | class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> { | ||||
public: | public: | ||||
using OpRewritePattern::OpRewritePattern; | using OpRewritePattern::OpRewritePattern; | ||||
LogicalResult matchAndRewrite(shape::CstrRequireOp op, | LogicalResult matchAndRewrite(shape::CstrRequireOp op, | ||||
PatternRewriter &rewriter) const override { | PatternRewriter &rewriter) const override { | ||||
rewriter.create<AssertOp>(op.getLoc(), op.pred(), op.msgAttr()); | rewriter.create<AssertOp>(op.getLoc(), op.pred(), op.msgAttr()); | ||||
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true); | rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true); | ||||
return success(); | return success(); | ||||
} | } | ||||
}; | }; | ||||
} // namespace | } // namespace | ||||
void mlir::populateConvertShapeConstraintsConversionPatterns( | void mlir::populateConvertShapeConstraintsConversionPatterns( | ||||
OwningRewritePatternList &patterns, MLIRContext *ctx) { | OwningRewritePatternList &patterns, MLIRContext *ctx) { | ||||
patterns.insert<ConvertCstrBroadcastableOp>(ctx); | patterns.insert<CstrBroadcastableToRequire>(ctx); | ||||
patterns.insert<ConvertCstrRequireOp>(ctx); | patterns.insert<ConvertCstrRequireOp>(ctx); | ||||
} | } | ||||
namespace { | namespace { | ||||
// This pass eliminates shape constraints from the program, converting them to | // This pass eliminates shape constraints from the program, converting them to | ||||
// eager (side-effecting) error handling code. After eager error handling code | // eager (side-effecting) error handling code. After eager error handling code | ||||
// is emitted, witnesses are satisfied, so they are replace with | // is emitted, witnesses are satisfied, so they are replace with | ||||
// `shape.const_witness true`. | // `shape.const_witness true`. | ||||
Show All 19 Lines |