diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -72,19 +72,27 @@
     attribute can be used to describe the error case.
   }];
 
-  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
-                       Shape_ShapeOrExtentTensorType:$rhs,
+  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes,
                        OptionalAttr<StrAttr>:$error);
   let results = (outs Shape_ShapeOrExtentTensorType:$result);
 
   let assemblyFormat = [{
-    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+    $shapes attr-dict `:` type($shapes) `->` type($result)
   }];
 
-  let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
-  let hasFolder = 1;
+  let builders = [OpBuilderDAG<(ins "::mlir::Type":$result,
+                                "::mlir::Value":$lhs, "::mlir::Value":$rhs,
+                                "/*optional*/ ::mlir::StringAttr":$error), [{
+    build($_builder, $_state, result, ::llvm::makeArrayRef({lhs, rhs}), error);
+  }]>
+  ];
 
-  let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
+  let hasFolder = 1;
+  let verifier = [{
+    if (getNumOperands() < 2)
+      return emitOpError("requires at least 2 input shapes");
+    return ::verifyShapeOrExtentTensorOp(*this);
+  }];
 }
 
 def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -14,7 +14,9 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
 
 using namespace mlir;
 using namespace mlir::shape;
@@ -73,6 +75,151 @@
   matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
+
+// The generic lowering of Broadcast given any number of inputs. It assumes all
+// verification was done before being called.
+Operation::result_range NaryBroadcastLowering(BroadcastOp::Adaptor &op,
+                                              ImplicitLocOpBuilder lb) {
+  Value zero = lb.create<ConstantIndexOp>(0);
+  Value one = lb.create<ConstantIndexOp>(1);
+  Type indexTy = lb.getIndexType();
+
+  // Save all the ranks for bounds checking
+  SmallVector<Value> ranks, rankDiffs;
+  llvm::append_range(ranks, llvm::map_range(op.shapes(), [&](Value v) {
+                       return lb.create<DimOp>(v, zero);
+                     }));
+
+  // Find the maximum rank
+  Value maxRank = ranks.front();
+  for (Value v : llvm::drop_begin(ranks, 1)) {
+    Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
+    maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
+  }
+
+  // Calculate the difference of ranks and the maximum rank for later offsets.
+  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
+                       return lb.create<SubIOp>(indexTy, maxRank, v);
+                     }));
+
+  return lb
+      .create<tensor::GenerateOp>(
+          getExtentTensorType(lb.getContext()), ValueRange{maxRank},
+          [&](OpBuilder &b, Location loc, ValueRange args) {
+            Value outputDimension = args[0];
+
+            Value reduceDim = one;
+            for (auto tup : llvm::zip(op.shapes(), rankDiffs)) {
+              Value shape = std::get<0>(tup);
+              Value rankDiff = std::get<1>(tup);
+              Value outOfBounds = b.create<CmpIOp>(loc, CmpIPredicate::ult,
+                                                   outputDimension, rankDiff);
+              Value dim =
+                  b.create<IfOp>(
+                       loc, TypeRange{indexTy}, outOfBounds,
+                       [&](OpBuilder &b, Location loc) {
+                         b.create<scf::YieldOp>(loc, one);
+                       },
+                       [&](OpBuilder &b, Location loc) {
+                         // This operand covers `outputDimension`: read its
+                         // extent, shifting the index by `rankDiff` because
+                         // shapes are aligned at their trailing dimensions.
+                         // The extents are combined below: an extent of 1
+                         // never replaces the running result, any other
+                         // extent (including 0) does. No error is raised for
+                         // incompatible extents; verification is assumed.
+                         Value lesserRankOperandDimension = b.create<SubIOp>(
+                             loc, indexTy, outputDimension, rankDiff);
+                         Value lesserRankOperandExtent =
+                             b.create<tensor::ExtractOp>(
+                                 loc, shape,
+                                 ValueRange{lesserRankOperandDimension});
+                         b.create<scf::YieldOp>(loc, lesserRankOperandExtent);
+                       })
+                      .getResult(0);
+
+              // Always give preference to a possibly non-1 extent
+              Value dimIsOne =
+                  b.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
+              reduceDim = b.create<SelectOp>(loc, dimIsOne, reduceDim, dim);
+            }
+
+            b.create<tensor::YieldOp>(loc, reduceDim);
+          })
+      ->getResults();
+}
+
+// The specialized lowering for the common binary broadcast case. This is
+// slightly more efficient and arguably easier to read. It assumes all
+// verification was done before being called.
+Operation::result_range BinaryBroadcastLowering(BroadcastOp::Adaptor &op,
+                                                ImplicitLocOpBuilder lb) {
+  Value zero = lb.create<ConstantIndexOp>(0);
+  Value one = lb.create<ConstantIndexOp>(1);
+  Type indexTy = lb.getIndexType();
+
+  // Find smaller and greater rank and extent tensor.
+  Value lhsRank = lb.create<DimOp>(op.shapes()[0], zero);
+  Value rhsRank = lb.create<DimOp>(op.shapes()[1], zero);
+  Value lhsRankULE = lb.create<CmpIOp>(CmpIPredicate::ule, lhsRank, rhsRank);
+  Value lesserRank = lb.create<SelectOp>(lhsRankULE, lhsRank, rhsRank);
+  Value greaterRank = lb.create<SelectOp>(lhsRankULE, rhsRank, lhsRank);
+  auto erasedRankType =
+      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
+  Value rankErasedLhs =
+      lb.create<tensor::CastOp>(erasedRankType, op.shapes()[0]);
+  Value rankErasedRhs =
+      lb.create<tensor::CastOp>(erasedRankType, op.shapes()[1]);
+  Value lesserRankOperand =
+      lb.create<SelectOp>(lhsRankULE, rankErasedLhs, rankErasedRhs);
+  Value greaterRankOperand =
+      lb.create<SelectOp>(lhsRankULE, rankErasedRhs, rankErasedLhs);
+
+  Value rankDiff = lb.create<SubIOp>(indexTy, greaterRank, lesserRank);
+  return lb
+      .create<tensor::GenerateOp>(
+          getExtentTensorType(lb.getContext()), ValueRange{greaterRank},
+          [&](OpBuilder &b, Location loc, ValueRange args) {
+            Value outputDimension = args[0];
+            Value isUnchallengedDimension = b.create<CmpIOp>(
+                loc, CmpIPredicate::ult, outputDimension, rankDiff);
+            Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
+                loc, greaterRankOperand, outputDimension);
+            // The initial dimensions of the greater-rank operand are
+            // unchallenged, so we can take them as-is. Otherwise, we need to do
+            // a comparison. We need an actual branch here (instead of a select)
+            // because the lesser-rank operand might be rank 0, so any
+            // tensor.extract would be invalid.
+            auto ifOp = b.create<IfOp>(
+                loc, TypeRange{indexTy}, isUnchallengedDimension,
+                [&](OpBuilder &b, Location loc) {
+                  b.create<scf::YieldOp>(loc, greaterRankOperandExtent);
+                },
+                [&](OpBuilder &b, Location loc) {
+                  // The broadcasting logic is:
+                  // - if one extent (here we arbitrarily choose the extent from
+                  // the greater-rank operand) is equal to 1, then take the
+                  // extent from the other operand
+                  // - otherwise, take the extent as-is.
+                  // Note that this logic remains correct in the presence of
+                  // dimensions of zero extent.
+                  Value lesserRankOperandDimension =
+                      b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
+                  Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
+                      loc, lesserRankOperand,
+                      ValueRange{lesserRankOperandDimension});
+                  Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
+                      loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
+                  Value broadcastedExtent = b.create<SelectOp>(
+                      loc, greaterRankOperandExtentIsOne,
+                      lesserRankOperandExtent, greaterRankOperandExtent);
+                  b.create<scf::YieldOp>(loc, broadcastedExtent);
+                });
+            b.create<tensor::YieldOp>(loc, ifOp.getResult(0));
+          })
+      ->getResults();
+}
+
 } // namespace
 
 LogicalResult BroadcastOpConverter::matchAndRewrite(
@@ -83,76 +230,20 @@
   if (op.getType().isa<ShapeType>())
     return failure();
 
-  assert(!op.lhs().getType().isa<ShapeType>() &&
-         !op.rhs().getType().isa<ShapeType>());
-  auto loc = op.getLoc();
   BroadcastOp::Adaptor transformed(operands);
-  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-
-  // Find smaller and greater rank and extent tensor.
-  Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
-  Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
-  Value lhsRankULE =
-      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
-  Type indexTy = rewriter.getIndexType();
-  Value lesserRank =
-      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
-  Value greaterRank =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
-  auto erasedRankType =
-      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
-  Value rankErasedLhs =
-      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
-  Value rankErasedRhs =
-      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
-  Value lesserRankOperand =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
-  Value greaterRankOperand =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
+  assert(llvm::all_of(transformed.shapes(),
+                      [](Value v) { return !v.getType().isa<ShapeType>(); }));
 
-  Value rankDiff =
-      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
-  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
-      op, getExtentTensorType(op.getContext()), ValueRange{greaterRank},
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value outputDimension = args[0];
-        Value isUnchallengedDimension = b.create<CmpIOp>(
-            loc, CmpIPredicate::ult, outputDimension, rankDiff);
-        Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
-            loc, greaterRankOperand, outputDimension);
-        // The initial dimensions of the greater-rank operand are unchallenged,
-        // so we can take them as-is. Otherwise, we need to do a comparison.
-        // We need an actual branch here (instead of a select) because the
-        // lesser-rank operand might be rank 0, so any tensor.extract would be
-        // invalid.
-        auto ifOp = b.create<IfOp>(
-            loc, TypeRange{indexTy}, isUnchallengedDimension,
-            [&](OpBuilder &b, Location loc) {
-              b.create<scf::YieldOp>(loc, greaterRankOperandExtent);
-            },
-            [&](OpBuilder &b, Location loc) {
-              // The broadcasting logic is:
-              // - if one extent (here we arbitrarily choose the extent from
-              // the greater-rank operand) is equal to 1, then take the extent
-              // from the other operand
-              // - otherwise, take the extent as-is.
-              // Note that this logic remains correct in the presence of
-              // dimensions of zero extent.
-              Value lesserRankOperandDimension =
-                  b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
-              Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
-                  loc, lesserRankOperand,
-                  ValueRange{lesserRankOperandDimension});
-              Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
-                  loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
-              Value broadcastedExtent = b.create<SelectOp>(
-                  loc, greaterRankOperandExtentIsOne, lesserRankOperandExtent,
-                  greaterRankOperandExtent);
-              b.create<scf::YieldOp>(loc, broadcastedExtent);
-            });
-        b.create<tensor::YieldOp>(loc, ifOp.getResult(0));
-      });
+  auto loc = op.getLoc();
+  ImplicitLocOpBuilder lb(loc, rewriter);
+
+  // Special-case the 2-input case, as its lowering is slightly more efficient
+  // and arguably easier to read.
+  if (transformed.shapes().size() == 2) {
+    rewriter.replaceOp(op, BinaryBroadcastLowering(transformed, lb));
+  } else {
+    rewriter.replaceOp(op, NaryBroadcastLowering(transformed, lb));
+  }
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -352,15 +352,19 @@
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
 
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
   if (!operands[1])
     return nullptr;
 
+  // TODO: Support folding with more than 2 input shapes
+  if (operands.size() > 2)
+    return nullptr;
+
   auto rhsShape = llvm::to_vector<6>(
       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
   if (rhsShape.empty())
-    return lhs();
+    return shapes()[0];
 
   if (!operands[0])
     return nullptr;
@@ -368,7 +372,7 @@
   auto lhsShape = llvm::to_vector<6>(
       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
   if (lhsShape.empty())
-    return rhs();
+    return shapes()[1];
 
   SmallVector<int64_t, 6> resultShape;
   // If the shapes are not compatible, we can't fold it.
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -459,3 +459,65 @@
 // CHECK: %[[RESULT:.*]] = shape.cstr_require %[[ALL_RESULT]], "required broadcastable shapes"
 // CHECK: return %[[RESULT]] : !shape.witness
 // CHECK: }
+
+// -----
+
+func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
+                                           %b : tensor<3xindex>,
+                                           %c : tensor<2xindex>) {
+// CHECK-LABEL:   func @broadcast_3_shapes_different_extents(
+// CHECK-SAME:        %[[ARG0:.*]]: tensor<2xindex>,
+// CHECK-SAME:        %[[ARG1:.*]]: tensor<3xindex>,
+// CHECK-SAME:        %[[ARG2:.*]]: tensor<2xindex>) {
+// CHECK:           %[[C0:.*]] = constant 0 : index
+// CHECK:           %[[C1:.*]] = constant 1 : index
+// CHECK:           %[[RANK0:.*]] = dim %[[ARG0]], %[[C0]] : tensor<2xindex>
+// CHECK:           %[[RANK1:.*]] = dim %[[ARG1]], %[[C0]] : tensor<3xindex>
+// CHECK:           %[[RANK2:.*]] = dim %[[ARG2]], %[[C0]] : tensor<2xindex>
+// CHECK:           %[[CMP0:.*]] = cmpi ugt, %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[LARGER_DIM:.*]] = select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[CMP1:.*]] = cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK:           %[[MAX_RANK:.*]] = select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK:           %[[DIM_DIFF0:.*]] = subi %[[MAX_RANK]], %[[RANK0]] : index
+// CHECK:           %[[DIM_DIFF1:.*]] = subi %[[MAX_RANK]], %[[RANK1]] : index
+// CHECK:           %[[DIM_DIFF2:.*]] = subi %[[MAX_RANK]], %[[RANK2]] : index
+// CHECK:           %[[RESULT:.*]] = tensor.generate %[[MAX_RANK]] {
+// CHECK:           ^bb0(%[[IDX:.*]]: index):
+// CHECK:             %[[OUTBOUNDS0:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK:             %[[DIM0:.*]] = scf.if %[[OUTBOUNDS0]] -> (index) {
+// CHECK:               scf.yield %[[C1]] : index
+// CHECK:             } else {
+// CHECK:               %[[IDX0:.*]] = subi %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK:               %[[EXTENT0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex>
+// CHECK:               scf.yield %[[EXTENT0]] : index
+// CHECK:             }
+// CHECK:             %[[DIM0_IS_1:.*]] = cmpi eq, %[[DIM0]], %[[C1]] : index
+// CHECK:             %[[MAX_DIM0:.*]] = select %[[DIM0_IS_1]], %[[C1]], %[[DIM0]] : index
+// CHECK:             %[[OUTBOUNDS1:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK:             %[[DIM1:.*]] = scf.if %[[OUTBOUNDS1]] -> (index) {
+// CHECK:               scf.yield %[[C1]] : index
+// CHECK:             } else {
+// CHECK:               %[[IDX1:.*]] = subi %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK:               %[[EXTENT1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex>
+// CHECK:               scf.yield %[[EXTENT1]] : index
+// CHECK:             }
+// CHECK:             %[[DIM1_IS_1:.*]] = cmpi eq, %[[DIM1]], %[[C1]] : index
+// CHECK:             %[[MAX_DIM1:.*]] = select %[[DIM1_IS_1]], %[[MAX_DIM0]], %[[DIM1]] : index
+// CHECK:             %[[OUTBOUNDS2:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK:             %[[DIM2:.*]] = scf.if %[[OUTBOUNDS2]] -> (index) {
+// CHECK:               scf.yield %[[C1]] : index
+// CHECK:             } else {
+// CHECK:               %[[IDX2:.*]] = subi %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK:               %[[EXTENT2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex>
+// CHECK:               scf.yield %[[EXTENT2]] : index
+// CHECK:             }
+// CHECK:             %[[DIM2_IS_1:.*]] = cmpi eq, %[[DIM2]], %[[C1]] : index
+// CHECK:             %[[MAX_DIM2:.*]] = select %[[DIM2_IS_1]], %[[MAX_DIM1]], %[[DIM2]] : index
+// CHECK:             tensor.yield %[[MAX_DIM2]] : index
+// CHECK:           } : tensor<?xindex>
+// CHECK:           return
+// CHECK:         }
+  %0 = shape.broadcast %a, %b, %c
+      : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
+  return
+}