diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -440,7 +440,9 @@ static std::function)> - getRegionBuilder(); + getRegionBuilder() { + return nullptr; + } static void createRegion(::mlir::OpBuilder &opBuilder, ::mlir::OperationState & odsState); @@ -450,6 +452,79 @@ let hasVerifier = 1; } + +//===----------------------------------------------------------------------===// +// Broadcast op. +//===----------------------------------------------------------------------===// + +def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [ + DeclareOpInterfaceMethods, + SameVariadicOperandSize, + SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = "Static broadcast operator"; + let description = [{ + Broadcast the input into the given shape by adding dimensions. + + Each element `dimensions[i]` is the init dimension that input dimension `i` + maps to. The length of the `dimensions` list should match the `input` rank + and dimensions should be in sorted order. There is no ambiguity at + compile-time about shape information.
+ + Example: + ``` + %bcast = linalg.broadcast + ins(%input:tensor<16xf32>) + outs(%init:tensor<16x64xf32>) + dimensions = [0] + ``` + }]; + + let arguments = (ins + // Input arg + TensorOrMemref:$input, + // Output arg + TensorOrMemref:$init, + + DenseI64ArrayAttr:$dimensions + ); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$region); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins "Value":$input, "Value":$init, + "DenseI64ArrayAttr":$dimensions, CArg<"ArrayRef", + "{}">:$attributes)>, + OpBuilder<(ins "Value":$input, "Value":$init, + "ArrayRef":$dimensions, CArg<"ArrayRef", + "{}">:$attributes)>, + ]; + + let extraClassDeclaration = structuredOpsBaseDecls # [{ + // Declare functions necessary for LinalgStructuredInterface. + SmallVector getIteratorTypesArray(); + ArrayAttr getIndexingMaps(); + std::string getLibraryCallName() { + return "op_has_no_registered_library_name"; + } + + // Implement functions necessary for DestinationStyleOpInterface. + std::pair getDpsInitsPositionRange() { + int64_t numOperands = this->getNumOperands(); + return {numOperands - 1, numOperands}; + } + + static std::function)> + getRegionBuilder() { + return nullptr; + } + }]; + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -662,7 +662,7 @@ //===----------------------------------------------------------------------===// static void buildGenericRegion( - OpBuilder &builder, OperationState &result, ValueRange inputs, + OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref bodyBuild) { SmallVector blockArgTypes; @@ -675,10 +675,9 @@ } OpBuilder::InsertionGuard guard(builder); - auto &region = *result.regions.front(); Block *bodyBlock = builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs); - bodyBuild(builder, result.location, bodyBlock->getArguments()); + bodyBuild(builder, loc, bodyBlock->getArguments()); } void GenericOp::getAsmBlockArgumentNames(Region &region, @@ -699,7 +698,8 @@ iteratorTypes, doc, libraryCall); result.addAttributes(attributes); if (bodyBuild) - buildGenericRegion(builder, result, inputs, outputs, bodyBuild); + buildGenericRegion(builder, result.location, *result.regions.front(), + inputs, outputs, bodyBuild); } void GenericOp::build( @@ -1346,7 +1346,8 @@ result.addTypes(initType); if (bodyBuild) - buildGenericRegion(builder, result, inputs, /*outputs=*/{}, bodyBuild); + buildGenericRegion(builder, result.location, *result.regions.front(), + inputs, /*outputs=*/{}, bodyBuild); } ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { @@ -1471,7 +1472,8 @@ } if (bodyBuild) - buildGenericRegion(builder, result, inputs, inits, bodyBuild); + buildGenericRegion(builder, result.location, *result.regions.front(), + inputs, inits, bodyBuild); } SmallVector ReduceOp::getIteratorTypesArray() { @@ -1648,13 +1650,13 @@ // TransposeOp //===----------------------------------------------------------------------===// -std::function)>
-TransposeOp::getRegionBuilder() { - return [](mlir::ImplicitLocOpBuilder &b, mlir::Block &block, - mlir::ArrayRef) { - b.create(block.getArguments().front()); - }; +static void buildIdentityRegion(OpBuilder &builder, Location loc, + Region &region, ValueRange inputs, + ValueRange outputs) { + buildGenericRegion(builder, loc, region, inputs, outputs, + [](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args[0]); + }); } void TransposeOp::build(::mlir::OpBuilder &builder, @@ -1671,11 +1673,8 @@ if (initType.isa()) result.addTypes(initType); - (void)result.addRegion(); - buildGenericRegion(builder, result, input, init, - [&](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args[0]); - }); + buildIdentityRegion(builder, result.location, *result.addRegion(), input, + init); } void TransposeOp::build(::mlir::OpBuilder &builder, @@ -1693,13 +1692,10 @@ }))) return failure(); - (void)result.addRegion(); OpBuilder builder(parser.getContext()); - buildGenericRegion(builder, result, /*inputs=*/result.operands, - /*outputs=*/{}, - [&](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args[0]); - }); + buildIdentityRegion(builder, result.location, *result.addRegion(), + /*inputs=*/result.operands, + /*outputs=*/{}); return success(); } @@ -1778,6 +1774,146 @@ getDpsInputOperands(), getDpsInitOperands()); } +//===----------------------------------------------------------------------===// +// BroadcastOp +//===----------------------------------------------------------------------===// + +void BroadcastOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &result, Value input, Value init, + DenseI64ArrayAttr dimensions, + ArrayRef attributes) { + result.addOperands(input); + result.addOperands(init); + result.addAttribute(getDimensionsAttrName(result.name), dimensions); + result.addAttributes(attributes); + + // Add output types for `RankedTensorType` output arguments.
+ Type initType = init.getType(); + if (initType.isa()) + result.addTypes(initType); + + buildIdentityRegion(builder, result.location, *result.addRegion(), input, + init); +} + +void BroadcastOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &result, Value input, Value init, + ArrayRef dimensions, + ArrayRef attributes) { + build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions), + attributes); +} + +ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) { + if (failed(parseDstStyleOp( + parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { + return parseDenseI64ArrayAttr(parser, attributes, "dimensions"); + }))) + return failure(); + + OpBuilder builder(parser.getContext()); + buildIdentityRegion(builder, result.location, *result.addRegion(), + /*inputs=*/result.operands, + /*outputs=*/{}); + return success(); +} + +void BroadcastOp::getAsmResultNames( + function_ref setNameFn) { + if (!getResults().empty()) + setNameFn(getResults().front(), "broadcasted"); +} + +void BroadcastOp::print(OpAsmPrinter &p) { + p.increaseIndent(); + printCommonStructuredOpPartsWithNewLine( + p, SmallVector(getDpsInputOperands()), + SmallVector(getDpsInitOperands())); + p.printNewline(); + + printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); + p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); + p.decreaseIndent(); +} + +LogicalResult BroadcastOp::verify() { + ArrayRef dimensionsRef = getDimensions(); + + if (!llvm::is_sorted(dimensionsRef)) + return emitOpError() << "dimensions should be in sorted order, implicit " "transpose is not supported"; + + auto inputType = getInput().getType(); + auto initType = getInit().getType(); + + int64_t inputRank = inputType.getRank(); + int64_t initRank = initType.getRank(); + + auto inputShape = inputType.getShape(); + auto initShape = initType.getShape(); + + if (inputRank != static_cast<int64_t>(dimensionsRef.size())) + return emitOpError() + << "input rank
does not match the number of dimensions. expected: " + << inputRank << ", got: " << dimensionsRef.size(); + + // Mapping from init dims to input dims; -1 marks an added dimension. + SmallVector reverseDimMap(initRank, -1); + + for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) { + if (dim < 0 || dim >= initRank) + return emitOpError() << "dimension " << idx + << " is out of range. expected range: [0, " + << initRank - 1 << "], got: " << dim; + + if (reverseDimMap[dim] != -1) + return emitOpError() << "dimension " << idx << " duplicates dimension " + << reverseDimMap[dim] << ", got: " << dim; + reverseDimMap[dim] = idx; + } + + for (const auto &[idx, inputDimIdx] : llvm::enumerate(reverseDimMap)) { + if (inputDimIdx == -1) { + // This dimension is being added, so its size must be statically known. + if (ShapedType::isDynamic(initShape[idx])) + return emitOpError() + << "init dim " << idx + << " can't be dynamic, because it's not matched to input"; + } else { + // This dimension is mapped from the input. Init and input dims should + // match. + if (inputShape[inputDimIdx] != initShape[idx]) + return emitOpError() + << "input dim " << inputDimIdx << " should match init dim " + << idx << ". 
input: " << inputShape[inputDimIdx] + << ", init: " << initShape[idx]; + } + } + + return success(); +} + +SmallVector BroadcastOp::getIteratorTypesArray() { + int64_t rank = getInit().getType().getRank(); + return SmallVector(rank, getParallelIteratorTypeName()); +} + +ArrayAttr BroadcastOp::getIndexingMaps() { + Builder builder(getContext()); + int64_t rank = getInit().getType().getRank(); + return builder.getAffineMapArrayAttr( + {builder.getMultiDimIdentityMap(rank).getSubMap( + llvm::to_vector_of(getDimensions())), + builder.getMultiDimIdentityMap(rank)}); +} + +void BroadcastOp::getEffects( + SmallVectorImpl> + &effects) { + getGenericEffectsImpl(effects, getOperation()->getResults(), + getDpsInputOperands(), getDpsInitOperands()); +} + //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -673,3 +673,94 @@ permutation = [1, 0, 2] func.return %transpose : tensor<32x64x16xf32> } + +// ----- + +func.func @broadcast_unsorted_dims( + %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>) + -> tensor<4x8x16xf32> { + // expected-error @+1 {{'linalg.broadcast' op dimensions should be in sorted order}} + %bcast = linalg.broadcast + ins(%input:tensor<4x16xf32>) + outs(%init:tensor<4x8x16xf32>) + dimensions = [1, 0] + func.return %bcast : tensor<4x8x16xf32> +} + +// ----- + +func.func @broadcast_input_dims_rank_mismatch( + %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>) + -> tensor<4x8x16xf32> { + // expected-error @+1 {{'linalg.broadcast' op input rank does not match the number of dimensions.
expected: 2, got: 1}} + %bcast = linalg.broadcast + ins(%input:tensor<4x16xf32>) + outs(%init:tensor<4x8x16xf32>) + dimensions = [0] + func.return %bcast : tensor<4x8x16xf32> +} + +// ----- + +func.func @broadcast_dims_out_of_range( + %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>) + -> tensor<4x8x16xf32> { + // expected-error @+1 {{'linalg.broadcast' op dimension 1 is out of range. expected range: [0, 2], got: 5}} + %bcast = linalg.broadcast + ins(%input:tensor<4x16xf32>) + outs(%init:tensor<4x8x16xf32>) + dimensions = [0, 5] + func.return %bcast : tensor<4x8x16xf32> +} + +// ----- + +func.func @broadcast_duplicate_dims( + %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>) + -> tensor<4x8x16xf32> { + // expected-error @+1 {{'linalg.broadcast' op dimension 1 duplicates dimension 0, got: 0}} + %bcast = linalg.broadcast + ins(%input:tensor<4x16xf32>) + outs(%init:tensor<4x8x16xf32>) + dimensions = [0, 0] + func.return %bcast : tensor<4x8x16xf32> +} + +// ----- + +func.func @broadcast_mapped_dim_mismatch( + %input: tensor<4x16xf32>, %init: tensor<5x8x16xf32>) + -> tensor<5x8x16xf32> { + // expected-error @+1 {{'linalg.broadcast' op input dim 0 should match init dim 0. input: 4, init: 5}} + %bcast = linalg.broadcast + ins(%input:tensor<4x16xf32>) + outs(%init:tensor<5x8x16xf32>) + dimensions = [0, 2] + func.return %bcast : tensor<5x8x16xf32> +} + +// ----- + +func.func @broadcast_added_dynamic_mismatch( + %input: tensor<4x16xf32>, %init: tensor<4x?x16xf32>) + -> tensor<4x?x16xf32> { + // expected-error @+1 {{'linalg.broadcast' op init dim 1 can't be dynamic, because it's not matched to input}} + %bcast = linalg.broadcast + ins(%input:tensor<4x16xf32>) + outs(%init:tensor<4x?x16xf32>) + dimensions = [0, 2] + func.return %bcast : tensor<4x?x16xf32> +} + +// ----- + +func.func @broadcast_size_1_extension_not_supported( + %input: tensor<1x16xf32>, %init: tensor<4x8x16xf32>) + -> tensor<4x8x16xf32> { + // expected-error @+1 {{'linalg.broadcast' op input dim 0 should match init dim 0.
input: 1, init: 4}} + %bcast = linalg.broadcast + ins(%input:tensor<1x16xf32>) + outs(%init:tensor<4x8x16xf32>) + dimensions = [0, 2] + func.return %bcast : tensor<4x8x16xf32> +} diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir --- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir @@ -388,6 +388,19 @@ // ----- +// CHECK-LABEL: func @broadcast +// CHECK-SAME: %[[ARG0:.*]]: memref<8x32xf32 +func.func @broadcast(%input: tensor<8x32xf32>, + %init: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> { + %bcast = linalg.broadcast + ins(%input:tensor<8x32xf32>) + outs(%init:tensor<8x16x32xf32>) + dimensions = [0, 2] + func.return %bcast : tensor<8x16x32xf32> +} + +// ----- + //===----------------------------------------------------------------------===// // AllocTensorOp elimination would produce SSA violations for the example below. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -517,3 +517,53 @@ func.return } // CHECK-LABEL: func @transpose_memref + +// ----- + +func.func @broadcast_static_sizes(%input: tensor<8x32xf32>, + %init: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> { + %bcast = linalg.broadcast + ins(%input:tensor<8x32xf32>) + outs(%init:tensor<8x16x32xf32>) + dimensions = [0, 2] + func.return %bcast : tensor<8x16x32xf32> +} +// CHECK-LABEL: func @broadcast_static_sizes +// CHECK: linalg.broadcast +// CHECK-NEXT: ins +// CHECK-NEXT: outs +// CHECK-NEXT: dimensions + +// ----- + +func.func @broadcast_with_dynamic_sizes( + %input: tensor<8x?xf32>, %init: tensor<8x16x?xf32>) + -> tensor<8x16x?xf32> { + %bcast = linalg.broadcast + ins(%input:tensor<8x?xf32>) + outs(%init:tensor<8x16x?xf32>) + dimensions = [0, 2] + func.return %bcast :
tensor<8x16x?xf32> +} +// CHECK-LABEL: func @broadcast_with_dynamic_sizes +// CHECK: linalg.broadcast +// CHECK-NEXT: ins +// CHECK-NEXT: outs +// CHECK-NEXT: dimensions + +// ----- + +func.func @broadcast_memref(%input: memref<8x32xf32>, + %init: memref<8x16x32xf32>) { + linalg.broadcast + ins(%input:memref<8x32xf32>) + outs(%init:memref<8x16x32xf32>) + dimensions = [0, 2] + func.return +} + +// CHECK-LABEL: func @broadcast_memref +// CHECK: linalg.broadcast +// CHECK-NEXT: ins +// CHECK-NEXT: outs +// CHECK-NEXT: dimensions diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir --- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir @@ -240,3 +240,29 @@ // CHECK: %[[OUT_ELEM:.*]] = memref.load %[[OUT]][%[[I]], %[[K]]] // CHECK: %[[ADD:.*]] = arith.addf %[[IN_ELEM]], %[[OUT_ELEM]] // CHECK: memref.store %[[ADD]], %[[OUT]][%[[I]], %[[K]]] + +// ----- + +func.func @broadcast(%input: memref<8x32xf32>, + %init: memref<8x16x32xf32>) { + linalg.broadcast + ins(%input:memref<8x32xf32>) + outs(%init:memref<8x16x32xf32>) + dimensions = [0, 2] + func.return +} +// CHECK-LABEL: func.func @broadcast( +// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: memref<8x32xf32>, +// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<8x16x32xf32> + +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index +// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index +// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index + +// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C8]] step %[[C1]] { +// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C16]] step %[[C1]] { +// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C32]] step %[[C1]] { +// CHECK: %[[ELEM:.*]] = memref.load %[[IN]][%[[I]], %[[K]]] +// CHECK: memref.store %[[ELEM]],
%[[OUT]][%[[I]], %[[J]], %[[K]]]