diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -193,11 +193,12 @@
   let assemblyFormat = "$input attr-dict `:` type($input)";
 }
 
-def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> {
-  let summary = "Determines if 2 shapes can be successfully broadcasted";
+def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable",
+    [Commutative, InferTypeOpInterface]> {
+  let summary = "Determines if 2+ shapes can be successfully broadcasted";
   let description = [{
-    Given two input shapes or extent tensors, return a predicate specifying if
-    they are broadcastable. This broadcastable follows the same logic as what
+    Given multiple input shapes or extent tensors, return a predicate
+    specifying if they are broadcastable, following the same logic that
     shape.broadcast documents.
 
     Concretely, shape.is_broadcastable returning true implies that
@@ -212,11 +213,30 @@
     ```
   }];
 
-  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
-                       Shape_ShapeOrExtentTensorType:$rhs);
+  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
   let results = (outs I1:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+  let builders = [
+    OpBuilderDAG<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
+      [{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>,
+  ];
+  let extraClassDeclaration = [{
+    // TODO: This should really be automatic. Figure out how to not need this
+    // defined.
+    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+        ::llvm::Optional<::mlir::Location> location,
+        ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+        ::mlir::RegionRange regions,
+        ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+      inferredReturnTypes.push_back(::mlir::IntegerType::get(context,
+                                                             /*width=*/1));
+      return success();
+    }
+  }];
+
+  let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
+  let verifier = [{
+    return success(getNumOperands() >= 2);
+  }];
 }
 
 def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
@@ -694,11 +714,12 @@
   let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
 }
 
-def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
-  let summary = "Determines if 2 shapes can be successfully broadcasted";
+def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable",
+    [Commutative, InferTypeOpInterface]> {
+  let summary = "Determines if 2+ shapes can be successfully broadcasted";
   let description = [{
-    Given two input shapes or extent tensors, return a witness specifying if
-    they are broadcastable. This broadcastable follows the same logic as what
+    Given input shapes or extent tensors, return a witness specifying if they
+    are broadcastable, following the same logic that
     shape.broadcast documents.
 
     "cstr" operations represent runtime assertions.
@@ -710,14 +731,32 @@
     ```
   }];
 
-  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
-                       Shape_ShapeOrExtentTensorType:$rhs);
+  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
   let results = (outs Shape_WitnessType:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+  let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
+
+  let builders = [
+    OpBuilderDAG<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
+      [{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>,
+  ];
+
+  let extraClassDeclaration = [{
+    // TODO: This should really be automatic. Figure out how to not need this
+    // defined.
+    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+        ::llvm::Optional<::mlir::Location> location,
+        ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+        ::mlir::RegionRange regions,
+        ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+      inferredReturnTypes.push_back(::mlir::shape::WitnessType::get(context));
+      return success();
+    }
+  }];
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let verifier = [{
+    return success(getNumOperands() >= 2);
+  }];
 }
 
 def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> {
 
diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
@@ -19,77 +19,8 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
-
 namespace {
-class ConvertCstrBroadcastableOp
-    : public OpRewritePattern<shape::CstrBroadcastableOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
-                                PatternRewriter &rewriter) const override {
-    if (op.getType().isa<shape::ShapeType>() ||
-        op.lhs().getType().isa<shape::ShapeType>() ||
-        op.rhs().getType().isa<shape::ShapeType>()) {
-      return rewriter.notifyMatchFailure(
-          op, "cannot convert error-propagating shapes");
-    }
-
-    auto loc = op.getLoc();
-    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-    Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-
-    // Find smaller and greater rank and extent tensor.
-    Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
-    Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
-    Value lhsRankULE =
-        rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
-    Type indexTy = rewriter.getIndexType();
-    Value lesserRank =
-        rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
-    Value greaterRank =
-        rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
-    Value lesserRankOperand =
-        rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
-    Value greaterRankOperand =
-        rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());
-
-    Value rankDiff =
-        rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
-
-    // Generate code to compare the shapes extent by extent, and emit errors
-    // for non-broadcast-compatible shapes.
-    // Two extents are broadcast-compatible if
-    // 1. they are both equal, or
-    // 2. at least one of them is 1.
-
-    rewriter.create<scf::ForOp>(
-        loc, rankDiff, greaterRank, one, llvm::None,
-        [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
-          Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
-              loc, greaterRankOperand, ValueRange{iv});
-          Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
-          Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
-              loc, lesserRankOperand, ValueRange{ivShifted});
-
-          Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
-              loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
-          Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
-              loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
-          Value extentsAgree =
-              b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
-                               lesserRankOperandExtent);
-          auto broadcastIsValid =
-              b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
-                             b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
-                                            lesserRankOperandExtentIsOne));
-          b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
-          b.create<scf::YieldOp>(loc);
-        });
-
-    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
-    return success();
-  }
-};
+#include "ShapeToStandard.cpp.inc"
 } // namespace
 
 namespace {
@@ -107,7 +38,7 @@
 void mlir::populateConvertShapeConstraintsConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<ConvertCstrBroadcastableOp>(ctx);
+  patterns.insert<CstrBroadcastableToRequire>(ctx);
   patterns.insert<ConvertCstrRequireOp>(ctx);
 }
 
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -237,63 +237,84 @@
   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
   // on shapes.
   IsBroadcastableOp::Adaptor transformed(operands);
-  if (transformed.lhs().getType().isa<ShapeType>() ||
-      transformed.rhs().getType().isa<ShapeType>())
+  if (!llvm::all_of(op.shapes(),
+                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
     return failure();
 
   auto loc = op.getLoc();
-  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+  ImplicitLocOpBuilder lb(loc, rewriter);
+  Value zero = lb.create<ConstantIndexOp>(0);
+  Value one = lb.create<ConstantIndexOp>(1);
+  Type indexTy = lb.getIndexType();
+
+  // Save all the ranks for bounds checking. Because this is a tensor
+  // representing the shape extents, the rank is the extent of the only
+  // dimension in the tensor.
+  SmallVector<Value> ranks, rankDiffs;
+  llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
+    return lb.create<DimOp>(v, zero);
+  }));
+
+  // Find the maximum rank.
+  Value maxRank = ranks.front();
+  for (Value v : llvm::drop_begin(ranks, 1)) {
+    Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
+    maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
+  }
+
+  // Calculate the difference of ranks and the maximum rank for later offsets.
+  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
+    return lb.create<SubIOp>(indexTy, maxRank, v);
+  }));
 
-  // Find smaller and greater rank and extent tensor.
-  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
-  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
-  Value lhsRankULE =
-      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
-  Type indexTy = rewriter.getIndexType();
-  Value lesserRank =
-      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
-  Value greaterRank =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
-  auto erasedRankType =
-      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
-  Value rankErasedLhs =
-      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
-  Value rankErasedRhs =
-      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
-  Value lesserRankOperand =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
-  Value greaterRankOperand =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
-  Value rankDiff =
-      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
   Type i1Ty = rewriter.getI1Type();
-  Value init =
+  Value trueVal =
       rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
 
-  // Determine if all overlapping extents are broadcastable.
-  auto reduceResult = rewriter.create<ForOp>(
-      loc, rankDiff, greaterRank, one, ValueRange{init},
+  auto reduceResult = lb.create<ForOp>(
+      zero, maxRank, one, ValueRange{trueVal},
       [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
-        Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
-            loc, greaterRankOperand, ValueRange{iv});
-        Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
-            loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
-        Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
-        Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
-            loc, lesserRankOperand, ValueRange{ivShifted});
-        Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
-            loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
-        Value extentsAreEqual =
-            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
-                             lesserRankOperandExtent);
-        Value broadcastableExtents = b.create<AndOp>(
-            loc, iterArgs[0],
-            b.create<OrOp>(loc,
-                           b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
-                                          lesserRankOperandExtentIsOne),
-                           extentsAreEqual));
-        b.create<scf::YieldOp>(loc, broadcastableExtents);
+        // Find a non-1 dim, if it exists. Note that the first part of this
+        // could reuse the Broadcast lowering entirely, but we redo the work
+        // here to make optimizations easier between the two loops.
+        Value broadcastedDim = getBroadcastedDim(
+            ImplicitLocOpBuilder(loc, b), transformed.shapes(), rankDiffs, iv);
+
+        Value broadcastable = iterArgs[0];
+        for (auto tup : llvm::zip(transformed.shapes(), rankDiffs)) {
+          Value shape, rankDiff;
+          std::tie(shape, rankDiff) = tup;
+          Value outOfBounds =
+              b.create<CmpIOp>(loc, CmpIPredicate::ult, iv, rankDiff);
+          broadcastable =
+              b.create<IfOp>(
+                   loc, TypeRange{i1Ty}, outOfBounds,
+                   [&](OpBuilder &b, Location loc) {
+                     // Non-existent dimensions are always broadcastable.
+                     b.create<scf::YieldOp>(loc, broadcastable);
+                   },
+                   [&](OpBuilder &b, Location loc) {
+                     // Every value needs to be either 1, or the same non-1
+                     // value to be broadcastable in this dim.
+                     Value operandDimension =
+                         b.create<SubIOp>(loc, indexTy, iv, rankDiff);
+                     Value dimensionExtent = b.create<tensor::ExtractOp>(
+                         loc, shape, ValueRange{operandDimension});
+
+                     Value equalOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
+                                                       dimensionExtent, one);
+                     Value equalBroadcasted =
+                         b.create<CmpIOp>(loc, CmpIPredicate::eq,
+                                          dimensionExtent, broadcastedDim);
+                     Value result = b.create<AndOp>(
+                         loc, broadcastable,
+                         b.create<OrOp>(loc, equalOne, equalBroadcasted));
+                     b.create<scf::YieldOp>(loc, result);
+                   })
+                  .getResult(0);
+        }
+
+        b.create<scf::YieldOp>(loc, broadcastable);
       });
 
   rewriter.replaceOp(op, reduceResult.results().front());
 
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
@@ -19,9 +19,9 @@
       $_builder.getStringAttr("required broadcastable shapes")
     }]>;
 
-def : Pat<(Shape_CstrBroadcastableOp $LHS, $RHS),
+def CstrBroadcastableToRequire : Pat<(Shape_CstrBroadcastableOp $shapes),
           (Shape_CstrRequireOp
-              (Shape_IsBroadcastableOp $LHS, $RHS),
+              (Shape_IsBroadcastableOp $shapes),
               (BroadcastableStringAttr))>;
 
 #endif // MLIR_CONVERSION_SHAPETOSTANDARD_TD
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -483,6 +483,10 @@
 }
 
 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+  // TODO: Add folding for the n-ary case.
+  if (operands.size() != 2)
+    return nullptr;
+
   // Both operands are not needed if one is a scalar.
   if (operands[0] &&
       operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
@@ -504,9 +508,9 @@
   // Lastly, see if folding can be completed based on what constraints are
   // known on the input shapes.
   SmallVector<int64_t, 6> lhsShape, rhsShape;
-  if (failed(getShapeVec(lhs(), lhsShape)))
+  if (failed(getShapeVec(shapes()[0], lhsShape)))
     return nullptr;
-  if (failed(getShapeVec(rhs(), rhsShape)))
+  if (failed(getShapeVec(shapes()[1], rhsShape)))
     return nullptr;
 
   if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
 
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -18,8 +18,9 @@
   (replaceWithValue $args),
   [(HasSingleElement $args)]>;
 
-def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $x, $x),
-  (Shape_ConstWitnessOp ConstBoolAttrTrue)>;
+def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $shapes),
+  (Shape_ConstWitnessOp ConstBoolAttrTrue),
+  [(AllInputShapesEq $shapes)]>;
 
 def CstrEqEqOps : Pat<(Shape_CstrEqOp:$op $shapes),
   (Shape_ConstWitnessOp ConstBoolAttrTrue),
 
diff --git a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
--- a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
@@ -4,28 +4,9 @@
 // CHECK-LABEL: func @cstr_broadcastable(
 // CHECK-SAME:  %[[LHS:.*]]: tensor<?xindex>,
 // CHECK-SAME:  %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
-// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[C1:.*]] = constant 1 : index
 // CHECK: %[[RET:.*]] = shape.const_witness true
-// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
-// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi ule, %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[LHS]], %[[RHS]] : tensor<?xindex>
-// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[RHS]], %[[LHS]] : tensor<?xindex>
-// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
-// CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
-// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
-// CHECK: %[[IVSHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
-// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[LESSER_RANK_OPERAND]][%[[IVSHIFTED]]] : tensor<?xindex>
-// CHECK: %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi eq, %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
-// CHECK: %[[LESSER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi eq, %[[LESSER_RANK_OPERAND_EXTENT]], %[[C1]] : index
-// CHECK: %[[EXTENTS_AGREE:.*]] = cmpi eq, %[[GREATER_RANK_OPERAND_EXTENT]], %[[LESSER_RANK_OPERAND_EXTENT]] : index
-// CHECK: %[[OR_TMP:.*]] = or %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT_IS_ONE]] : i1
-// CHECK: %[[BROADCAST_IS_VALID:.*]] = or %[[EXTENTS_AGREE]], %[[OR_TMP]] : i1
-// CHECK: assert %[[BROADCAST_IS_VALID]], "invalid broadcast"
-// CHECK: }
+// CHECK: %[[BROADCAST_IS_VALID:.*]] = shape.is_broadcastable %[[LHS]], %[[RHS]]
+// CHECK: assert %[[BROADCAST_IS_VALID]], "required broadcastable shapes"
 // CHECK: return %[[RET]] : !shape.witness
 // CHECK: }
 func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
 
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -305,77 +305,184 @@
 
 // -----
 
-func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
-  %0 = shape.is_broadcastable %a, %b : tensor<3xindex>, tensor<?xindex>
+func @try_is_broadcastable(%a : tensor<2xindex>, %b : tensor<3xindex>, %c : tensor<2xindex>) -> i1 {
+  %0 = shape.is_broadcastable %a, %b, %c : tensor<2xindex>, tensor<3xindex>, tensor<2xindex>
   return %0 : i1
 }
-
-// CHECK-LABEL: func @try_is_broadcastable(
-// CHECK-SAME:  %[[LHS:.*]]: tensor<3xindex>,
-// CHECK-SAME:  %[[RHS:.*]]: tensor<?xindex>) -> i1 {
+// CHECK-LABEL: @try_is_broadcastable
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2xindex>,
+// CHECK-SAME:  %[[ARG1:.*]]: tensor<3xindex>,
+// CHECK-SAME:  %[[ARG2:.*]]: tensor<2xindex>)
 // CHECK: %[[C0:.*]] = constant 0 : index
 // CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<3xindex>
-// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-// CHECK: %[[LHS_SMALLER:.*]] = cmpi ule, %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK: %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK: %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<3xindex> to tensor<?xindex>
-// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
-// CHECK: %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
-// CHECK: %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
-// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
+// CHECK: %[[RANK0:.*]] = dim %[[ARG0]], %[[C0]] : tensor<2xindex>
+// CHECK: %[[RANK1:.*]] = dim %[[ARG1]], %[[C0]] : tensor<3xindex>
+// CHECK: %[[RANK2:.*]] = dim %[[ARG2]], %[[C0]] : tensor<2xindex>
+// CHECK: %[[CMP0:.*]] = cmpi ugt, %[[RANK1]], %[[RANK0]] : index
+// CHECK: %[[LARGER_DIM:.*]] = select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
+// CHECK: %[[CMP1:.*]] = cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK: %[[MAX_RANK:.*]] = select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK: %[[DIM_DIFF0:.*]] = subi %[[MAX_RANK]], %[[RANK0]] : index
+// CHECK: %[[DIM_DIFF1:.*]] = subi %[[MAX_RANK]], %[[RANK1]] : index
+// CHECK: %[[DIM_DIFF2:.*]] = subi %[[MAX_RANK]], %[[RANK2]] : index
 // CHECK: %[[TRUE:.*]] = constant true
-// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[I:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
-// CHECK: %[[LARGER_EXTENT:.*]] = tensor.extract %[[LARGER_SHAPE]]{{\[}}%[[I]]] : tensor<?xindex>
-// CHECK: %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi eq, %[[LARGER_EXTENT]], %[[C1]] : index
-// CHECK: %[[SMALLER_EXTENT_INDEX:.*]] = subi %[[I]], %[[RANK_DIFF]] : index
-// CHECK: %[[SMALLER_EXTENT:.*]] = tensor.extract %[[SMALLER_SHAPE]]{{\[}}%[[SMALLER_EXTENT_INDEX]]] : tensor<?xindex>
-// CHECK: %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi eq, %[[SMALLER_EXTENT]], %[[C1]] : index
-// CHECK: %[[EXTENTS_ARE_EQUAL:.*]] = cmpi eq, %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
-// CHECK: %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
-// CHECK: %[[OR_EXTENTS_ARE_EQUAL:.*]] = or %[[EITHER_EXTENT_IS_ONE]], %[[EXTENTS_ARE_EQUAL]] : i1
-// CHECK: %[[NEW_ALL_SO_FAR:.*]] = and %[[ALL_SO_FAR]], %[[OR_EXTENTS_ARE_EQUAL]] : i1
-// CHECK: scf.yield %[[NEW_ALL_SO_FAR]] : i1
-// CHECK: }
-// CHECK: return %[[ALL_RESULT]] : i1
-// CHECK: }
+// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[IDX:.*]] = %[[C0]] to %[[MAX_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
+// CHECK: %[[C1_0:.*]] = constant 1 : index
+// CHECK: %[[OUTBOUNDS0:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK: %[[DIM0:.*]] = scf.if %[[OUTBOUNDS0]] -> (index) {
+// CHECK: scf.yield %[[C1_0]] : index
+// CHECK: } else {
+// CHECK: %[[IDX0:.*]] = subi %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK: %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex>
+// CHECK: %[[DIM0_IS_1:.*]] = cmpi eq, %[[EXTRACTED_0]], %[[C1_0]] : index
+// CHECK: %[[MAX_DIM0:.*]] = select %[[DIM0_IS_1]], %[[C1_0]], %[[EXTRACTED_0]] : index
+// CHECK: }
+// CHECK: %[[VAL_28:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK: %[[DIM1:.*]] = scf.if %[[VAL_28]] -> (index) {
+// CHECK: scf.yield %[[DIM0]] : index
+// CHECK: } else {
+// CHECK: %[[IDX1:.*]] = subi %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK: %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex>
+// CHECK: %[[DIM1_IS_1:.*]] = cmpi eq, %[[EXTRACTED_1]], %[[C1_0]] : index
+// CHECK: %[[MAX_DIM1:.*]] = select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index
+// CHECK: }
+// CHECK: %[[VAL_36:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK: %[[DIM2:.*]] = scf.if %[[VAL_36]] -> (index) {
+// CHECK: scf.yield %[[DIM1]] : index
+// CHECK: } else {
+// CHECK: %[[IDX2:.*]] = subi %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK: %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex>
+// CHECK: %[[DIM2_IS_1:.*]] = cmpi eq, %[[EXTRACTED_2]], %[[C1_0]] : index
+// CHECK: %[[MAX_DIM2:.*]] = select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index
+// CHECK: }
+// CHECK: %[[OUT_BOUND_0:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK: %[[REDUCTION_0:.*]] = scf.if %[[OUT_BOUND_0]] -> (i1) {
+// CHECK: scf.yield %[[ALL_SO_FAR]] : i1
+// CHECK: } else {
+// CHECK: %[[SHIFTED:.*]] = subi %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %arg0[%[[SHIFTED]]] : tensor<2xindex>
+// CHECK: %[[EQUALS_1:.*]] = cmpi eq, %[[EXTRACTED]], %c1 : index
+// CHECK: %[[EQUALS_BROADCASTED:.*]] = cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
+// CHECK: %[[GOOD:.*]] = or %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
+// CHECK: %[[AND_REDUCTION:.*]] = and %[[ALL_SO_FAR]], %[[GOOD]] : i1
+// CHECK: scf.yield %[[AND_REDUCTION]] : i1
+// CHECK: }
+// CHECK: %[[OUT_BOUND_1:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK: %[[SECOND_REDUCTION:.*]] = scf.if %[[OUT_BOUND_1]] -> (i1) {
+// CHECK: scf.yield %[[REDUCTION_0]] : i1
+// CHECK: } else {
+// CHECK: %[[SHIFTED:.*]] = subi %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %arg1[%[[SHIFTED]]] : tensor<3xindex>
+// CHECK: %[[EQUALS_1:.*]] = cmpi eq, %[[EXTRACTED]], %c1 : index
+// CHECK: %[[EQUALS_BROADCASTED:.*]] = cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
+// CHECK: %[[GOOD:.*]] = or %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
+// CHECK: %[[AND_REDUCTION:.*]] = and %[[REDUCTION_0]], %[[GOOD]] : i1
+// CHECK: scf.yield %[[AND_REDUCTION]] : i1
+// CHECK: }
+// CHECK: %[[OUT_BOUND_2:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK: %[[FINAL_RESULT:.*]] = scf.if %[[OUT_BOUND_2]] -> (i1) {
+// CHECK: scf.yield %[[SECOND_REDUCTION]] : i1
+// CHECK: } else {
+// CHECK: %[[SHIFTED:.*]] = subi %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %arg2[%[[SHIFTED]]] : tensor<2xindex>
+// CHECK: %[[EQUALS_1:.*]] = cmpi eq, %[[EXTRACTED]], %c1 : index
+// CHECK: %[[EQUALS_BROADCASTED:.*]] = cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
+// CHECK: %[[GOOD:.*]] = or %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
+// CHECK: %[[AND_REDUCTION:.*]] = and %[[SECOND_REDUCTION]], %[[GOOD]] : i1
+// CHECK: scf.yield %[[AND_REDUCTION]] : i1
+// CHECK: }
+// CHECK: scf.yield %[[FINAL_RESULT]] : i1
 
 // -----
 
-func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) -> !shape.witness {
-  %0 = shape.cstr_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
+func @broadcast(%a : tensor<2xindex>, %b : tensor<3xindex>, %c : tensor<2xindex>) -> !shape.witness {
+  %0 = shape.cstr_broadcastable %a, %b, %c : tensor<2xindex>, tensor<3xindex>, tensor<2xindex>
   return %0 : !shape.witness
 }
-
 // CHECK-LABEL: func @broadcast(
-// CHECK-SAME:  %[[LHS:.*]]: tensor<?xindex>,
-// CHECK-SAME:  %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2xindex>,
+// CHECK-SAME:  %[[ARG1:.*]]: tensor<3xindex>,
+// CHECK-SAME:  %[[ARG2:.*]]: tensor<2xindex>)
 // CHECK: %[[C0:.*]] = constant 0 : index
 // CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
-// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-// CHECK: %[[LHS_SMALLER:.*]] = cmpi ule, %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK: %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK: %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
-// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
-// CHECK: %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
-// CHECK: %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
-// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
+// CHECK: %[[RANK0:.*]] = dim %[[ARG0]], %[[C0]] : tensor<2xindex>
+// CHECK: %[[RANK1:.*]] = dim %[[ARG1]], %[[C0]] : tensor<3xindex>
+// CHECK: %[[RANK2:.*]] = dim %[[ARG2]], %[[C0]] : tensor<2xindex>
+// CHECK: %[[CMP0:.*]] = cmpi ugt, %[[RANK1]], %[[RANK0]] : index
+// CHECK: %[[LARGER_DIM:.*]] = select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
+// CHECK: %[[CMP1:.*]] = cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK: %[[MAX_RANK:.*]] = select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK: %[[DIM_DIFF0:.*]] = subi %[[MAX_RANK]], %[[RANK0]] : index
+// CHECK: %[[DIM_DIFF1:.*]] = subi %[[MAX_RANK]], %[[RANK1]] : index
+// CHECK: %[[DIM_DIFF2:.*]] = subi %[[MAX_RANK]], %[[RANK2]] : index
 // CHECK: %[[TRUE:.*]] = constant true
-// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[VAL_16:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
-// CHECK: %[[LARGER_EXTENT:.*]] = tensor.extract %[[LARGER_SHAPE]]{{\[}}%[[VAL_16]]] : tensor<?xindex>
-// CHECK: %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi eq, %[[LARGER_EXTENT]], %[[C1]] : index
-// CHECK: %[[LHS_EXTENT_INDEX:.*]] = subi %[[VAL_16]], %[[RANK_DIFF]] : index
-// CHECK: %[[SMALLER_EXTENT:.*]] = tensor.extract %[[SMALLER_SHAPE]]{{\[}}%[[LHS_EXTENT_INDEX]]] : tensor<?xindex>
-// CHECK: %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi eq, %[[SMALLER_EXTENT]], %[[C1]] : index
-// CHECK: %[[EXTENTS_ARE_EQUAL:.*]] = cmpi eq, %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
-// CHECK: %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
-// CHECK: %[[OR_EXTENTS_ARE_EQUAL:.*]] = or %[[EITHER_EXTENT_IS_ONE]], %[[EXTENTS_ARE_EQUAL]] : i1
-// CHECK: %[[NEW_ALL_SO_FAR:.*]] = and %[[ALL_SO_FAR]], %[[OR_EXTENTS_ARE_EQUAL]] : i1
-// CHECK: scf.yield %[[NEW_ALL_SO_FAR]] : i1
-// CHECK: }
+// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[IDX:.*]] = %[[C0]] to %[[MAX_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
+// CHECK: %[[C1_0:.*]] = constant 1 : index
+// CHECK: %[[OUTBOUNDS0:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK: %[[DIM0:.*]] = scf.if %[[OUTBOUNDS0]] -> (index) {
+// CHECK: scf.yield %[[C1_0]] : index
+// CHECK: } else {
+// CHECK: %[[IDX0:.*]] = subi %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK: %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex>
+// CHECK: %[[DIM0_IS_1:.*]] = cmpi eq, %[[EXTRACTED_0]], %[[C1_0]] : index
+// CHECK: %[[MAX_DIM0:.*]] = select %[[DIM0_IS_1]], %[[C1_0]], %[[EXTRACTED_0]] : index
+// CHECK: }
+// CHECK: %[[VAL_28:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK: %[[DIM1:.*]] = scf.if %[[VAL_28]] -> (index) {
+// CHECK: scf.yield %[[DIM0]] : index
+// CHECK: } else {
+// CHECK: %[[IDX1:.*]] = subi %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK: %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex>
+// CHECK: %[[DIM1_IS_1:.*]] = cmpi eq, %[[EXTRACTED_1]], %[[C1_0]] : index
+// CHECK: %[[MAX_DIM1:.*]] = select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index
+// CHECK: }
+// CHECK: %[[VAL_36:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK: %[[DIM2:.*]] = scf.if %[[VAL_36]] -> (index) {
+// CHECK: scf.yield %[[DIM1]] : index
+// CHECK: } else {
+// CHECK: %[[IDX2:.*]] = subi %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK: %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex>
+// CHECK: %[[DIM2_IS_1:.*]] = cmpi eq, %[[EXTRACTED_2]], %[[C1_0]] : index
+// CHECK: %[[MAX_DIM2:.*]] = select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index
+// CHECK: }
+// CHECK: %[[OUT_BOUND_0:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK: %[[REDUCTION_0:.*]] = scf.if %[[OUT_BOUND_0]] -> (i1) {
+// CHECK: scf.yield %[[ALL_SO_FAR]] : i1
+// CHECK: } else {
+// CHECK: %[[SHIFTED:.*]] = subi %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %arg0[%[[SHIFTED]]] : tensor<2xindex>
+// CHECK: %[[EQUALS_1:.*]] = cmpi eq, %[[EXTRACTED]], %c1 : index
+// CHECK: %[[EQUALS_BROADCASTED:.*]] = cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
+// CHECK: %[[GOOD:.*]] = or %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
+// CHECK: %[[AND_REDUCTION:.*]] = and %[[ALL_SO_FAR]], %[[GOOD]] : i1
+// CHECK: scf.yield %[[AND_REDUCTION]] : i1
+// CHECK: }
+// CHECK: %[[OUT_BOUND_1:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK: %[[SECOND_REDUCTION:.*]] = scf.if %[[OUT_BOUND_1]] -> (i1) {
+// CHECK: scf.yield %[[REDUCTION_0]] : i1
+// CHECK: } else {
+// CHECK: %[[SHIFTED:.*]] = subi %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %arg1[%[[SHIFTED]]] : tensor<3xindex>
+// CHECK: %[[EQUALS_1:.*]] = cmpi eq, %[[EXTRACTED]], %c1 : index
+// CHECK: %[[EQUALS_BROADCASTED:.*]] = cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
+// CHECK: %[[GOOD:.*]] = or %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
+// CHECK: %[[AND_REDUCTION:.*]] = and %[[REDUCTION_0]], %[[GOOD]] : i1
+// CHECK: scf.yield %[[AND_REDUCTION]] : i1
+// CHECK: }
+// CHECK: %[[OUT_BOUND_2:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK: %[[FINAL_RESULT:.*]] = scf.if %[[OUT_BOUND_2]] -> (i1) {
+// CHECK: scf.yield %[[SECOND_REDUCTION]] : i1
+// CHECK: } else {
+// CHECK: %[[SHIFTED:.*]] = subi %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %arg2[%[[SHIFTED]]] : tensor<2xindex>
+// CHECK: %[[EQUALS_1:.*]] = cmpi eq, %[[EXTRACTED]], %c1 : index
+// CHECK: %[[EQUALS_BROADCASTED:.*]] = cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
+// CHECK: %[[GOOD:.*]] = or %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
+// CHECK: %[[AND_REDUCTION:.*]] = and %[[SECOND_REDUCTION]], %[[GOOD]] : i1
+// CHECK: scf.yield %[[AND_REDUCTION]] : i1
+// CHECK: }
+// CHECK: scf.yield %[[FINAL_RESULT]] : i1
+
 // CHECK: %[[RESULT:.*]] = shape.cstr_require %[[ALL_RESULT]], "required broadcastable shapes"
 // CHECK: return %[[RESULT]] : !shape.witness
 // CHECK: }
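
Usage note (illustrative, mirroring the updated tests above rather than anything beyond the patch): after this change both ops take a variadic list of two or more shapes or extent tensors, printed with the new `$shapes attr-dict `:` type($shapes)` assembly format:

  %0 = shape.is_broadcastable %a, %b, %c
         : tensor<2xindex>, tensor<3xindex>, tensor<2xindex>
  %w = shape.cstr_broadcastable %a, %b, %c
         : tensor<2xindex>, tensor<3xindex>, tensor<2xindex>

Existing two-operand call sites keep building through the extra OpBuilderDAG builder, and the new verifier rejects fewer than two operands.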