diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
@@ -24,7 +24,6 @@
 using linalg_vecmat = OperationBuilder<linalg::VecmatOp>;
 using linalg_range = ValueBuilder<linalg::RangeOp>;
 using linalg_reshape = ValueBuilder<linalg::ReshapeOp>;
-using linalg_slice = ValueBuilder<linalg::SliceOp>;
 using linalg_yield = OperationBuilder<linalg::YieldOp>;

 } // namespace intrinsics
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -433,78 +433,6 @@
   let hasCanonicalizer = 1;
 }

-def Linalg_SliceOp : Linalg_Op<"slice", [
-      DeclareOpInterfaceMethods<ViewLikeOpInterface>, NoSideEffect]>,
-    Arguments<(ins AnyStridedMemRef:$view,
-                   Variadic<AnyTypeOf<[Range, Index]>>:$indexings)>,
-    Results<(outs AnyStridedMemRef)> {
-  let summary = "Produce a rank-reduced `subview` of a base `view`.";
-  let description = [{
-    The `linalg.slice` op allows defining a subregion of a smaller rank than the
-    operand `view` within the underlying buffer.
-
-    A `linalg.slice` op takes a view and a variadic number of indexings and
-    produces a `view` of the same elemental type. An indexing is either:
-      1. a `linalg.range`, in which case it does not reduce the rank of the
-         parent `view` along the corresponding dimension.
-      2. an `index`, in which case it reduces the rank of the parent view by
-         one.
-
-    If an indexing extends past the size of the `view`, this is undefined
-    behavior. Ideally the `linalg.slice` operation would automatically truncate
-    it to be within bounds but there are tradeoffs involved now that `std.view`
-    is a standard op.
-
-    Examples:
-
-      1. rank-preserving `slice`:
-
-      ```mlir
-      %4 = linalg.slice %0[%1, %2] : memref<?x?xf32, stride_spec>,
-        !linalg.range, !linalg.range, memref<?x?xf32, stride_spec>
-      ```
-
-      2. rank-reducing `slice` (from 2-D to 1-D):
-
-      ```mlir
-      %4 = linalg.slice %0[%1, %2] : memref<?x?xf32, stride_spec>,
-        index, !linalg.range, memref<?xf32, stride_spec>
-      ```
-
-      3. rank-reducing `slice` (from 2-D to 0-D):
-
-      ```mlir
-      %4 = linalg.slice %0[%1, %2] : memref<?x?xf32, stride_spec>,
-        index, index, memref<f32, stride_spec>
-      ```
-  }];
-
-  let builders = [OpBuilderDAG<(ins "Value":$base, "ValueRange":$indexings)>];
-
-  let extraClassDeclaration = [{
-    enum { FirstIndexingOperand = 1 };
-    unsigned getRank() { return getShapedType().getRank(); }
-    Type getElementType() { return getShapedType().getElementType(); }
-    ShapedType getShapedType() { return getType().cast<ShapedType>(); }
-    unsigned getBaseViewRank() { return getBaseViewType().getRank(); }
-    ShapedType getBaseViewType() { return view().getType().cast<ShapedType>();}
-
-    // Get the underlying indexing at a given rank.
-    Value indexing(unsigned rank) { return *(indexings().begin() + rank); }
-
-    // Get the subset of indexings that are of RangeType.
-    SmallVector<Value, 8> getRanges() {
-      SmallVector<Value, 8> res;
-      for (auto operand : indexings())
-        if (!operand.getType().isa<IndexType>())
-          res.push_back(operand);
-      return res;
-    }
-  }];
-
-  let hasFolder = 1;
-}
-
 def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>,
     Arguments<(ins Variadic<AnyType>:$values)> {
   let summary = "Linalg yield operation";
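Reviewer note: the removed summary already describes `linalg.slice` as a rank-reduced `subview` of a base `view`, and the standard `subview` op is what remains after this patch. As a hedged sketch (mine, not part of the patch; the function name is illustrative), example 3 above can be expressed with `subview` by pinning both dimensions to unit size and unit stride, which drops them from the result type:

```mlir
// Sketch only: a 0-D rank-reducing subview standing in for the removed
// rank-reducing `linalg.slice` form with two `index` indexings.
func @slice_as_subview(%m: memref<5x3xf32>) {
  // Offsets (1, 1) pin the element; sizes and strides of 1 make both
  // dimensions unit-sized, so they are dropped from the 0-D result type.
  // The resulting element offset is 1 * 3 + 1 = 4.
  %v = subview %m[1, 1] [1, 1] [1, 1]
      : memref<5x3xf32> to memref<f32, affine_map<() -> (4)>>
  return
}
```

This mirrors the `%27` test this patch adds to core-ops.mlir below.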
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -45,9 +45,9 @@
     The invariants of this interface are:
       1. `static_offsets`, `static_sizes` and `static_strides` have length
-         at most `getArrayAttrMaxRanks()`[0] (resp. [1], [2]).
+         exactly `getArrayAttrMaxRanks()`[0] (resp. [1], [2]).
       2. `offsets`, `sizes` and `strides` have each length at most
-         length `static_offsets` (resp. `static_sizes`, `static_strides`).
+         `getArrayAttrMaxRanks()`[0] (resp. [1], [2]).
       3. if an entry of `static_offsets` (resp. `static_sizes`,
          `static_strides`) is equal to a special sentinel value, namely
         `ShapedType::kDynamicStrideOrOffset` (resp. `ShapedType::kDynamicSize`,
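To make the tightened invariant concrete, a hedged illustration (not from the patch; function name is mine) using `subview`, which implements this interface. The static arrays always carry exactly rank-many entries, while the SSA operand lists only carry the dynamic ones:

```mlir
// Sketch: rank-2 subview with one dynamic and one static offset.
// static_offsets = [kDynamicStrideOrOffset, 1]: length exactly 2 (the rank).
// offsets = (%o): one SSA value backing the single sentinel entry, so its
// length (1) is at most 2.
func @mixed_static_dynamic(%m: memref<8x8xf32>, %o: index) {
  %v = subview %m[%o, 1] [4, 4] [1, 1]
      : memref<8x8xf32> to memref<4x4xf32, offset: ?, strides: [8, 1]>
  return
}
```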
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -185,93 +185,6 @@
   }
 };

-/// Conversion pattern that transforms a linalg.slice op into:
-///   1. An "undef" value for the ViewDescriptor.
-///   2. Updates to the ViewDescriptor to introduce the data ptr, offset, size
-///      and stride corresponding to the region of memory within the bounds of
-///      the parent view.
-/// The linalg.slice op is replaced by the alloca'ed pointer.
-class SliceOpConversion : public ConvertOpToLLVMPattern<SliceOp> {
-public:
-  using ConvertOpToLLVMPattern<SliceOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(SliceOp sliceOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    edsc::ScopedContext context(rewriter, sliceOp->getLoc());
-    SliceOpAdaptor adaptor(operands);
-    BaseViewConversionHelper baseDesc(adaptor.view());
-
-    auto memRefType = sliceOp.getBaseViewType();
-    auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64));
-
-    BaseViewConversionHelper desc(
-        typeConverter->convertType(sliceOp.getShapedType()));
-
-    // TODO: extract sizes and emit asserts.
-    SmallVector<Value, 4> strides(memRefType.getRank());
-    for (int i = 0, e = memRefType.getRank(); i < e; ++i)
-      strides[i] = baseDesc.stride(i);
-
-    auto pos = [&rewriter](ArrayRef<int64_t> values) {
-      return rewriter.getI64ArrayAttr(values);
-    };
-
-    // Compute base offset.
-    Value baseOffset = baseDesc.offset();
-    for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
-      Value indexing = adaptor.indexings()[i];
-      Value min = indexing;
-      if (sliceOp.indexing(i).getType().isa<RangeType>())
-        min = llvm_extractvalue(int64Ty, indexing, pos(0));
-      baseOffset = llvm_add(baseOffset, llvm_mul(min, strides[i]));
-    }
-
-    // Insert the base and aligned pointers.
-    desc.setAllocatedPtr(baseDesc.allocatedPtr());
-    desc.setAlignedPtr(baseDesc.alignedPtr());
-
-    // Insert base offset.
-    desc.setOffset(baseOffset);
-
-    // Corner case, no sizes or strides: early return the descriptor.
-    if (sliceOp.getShapedType().getRank() == 0)
-      return rewriter.replaceOp(sliceOp, {desc}), success();
-
-    Value zero = llvm_constant(
-        int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
-    // Compute and insert view sizes (max - min along the range) and strides.
-    // Skip the non-range operands as they will be projected away from the view.
-    int numNewDims = 0;
-    for (auto en : llvm::enumerate(sliceOp.indexings())) {
-      Value indexing = en.value();
-      if (indexing.getType().isa<RangeType>()) {
-        int rank = en.index();
-        Value rangeDescriptor = adaptor.indexings()[rank];
-        Value min = llvm_extractvalue(int64Ty, rangeDescriptor, pos(0));
-        Value max = llvm_extractvalue(int64Ty, rangeDescriptor, pos(1));
-        Value step = llvm_extractvalue(int64Ty, rangeDescriptor, pos(2));
-        Value baseSize = baseDesc.size(rank);
-
-        // Bound upper by base view upper bound.
-        max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
-                          baseSize);
-        Value size = llvm_sub(max, min);
-        // Bound lower by zero.
-        size =
-            llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
-        Value stride = llvm_mul(strides[rank], step);
-        desc.setSize(numNewDims, size);
-        desc.setStride(numNewDims, stride);
-        ++numNewDims;
-      }
-    }
-
-    rewriter.replaceOp(sliceOp, {desc});
-    return success();
-  }
-};
-
 // YieldOp produces an LLVM::ReturnOp.
 class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
 public:
@@ -289,8 +202,8 @@
 /// Populate the given list with patterns that convert from Linalg to LLVM.
 void mlir::populateLinalgToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
-  patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
-                  YieldOpConversion>(converter);
+  patterns.insert<RangeOpConversion, ReshapeOpConversion, YieldOpConversion>(
+      converter);

   // Populate the type conversions for the linalg types.
   converter.addConversion(
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1318,83 +1318,6 @@
       context);
 }

-//===----------------------------------------------------------------------===//
-// SliceOp
-//===----------------------------------------------------------------------===//
-void mlir::linalg::SliceOp::build(OpBuilder &b, OperationState &result,
-                                  Value base, ValueRange indexings) {
-  result.addOperands(base);
-  result.addOperands(indexings);
-
-  auto memRefType = base.getType().cast<MemRefType>();
-  int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  auto res = getStridesAndOffset(memRefType, strides, offset);
-  assert(succeeded(res) && strides.size() == indexings.size());
-  (void)res;
-
-  unsigned rank = memRefType.getRank();
-  // TODO: propagate static size and stride information when available.
-  SmallVector<int64_t, 4> sizes(rank, -1); // -1 encodes dynamic size.
-  result.addTypes({MemRefType::Builder(memRefType)
-                       .setShape(sizes)
-                       .setAffineMaps(makeStridedLinearLayoutMap(
-                           strides, offset, b.getContext()))});
-}
-
-static void print(OpAsmPrinter &p, SliceOp op) {
-  auto indexings = op.indexings();
-  p << SliceOp::getOperationName() << " " << op.view() << "[" << indexings
-    << "] ";
-  p.printOptionalAttrDict(op.getAttrs());
-  p << " : " << op.getBaseViewType();
-  if (!indexings.empty())
-    p << ", " << op.indexings().getTypes();
-  p << ", " << op.getType();
-}
-
-static ParseResult parseSliceOp(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::OperandType baseInfo;
-  SmallVector<OpAsmParser::OperandType, 8> operands;
-  SmallVector<Type, 8> types;
-  if (parser.parseOperand(baseInfo) ||
-      parser.parseOperandList(operands, OpAsmParser::Delimiter::Square) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonTypeList(types))
-    return failure();
-
-  if (types.size() < 2)
-    return parser.emitError(parser.getCurrentLocation(),
-                            "expected at least input and result view types");
-
-  ArrayRef<Type> indexingTypes = ArrayRef<Type>(types).drop_front().drop_back();
-  return failure(
-      parser.resolveOperand(baseInfo, types.front(), result.operands) ||
-      (!operands.empty() &&
-       parser.resolveOperands(operands, indexingTypes,
-                              operands.front().location, result.operands)) ||
-      parser.addTypeToList(types.back(), result.types));
-}
-
-static LogicalResult verify(SliceOp op) {
-  unsigned rank = op.getBaseViewRank();
-  if (rank != llvm::size(op.indexings()))
-    return op.emitOpError("expected ")
-           << rank << " indexings, got " << llvm::size(op.indexings());
-  unsigned index = 0;
-  for (auto indexing : op.indexings()) {
-    if (indexing.getType().isa<IndexType>())
-      --rank;
-    ++index;
-  }
-  if (op.getRank() != rank)
-    return op.emitOpError() << "expected rank of the view(" << op.getRank()
-                            << ") to be the number of ranges(" << rank << ")";
-  return success();
-}
-
-Value SliceOp::getViewSource() { return view(); }
-
 //===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//
@@ -1746,11 +1669,6 @@
     return getResult();
   return foldReshapeOp(*this, operands);
 }
-OpFoldResult SliceOp::fold(ArrayRef<Attribute>) {
-  if (succeeded(foldMemRefCast(*this)))
-    return getResult();
-  return {};
-}
 OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) {
   return foldReshapeOp(*this, operands);
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -426,9 +426,8 @@
   // Must be a subview or a slice to guarantee there are loops we can fuse
   // into.
   auto subView = consumerOpOperand.get().getDefiningOp<SubViewOp>();
-  auto slice = consumerOpOperand.get().getDefiningOp<SliceOp>();
-  if (!subView && !slice) {
-    LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview or slice)");
+  if (!subView) {
+    LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview)");
     return llvm::None;
   }
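A hedged sketch of what fusion now requires (names are illustrative, not from the patch): the consumer must read its operand through a `subview`, whose loop structure gives the fusion pass something to tile into; with `linalg.slice` gone, no other view-producing op qualifies here.

```mlir
// Sketch: %sv is defined by a subview, so the matmul consuming it is a
// fusion candidate; an operand not defined by a subview is rejected with
// the "Not fusable (not a subview)" debug message above.
func @fusable_consumer(%A: memref<16x8xf32>, %B: memref<8x8xf32>,
                       %C: memref<4x8xf32>, %i: index) {
  %sv = subview %A[%i, 0] [4, 8] [1, 1]
      : memref<16x8xf32> to memref<4x8xf32, offset: ?, strides: [8, 1]>
  linalg.matmul ins(%sv, %B : memref<4x8xf32, offset: ?, strides: [8, 1]>,
                              memref<8x8xf32>)
               outs(%C : memref<4x8xf32>)
  return
}
```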
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -754,6 +754,12 @@
     return t;
   }

+  // 0-D corner case for an empty shape that still has an affine map. Example:
+  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1-element memref
+  // whose offset needs to remain; just return t.
+  if (t.getShape().empty())
+    return t;
+
   // If the canonical strided layout for the sizes of `t` is equal to the
   // simplified layout of `t` we can just return an empty layout. Otherwise,
   // just simplify the existing layout.
@@ -770,6 +776,9 @@
 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                 ArrayRef<AffineExpr> exprs,
                                                 MLIRContext *context) {
+  assert(!sizes.empty() && !exprs.empty() &&
+         "expected non-empty sizes and exprs");
+
   // Size 0 corner case is useful for canonicalizations.
   if (llvm::is_contained(sizes, 0))
     return getAffineConstantExpr(0, context);
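For the 0-D corner case guarded above (in what appears to be the strided-layout canonicalization helper), a hedged example of the kind of type that must survive canonicalization unchanged; the function name below is mine:

```mlir
// Sketch: the layout map carries the only information about which element
// this 0-D memref points at; erasing it would silently change the offset.
func @zero_d_keeps_offset(%m: memref<5x3xf32>, %i: index) {
  %e = subview %m[%i, 1] [1, 1] [1, 1]
      : memref<5x3xf32> to memref<f32, affine_map<()[s0] -> (s0)>>
  return
}
```

This is exactly the shape of the `%28` subview test added to core-ops.mlir below.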
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -9,15 +9,11 @@
   %1 = alloc (%b) : memref<?xi8>
   %2 = view %1[%c0][] : memref<?xi8> to memref<16x16xf32>
   %3 = memref_cast %2 : memref<16x16xf32> to memref<?x?xf32>
-  %r0 = linalg.range %c0:%c8:%c1 : !linalg.range
-
-  // CHECK: linalg.slice {{.*}} : memref<16x16xf32>, !linalg.range, !linalg.range, memref<?x?xf32>
-  %4 = linalg.slice %3[%r0, %r0] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32>

   // CHECK: linalg.matmul ins({{.*}}memref<16x16xf32>, memref<16x16xf32>) outs({{.*}}memref<16x16xf32>)
   linalg.matmul ins(%3, %3: memref<?x?xf32>, memref<?x?xf32>)
                outs(%3: memref<?x?xf32>)
-  return %4: memref<?x?xf32>
+  return %3: memref<?x?xf32>
 }

 // -----
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -8,22 +8,6 @@

 // -----

-func @slice_number_of_indexings(%arg0: memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
-  // expected-error @+2 {{expected 2 indexings, got 1}}
-  %c0 = constant 0: index
-  %0 = linalg.slice %arg0[%c0] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>, index, memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
-}
-
-// -----
-
-func @slice_rank_vs_range_indices(%arg0: memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
-  // expected-error @+2 {{op expected rank of the view(1) to be the number of ranges(0)}}
-  %c0 = constant 0: index
-  %0 = linalg.slice %arg0[%c0, %c0] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>, index, index, memref<?xf32, affine_map<(i)[off]->(off + i)>>
-}
-
-// -----
-
 func @store_number_of_indices(%v : memref<f32>) {
   // expected-error @+3 {{store index operand count not equal to memref rank}}
   %c0 = constant 0 : index
diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -14,61 +14,6 @@
 // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(i64, i64, i64)>
 // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(i64, i64, i64)>

-func @slice(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: !linalg.range) {
-  %1 = linalg.slice %arg0[%arg1] : memref<?xf32, offset: ?, strides: [1]>, !linalg.range, memref<?xf32, offset: ?, strides: [1]>
-  return
-}
-// CHECK-LABEL: func @slice
-//   insert data ptr for slice op
-// CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK-NEXT: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(i64, i64, i64)>
-// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : i64
-// CHECK-NEXT: llvm.add %{{.*}}, %{{.*}} : i64
-//   insert offset
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK-NEXT: llvm.mlir.constant(0 : index)
-// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(i64, i64, i64)>
-// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, i64, i64)>
-// CHECK-NEXT: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(i64, i64, i64)>
-//   get size[0] from parent view
-// CHECK-NEXT: llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : i64
-// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : i1, i64
-//   compute size[0] bounded by parent view's size[0]
-// CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : i64
-//   bound below by 0
-// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : i64
-// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : i1, i64
-//   compute stride[0] using bounded size
-// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : i64
-//   insert size and stride
-// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-
-func @slice_with_range_and_index(%arg0: memref<?x?xf32>) {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %R = linalg.range %c0:%c1:%c1 : !linalg.range
-  scf.for %i0 = %c0 to %c1 step %c1 {
-    %1 = linalg.slice %arg0[%i0, %R] : memref<?x?xf32>, index, !linalg.range, memref<?xf32, offset: ?, strides: [1]>
-  }
-  return
-}
-// CHECK-LABEL: func @slice_with_range_and_index
-//   loop-body.
-// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(i64, i64, i64)>
-// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, i64, i64)>
-// CHECK: llvm.insertvalue %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-
 func @reshape_static_expand(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
   // Reshapes that expand a contiguous tensor with some 1's.
   %0 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>,
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -85,32 +85,13 @@

 // -----

-// CHECK-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
-
 func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index,
             %arg4: index) {
   %c0 = constant 0 : index
   %0 = muli %arg0, %arg0 : index
   %1 = alloc (%0) : memref<?xi8>
   %2 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
   %3 = view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xf32>
-  %4 = linalg.slice %3[%2, %2] :
-    memref<?x?xf32>,
-    !linalg.range,
-    !linalg.range,
-    memref<?x?xf32>
-  %5 = linalg.slice %3[%2, %arg2] : memref<?x?xf32>,
-                                    !linalg.range,
-                                    index,
-                                    memref<?xf32, offset: ?, strides: [1]>
-  %6 = linalg.slice %3[%arg2, %2] : memref<?x?xf32>,
-                                    index,
-                                    !linalg.range,
-                                    memref<?xf32, offset: ?, strides: [1]>
-  %7 = linalg.slice %3[%arg2, %arg3] : memref<?x?xf32>,
-                                       index,
-                                       index,
-                                       memref<f32>
-  %8 = view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xvector<4x4xf32>>
+  %4 = view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xvector<4x4xf32>>
   dealloc %1 : memref<?xi8>
   return
 }
@@ -120,26 +101,6 @@
 // CHECK-NEXT: range
 // CHECK-NEXT: std.view %{{.*}}[%{{.*}}][%{{.*}}] :
 // CHECK-SAME:   memref<?xi8> to memref<?x?xf32>
-// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] :
-// CHECK-SAME:   memref<?x?xf32>,
-// CHECK-SAME:   !linalg.range,
-// CHECK-SAME:   !linalg.range,
-// CHECK-SAME:   memref<?x?xf32>
-// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] :
-// CHECK-SAME:   memref<?x?xf32>,
-// CHECK-SAME:   !linalg.range,
-// CHECK-SAME:   index,
-// CHECK-SAME:   memref<?xf32, #[[$strided1D]]>
-// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] :
-// CHECK-SAME:   memref<?x?xf32>,
-// CHECK-SAME:   index,
-// CHECK-SAME:   !linalg.range,
-// CHECK-SAME:   memref<?xf32, #[[$strided1D]]>
-// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] :
-// CHECK-SAME:   memref<?x?xf32>,
-// CHECK-SAME:   index,
-// CHECK-SAME:   index,
-// CHECK-SAME:   memref<f32>
 // CHECK-NEXT: view %{{.*}}[%{{.*}}][%{{.*}}] :
 // CHECK-SAME:   memref<?xi8> to memref<?x?xvector<4x4xf32>>
 // CHECK-NEXT: dealloc %{{.*}} : memref<?xi8>
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -24,6 +24,8 @@
 // CHECK-DAG: #[[$SUBVIEW_MAP8:map[0-9]+]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
 // CHECK-DAG: #[[$SUBVIEW_MAP9:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 3 + d1 + 6)>
 // CHECK-DAG: #[[$SUBVIEW_MAP10:map[0-9]+]] = affine_map<(d0) -> (d0 + 3)>
+// CHECK-DAG: #[[$SUBVIEW_MAP11:map[0-9]+]] = affine_map<() -> (4)>
+// CHECK-DAG: #[[$SUBVIEW_MAP12:map[0-9]+]] = affine_map<()[s0] -> (s0)>

 // CHECK-LABEL: func @func_with_ops
 // CHECK-SAME: %[[ARG:.*]]: f32
@@ -803,6 +805,13 @@
   // CHECK: subview %{{.*}}[1] [1] [1] : memref<5x3xf32> to memref<3xf32, #[[$SUBVIEW_MAP10]]>
   %26 = subview %24[1][1][1]: memref<5x3xf32> to memref<3xf32, offset: 3, strides: [1]>

+  // Corner case of a 0-D rank-reducing subview with an offset.
+  // CHECK: subview %{{.*}}[1, 1] [1, 1] [1, 1] : memref<5x3xf32> to memref<f32, #[[$SUBVIEW_MAP11]]>
+  %27 = subview %24[1, 1] [1, 1] [1, 1] : memref<5x3xf32> to memref<f32, affine_map<() -> (4)>>
+
+  // CHECK: subview %{{.*}}[%{{.*}}, 1] [1, 1] [1, 1] : memref<5x3xf32> to memref<f32, #[[$SUBVIEW_MAP12]]>
+  %28 = subview %24[%arg0, 1] [1, 1] [1, 1] : memref<5x3xf32> to memref<f32, affine_map<()[s0] -> (s0)>>
+
   return
 }

@@ -903,9 +912,9 @@

 // CHECK-LABEL: func @subtensor_insert({{.*}}) {
 func @subtensor_insert(
-    %t: tensor<8x16x4xf32>,
-    %t2: tensor<16x32x8xf32>,
-    %t3: tensor<4x4xf32>,
+    %t: tensor<8x16x4xf32>,
+    %t2: tensor<16x32x8xf32>,
+    %t3: tensor<4x4xf32>,
     %idx : index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index