diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -183,14 +183,14 @@ } def FillOp : LinalgStructured_Op<"fill", []> { - let arguments = (ins AnyShaped:$output, - AnyTypeOf<[AnyComplex, AnyFloat, AnySignlessInteger, - AnyVector]>:$value); + let arguments = (ins + AnyTypeOf<[AnyComplex, AnyFloat, AnySignlessInteger, AnyVector]>:$value, + AnyShaped:$output); let results = (outs Optional<AnyRankedTensor>:$result); let regions = (region AnyRegion:$region); let extraClassDeclaration = structuredOpsDecls # [{ - ValueRange inputs() { return {}; } - ValueRange outputs() { return getOperands().take_front(); } + ValueRange inputs() { return getOperands().take_front(); } + ValueRange outputs() { return getOperands().take_back(); } // Rank-polymorphic. // filling_value -> O(ivs) with parallel iterators. @@ -204,6 +204,7 @@ MLIRContext *context = getContext(); // filling_value -> O(ivs) return Builder(getContext()).getAffineMapArrayAttr({ + AffineMap::get(getNumParallelLoops(), 0, {}, getContext()), extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)}); } @@ -214,13 +215,13 @@ getRegionBuilder() { return &regionBuilder; } - static unsigned getNumRegionArgs() { return 1; } + static unsigned getNumRegionArgs() { return 2; } }]; let assemblyFormat = [{ `(` $output `,` $value `)` attr-dict `:` type($output) `,` type($value) (`->` type($result)^)? 
- custom($region, ref(type($output)), ref($value)) + custom($region, ref(type($output)), ref(type($value))) }]; let builders = [ diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -104,9 +104,21 @@ if (isa(op)) return failure(); + // Swap the operand order of the FillOp to maintain the pretty printed + // signature that takes an output buffer followed by the fill value. + SmallVector originalOperandOrder = op->getOperands(); + if (auto fillOp = dyn_cast(op.getOperation())) { + Value value = fillOp.value(); + Value output = fillOp.output(); + op->setOperands(ValueRange{output, value}); + } + auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); - if (!libraryCallName) + if (!libraryCallName) { + // Restore the operand order in case it has been modified. + op->setOperands(originalOperandOrder); return failure(); + } // TODO: Add support for more complex library call signatures that include // indices or captured values. 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -421,32 +421,29 @@ //===----------------------------------------------------------------------===// void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, ValueRange captures) { - assert(captures.size() == 1 && "FillOp regionBuilder expects 1 capture"); - b.create(captures); + assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args"); + b.create(block.getArgument(0)); } void FillOp::build(OpBuilder &builder, OperationState &result, Value output, Value value) { - build(builder, result, output.getType().dyn_cast(), output, - value); - fillStructuredOpRegion(builder, *result.regions.front(), TypeRange{}, - TypeRange{output.getType()}, value); + build(builder, result, output.getType().dyn_cast(), value, + output); + fillStructuredOpRegion(builder, *result.regions.front(), + TypeRange{value.getType()}, + TypeRange{output.getType()}, {}); } ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type outputType, - OpAsmParser::OperandType valueRef) { + Type valueType) { OpBuilder opBuilder(parser.getBuilder().getContext()); - // Resolve `valueRef` into `value` at parse time so we can build the region - // with captures. - SmallVector value; - parser.resolveOperand(valueRef, getElementTypeOrSelf(outputType), value); - fillStructuredOpRegion(opBuilder, r, TypeRange{}, - TypeRange{outputType}, value); + fillStructuredOpRegion(opBuilder, r, TypeRange{valueType}, + TypeRange{outputType}); return success(); } /// FillOp region is elided when printing. 
-void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {} +void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {} static LogicalResult verify(FillOp op) { OpOperand *output = op.getOutputOperand(0); diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -31,13 +31,13 @@ if isa(RankedTensorType, output.type): results = [output.type] op = self.build_generic(results=results, - operands=[output, value], + operands=[value, output], attributes=None, loc=loc, ip=ip) OpView.__init__(self, op) linalgDialect = Context.current.get_dialect_descriptor("linalg") - fill_builtin_region(linalgDialect, self.operation, [value]) + fill_builtin_region(linalgDialect, self.operation, []) # TODO: self.result is None. When len(results) == 1 we expect it to be # results[0] as per _linalg_ops_gen.py. This seems like an orthogonal bug # in the generator of _linalg_ops_gen.py where we have: diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir --- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir @@ -296,14 +296,15 @@ // TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]]) // TLOOP-SAME: step (%[[C32]], %[[C64]]) // TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]], -// TLOOP-SAME: %[[B_:.*]] = %[[B]]: [[TY]]) +// TLOOP-SAME: %[[B_:.*]] = %[[B]]: [[TY]], +// TLOOP-SAME: %[[C0_F32_:.*]] = %[[C0_F32]] // TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) { // TLOOP: %[[DIM_A__1:.*]] = memref.dim %[[A_]], %[[C1]] : [[TY]] // TLOOP: %[[A_SUB:.*]] = subtensor %[[A_]][%[[I]], 0] // TLOOP: %[[B_SUB:.*]] = subtensor %[[B_]][0, %[[J]]] // TLOOP: %[[OUT_SUB:.*]] = subtensor %[[OUT_]][%[[I]], %[[J]]] -// TLOOP: %[[INIT_SUB:.*]] = linalg.fill(%[[OUT_SUB]], %[[C0_F32]]) +// TLOOP: 
%[[INIT_SUB:.*]] = linalg.fill(%[[OUT_SUB]], %[[C0_F32_]]) // TLOOP: %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]]) // TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]]) @@ -398,3 +399,4 @@ // TLOOP: linalg.yield %[[SUB_RESULT]] : [[TY]] // TLOOP: } // TLOOP: return %[[AB]] : [[TY]] + diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir --- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir @@ -476,15 +476,17 @@ return } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> ()> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK: func @generalize_fill // CHECK-SAME: (%[[ARG0:.+]]: memref, %[[VAL:.+]]: f32) // CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]]] +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] // CHECK-SAME: iterator_types = ["parallel", "parallel"]} +// CHECK-SAME: ins(%[[VAL]] : f32) // CHECK-SAME: outs(%{{.+}} : memref) -// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32) -// CHECK-NEXT: linalg.yield %[[VAL]] : f32 +// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32) +// CHECK-NEXT: linalg.yield %[[BBARG0]] : f32 diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -734,7 +734,7 @@ func @illegal_fill_tensor_with_memref_return (%arg0 : tensor, %arg1 : f32) -> memref { - // expected-error @+1 {{expected type of operand #0 ('tensor') to match type of corresponding result ('memref')}} + // expected-error @+1 {{expected type of operand #1 ('tensor') to match type of corresponding result ('memref')}} %0 = linalg.fill(%arg0, %arg1) : tensor, f32 -> memref return %0 : memref } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- 
a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -224,8 +224,8 @@ patterns.add>( ctx, LinalgPromotionOptions() - .setOperandsToPromote({0}) - .setUseFullTileBuffers({true}) + .setOperandsToPromote({1}) + .setUseFullTileBuffers({false, true}) .setAlignment(32), LinalgTransformationFilter( Identifier::get("_promote_views_aligned_", ctx),