Reviewer note: this patch was recovered from a copy in which line breaks were
collapsed and every `<...>` span had been stripped. Template and type arguments
were restored from the op names checked by the accompanying test and from the
surrounding code. The arguments of `target.addLegalDialect` and
`target.addLegalOp` in the second hunk could not be recovered from the patch
itself and are left as found; confirm them against the tree before applying.

diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -19,6 +19,110 @@
 using namespace mlir::shape;
 using namespace mlir::scf;
 
+namespace {
+/// Converts `shape.broadcast` to `scf.if` and `scf.for` ops. For now, the
+/// lowering is only defined on `tensor<?xindex>` operands. The broadcasted
+/// extents are stored to a stack-allocated `memref<?xi64>`, which is finally
+/// loaded as the resulting extent tensor.
+struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
+  using OpConversionPattern<BroadcastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult BroadcastOpConverter::matchAndRewrite(
+    BroadcastOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
+  // on shapes.
+  if (op.getType().isa<ShapeType>())
+    return failure();
+
+  assert(!op.lhs().getType().isa<ShapeType>() &&
+         !op.rhs().getType().isa<ShapeType>());
+  auto loc = op.getLoc();
+  BroadcastOp::Adaptor transformed(operands);
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+
+  // Find smaller and greater rank and extent tensor.
+  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
+  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
+  Value lhsSmaller =
+      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
+  Type indexTy = rewriter.getIndexType();
+  Type extentTensorTy = op.getType();
+  auto ifOp = rewriter.create<IfOp>(
+      loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
+      lhsSmaller,
+      [&](OpBuilder &b, Location loc) {
+        b.create<scf::YieldOp>(loc, ValueRange{lhsRank, transformed.lhs(),
+                                               rhsRank, transformed.rhs()});
+      },
+      [&](OpBuilder &b, Location loc) {
+        b.create<scf::YieldOp>(loc, ValueRange{rhsRank, transformed.rhs(),
+                                               lhsRank, transformed.lhs()});
+      });
+  Value smallerRank = ifOp.getResult(0);
+  Value smallerOperand = ifOp.getResult(1);
+  Value greaterRank = ifOp.getResult(2);
+  Value greaterOperand = ifOp.getResult(3);
+
+  // Allocate stack memory for the broadcasted extent tensor.
+  Type i64Ty = rewriter.getI64Type();
+  Type memTy = MemRefType::get({ShapedType::kDynamicSize},
+                               i64Ty); // TODO: Replace i64 with index type when
+                                       // memref<?xindex> is supported.
+  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});
+
+  // Copy the leading extents that only the greater operand provides.
+  Value rankDiff =
+      rewriter.create<SubIOp>(loc, indexTy, greaterRank, smallerRank);
+  rewriter.create<ForOp>(
+      loc, zero, rankDiff, one, llvm::None,
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
+        Value extent =
+            b.create<ExtractElementOp>(loc, greaterOperand, ValueRange{iv});
+        Value extentI64 = b.create<IndexCastOp>(
+            loc, i64Ty, extent); // TODO: Remove this cast when redundant.
+        b.create<StoreOp>(loc, extentI64, mem, ValueRange{iv});
+        b.create<scf::YieldOp>(loc);
+      });
+
+  // Determine remaining broadcasted extents.
+  rewriter.create<ForOp>(
+      loc, rankDiff, greaterRank, one, llvm::None,
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
+        Value greaterOperandExtent =
+            b.create<ExtractElementOp>(loc, greaterOperand, ValueRange{iv});
+        Value greaterOperandExtentIsOne =
+            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
+        auto ifOp = b.create<IfOp>(
+            loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
+            [&](OpBuilder &b, Location loc) {
+              Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
+              Value smallerOperandExtent = b.create<ExtractElementOp>(
+                  loc, smallerOperand, ValueRange{ivShifted});
+              b.create<scf::YieldOp>(loc, smallerOperandExtent);
+            },
+            [&](OpBuilder &b, Location loc) {
+              b.create<scf::YieldOp>(loc, greaterOperandExtent);
+            });
+        Value extent = ifOp.getResult(0);
+        Value extentI64 = b.create<IndexCastOp>(
+            loc, i64Ty, extent); // TODO: Remove this cast when redundant.
+        b.create<StoreOp>(loc, extentI64, mem, ValueRange{iv});
+        b.create<scf::YieldOp>(loc);
+      });
+
+  // Load broadcasted shape as an extent tensor.
+  rewriter.replaceOpWithNewOp<TensorLoadOp>(op, mem);
+  return success();
+}
+
 namespace {
 /// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
 /// only defined on `tensor<?xindex>` operands. The test for equality first
@@ -226,7 +330,6 @@
   // Setup target legality.
   ConversionTarget target(getContext());
   target.addLegalDialect();
-  target.addLegalOp();
 
   // Apply conversion.
   if (failed(applyPartialConversion(getFunction(), target, patterns)))
@@ -237,6 +340,7 @@
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
   // clang-format off
   patterns.insert<
+      BroadcastOpConverter,
       ShapeEqOpConverter,
       ReduceOpConverter,
       ShapeOfOpConverter>(ctx);
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -82,3 +82,55 @@
   %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
   return %result : i1
 }
+
+// -----
+
+// Don't lower `shape.broadcast` if a `shape.shape` type is involved.
+// CHECK-LABEL: @broadcast
+func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
+  // CHECK: shape.broadcast
+  %c = shape.broadcast %a, %b : tensor<?xindex>, !shape.shape -> !shape.shape
+  return %c : !shape.shape
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast
+// CHECK-SAME:  (%[[LHS:.*]]: tensor<?xindex>, %[[RHS:.*]]: tensor<?xindex>)
+func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]]
+  // CHECK: %[[ARG:.*]]:4 = scf.if %[[LHS_SMALLER]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
+  // CHECK:   scf.yield %[[LHS_RANK]], %[[LHS]], %[[RHS_RANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+  // CHECK: } else {
+  // CHECK:   scf.yield %[[RHS_RANK]], %[[RHS]], %[[LHS_RANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+  // CHECK: }
+  // CHECK: %[[MEM:.*]] = alloca(%[[ARG]]#2) : memref<?xi64>
+  // CHECK: %[[RANK_DIFF:.*]] = subi %[[ARG]]#2, %[[ARG]]#0 : index
+  // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
+  // CHECK:   %[[EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
+  // CHECK:   %[[EXTENT_I64:.*]] = index_cast %[[EXTENT]] : index to i64
+  // CHECK:   store %[[EXTENT_I64]], %[[MEM]][%[[IV]]] : memref<?xi64>
+  // CHECK: }
+  // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[ARG]]#2 step %[[C1]] {
+  // CHECK:   %[[GREATER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
+  // CHECK:   %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_OPERAND_EXTENT]], %[[C1]] : index
+  // CHECK:   %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
+  // CHECK:     %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
+  // CHECK:     %[[SMALLER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#1[%[[IV_SHIFTED]]] : tensor<?xindex>
+  // CHECK:     scf.yield %[[SMALLER_OPERAND_EXTENT]] : index
+  // CHECK:   } else {
+  // CHECK:     scf.yield %[[GREATER_OPERAND_EXTENT]] : index
+  // CHECK:   }
+  // CHECK:   %[[EXTENT_I64:.*]] = index_cast %[[EXTENT]] : index to i64
+  // CHECK:   store %[[EXTENT_I64]], %[[MEM]][%[[IV]]] : memref<?xi64>
+  // CHECK: }
+  // CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xi64>
+  %0 = shape.broadcast %a, %b
+      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+  return
+}
+