diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -19,7 +19,7 @@
 
 def SCF_Dialect : Dialect {
   let name = "scf";
-  let cppNamespace = "";
+  let cppNamespace = "scf";
 }
 
 // Base class for SCF dialect ops.
@@ -39,7 +39,7 @@
 def ForOp : SCF_Op<"for",
       [DeclareOpInterfaceMethods<LoopLikeOpInterface>,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
-       SingleBlockImplicitTerminator<"YieldOp">,
+       SingleBlockImplicitTerminator<"scf::YieldOp">,
        RecursiveSideEffects]> {
   let summary = "for operation";
   let description = [{
@@ -183,7 +183,7 @@
 
 def IfOp : SCF_Op<"if",
       [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
-       SingleBlockImplicitTerminator<"YieldOp">, RecursiveSideEffects,
+       SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveSideEffects,
        NoRegionArguments]> {
   let summary = "if-then-else operation";
   let description = [{
@@ -271,7 +271,7 @@
     [AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<LoopLikeOpInterface>,
     RecursiveSideEffects,
-    SingleBlockImplicitTerminator<"YieldOp">]> {
+    SingleBlockImplicitTerminator<"scf::YieldOp">]> {
   let summary = "parallel for operation";
   let description = [{
     The "scf.parallel" operation represents a loop nest taking 4 groups of SSA
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1474,6 +1474,38 @@
   let summary = "floating point division operation";
 }
 
+//===----------------------------------------------------------------------===//
+// DynamicTensorFromElementsOp
+//===----------------------------------------------------------------------===//
+
+def DynamicTensorFromElementsOp : Std_Op<"dynamic_tensor_from_elements",
+    [NoSideEffect]> {
+  string summary = "Creates a dynamically sized tensor from elements";
+  string description = [{
+    This operation creates a dynamically sized tensor with elements of any type.
+    It expects one index operand per dynamic dimension of the result tensor,
+    each defining the corresponding extent at runtime.
+
+    The body region defines the tensor's elements. It takes index operands as
+    its region arguments that span the index space. The element at the given
+    position is yielded with the `yield` operation (see `YieldOp`).
+
+    Example:
+
+    ```mlir
+    %tnsr = dynamic_tensor_from_elements %m, %n {
+    ^bb0(%i : index, %j : index, %k : index):
+      ...
+      yield %elem : f32
+    } : tensor<?x3x?xf32>
+    ```
+  }];
+
+  let arguments = (ins Variadic<Index>:$dynamicDimensions);
+  let results = (outs AnyRankedTensor:$result);
+  let regions = (region SizedRegion<1>:$body);
+}
+
 //===----------------------------------------------------------------------===//
 // ExpOp
 //===----------------------------------------------------------------------===//
@@ -3223,6 +3255,22 @@
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// YieldOp
+//===----------------------------------------------------------------------===//
+
+def YieldOp : Std_Op<"yield", [NoSideEffect, Terminator]> {
+  let summary = "Yield a value from a region";
+  let description = [{
+    This operation is used to yield a single value from within a region. It
+    is used to create dynamically sized tensors
+    (see `DynamicTensorFromElementsOp`).
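+
+    Example (illustrative; this mirrors the `dynamic_tensor_from_elements`
+    example above):
+
+    ```mlir
+    %tnsr = dynamic_tensor_from_elements %m, %n {
+    ^bb0(%i : index, %j : index, %k : index):
+      %elem = constant 8.0 : f32
+      yield %elem : f32
+    } : tensor<?x3x?xf32>
+    ```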
+  }];
+
+  let arguments = (ins AnyType:$value);
+  let assemblyFormat = "$value attr-dict `:` type($value)";
+}
+
 //===----------------------------------------------------------------------===//
 // XOrOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -338,7 +338,8 @@
 class YieldOpConversion : public ConvertToLLVMPattern {
 public:
   explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
-      : ConvertToLLVMPattern(YieldOp::getOperationName(), context, lowering_) {}
+      : ConvertToLLVMPattern(linalg::YieldOp::getOperationName(), context,
+                             lowering_) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
diff --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
--- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
+++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
@@ -356,7 +356,7 @@
     // A loop is constructed with an empty "yield" terminator if there are
     // no results.
     rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
-    rewriter.create<YieldOp>(loc, forOp.getResults());
+    rewriter.create<scf::YieldOp>(loc, forOp.getResults());
   }
 
   rewriter.setInsertionPointToStart(forOp.getBody());
@@ -391,7 +391,7 @@
 
   if (!yieldOperands.empty()) {
     rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
-    rewriter.create<YieldOp>(loc, yieldOperands);
+    rewriter.create<scf::YieldOp>(loc, yieldOperands);
   }
 
   rewriter.replaceOp(parallelOp, loopResults);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -898,7 +898,7 @@
 // YieldOp
 //===----------------------------------------------------------------------===//
 
-static void print(OpAsmPrinter &p, YieldOp op) {
+static void print(OpAsmPrinter &p, linalg::YieldOp op) {
   p << op.getOperationName();
   if (op.getNumOperands() > 0)
     p << ' ' << op.getOperands();
@@ -919,7 +919,8 @@
 
 // Check the operand number and types must match the element types of the
 // LinalgOp interface's shaped operands.
-static LogicalResult verifyYield(YieldOp op, LinalgOp linalgOpInterface) {
+static LogicalResult verifyYield(linalg::YieldOp op,
+                                 LinalgOp linalgOpInterface) {
   auto nOutputs = linalgOpInterface.getNumOutputs();
   if (op.getNumOperands() != nOutputs)
     return op.emitOpError("expected number of yield values (")
@@ -939,7 +940,7 @@
   return success();
 }
 
-static LogicalResult verify(YieldOp op) {
+static LogicalResult verify(linalg::YieldOp op) {
   auto *parentOp = op.getParentOp();
   if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
     return op.emitOpError("expected single non-empty parent region");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -660,7 +660,7 @@
   // Add operations from producer (except the yield operation) to the fused
   // op.
   for (auto &op : producerBlock.getOperations()) {
-    if (auto yieldOp = dyn_cast<YieldOp>(op)) {
+    if (auto yieldOp = dyn_cast<linalg::YieldOp>(op)) {
       // Lookup the value the yield operation is mapped to.
       Value yieldVal = yieldOp.getOperand(0);
       if (Value clonedVal = mapper.lookupOrNull(yieldVal))
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -147,7 +147,7 @@
   }
 
   Operation &terminator = block.back();
-  assert(isa<YieldOp>(terminator) &&
+  assert(isa<linalg::YieldOp>(terminator) &&
          "expected a yield op in the end of the region");
   for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
     IndexedValueType O(outputBuffers[i]);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -48,14 +48,14 @@
   auto c = m_Val(r.getArgument(2));
   // TODO: Update this detection once we have matcher support for specifying
   // that any permutation of operands matches.
-  auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
-  auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
-  auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
-  auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
-  auto pattern5 = m_Op<YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
-  auto pattern6 = m_Op<YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
-  auto pattern7 = m_Op<YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
-  auto pattern8 = m_Op<YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
+  auto pattern1 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
+  auto pattern2 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
+  auto pattern3 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
+  auto pattern4 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
+  auto pattern5 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
+  auto pattern6 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
+  auto pattern7 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
+  auto pattern8 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
   return pattern1.match(&r.front().back()) ||
          pattern2.match(&r.front().back()) ||
          pattern3.match(&r.front().back()) ||
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -38,7 +38,7 @@
   // as necessary. Required when the region has only one block.
   void handleTerminator(Operation *op,
                         ArrayRef<Value> valuesToRepl) const final {
-    auto retValOp = dyn_cast<YieldOp>(op);
+    auto retValOp = dyn_cast<scf::YieldOp>(op);
     if (!retValOp)
       return;
@@ -889,7 +889,7 @@
   return success();
 }
 
-static void print(OpAsmPrinter &p, YieldOp op) {
+static void print(OpAsmPrinter &p, scf::YieldOp op) {
   p << op.getOperationName();
   if (op.getNumOperands() != 0)
     p << ' ' << op.getOperands() << " : " << op.getOperandTypes();
@@ -899,5 +899,9 @@
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
 
+namespace mlir {
+namespace scf {
 #define GET_OP_CLASSES
 #include "mlir/Dialect/SCF/SCFOps.cpp.inc"
+} // namespace scf
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -779,7 +779,7 @@
 // YieldOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(YieldOp op) {
+static LogicalResult verify(shape::YieldOp op) {
   auto *parentOp = op.getParentOp();
   auto results = parentOp->getResults();
   auto operands = op.getOperands();
diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
--- a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
@@ -45,7 +45,7 @@
   OpBuilder b = OpBuilder::atBlockEnd(body);
   Value product = b.create<MulOp>(loc, valueType, body->getArgument(1),
                                   body->getArgument(2));
-  b.create<YieldOp>(loc, product);
+  b.create<shape::YieldOp>(loc, product);
 
   rewriter.replaceOp(op, reduce.result());
   return success();
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1312,7 +1312,6 @@
 }
 
 static LogicalResult verify(DimOp op) {
-  // Assume unknown index to be in range.
   Optional<int64_t> index = op.getConstantIndex();
   if (!index.hasValue())
@@ -1634,6 +1633,67 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// DynamicTensorFromElementsOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseDynamicTensorFromElementsOp(OpAsmParser &parser,
+                                                    OperationState &result) {
+  // Parse operands.
+  SmallVector<OpAsmParser::OperandType, 4> dynamicDimensions;
+  Type indexTy = parser.getBuilder().getIndexType();
+  if (parser.parseOperandList(dynamicDimensions) ||
+      parser.resolveOperands(dynamicDimensions, indexTy, result.operands))
+    return failure();
+
+  // Parse body.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, {}, {}))
+    return failure();
+
+  // Parse result type.
+  Type resultType;
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(resultType))
+    return failure();
+  result.addTypes(resultType);
+
+  return success();
+}
+
+static void print(OpAsmPrinter &p, DynamicTensorFromElementsOp op) {
+  p << "dynamic_tensor_from_elements " << op.dynamicDimensions();
+  p.printRegion(op.body());
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : " << op.getType();
+}
+
+static LogicalResult verify(DynamicTensorFromElementsOp op) {
+  // Ensure that the tensor type has as many dynamic dimensions as are
+  // specified by the operands.
+  RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
+  if (op.getNumOperands() != resultTy.getNumDynamicDims())
+    return op.emitError(
+        "must have as many index operands as dynamic dimensions in the result "
+        "type");
+
+  // Ensure that region arguments span the index space.
+  if (op.body().getNumArguments() != resultTy.getRank() ||
+      !llvm::all_of(op.body().getArgumentTypes(),
+                    [](Type ty) { return ty.isIndex(); }))
+    return op.emitError("body arguments must span index space");
+
+  // Ensure that the region yields an element of the right type.
+  auto yieldOp =
+      llvm::dyn_cast<YieldOp>(op.body().getBlocks().front().getTerminator());
+  if (!yieldOp || yieldOp.value().getType() != resultTy.getElementType())
+    return op.emitOpError(
+        "body must be terminated with a `yield` operation of the tensor "
+        "element type");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractElementOp
 //===----------------------------------------------------------------------===//
@@ -3248,6 +3308,12 @@
                         [](APInt a, APInt b) { return a ^ b; });
 }
 
+//===----------------------------------------------------------------------===//
+// YieldOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(YieldOp op) { return success(); }
+
 //===----------------------------------------------------------------------===//
 // ZeroExtendIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -15,3 +15,68 @@
   %0 = index_cast %arg0 : tensor<index> to i64
   return %0 : i64
 }
+
+// -----
+
+func @dynamic_tensor_from_elements(%m : index)
+    -> tensor<?x3x?xf32> {
+  // expected-error @+1 {{must have as many index operands as dynamic dimensions in the result type}}
+  %tnsr = dynamic_tensor_from_elements %m {
+  ^bb0(%i : index, %j : index, %k : index):
+    %elem = constant 8.0 : f32
+    yield %elem : f32
+  } : tensor<?x3x?xf32>
+  return %tnsr : tensor<?x3x?xf32>
+}
+
+// -----
+
+func @dynamic_tensor_from_elements(%m : index, %n : index)
+    -> tensor<?x3x?xf32> {
+  // expected-error @+1 {{body arguments must span index space}}
+  %tnsr = dynamic_tensor_from_elements %m, %n {
+  ^bb0(%i : index, %j : index):
+    %elem = constant 8.0 : f32
+    yield %elem : f32
+  } : tensor<?x3x?xf32>
+  return %tnsr : tensor<?x3x?xf32>
+}
+
+// -----
+
+func @dynamic_tensor_from_elements(%m : index, %n : index)
+    -> tensor<?x3x?xf32> {
+  // expected-error @+1 {{body arguments must span index space}}
+  %tnsr = dynamic_tensor_from_elements %m, %n {
+  ^bb0(%i : index, %j : index, %k : i64):
+    %elem = constant 8.0 : f32
+    yield %elem : f32
+  } : tensor<?x3x?xf32>
+  return %tnsr : tensor<?x3x?xf32>
+}
+
+// -----
+
+func @dynamic_tensor_from_elements(%m : index, %n : index)
+    -> tensor<?x3x?xf32> {
+  // expected-error @+1 {{body must be terminated with a `yield` operation of the tensor element type}}
+  %tnsr = dynamic_tensor_from_elements %m, %n {
+  ^bb0(%i : index, %j : index, %k : index):
+    %elem = constant 8.0 : f32
+    return %elem : f32
+  } : tensor<?x3x?xf32>
+  return %tnsr : tensor<?x3x?xf32>
+}
+
+// -----
+
+func @dynamic_tensor_from_elements(%m : index, %n : index)
+    -> tensor<?x3x?xf32> {
+  // expected-error @+1 {{body must be terminated with a `yield` operation of the tensor element type}}
+  %tnsr = dynamic_tensor_from_elements %m, %n {
+  ^bb0(%i : index, %j : index, %k : index):
+    %elem = constant 8 : i32
+    yield %elem : i32
+  } : tensor<?x3x?xf32>
+  return %tnsr : tensor<?x3x?xf32>
+}
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt -split-input-file %s | FileCheck %s
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
 
 // CHECK-LABEL: test_index_cast
 func @test_index_cast(%arg0 : index) -> i64 {
@@ -22,3 +23,14 @@
   assert %arg, "Some message in case this assertion fails."
   return
 }
+
+func @dynamic_tensor_from_elements(%m : index, %n : index)
+    -> tensor<?x3x?xf32> {
+  %tnsr = dynamic_tensor_from_elements %m, %n {
+  ^bb0(%i : index, %j : index, %k : index):
+    %elem = constant 8.0 : f32
+    yield %elem : f32
+  } : tensor<?x3x?xf32>
+  return %tnsr : tensor<?x3x?xf32>
+}
+
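
For reference, the second RUN line added to ops.mlir round-trips the test through MLIR's generic op syntax. A rough sketch of how the new test function would look in that form (the exact output of `--mlir-print-op-generic` may differ in detail; the surrounding `func`/`return` are kept in custom form for brevity):

```mlir
func @dynamic_tensor_from_elements(%m : index, %n : index) -> tensor<?x3x?xf32> {
  // The two index operands supply the dynamic extents; the region computes
  // each element and hands it back through the new `std.yield` terminator.
  %tnsr = "std.dynamic_tensor_from_elements"(%m, %n) ( {
  ^bb0(%i: index, %j: index, %k: index):
    %elem = "std.constant"() {value = 8.000000e+00 : f32} : () -> f32
    "std.yield"(%elem) : (f32) -> ()
  }) : (index, index) -> tensor<?x3x?xf32>
  return %tnsr : tensor<?x3x?xf32>
}
```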