Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/Shape/IR/Shape.cpp
Show First 20 Lines • Show All 485 Lines • ▼ Show 20 Lines | void CstrBroadcastableOp::getCanonicalizationPatterns( | ||||
OwningRewritePatternList &patterns, MLIRContext *context) { | OwningRewritePatternList &patterns, MLIRContext *context) { | ||||
// Canonicalization patterns have overlap with the considerations during | // Canonicalization patterns have overlap with the considerations during | ||||
// folding in case additional shape information is inferred at some point that | // folding in case additional shape information is inferred at some point that | ||||
// does not result in folding. | // does not result in folding. | ||||
patterns.insert<CstrBroadcastableEqOps>(context); | patterns.insert<CstrBroadcastableEqOps>(context); | ||||
} | } | ||||
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { | OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { | ||||
// TODO: Add folding for the nary case | |||||
if (operands.size() != 2) | |||||
return nullptr; | |||||
// Both operands are not needed if one is a scalar. | // Both operands are not needed if one is a scalar. | ||||
if (operands[0] && | if (operands[0] && | ||||
operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0) | operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0) | ||||
return BoolAttr::get(getContext(), true); | return BoolAttr::get(getContext(), true); | ||||
if (operands[1] && | if (operands[1] && | ||||
operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0) | operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0) | ||||
return BoolAttr::get(getContext(), true); | return BoolAttr::get(getContext(), true); | ||||
if (operands[0] && operands[1]) { | if (operands[0] && operands[1]) { | ||||
auto lhsShape = llvm::to_vector<6>( | auto lhsShape = llvm::to_vector<6>( | ||||
operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); | operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); | ||||
auto rhsShape = llvm::to_vector<6>( | auto rhsShape = llvm::to_vector<6>( | ||||
operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); | operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); | ||||
SmallVector<int64_t, 6> resultShape; | SmallVector<int64_t, 6> resultShape; | ||||
if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) | if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) | ||||
return BoolAttr::get(getContext(), true); | return BoolAttr::get(getContext(), true); | ||||
} | } | ||||
// Lastly, see if folding can be completed based on what constraints are known | // Lastly, see if folding can be completed based on what constraints are known | ||||
// on the input shapes. | // on the input shapes. | ||||
SmallVector<int64_t, 6> lhsShape, rhsShape; | SmallVector<int64_t, 6> lhsShape, rhsShape; | ||||
if (failed(getShapeVec(lhs(), lhsShape))) | if (failed(getShapeVec(shapes()[0], lhsShape))) | ||||
return nullptr; | return nullptr; | ||||
if (failed(getShapeVec(rhs(), rhsShape))) | if (failed(getShapeVec(shapes()[1], rhsShape))) | ||||
return nullptr; | return nullptr; | ||||
if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) | if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) | ||||
return BoolAttr::get(getContext(), true); | return BoolAttr::get(getContext(), true); | ||||
// Because a failing witness result here represents an eventual assertion | // Because a failing witness result here represents an eventual assertion | ||||
// failure, we do not replace it with a constant witness. | // failure, we do not replace it with a constant witness. | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
static LogicalResult verify(CstrBroadcastableOp op) { | |||||
// Ensure that AssumingAllOp contains at least one operand | |||||
if (op.getNumOperands() < 2) | |||||
return op.emitOpError("required at least 2 input shapes"); | |||||
return success(); | |||||
} | |||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// CstrEqOp | // CstrEqOp | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, | void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, | ||||
MLIRContext *context) { | MLIRContext *context) { | ||||
// If inputs are equal, return passing witness | // If inputs are equal, return passing witness | ||||
patterns.insert<CstrEqEqOps>(context); | patterns.insert<CstrEqEqOps>(context); | ||||
▲ Show 20 Lines • Show All 183 Lines • ▼ Show 20 Lines | void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, | ||||
} else { | } else { | ||||
Value dim = | Value dim = | ||||
builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr); | builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr); | ||||
build(builder, result, builder.getIndexType(), shape, dim); | build(builder, result, builder.getIndexType(), shape, dim); | ||||
} | } | ||||
} | } | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// IsBroadcastableOp | |||||
//===----------------------------------------------------------------------===// | |||||
static LogicalResult verify(IsBroadcastableOp op) { | |||||
// Ensure that AssumingAllOp contains at least one operand | |||||
if (op.getNumOperands() < 2) | |||||
return op.emitOpError("required at least 2 input shapes"); | |||||
return success(); | |||||
} | |||||
//===----------------------------------------------------------------------===// | |||||
// RankOp | // RankOp | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { | OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { | ||||
auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); | auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); | ||||
if (!shape) | if (!shape) | ||||
return {}; | return {}; | ||||
int64_t rank = shape.getNumElements(); | int64_t rank = shape.getNumElements(); | ||||
▲ Show 20 Lines • Show All 317 Lines • Show Last 20 Lines |