diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -206,6 +206,91 @@
   return success();
 }
 
+namespace {
+struct CstrBroadcastableOpConverter
+    : public OpConversionPattern<CstrBroadcastableOp> {
+  using OpConversionPattern<CstrBroadcastableOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CstrBroadcastableOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+// A lowering to remove the implicit broadcasting logic of
+// shape.cstr_broadcastable without removing the benefits that the constraint
+// yields. To actually remove the constraint, look to
+// ConvertShapeConstraints.cpp.
+LogicalResult CstrBroadcastableOpConverter::matchAndRewrite(
+    CstrBroadcastableOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
+  // on shapes.
+  CstrBroadcastableOp::Adaptor transformed(operands);
+  if (transformed.lhs().getType().isa<ShapeType>() ||
+      transformed.rhs().getType().isa<ShapeType>())
+    return failure();
+
+  auto loc = op.getLoc();
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+
+  // Find smaller and greater rank and extent tensor.
+  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
+  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
+  Value lhsRankULE =
+      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
+  Type indexTy = rewriter.getIndexType();
+  Value lesserRank =
+      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
+  Value greaterRank =
+      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
+  auto erasedRankType =
+      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
+  Value rankErasedLhs =
+      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
+  Value rankErasedRhs =
+      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
+  Value lesserRankOperand =
+      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
+  Value greaterRankOperand =
+      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
+  Value rankDiff =
+      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
+  Type i1Ty = rewriter.getI1Type();
+  Value init =
+      rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
+
+  // Determine if all overlapping extents are broadcastable.
+  auto reduceResult = rewriter.create<scf::ForOp>(
+      loc, rankDiff, greaterRank, one, ValueRange{init},
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
+        Value greaterRankOperandExtent = b.create<ExtractElementOp>(
+            loc, greaterRankOperand, ValueRange{iv});
+        Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
+            loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
+        Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
+        Value lesserRankOperandExtent = b.create<ExtractElementOp>(
+            loc, lesserRankOperand, ValueRange{ivShifted});
+        Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
+            loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
+        Value extentsAreEqual =
+            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
+                             lesserRankOperandExtent);
+        Value broadcastableExtents = b.create<AndOp>(
+            loc, iterArgs[0],
+            b.create<OrOp>(loc,
+                           b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
+                                          lesserRankOperandExtentIsOne),
+                           extentsAreEqual));
+        b.create<scf::YieldOp>(loc, broadcastableExtents);
+      });
+
+  rewriter.replaceOpWithNewOp<CstrRequireOp>(
+      op, reduceResult.results().front(), "required broadcastable shapes");
+  return success();
+}
+
 namespace {
 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
@@ -499,7 +584,7 @@
   MLIRContext &ctx = getContext();
   ConversionTarget target(ctx);
   target.addLegalDialect<StandardOpsDialect, SCFDialect>();
-  target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp>();
+  target.addLegalOp<CstrRequireOp, ModuleOp, ModuleTerminatorOp, FuncOp>();
 
   // Setup conversion patterns.
   OwningRewritePatternList patterns;
@@ -521,6 +606,7 @@
       BroadcastOpConverter,
       ConstShapeOpConverter,
       ConstSizeOpConversion,
+      CstrBroadcastableOpConverter,
       GetExtentOpConverter,
       RankOpConverter,
       ReduceOpConverter,
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -382,3 +382,42 @@
     : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
   return
 }
+
+// -----
+
+func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) -> !shape.witness {
+  %0 = shape.cstr_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
+  return %0 : !shape.witness
+}
+
+// CHECK-LABEL: func @broadcast(
+// CHECK-SAME: %[[LHS:.*]]: tensor<?xindex>,
+// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
+// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+// CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK: %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK: %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
+// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
+// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
+// CHECK: %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
+// CHECK: %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
+// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
+// CHECK: %[[TRUE:.*]] = constant true
+// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[VAL_16:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
+// CHECK: %[[LARGER_EXTENT:.*]] = extract_element %[[LARGER_SHAPE]]{{\[}}%[[VAL_16]]] : tensor<?xindex>
+// CHECK: %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[C1]] : index
+// CHECK: %[[LHS_EXTENT_INDEX:.*]] = subi %[[VAL_16]], %[[RANK_DIFF]] : index
+// CHECK: %[[SMALLER_EXTENT:.*]] = extract_element %[[SMALLER_SHAPE]]{{\[}}%[[LHS_EXTENT_INDEX]]] : tensor<?xindex>
+// CHECK: %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[SMALLER_EXTENT]], %[[C1]] : index
+// CHECK: %[[EXTENTS_ARE_EQUAL:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
+// CHECK: %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
+// CHECK: %[[OR_EXTENTS_ARE_EQUAL:.*]] = or %[[EITHER_EXTENT_IS_ONE]], %[[EXTENTS_ARE_EQUAL]] : i1
+// CHECK: %[[NEW_ALL_SO_FAR:.*]] = and %[[ALL_SO_FAR]], %[[OR_EXTENTS_ARE_EQUAL]] : i1
+// CHECK: scf.yield %[[NEW_ALL_SO_FAR]] : i1
+// CHECK: }
+// CHECK: %[[RESULT:.*]] = shape.cstr_require %[[ALL_RESULT]], "required broadcastable shapes"
+// CHECK: return %[[RESULT]] : !shape.witness
+// CHECK: }
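Note on the lowering above: the generated `scf.for` encodes the usual broadcastability rule — right-align the two extent tensors, skip the leading extents that only the higher-rank operand has (they are always compatible), and require each overlapping extent pair to be equal or to contain a 1. The following standalone C++ sketch of that scalar logic is illustrative only and not part of the patch; `isBroadcastable` is a hypothetical helper name.

#include <cstddef>
#include <vector>

// Mirrors the lowered loop: two shapes are broadcastable iff, after
// right-aligning them, every overlapping extent pair is equal or one of
// the two extents is 1. Extents that only the higher-rank shape has are
// trivially compatible, which is why the loop starts at rankDiff.
static bool isBroadcastable(const std::vector<long> &lhs,
                            const std::vector<long> &rhs) {
  const std::vector<long> &lesser = lhs.size() <= rhs.size() ? lhs : rhs;
  const std::vector<long> &greater = lhs.size() <= rhs.size() ? rhs : lhs;
  std::size_t rankDiff = greater.size() - lesser.size();
  bool allSoFar = true; // plays the role of the i1 iter_args accumulator
  for (std::size_t iv = rankDiff; iv < greater.size(); ++iv) {
    long greaterExtent = greater[iv];
    long lesserExtent = lesser[iv - rankDiff];
    allSoFar = allSoFar && (greaterExtent == 1 || lesserExtent == 1 ||
                            greaterExtent == lesserExtent);
  }
  return allSoFar;
}

For example, isBroadcastable({3, 1, 5}, {4, 5}) is true (the overlapping pairs (1, 4) and (5, 5) are compatible and the leading 3 is skipped), while isBroadcastable({3, 2}, {3}) is false because the trailing pair (2, 3) is neither equal nor contains a 1.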