diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -15,8 +15,6 @@ include "mlir/IR/OpBase.td" -def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>; - def Linalg_Dialect : Dialect { let name = "linalg"; let description = [{ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -584,6 +584,19 @@ return {}; }] >, + InterfaceMethod< + /*desc=*/[{ + Return true if the `opOperand` is a scalar value. + }], + /*retTy=*/"bool", + /*methodName=*/"isScalar", + /*args=*/(ins "OpOperand*":$opOperand), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(opOperand->getOwner() == this->getOperation()); + return !opOperand->get().getType().template isa(); + }] + >, InterfaceMethod< /*desc=*/[{ Return the input or output indexing map for `opOperand`. 
@@ -694,10 +707,13 @@ /*methodBody=*/"", /*defaultImplementation=*/[{ return this->getOperation()->getNumResults() == 0 && - llvm::all_of(getInputAndOutputOperands(), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); + llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) { + return isScalar(opOperand) || + opOperand->get().getType().template isa(); + }) && + llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) { + return opOperand->get().getType().template isa(); + }); }] >, InterfaceMethod< @@ -709,8 +725,12 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return llvm::all_of(getInputAndOutputOperands(), - [](OpOperand *opOperand) { + return + llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) { + return isScalar(opOperand) || + opOperand->get().getType().template isa(); + }) && + llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) { return opOperand->get().getType().template isa(); }); }] diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -640,8 +640,8 @@ let arguments = (ins Variadic:$lowerBound, Variadic:$upperBound, Variadic:$step, - Variadic:$inputs, - Variadic:$outputs, + Variadic:$inputs, + Variadic:$outputs, ArrayAttr:$iterator_types, OptionalAttr:$distribution_types); let results = (outs Variadic:$results); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -517,17 +517,12 @@ //===----------------------------------------------------------------------===// // Generic Linalg ops. 
//===----------------------------------------------------------------------===// -class LinalgOperandOfRank: Type< - And<[ - LinalgOperand.predicate, - CPred<"$_self.cast().getRank() == " # rank>] - >>; class GenericOpBase : LinalgStructuredBase_Op, SingleBlockImplicitTerminator<"YieldOp">]> { - let arguments = (ins Variadic:$inputs, + let arguments = (ins Variadic:$inputs, Variadic:$outputs, AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types, diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -338,7 +338,7 @@ if (failed(linalgOp.verifyIndexingMapRequiredAttributes())) return failure(); - // All shaped operands must be indexed. + // All input/output operands must be indexed. if (static_cast(linalgOp.indexing_maps().size()) != linalgOp.getNumInputsAndOutputs()) return op->emitOpError("expected the number of indexing_map (") @@ -363,7 +363,7 @@ int64_t rank = linalgOp.getRank(opOperand); if (indexingMap.getNumResults() != rank) - return op->emitOpError("expected shaped value rank (") + return op->emitOpError("expected operand rank (") << rank << ") to match the result rank of indexing_map #" << opOperand->getOperandNumber() << " (" << indexingMap.getNumResults() << ")"; @@ -444,7 +444,7 @@ if (linalgOp.getNumInputsAndOutputs() + numBBIvs != block.getNumArguments()) return op->emitOpError("expected as many non-induction variable region " - "arguments as the number of shaped operands"); + "arguments as the number of input/output operands"); // Note: the number and type of yield values are checked in the YieldOp. 
for (unsigned i = 0; i < numBBIvs; ++i) @@ -452,14 +452,14 @@ return op->emitOpError("expected index block argument #") << i; for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - Type elementType = getElementTypeOrSelf(opOperand->get().getType()); + Type elementType = getElementTypeOrSelf(opOperand->get()); Type argType = block.getArgument(numBBIvs + opOperand->getOperandNumber()).getType(); if (elementType != argType) return op->emitOpError("expected type of bb argument #") << numBBIvs + opOperand->getOperandNumber() << " (" << argType << ")" - << " to match element type of corresponding shaped operand (" + << " to match element or self type of the corresponding operand (" << elementType << ")"; } @@ -489,10 +489,11 @@ // The first index or last index should be the maximum or the minimum in // the inferred index ranges since the range is increasing or - // decreasing. The size of dimensions of shaped operands and the maximum - // value + 1 in the inferred range should be the same. But, for now we - // check if the inferred ranges are in boundary of shaped operands' size - // or not in case that Affine Expressions are complicated such as d0 * 3 + // decreasing. The size of dimensions of input/output operands and the + // maximum value + 1 in the inferred range should be the same. But, for + // now we check if the inferred ranges are in boundary of input/output + // operands' size or not in case that Affine Expressions are complicated + // such as d0 * 3 // + d1 since it is not easy to handle the issues. 
// Found the case that this solution can't check, for example, (d0, d1) // -> (d1 - d0) @@ -510,14 +511,14 @@ } if (indexingMap.getResult(dim).dyn_cast()) { if (inferredDimSize != shape[dim]) { - return op->emitOpError("inferred shaped operand #") + return op->emitOpError("inferred input/output operand #") << opOperand->getOperandNumber() << " has shape's dimension #" << dim << " to be " << inferredDimSize << ", but found " << shape[dim]; } } else { if (inferredDimSize > shape[dim]) { - return op->emitOpError("inferred shaped operand #") + return op->emitOpError("inferred input/output operand #") << opOperand->getOperandNumber() << " has shape's dimension #" << dim << " to be greater than or equal to " << inferredDimSize diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -377,8 +377,7 @@ static LogicalResult verify(CopyOp op) { OpOperand *output = op.getOutputOperand(0); OpOperand *input = op.getInputOperand(0); - if (getElementTypeOrSelf(input->get().getType()) != - getElementTypeOrSelf(output->get().getType())) + if (getElementTypeOrSelf(input->get()) != getElementTypeOrSelf(output->get())) return op.emitOpError("expects views of the same type"); if (op.getRank(input) != op.getRank(output)) return op.emitOpError("expects views of the same rank"); @@ -452,7 +451,7 @@ static LogicalResult verify(FillOp op) { OpOperand *output = op.getOutputOperand(0); Type fillType = op.value().getType(); - if (getElementTypeOrSelf(output->get().getType()) != fillType) + if (getElementTypeOrSelf(output->get()) != fillType) return op.emitOpError("expects fill type to match view elemental type"); if (!op.getNumResults() && !output->get().getType().isa()) { return op.emitOpError( @@ -489,7 +488,7 @@ SmallVector blockArgTypes; for (ValueRange container : {inputs, outputs}) for (Value v : container) - 
blockArgTypes.push_back(v.getType().cast().getElementType()); + blockArgTypes.push_back(getElementTypeOrSelf(v)); OpBuilder::InsertionGuard guard(builder); auto ®ion = *result.regions.front(); @@ -545,7 +544,7 @@ SmallVector blockArgTypes(nLoops, builder.getIndexType()); for (ValueRange container : {inputs, outputs}) for (Value v : container) - blockArgTypes.push_back(v.getType().cast().getElementType()); + blockArgTypes.push_back(getElementTypeOrSelf(v)); OpBuilder::InsertionGuard guard(builder); auto ®ion = *result.regions.front(); @@ -2949,7 +2948,6 @@ TypeRange inputTypes, TypeRange outputTypes, ValueRange captures, std::function errorHandler) { - assert(llvm::all_of(inputTypes, [](Type t) { return t.isa(); })); assert(llvm::all_of(outputTypes, [](Type t) { return t.isa(); })); // TODO: atm all operands go through getElementTypeOrSelf, diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -484,18 +484,21 @@ b.setInsertionPoint(op); Location loc = op.getLoc(); - SmallVector newInputBuffers; - newInputBuffers.reserve(op.getNumInputs()); + SmallVector newInputs; + newInputs.reserve(op.getNumInputs()); for (OpOperand *opOperand : op.getInputOperands()) { - Value v = lookup(bvm, opOperand->get()); - if (!v) + if (op.isScalar(opOperand)) { + newInputs.push_back(opOperand->get()); + continue; + } + newInputs.push_back(lookup(bvm, opOperand->get())); + if (!newInputs.back()) return failure(); - newInputBuffers.push_back(v); } - SmallVector newOutputBuffers; + SmallVector newOutputBuffers; if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm))) return failure(); - finalizeBufferAllocation(b, op, newInputBuffers, newOutputBuffers, bvm); + finalizeBufferAllocation(b, op, newInputs, newOutputBuffers, bvm); return success(); } diff 
--git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -301,7 +301,7 @@ ++dim; } // Compute the tensor or scalar replacement type. - Type elementType = getElementTypeOrSelf(opOperand->get().getType()); + Type elementType = getElementTypeOrSelf(opOperand->get()); Type replacementType = elementType == opOperand->get().getType() ? elementType : RankedTensorType::get(newShape, elementType); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -129,14 +129,14 @@ assert(producer.hasTensorSemantics() && "only fusion on tensors is currently supported for TiledLinalgOp"); - for (OpOperand *producerInput : producer.getInputTensorOperands()) { + for (OpOperand *producerInput : producer.getInputOperands()) { OpOperand *addedInput = tiledLoop.findInputOperand(producerInput->get()); if (addedInput == nullptr) addedInput = &tiledLoop.appendInputOperand(b, producerInput->get()); BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput); tiledOperands.push_back(addedBlockArg); } - for (OpOperand *producerOutput : producer.getOutputTensorOperands()) { + for (OpOperand *producerOutput : producer.getOutputOperands()) { OpResult result = producer.getTiedOpResult(producerOutput); OpOperand *resultInputOperand = tiledLoop.findInputOperand(result); OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -126,8 +126,12 @@ // TODO: Avoid the loads if the corresponding argument of the // region has no uses. - // 1.a. 
Emit load from input views. + // 1.a. Emit loads from input operands; scalar operands are used directly. for (OpOperand *inputOperand : linalgOp.getInputOperands()) { + if (linalgOp.isScalar(inputOperand)) { + indexedValues.push_back(inputOperand->get()); + continue; + } auto indexing = makeCanonicalAffineApplies( b, loc, linalgOp.getTiedIndexingMap(inputOperand), allIvsPlusDims); indexedValues.push_back( diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -149,7 +149,7 @@ } Value pad = options.paddingValueComputationFunction(rewriter, *opOperand); auto staticTensorType = RankedTensorType::get( - staticSizes, getElementTypeOrSelf(opOperand->get().getType())); + staticSizes, getElementTypeOrSelf(opOperand->get())); result = linalg::PadTensorOp::createPadHighOp( staticTensorType, opOperand->get(), pad, opToPad->getLoc(), rewriter); return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -479,6 +479,10 @@ SmallVector indexings; for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { BlockArgument bbarg = block.getArgument(opOperand->getOperandNumber()); + if (linalgOp.isScalar(opOperand)) { + bvm.map(bbarg, opOperand->get()); + continue; + } // TODO: 0-d vectors.
if (linalgOp.getShape(opOperand).empty()) { Value loaded = @@ -494,14 +498,13 @@ if (broadcastToMaximalCommonShape) { map = inverseAndBroadcastProjectedPermuation( linalgOp.getTiedIndexingMap(opOperand)); - vectorType = VectorType::get( - commonVectorShape, getElementTypeOrSelf(opOperand->get().getType())); + vectorType = VectorType::get(commonVectorShape, + getElementTypeOrSelf(opOperand->get())); } else { map = inversePermutation( reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand))); - vectorType = - VectorType::get(map.compose(linalgOp.getShape(opOperand)), - getElementTypeOrSelf(opOperand->get().getType())); + vectorType = VectorType::get(map.compose(linalgOp.getShape(opOperand)), + getElementTypeOrSelf(opOperand->get())); } Value vectorRead = buildVectorRead(b, opOperand->get(), vectorType, map); LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg(" @@ -1157,7 +1160,7 @@ int64_t rank = op.getRank(input); int64_t numDims = mapping.size(); - Type elemType = getElementTypeOrSelf(input->get().getType()); + Type elemType = getElementTypeOrSelf(input->get()); auto map = AffineMap::get(rank, 0, mapping, context); SmallVector zeros(rank, rewriter.create(loc, 0)); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -1372,6 +1372,8 @@ // Detects sparse annotations and translate the per-dimension sparsity // information for all tensors to loop indices in the kernel. 
assert(op.getNumOutputs() == 1); + assert(llvm::none_of(op.getInputAndOutputOperands(), + [&](OpOperand *t) { return op.isScalar(t); })); unsigned numTensors = op.getNumInputsAndOutputs(); unsigned numLoops = op.iterator_types().getValue().size(); Merger merger(numTensors, numLoops); diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -2,6 +2,7 @@ #accesses = [ affine_map<(i, j, k, l, m) -> (i, k, m)>, + affine_map<(i, j, k, l, m) -> ()>, affine_map<(i, j, k, l, m) -> (i, k, j, l, m)> ] @@ -11,21 +12,22 @@ library_call = "some_external_func" } -func @drop_one_trip_loops(%arg0 : tensor, %shape: tensor) -> tensor { +func @drop_one_trip_loops(%arg0 : tensor, %arg1 : f32, %shape: tensor) -> tensor { %0 = linalg.generic #trait - ins(%arg0 : tensor) + ins(%arg0, %arg1 : tensor, f32) outs(%shape : tensor) { - ^bb0(%arg2 : f32, %arg3 : f32) : - linalg.yield %arg2 : f32 + ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) : + linalg.yield %arg3 : f32 } -> tensor return %0 : tensor } -// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()> // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @drop_one_trip_loops // CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1], [2]] // CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[$MAP2]], #[[$MAP3]]] +// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] // CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]] diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir --- 
a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir @@ -292,7 +292,7 @@ // TLOOP: %[[DIM_A_0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]] // TLOOP: %[[DIM_B_1:.*]] = memref.dim %[[B]], %[[C1]] : [[TY]] -// TLOOP: %[[AB:.*]] = linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) = +// TLOOP: %[[AB:.*]] = linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) = // TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]]) // TLOOP-SAME: step (%[[C32]], %[[C64]]) // TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]], @@ -305,7 +305,80 @@ // TLOOP: %[[OUT_SUB:.*]] = subtensor %[[OUT_]][%[[I]], %[[J]]] // TLOOP: %[[INIT_SUB:.*]] = linalg.fill(%[[OUT_SUB]], %[[C0_F32]]) -// TLOOP: %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]]) +// TLOOP: %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]]) +// TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]]) +// TLOOP-SAME: ins (%[[A_SUB_:.*]] = %[[A_SUB]]: [[TY]], +// TLOOP-SAME: %[[B_SUB_:.*]] = %[[B_SUB]]: [[TY]]) +// TLOOP-SAME: outs (%[[INIT_SUB_:.*]] = %[[INIT_SUB]]: [[TY]]) +// TLOOP-SAME: iterators["reduction"] { + +// TLOOP: %[[A_SUB_SUB:.*]] = subtensor %[[A_SUB_]][0, %[[K]]] +// TLOOP: %[[B_SUB_SUB:.*]] = subtensor %[[B_SUB_]][%[[K]], 0] + +// TLOOP: %[[AB_SUB_SUB:.*]] = linalg.matmul +// TLOOP-SAME: ins(%[[A_SUB_SUB]], %[[B_SUB_SUB]] : [[TY]], [[TY]]) +// TLOOP-SAME: outs(%[[INIT_SUB_]] : [[TY]]) -> [[TY]] +// TLOOP: linalg.yield %[[AB_SUB_SUB]] : [[TY]] +// TLOOP: } +// TLOOP: %[[SUB_RESULT:.*]] = subtensor_insert %[[AB_SUB]] +// TLOOP-SAME: into %[[OUT_]][%[[I]], %[[J]]] +// TLOOP: linalg.yield %[[SUB_RESULT]] : [[TY]] +// TLOOP: } +// TLOOP: return %[[AB]] : [[TY]] + +// ----- + +module { + func @generic_plus_matmul(%arg0: tensor, %arg1: tensor, + %arg2: tensor) -> tensor { + %c0 = constant 0.0 : f32 + %0 = linalg.generic { + indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>], + iterator_types = ["parallel", "parallel"]} + ins(%c0 : f32) + outs(%arg0: 
tensor) { + ^bb(%0: f32, %1: f32) : + linalg.yield %0 : f32 + } -> tensor + %1 = linalg.matmul {__internal_linalg_transform__ = "out_fusion"} + ins(%arg1, %arg2 : tensor, tensor) + outs(%0 : tensor) -> tensor + return %1 : tensor + } +} + +// TLOOP-LABEL: func @generic_plus_matmul( +// TLOOP-SAME: %[[OUT:[a-zA-Z0-9_]+]]: tensor +// TLOOP-SAME: %[[A:[a-zA-Z0-9_]+]]: tensor +// TLOOP-SAME: %[[B:[a-zA-Z0-9_]+]]: tensor + +// TLOOP-DAG: %[[C0_F32:.*]] = constant 0.0 +// TLOOP-DAG: %[[C32:.*]] = constant 32 : index +// TLOOP-DAG: %[[C64:.*]] = constant 64 : index +// TLOOP-DAG: %[[C16:.*]] = constant 16 : index +// TLOOP-DAG: %[[C0:.*]] = constant 0 : index +// TLOOP-DAG: %[[C1:.*]] = constant 1 : index + +// TLOOP: %[[DIM_A_0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]] +// TLOOP: %[[DIM_B_1:.*]] = memref.dim %[[B]], %[[C1]] : [[TY]] + +// TLOOP: %[[AB:.*]] = linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) = +// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]]) +// TLOOP-SAME: step (%[[C32]], %[[C64]]) +// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]], +// TLOOP-SAME: %[[B_:.*]] = %[[B]]: [[TY]], +// TLOOP-SAME: %[[C0_F32_:.*]] = %[[C0_F32]] +// TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) { + +// TLOOP: %[[DIM_A__1:.*]] = memref.dim %[[A_]], %[[C1]] : [[TY]] +// TLOOP: %[[A_SUB:.*]] = subtensor %[[A_]][%[[I]], 0] +// TLOOP: %[[B_SUB:.*]] = subtensor %[[B_]][0, %[[J]]] +// TLOOP: %[[OUT_SUB:.*]] = subtensor %[[OUT_]][%[[I]], %[[J]]] +// TLOOP: %[[INIT_SUB:.*]] = linalg.generic +// TLOOP-SAME: ins(%[[C0_F32_]] +// TLOOP-SAME: outs(%[[OUT_SUB]] + +// TLOOP: %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]]) // TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]]) // TLOOP-SAME: ins (%[[A_SUB_:.*]] = %[[A_SUB]]: [[TY]], // TLOOP-SAME: %[[B_SUB_:.*]] = %[[B_SUB]]: [[TY]]) diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir --- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir +++ 
b/mlir/test/Dialect/Linalg/fusion-tensor.mlir @@ -40,6 +40,48 @@ // ----- +// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: [[$MAP1:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> ()> +#map0 = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> ()> + +// CHECK-LABEL: @scalar_add_mul_fusion +func @scalar_add_mul_fusion(%arg0: tensor, %arg1 : f32, %arg2 : f32) -> tensor +{ + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = memref.dim %arg0, %c0 : tensor + %1 = memref.dim %arg0, %c1 : tensor + %2 = linalg.init_tensor [%0, %1] : tensor + %3 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} + ins(%arg0, %arg1 : tensor, f32) + outs(%2 : tensor) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + %4 = addf %arg3, %arg4 : f32 + linalg.yield %4 : f32 + } -> tensor + // CHECK: linalg.generic { + // CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP1]], [[$MAP1]], [[$MAP0]]{{\]}} + %4 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} + ins(%3, %arg2 : tensor, f32) + outs(%2 : tensor) { + // CHECK: ^{{[a-zA-Z0-9_]*}} + // CHECK-SAME: [[ARG3:%[a-zA-Z0-9_]*]] + // CHECK-SAME: [[ARG4:%[a-zA-Z0-9_]*]] + // CHECK-SAME: [[ARG5:%[a-zA-Z0-9_]*]] + ^bb0(%arg5: f32, %arg6: f32, %arg7: f32): // no predecessors + // CHECK: [[T1:%[a-zA-Z0-9_]*]] = addf [[ARG3]], [[ARG4]] + // CHECK-NOT: linalg.yield + // CHECK: mulf [[T1]], [[ARG5]] + // CHECK: linalg.yield + %5 = mulf %arg5, %arg6 : f32 + linalg.yield %5 : f32 + } -> tensor + return %4 : tensor +} + +// ----- + // CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: [[$MAP1:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d1, d0)> #map0 = affine_map<(d0, d1) -> (d0, d1)> diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ 
b/mlir/test/Dialect/Linalg/invalid.mlir @@ -96,7 +96,7 @@ // ----- func @generic_one_d_view(%arg0: memref(off + i)>>) { - // expected-error @+1 {{expected shaped value rank (1) to match the result rank of indexing_map #0 (2)}} + // expected-error @+1 {{expected operand rank (1) to match the result rank of indexing_map #0 (2)}} linalg.generic { indexing_maps = [ affine_map<() -> (0, 0)> ], iterator_types = []} @@ -108,6 +108,21 @@ // ----- +func @generic_scalar_view(%arg0: memref(off + i)>>) { + %cst = constant 0.0 : f32 + // expected-error @+1 {{expected operand rank (0) to match the result rank of indexing_map #0 (1)}} + linalg.generic { + indexing_maps = [ affine_map<() -> (0)>, affine_map<() -> (0, 0)> ], + iterator_types = []} + ins(%cst : f32) + outs(%arg0 : memref(off + i)>>) { + ^bb(%0 : f32, %1 : f32): + linalg.yield %0: f32 + } +} + +// ----- + func @generic_result_0_element_type(%arg0: memref(off + i)>>) { // expected-error @+7 {{'linalg.yield' op type of yield operand 1 ('i4') doesn't match the element type of the enclosing linalg.generic op ('f32')}} linalg.generic { @@ -174,7 +189,7 @@ // ----- func @generic_mismatched_num_arguments(%arg0: memref) { - // expected-error @+1 {{expected as many non-induction variable region arguments as the number of shaped operands}} + // expected-error @+1 {{expected as many non-induction variable region arguments as the number of input/output operands}} linalg.generic { indexing_maps = [ affine_map<() -> ()>, affine_map<() -> ()> ], iterator_types = []} @@ -186,8 +201,8 @@ // ----- -func @generic_block_arg_type(%arg0: memref) { - // expected-error @+1 {{expected type of bb argument #0 ('i1') to match element type of corresponding shaped operand ('f32')}} +func @generic_shaped_operand_block_arg_type(%arg0: memref) { + // expected-error @+1 {{expected type of bb argument #0 ('i1') to match element or self type of the corresponding operand ('f32')}} linalg.generic { indexing_maps = [ affine_map<() -> ()> ], iterator_types 
= []} @@ -199,8 +214,21 @@ // ----- +func @generic_scalar_operand_block_arg_type(%arg0: f32) { + // expected-error @+1 {{expected type of bb argument #0 ('i1') to match element or self type of the corresponding operand ('f32')}} + linalg.generic { + indexing_maps = [ affine_map<() -> ()> ], + iterator_types = []} + outs(%arg0 : f32) { + ^bb(%i: i1): + linalg.yield %i : i1 + } +} + +// ----- + func @indexed_generic_block_arg_count(%arg0: memref) { - // expected-error @+1 {{expected as many non-induction variable region arguments as the number of shaped operands}} + // expected-error @+1 {{expected as many non-induction variable region arguments as the number of input/output operands}} linalg.indexed_generic { indexing_maps = [ affine_map<(i) -> (i)> ], iterator_types = ["parallel"]} @@ -226,7 +254,7 @@ // ----- func @indexed_generic_block_arg_type(%arg0: memref) { - // expected-error @+1 {{expected type of bb argument #1 ('i1') to match element type of corresponding shaped operand ('f32')}} + // expected-error @+1 {{expected type of bb argument #1 ('i1') to match element or self type of the corresponding operand ('f32')}} linalg.indexed_generic { indexing_maps = [ affine_map<(d0) -> (d0)> ], iterator_types = ["parallel"]} @@ -239,7 +267,7 @@ // ----- func @indexed_generic_arg_count(%arg0: memref) { - // expected-error @+1 {{expected as many non-induction variable region arguments as the number of shaped operands}} + // expected-error @+1 {{expected as many non-induction variable region arguments as the number of input/output operands}} linalg.indexed_generic { indexing_maps = [ affine_map<()[] -> ()> ], iterator_types = []} @@ -401,7 +429,7 @@ func @pooling_rank_mismatch(%arg0: memref, %arg1: memref<2x3xf32>, %arg2: memref) { - // expected-error @+1 {{expected shaped value rank (2) to match the result rank of indexing_map #1 (3)}} + // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #1 (3)}} linalg.pooling_max(%arg0, %arg1, 
%arg2) {strides = [2, 1, 2]}: memref, memref<2x3xf32>, memref return @@ -410,7 +438,7 @@ // ----- func @named_ops(%a3: memref, %b3: memref, %c3: memref) { - // expected-error @+1 {{expected shaped value rank (2) to match the result rank of indexing_map #1 (3)}} + // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #1 (3)}} linalg.batch_matmul ins(%a3, %b3: memref, memref) outs(%c3 : memref) return @@ -714,7 +742,7 @@ // ----- func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) { - // expected-error @+1 {{inferred shaped operand #1 has shape's dimension #0 to be 4, but found 3}} + // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 4, but found 3}} linalg.matmul ins(%arg0, %arg1 : memref<2x4xf32>, memref<3x4xf32>) outs(%arg2 :memref<2x4xf32>) return @@ -723,7 +751,7 @@ // ----- func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2x1xf32>, %output: memref<1x2x3x1xf32>) { - // expected-error @+1 {{inferred shaped operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}} + // expected-error @+1 {{inferred input/output operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}} linalg.conv_2d_input_nhwc_filter_hwcf { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%input, %filter : memref<1x3x4x2xf32>, memref<3x2x2x1xf32>) diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -975,6 +975,30 @@ // CHECKPARALLEL: %[[a:.*]] = memref.load %[[ARG0]][] // CHECKPARALLEL: store %[[a]], %[[ARG1]][%[[i]], %[[j]]] +func @generic_op_scalar(%arg0: f32, %arg1: memref<3x4xf32>) +{ + linalg.generic #trait_broadcast + ins(%arg0 : f32) + outs(%arg1 : memref<3x4xf32>) { + ^bb(%a: f32, %b: f32) : + linalg.yield %a : f32 + } + return +} + +// 
CHECK-LABEL: @generic_op_scalar +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32 +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xf32> +// CHECK: scf.for %[[i:.*]] = {{.*}} +// CHECK: scf.for %[[j:.*]] = {{.*}} +// CHECK: store %[[ARG0]], %[[ARG1]][%[[i]], %[[j]]] + +// CHECKPARALLEL-LABEL: @generic_op_scalar +// CHECKPARALLEL-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32 +// CHECKPARALLEL-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xf32> +// CHECKPARALLEL: scf.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]]) +// CHECKPARALLEL: store %[[ARG0]], %[[ARG1]][%[[i]], %[[j]]] + func @generic_index_op_zero_rank(%arg0: memref, %arg1: memref<3x4xi32>) { linalg.generic #trait_broadcast diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir --- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir @@ -2,37 +2,42 @@ // RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=true" -split-input-file | FileCheck %s --check-prefix=FOLDUNITDIM #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)> #map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)> +#map2 = affine_map<(d0, d1, d2) -> ()> func @generic_op_reshape_producer_fusion(%arg0 : tensor, - %arg1 : tensor) -> + %arg1 : tensor, + %arg2 : f32) -> tensor { %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2], [3]] : tensor into tensor %1 = linalg.generic { - indexing_maps = [#map0, #map1, #map1], + indexing_maps = [#map0, #map1, #map2, #map1], iterator_types = ["parallel", "parallel", "parallel"]} - ins(%0, %arg1 : tensor, tensor) + ins(%0, %arg1, %arg2 : tensor, tensor, f32) outs(%0 : tensor) { - ^bb0(%arg3: f32, %arg4: f32, %s: f32): // no predecessors + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %s: f32): // no predecessors %1 = mulf %arg3, %arg4 : f32 - linalg.yield %1 : f32 + %2 = addf %1, %arg5 : f32 + linalg.yield %2 : f32 } -> tensor return %1 : tensor } // CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, 
d2)> // CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d0, d1)> +// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3) -> ()> // CHECK: func @generic_op_reshape_producer_fusion // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32 // CHECK: %[[T0:.+]] = linalg.tensor_collapse_shape %[[ARG0]] // CHECK-SAME: [0], [1, 2], [3] // CHECK: %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG1]] // CHECK-SAME: [0], [1], [2, 3] // CHECK: %[[T3:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP6]]] +// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]] // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"] -// CHECK-SAME: ins(%[[ARG0]], %[[T1]] : tensor, tensor) +// CHECK-SAME: ins(%[[ARG0]], %[[T1]], %[[ARG2]] : tensor, tensor, f32) // CHECK-SAME: outs(%{{.+}} : tensor) // CHECK: %[[T4:.+]] = linalg.tensor_collapse_shape %[[T3]] // CHECK-SAME: [0], [1], [2, 3] @@ -42,18 +47,21 @@ // ----- #map0 = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> ()> func @generic_op_reshape_consumer_fusion(%arg0 : tensor, - %arg1 : tensor) -> + %arg1 : tensor, + %arg2 : f32) -> tensor { %0 = linalg.generic { - indexing_maps = [#map0, #map0, #map0], + indexing_maps = [#map0, #map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : tensor, tensor) + ins(%arg0, %arg1, %arg2 : tensor, tensor, f32) outs(%arg0 : tensor) { - ^bb0(%arg3: f32, %arg4: f32, %s: f32): // no predecessors + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %s: f32): // no predecessors %1 = mulf %arg3, %arg4 : f32 - linalg.yield %1 : f32 + %2 = addf %1, %arg5 : f32 + linalg.yield %2 : f32 } -> tensor %1 = linalg.tensor_expand_shape %0 [[0], [1, 2, 3]] : tensor into tensor @@ -61,9 +69,12 @@ } // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) 
-> ()> + // CHECK: func @generic_op_reshape_consumer_fusion -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor, -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor) +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32 // CHECK: %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]] // CHECK-SAME: [0], [1, 2, 3] // CHECK-SAME: tensor into tensor @@ -71,9 +82,9 @@ // CHECK-SAME: [0], [1, 2, 3] // CHECK-SAME: tensor into tensor // CHECK: %[[T3:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]] +// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]], #[[MAP2]]] // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"] -// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor, tensor) +// CHECK-SAME: ins(%[[T0]], %[[T1]], %[[ARG2]] : tensor, tensor, f32) // CHECK-SAME: outs(%{{.+}} : tensor) // CHECK: return %[[T3]] : tensor diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -316,6 +316,7 @@ #accesses_0 = [ affine_map<(i, j, k) -> (j, i)>, + affine_map<(i, j, k) -> ()>, affine_map<(i, j, k) -> (i, k, i + j)> ] @@ -327,34 +328,34 @@ func @generic(%arg0: memref, offset: ?, strides: [?, 1]>, %arg1: memref) { + %cst = constant 0.0 : f32 linalg.generic #trait_0 - ins(%arg0 : memref, offset: ?, strides: [?, 1]>) + ins(%arg0, %cst : memref, offset: ?, strides: [?, 1]>, f32) outs(%arg1 : memref) attrs = {foo = 1} { - ^bb(%0: vector<3x4xi4>, %1: f32) : - %f0 = constant 0.0 : f32 - linalg.yield %f0 : f32 + ^bb(%0: vector<3x4xi4>, %1: f32, %2: f32) : + linalg.yield %1 : f32 } return } // CHECK-LABEL: func @generic // CHECK: linalg.generic { -// CHECK-SAME: indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}], +// CHECK-SAME: indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}, #{{[0-9a-z]*}}], // CHECK-SAME: iterator_types = ["parallel", 
"parallel", "parallel"], // CHECK-SAME: library_call = "some_external_function_name_1"} -// CHECK-SAME: ins({{.*}} : memref, #[[$strided2D]]>) +// CHECK-SAME: ins({{.*}}, {{.*}} : memref, #[[$strided2D]]>, f32) // CHECK-SAME: outs({{.*}} : memref) // CHECK-SAME: {foo = 1 : i64} func @generic_with_tensor_input(%arg0: tensor>, %arg1: memref) { + %cst = constant 0.0 : f32 linalg.generic #trait_0 - ins(%arg0 : tensor>) + ins(%arg0, %cst : tensor>, f32) outs(%arg1 : memref) attrs = {foo = 1} { - ^bb(%0: vector<3x4xi4>, %1: f32) : - %f0 = constant 0.0 : f32 - linalg.yield %f0 : f32 + ^bb(%0: vector<3x4xi4>, %1: f32, %2: f32) : + linalg.yield %1 : f32 } return } @@ -362,7 +363,7 @@ // CHECK: linalg.generic { // CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], // CHECK-SAME: library_call = "some_external_function_name_1"} -// CHECK-SAME: ins({{.*}} : tensor>) +// CHECK-SAME: ins({{.*}}, {{.*}} : tensor>, f32) // CHECK-SAME: outs({{.*}} : memref) // CHECK-SAME: {foo = 1 : i64} diff --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir --- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir @@ -8,6 +8,7 @@ func @matmul_tensors( %arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +// CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor) { // CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor) { // CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor) { @@ -19,11 +20,11 @@ // CHECK-NOT: linalg.matmul {{.*}} tensor // Padding injects static information. 
-// CHECK: %[[pA:.*]] = linalg.pad_tensor %[[sTA]] low[%c0, %c0] high[%{{.*}}, %{{.*}}] +// CHECK: %[[pA:.*]] = linalg.pad_tensor %[[sTA]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}] // CHECK: : tensor to tensor<2x4xi8> -// CHECK: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] low[%c0, %c0] high[%{{.*}}, %{{.*}}] +// CHECK: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}] // CHECK: : tensor to tensor<4x3xi8> -// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%c0, %c0] high[%{{.*}}, %{{.*}}] +// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}] // CHECK: : tensor to tensor<2x3xi32> // CHECK: %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x4xi8>, tensor<4x3xi8>) // CHECK-SAME: outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32> @@ -41,6 +42,41 @@ return %0 : tensor } +// CHECK-LABEL: func @generic_scalar_and_tensor( +// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor +// CHECK-SAME: %[[VAL:[0-9a-z]+]]: f32) -> tensor { +func @generic_scalar_and_tensor( + %arg0: tensor, %arg1: f32) + -> tensor { +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor) { +// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor) { +// CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor) { +// CHECK: %[[sTC:.*]] = subtensor %[[TC2]][{{.*}}] : tensor to tensor + +// Padding injects static information. 
+// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}, %{{.*}}] +// CHECK: : tensor to tensor<2x3x4xf32> +// CHECK: %[[pD:.*]] = linalg.generic +// CHECK-SAME: ins(%[[VAL]] : f32) outs(%[[pC]] : tensor<2x3x4xf32>) +// CHECK: %[[sTD:.*]] = subtensor %[[pD]][0, 0, 0] [%{{.*}}, %{{.*}}, %{{.*}}] [1, 1, 1] : tensor<2x3x4xf32> to tensor +// CHECK: %[[TD:.*]] = subtensor_insert %[[sTD]] into %[[TC2]][{{.*}}] : tensor into tensor +// CHECK: scf.yield %[[TD]] : tensor +// CHECK: scf.yield %[[TD2]] : tensor +// CHECK: scf.yield %[[TD1]] : tensor + %0 = linalg.generic { + indexing_maps = [ affine_map<(d0, d1, d2) -> ()>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)> ], + iterator_types = ["parallel", "parallel", "parallel"]} + {__internal_linalg_transform__ = "tile-and-pad"} + ins(%arg1 : f32) + outs(%arg0: tensor) { + ^bb(%0: f32, %1: f32) : + linalg.yield %0 : f32 + } -> tensor + return %0 : tensor +} + // CHECK-1DIM-TILE: func @matmul_tensors( // CHECK-1DIM-TILE: %[[TA:[0-9a-z]+]]: tensor // CHECK-1DIM-TILE: %[[TB:[0-9a-z]+]]: tensor @@ -65,6 +101,7 @@ // CHECK-1DIM-TILE-SAME: %[[TA:[0-9a-z]+]]: tensor // CHECK-1DIM-TILE-SAME: %[[TB:[0-9a-z]+]]: tensor<8x?xi8> // CHECK-1DIM-TILE-SAME: %[[TC:[0-9a-z]+]]: tensor) -> tensor { +// CHECK-1DIM-TILE: %[[C0:.*]] = constant 0 : index // CHECK-1DIM-TILE: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor) { // CHECK-1DIM-TILE: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor) { // CHECK-1DIM-TILE: %[[sTA:.*]] = subtensor %[[TA]][{{.*}}] : tensor to tensor @@ -72,11 +109,11 @@ // CHECK-1DIM-TILE: %[[sTB:.*]] = subtensor %[[TB]][{{.*}}] : tensor<8x?xi8> to tensor<8x?xi8> // CHECK-1DIM-TILE: %[[sTBc:.*]] = tensor.cast %[[sTB]] : tensor<8x?xi8> to tensor // CHECK-1DIM-TILE: %[[sTC:.*]] = subtensor %[[TC1]][{{.*}}] : tensor to tensor -// CHECK-1DIM-TILE: %[[pA:.*]] = linalg.pad_tensor %[[sTAc]] 
low[%c0, %c0] high[%{{.*}}, %{{.*}}] +// CHECK-1DIM-TILE: %[[pA:.*]] = linalg.pad_tensor %[[sTAc]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}] // CHECK-1DIM-TILE: : tensor to tensor<2x8xi8> -// CHECK-1DIM-TILE: %[[pB:.*]] = linalg.pad_tensor %[[sTBc]] low[%c0, %c0] high[%{{.*}}, %{{.*}}] +// CHECK-1DIM-TILE: %[[pB:.*]] = linalg.pad_tensor %[[sTBc]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}] // CHECK-1DIM-TILE: : tensor to tensor<8x3xi8> -// CHECK-1DIM-TILE: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%c0, %c0] high[%{{.*}}, %{{.*}}] +// CHECK-1DIM-TILE: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}] // CHECK-1DIM-TILE: : tensor to tensor<2x3xi32> // CHECK-1DIM-TILE: %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>) // CHECK-1DIM-TILE: outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32> diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -136,6 +136,23 @@ // ----- +// CHECK-LABEL: func @test_vectorize_scalar_input +func @test_vectorize_scalar_input(%A : memref<8x16xf32>, %arg0 : f32) { + // CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32> + // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> + linalg.generic { + indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg0 : f32) + outs(%A: memref<8x16xf32>) { + ^bb(%0: f32, %1: f32) : + linalg.yield %0 : f32 + } + return +} + +// ----- + // CHECK-LABEL: func @test_vectorize_fill func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) { // CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32> diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp --- 
a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp @@ -162,8 +162,7 @@ *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); changed = true; } - } else { - assert(opOperand->get().getType().isa()); + } else if (opOperand->get().getType().isa()) { // Tile and Fuse tensor input. if (opOperand->getOperandNumber() >= linalgOp.getNumInputs()) continue; diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -533,7 +533,7 @@ // For now, just assume it is the zero of type. // In the future, it should be the zero of type + op. static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) { - auto t = getElementTypeOrSelf(op.get().getType()); + auto t = getElementTypeOrSelf(op.get()); return b.create(op.getOwner()->getLoc(), t, b.getZeroAttr(t)); } @@ -544,7 +544,8 @@ linalg::LinalgTilingOptions() .setTileSizes(tileSizes) .setPaddingValueComputationFunction(getNeutralOfLinalgOp); - tilingPattern.add>( + tilingPattern.add, + linalg::LinalgTilingPattern>( context, linalgTilingOptions, linalg::LinalgTransformationFilter( Identifier::get("tile-and-pad", context)));