diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -72,19 +72,29 @@
     attribute can be used to describe the error case.
   }];
 
-  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
-                       Shape_ShapeOrExtentTensorType:$rhs,
+  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes,
                        OptionalAttr<StrAttr>:$error);
   let results = (outs Shape_ShapeOrExtentTensorType:$result);
 
   let assemblyFormat = [{
-    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+    $shapes attr-dict `:` type($shapes) `->` type($result)
   }];
 
-  let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
-  let hasFolder = 1;
+  // Convenience builder for the binary (two-shape) form.
+  let builders = [OpBuilderDAG<(ins "::mlir::Type":$result,
+                                "::mlir::Value":$lhs, "::mlir::Value":$rhs,
+                                "/*optional*/ ::mlir::StringAttr":$error), [{
+    build($_builder, $_state, result, ::llvm::makeArrayRef({lhs, rhs}), error);
+  }]>
+  ];
 
-  let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
+  let hasFolder = 1;
+  // The op requires at least two input shapes.
+  let verifier = [{
+    if (getNumOperands() < 2)
+      return emitOpError("requires at least 2 input shapes");
+    return ::verifyShapeOrExtentTensorOp(*this);
+  }];
 }
 
 def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -14,7 +14,9 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
 
 using namespace mlir;
 using namespace mlir::shape;
@@ -73,6 +75,92 @@
   matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
+
+/// Lowers a `shape.broadcast` of any number (at least two) of extent
+/// tensors.
+///
+/// The shapes are right-aligned against the maximum rank. For every output
+/// dimension the result extent is folded over the operands in order,
+/// starting from 1: an operand that does not reach that dimension, or whose
+/// extent there is 1, keeps the running extent; any other operand's extent
+/// replaces it. Incompatible extents are not diagnosed. Returns the results
+/// of the created `tensor.generate` op.
+Operation::result_range naryBroadcastLowering(BroadcastOp::Adaptor &op,
+                                              ImplicitLocOpBuilder lb) {
+  assert(llvm::all_of(op.shapes(),
+                      [](Value v) { return !v.getType().isa<ShapeType>(); }));
+
+  Value zero = lb.create<ConstantIndexOp>(0);
+  Value one = lb.create<ConstantIndexOp>(1);
+  Type indexTy = lb.getIndexType();
+
+  // Save all the ranks for bounds checking. Because this is a tensor
+  // representing the shape extents, the rank is the extent of the only
+  // dimension in the tensor.
+  SmallVector<Value> ranks, rankDiffs;
+  llvm::append_range(ranks, llvm::map_range(op.shapes(), [&](Value v) {
+                       return lb.create<DimOp>(v, zero);
+                     }));
+  assert(!ranks.empty() && "expected at least one shape operand");
+
+  // Find the maximum rank.
+  Value maxRank = ranks.front();
+  for (Value v : llvm::drop_begin(ranks, 1)) {
+    Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
+    maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
+  }
+
+  // Offset of each operand relative to the maximum rank: output dimension
+  // `i` maps to operand dimension `i - rankDiff` whenever `i >= rankDiff`.
+  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
+                       return lb.create<SubIOp>(indexTy, maxRank, v);
+                     }));
+
+  return lb
+      .create<tensor::GenerateOp>(
+          getExtentTensorType(lb.getContext()), ValueRange{maxRank},
+          [&](OpBuilder &b, Location loc, ValueRange args) {
+            Value outputDimension = args[0];
+
+            // Running broadcast extent for `outputDimension`, starting from 1.
+            Value broadcastedDim = one;
+            for (auto tup : llvm::zip(op.shapes(), rankDiffs)) {
+              Value shape = std::get<0>(tup);
+              Value rankDiff = std::get<1>(tup);
+              // A real branch (not a select) is needed: when the operand
+              // does not reach this dimension, `outputDimension - rankDiff`
+              // is not a valid index for tensor.extract.
+              Value outOfBounds = b.create<CmpIOp>(loc, CmpIPredicate::ult,
+                                                   outputDimension, rankDiff);
+              broadcastedDim =
+                  b.create<IfOp>(
+                       loc, TypeRange{indexTy}, outOfBounds,
+                       [&](OpBuilder &b, Location loc) {
+                         b.create<scf::YieldOp>(loc, broadcastedDim);
+                       },
+                       [&](OpBuilder &b, Location loc) {
+                         // If this operand's extent is 1, keep the
+                         // running extent; otherwise take this operand's
+                         // extent, so a non-1 extent always wins over 1.
+                         // This remains correct in the presence of
+                         // dimensions of zero extent.
+                         Value operandDimension = b.create<SubIOp>(
+                             loc, indexTy, outputDimension, rankDiff);
+                         Value operandExtent = b.create<tensor::ExtractOp>(
+                             loc, shape, ValueRange{operandDimension});
+                         Value extentIsOne = b.create<CmpIOp>(
+                             loc, CmpIPredicate::eq, operandExtent, one);
+                         Value dim = b.create<SelectOp>(
+                             loc, extentIsOne, broadcastedDim, operandExtent);
+                         b.create<scf::YieldOp>(loc, dim);
+                       })
+                      .getResult(0);
+            }
+
+            b.create<tensor::YieldOp>(loc, broadcastedDim);
+          })
+      ->getResults();
+}
 } // namespace
 
 LogicalResult BroadcastOpConverter::matchAndRewrite(
@@ -83,76 +171,10 @@
   if (op.getType().isa<ShapeType>())
     return failure();
 
-  assert(!op.lhs().getType().isa<ShapeType>() &&
-         !op.rhs().getType().isa<ShapeType>());
   auto loc = op.getLoc();
+  ImplicitLocOpBuilder lb(loc, rewriter);
   BroadcastOp::Adaptor transformed(operands);
-  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-
-  // Find smaller and greater rank and extent tensor.
-  Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
-  Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
-  Value lhsRankULE =
-      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
-  Type indexTy = rewriter.getIndexType();
-  Value lesserRank =
-      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
-  Value greaterRank =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
-  auto erasedRankType =
-      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
-  Value rankErasedLhs =
-      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
-  Value rankErasedRhs =
-      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
-  Value lesserRankOperand =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
-  Value greaterRankOperand =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
-
-  Value rankDiff =
-      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
-  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
-      op, getExtentTensorType(op.getContext()), ValueRange{greaterRank},
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value outputDimension = args[0];
-        Value isUnchallengedDimension = b.create<CmpIOp>(
-            loc, CmpIPredicate::ult, outputDimension, rankDiff);
-        Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
-            loc, greaterRankOperand, outputDimension);
-        // The initial dimensions of the greater-rank operand are unchallenged,
-        // so we can take them as-is. Otherwise, we need to do a comparison.
-        // We need an actual branch here (instead of a select) because the
-        // lesser-rank operand might be rank 0, so any tensor.extract would be
-        // invalid.
-        auto ifOp = b.create<IfOp>(
-            loc, TypeRange{indexTy}, isUnchallengedDimension,
-            [&](OpBuilder &b, Location loc) {
-              b.create<scf::YieldOp>(loc, greaterRankOperandExtent);
-            },
-            [&](OpBuilder &b, Location loc) {
-              // The broadcasting logic is:
-              // - if one extent (here we arbitrarily choose the extent from
-              // the greater-rank operand) is equal to 1, then take the extent
-              // from the other operand
-              // - otherwise, take the extent as-is.
-              // Note that this logic remains correct in the presence of
-              // dimensions of zero extent.
-              Value lesserRankOperandDimension =
-                  b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
-              Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
-                  loc, lesserRankOperand,
-                  ValueRange{lesserRankOperandDimension});
-              Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
-                  loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
-              Value broadcastedExtent = b.create<SelectOp>(
-                  loc, greaterRankOperandExtentIsOne, lesserRankOperandExtent,
-                  greaterRankOperandExtent);
-              b.create<scf::YieldOp>(loc, broadcastedExtent);
-            });
-        b.create<tensor::YieldOp>(loc, ifOp.getResult(0));
-      });
+  rewriter.replaceOp(op, naryBroadcastLowering(transformed, lb));
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -357,10 +357,16 @@
   if (!operands[1])
     return nullptr;
 
+  // TODO: Support folding with more than 2 input shapes.
+  // `operands` holds one entry per shape operand and an entry is null when
+  // that operand is not constant (cf. the check above), so bail out on arity.
+  if (operands.size() > 2)
+    return nullptr;
+
   auto rhsShape = llvm::to_vector<6>(
       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
   if (rhsShape.empty())
-    return lhs();
+    return shapes()[0];
 
   if (!operands[0])
     return nullptr;
@@ -368,7 +374,7 @@
   auto lhsShape = llvm::to_vector<6>(
       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
   if (lhsShape.empty())
-    return rhs();
+    return shapes()[1];
 
   SmallVector<int64_t, 6> resultShape;
   // If the shapes are not compatible, we can't fold it.
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -305,86 +305,6 @@
 
 // -----
 
-// CHECK-LABEL:   func @broadcast_unknown_extents(
-// CHECK-SAME:          %[[LHS:.*]]: tensor<?xindex>,
-// CHECK-SAME:          %[[RHS:.*]]: tensor<?xindex>) {
-func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
-  // CHECK:           %[[C0:.*]] = constant 0 : index
-  // CHECK:           %[[C1:.*]] = constant 1 : index
-  // CHECK:           %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
-  // CHECK:           %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-  // CHECK:           %[[LHS_RANK_ULE:.*]] = cmpi ule, %[[LHS_RANK]], %[[RHS_RANK]] : index
-  // CHECK:           %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-  // CHECK:           %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-  // CHECK:           %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
-  // CHECK:           %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
-  // CHECK:           %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
-  // CHECK:           %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
-  // CHECK:           %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
-  // CHECK:           %[[RESULT:.*]] = tensor.generate %[[GREATER_RANK]] {
-  // CHECK:           ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
-  // CHECK:             %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi ult, %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
-  // CHECK:             %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
-  // CHECK:             %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
-  // CHECK:               scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
-  // CHECK:             } else {
-  // CHECK:               %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
-  // CHECK:               %[[LESSER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
-  // CHECK:               %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi eq, %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
-  // CHECK:               %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
-  // CHECK:               scf.yield %[[BROADCASTED_EXTENT]] : index
-  // CHECK:             }
-  // CHECK:             yield %[[OUTPUT_EXTENT:.*]] : index
-  // CHECK:           } : tensor<?xindex>
-  // CHECK:           return
-  // CHECK:         }
-  %0 = shape.broadcast %a, %b
-      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
-  return
-}
-
-// -----
-
-// CHECK-LABEL:   func @broadcast_known_different_extents(
-// CHECK-SAME:          %[[LHS:.*]]: tensor<2xindex>,
-// CHECK-SAME:          %[[RHS:.*]]: tensor<3xindex>) {
-func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xindex>) {
-  // CHECK:           %[[C0:.*]] = constant 0 : index
-  // CHECK:           %[[C1:.*]] = constant 1 : index
-  // CHECK:           %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<2xindex>
-  // CHECK:           %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<3xindex>
-  // CHECK:           %[[LHS_RANK_ULE:.*]] = cmpi ule, %[[LHS_RANK]], %[[RHS_RANK]] : index
-  // CHECK:           %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-  // CHECK:           %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-  // CHECK:           %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
-  // CHECK:           %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
-  // CHECK:           %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
-  // CHECK:           %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
-  // CHECK:           %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
-  // CHECK:           %[[RESULT:.*]] = tensor.generate %[[GREATER_RANK]] {
-  // CHECK:           ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
-  // CHECK:             %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi ult, %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
-  // CHECK:             %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
-  // CHECK:             %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
-  // CHECK:               scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
-  // CHECK:             } else {
-  // CHECK:               %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
-  // CHECK:               %[[LESSER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
-  // CHECK:               %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi eq, %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
-  // CHECK:               %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
-  // CHECK:               scf.yield %[[BROADCASTED_EXTENT]] : index
-  // CHECK:             }
-  // CHECK:             yield %[[OUTPUT_EXTENT:.*]] : index
-  // CHECK:           } : tensor<?xindex>
-  // CHECK:           return
-  // CHECK:         }
-  %0 = shape.broadcast %a, %b
-      : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
-  return
-}
-
-// -----
-
 func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
   %0 = shape.is_broadcastable %a, %b : tensor<3xindex>, tensor<?xindex>
   return %0 : i1
@@ -459,3 +379,111 @@
 // CHECK:           %[[RESULT:.*]] = shape.cstr_require %[[ALL_RESULT]], "required broadcastable shapes"
 // CHECK:           return %[[RESULT]] : !shape.witness
 // CHECK:         }
+
+// -----
+
+func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
+                                           %b : tensor<3xindex>,
+                                           %c : tensor<2xindex>) {
+// CHECK-LABEL:   func @broadcast_3_shapes_different_extents(
+// CHECK-SAME:          %[[ARG0:.*]]: tensor<2xindex>,
+// CHECK-SAME:          %[[ARG1:.*]]: tensor<3xindex>,
+// CHECK-SAME:          %[[ARG2:.*]]: tensor<2xindex>) {
+// CHECK:           %[[C0:.*]] = constant 0 : index
+// CHECK:           %[[C1:.*]] = constant 1 : index
+// CHECK:           %[[RANK0:.*]] = dim %[[ARG0]], %[[C0]] : tensor<2xindex>
+// CHECK:           %[[RANK1:.*]] = dim %[[ARG1]], %[[C0]] : tensor<3xindex>
+// CHECK:           %[[RANK2:.*]] = dim %[[ARG2]], %[[C0]] : tensor<2xindex>
+// CHECK:           %[[CMP0:.*]] = cmpi ugt, %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[LARGER_DIM:.*]] = select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[CMP1:.*]] = cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK:           %[[MAX_RANK:.*]] = select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK:           %[[DIM_DIFF0:.*]] = subi %[[MAX_RANK]], %[[RANK0]] : index
+// CHECK:           %[[DIM_DIFF1:.*]] = subi %[[MAX_RANK]], %[[RANK1]] : index
+// CHECK:           %[[DIM_DIFF2:.*]] = subi %[[MAX_RANK]], %[[RANK2]] : index
+// CHECK:           %[[RESULT:.*]] = tensor.generate %[[MAX_RANK]] {
+// CHECK:           ^bb0(%[[IDX:.*]]: index):
+// CHECK:             %[[OUTBOUNDS0:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK:             %[[DIM0:.*]] = scf.if %[[OUTBOUNDS0]] -> (index) {
+// CHECK:               scf.yield %[[C1]] : index
+// CHECK:             } else {
+// CHECK:               %[[IDX0:.*]] = subi %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK:               %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex>
+// CHECK:               %[[DIM0_IS_1:.*]] = cmpi eq, %[[EXTRACTED_0]], %[[C1]] : index
+// CHECK:               %[[MAX_DIM0:.*]] = select %[[DIM0_IS_1]], %[[C1]], %[[EXTRACTED_0]] : index
+// CHECK:               scf.yield %[[MAX_DIM0]] : index
+// CHECK:             }
+// CHECK:             %[[OUTBOUNDS1:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK:             %[[DIM1:.*]] = scf.if %[[OUTBOUNDS1]] -> (index) {
+// CHECK:               scf.yield %[[DIM0]] : index
+// CHECK:             } else {
+// CHECK:               %[[IDX1:.*]] = subi %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK:               %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex>
+// CHECK:               %[[DIM1_IS_1:.*]] = cmpi eq, %[[EXTRACTED_1]], %[[C1]] : index
+// CHECK:               %[[MAX_DIM1:.*]] = select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index
+// CHECK:               scf.yield %[[MAX_DIM1]] : index
+// CHECK:             }
+// CHECK:             %[[OUTBOUNDS2:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK:             %[[DIM2:.*]] = scf.if %[[OUTBOUNDS2]] -> (index) {
+// CHECK:               scf.yield %[[DIM1]] : index
+// CHECK:             } else {
+// CHECK:               %[[IDX2:.*]] = subi %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK:               %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex>
+// CHECK:               %[[DIM2_IS_1:.*]] = cmpi eq, %[[EXTRACTED_2]], %[[C1]] : index
+// CHECK:               %[[MAX_DIM2:.*]] = select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index
+// CHECK:               scf.yield %[[MAX_DIM2]] : index
+// CHECK:             }
+// CHECK:             tensor.yield %[[DIM2]] : index
+// CHECK:           } : tensor<?xindex>
+// CHECK:           return
+// CHECK:         }
+  %0 = shape.broadcast %a, %b, %c
+      : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
+  return
+}
+
+// -----
+
+func @broadcast_2_shapes_unknown_extents(%a : tensor<?xindex>,
+                                         %b : tensor<?xindex>) {
+// CHECK-LABEL:   func @broadcast_2_shapes_unknown_extents(
+// CHECK-SAME:          %[[ARG0:.*]]: tensor<?xindex>,
+// CHECK-SAME:          %[[ARG1:.*]]: tensor<?xindex>) {
+// CHECK:           %[[C0:.*]] = constant 0 : index
+// CHECK:           %[[C1:.*]] = constant 1 : index
+// CHECK:           %[[RANK0:.*]] = dim %[[ARG0]], %[[C0]] : tensor<?xindex>
+// CHECK:           %[[RANK1:.*]] = dim %[[ARG1]], %[[C0]] : tensor<?xindex>
+// CHECK:           %[[CMP0:.*]] = cmpi ugt, %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[MAX_RANK:.*]] = select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[DIM_DIFF0:.*]] = subi %[[MAX_RANK]], %[[RANK0]] : index
+// CHECK:           %[[DIM_DIFF1:.*]] = subi %[[MAX_RANK]], %[[RANK1]] : index
+// CHECK:           %[[RESULT:.*]] = tensor.generate %[[MAX_RANK]] {
+// CHECK:           ^bb0(%[[IDX:.*]]: index):
+// CHECK:             %[[OUTBOUNDS0:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK:             %[[DIM0:.*]] = scf.if %[[OUTBOUNDS0]] -> (index) {
+// CHECK:               scf.yield %[[C1]] : index
+// CHECK:             } else {
+// CHECK:               %[[IDX0:.*]] = subi %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK:               %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<?xindex>
+// CHECK:               %[[DIM0_IS_1:.*]] = cmpi eq, %[[EXTRACTED_0]], %[[C1]] : index
+// CHECK:               %[[MAX_DIM0:.*]] = select %[[DIM0_IS_1]], %[[C1]], %[[EXTRACTED_0]] : index
+// CHECK:               scf.yield %[[MAX_DIM0]] : index
+// CHECK:             }
+// CHECK:             %[[OUTBOUNDS1:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK:             %[[DIM1:.*]] = scf.if %[[OUTBOUNDS1]] -> (index) {
+// CHECK:               scf.yield %[[DIM0]] : index
+// CHECK:             } else {
+// CHECK:               %[[IDX1:.*]] = subi %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK:               %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<?xindex>
+// CHECK:               %[[DIM1_IS_1:.*]] = cmpi eq, %[[EXTRACTED_1]], %[[C1]] : index
+// CHECK:               %[[MAX_DIM1:.*]] = select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index
+// CHECK:               scf.yield %[[MAX_DIM1]] : index
+// CHECK:             }
+// CHECK:             tensor.yield %[[DIM1]] : index
+// CHECK:           } : tensor<?xindex>
+// CHECK:           return
+// CHECK:         }
+  %0 = shape.broadcast %a, %b
+      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+  return
+}