diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md --- a/mlir/docs/Dialects/Linalg.md +++ b/mlir/docs/Dialects/Linalg.md @@ -554,9 +554,9 @@ * `std.view`, * `std.subview`, + * `std.transpose`, * `linalg.range`, * `linalg.slice`, - * `linalg.transpose`. * `linalg.reshape`, Future ops are added on a per-need basis but should include: diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -287,36 +287,6 @@ let hasFolder = 1; } -def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>, - Arguments<(ins AnyStridedMemRef:$view, AffineMapAttr:$permutation)>, - Results<(outs AnyStridedMemRef)> { - let summary = "`transpose` produces a new strided memref (metadata-only)"; - let description = [{ - The `linalg.transpose` op produces a strided memref whose sizes and strides - are a permutation of the original `view`. This is a pure metadata - transformation. 
- - Example: - - ```mlir - %1 = linalg.transpose %0 (i, j) -> (j, i) : memref to memref - ``` - }]; - - let builders = [OpBuilder< - "OpBuilder &b, OperationState &result, Value view, " - "AffineMapAttr permutation, ArrayRef attrs = {}">]; - - let verifier = [{ return ::verify(*this); }]; - - let extraClassDeclaration = [{ - static StringRef getPermutationAttrName() { return "permutation"; } - ShapedType getShapedType() { return view().getType().cast(); } - }]; - - let hasFolder = 1; -} - def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, Terminator]>, Arguments<(ins Variadic:$values)> { let summary = "Linalg yield operation"; diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -3428,6 +3428,38 @@ let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)"; } +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +def TransposeOp : Std_Op<"transpose", [NoSideEffect]>, + Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>, + Results<(outs AnyStridedMemRef)> { + let summary = "`transpose` produces a new strided memref (metadata-only)"; + let description = [{ + The `transpose` op produces a strided memref whose sizes and strides + are a permutation of the original `in` memref. This is purely a metadata + transformation. 
+ + Example: + + ```mlir + %1 = transpose %0 (i, j) -> (j, i) : memref to memref (d1 * s0 + d0)>> + ``` + }]; + + let builders = [OpBuilder< + "OpBuilder &b, OperationState &result, Value in, " + "AffineMapAttr permutation, ArrayRef attrs = {}">]; + + let extraClassDeclaration = [{ + static StringRef getPermutationAttrName() { return "permutation"; } + ShapedType getShapedType() { return in().getType().cast(); } + }]; + + let hasFolder = 1; +} + //===----------------------------------------------------------------------===// // TruncateIOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -284,57 +284,6 @@ } }; -/// Conversion pattern that transforms a linalg.transpose op into: -/// 1. A function entry `alloca` operation to allocate a ViewDescriptor. -/// 2. A load of the ViewDescriptor from the pointer allocated in 1. -/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size -/// and stride. Size and stride are permutations of the original values. -/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer. -/// The linalg.transpose op is replaced by the alloca'ed pointer. -class TransposeOpConversion : public ConvertToLLVMPattern { -public: - explicit TransposeOpConversion(MLIRContext *context, - LLVMTypeConverter &lowering_) - : ConvertToLLVMPattern(TransposeOp::getOperationName(), context, - lowering_) {} - - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - // Initialize the common boilerplate and alloca at the top of the FuncOp. 
- edsc::ScopedContext context(rewriter, op->getLoc()); - TransposeOpAdaptor adaptor(operands); - BaseViewConversionHelper baseDesc(adaptor.view()); - - auto transposeOp = cast(op); - // No permutation, early exit. - if (transposeOp.permutation().isIdentity()) - return rewriter.replaceOp(op, {baseDesc}), success(); - - BaseViewConversionHelper desc( - typeConverter.convertType(transposeOp.getShapedType())); - - // Copy the base and aligned pointers from the old descriptor to the new - // one. - desc.setAllocatedPtr(baseDesc.allocatedPtr()); - desc.setAlignedPtr(baseDesc.alignedPtr()); - - // Copy the offset pointer from the old descriptor to the new one. - desc.setOffset(baseDesc.offset()); - - // Iterate over the dimensions and apply size/stride permutation. - for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) { - int sourcePos = en.index(); - int targetPos = en.value().cast().getPosition(); - desc.setSize(targetPos, baseDesc.size(sourcePos)); - desc.setStride(targetPos, baseDesc.stride(sourcePos)); - } - - rewriter.replaceOp(op, {desc}); - return success(); - } -}; - // YieldOp produces and LLVM::ReturnOp. class YieldOpConversion : public ConvertToLLVMPattern { public: @@ -356,7 +305,7 @@ LLVMTypeConverter &converter, OwningRewritePatternList &patterns, MLIRContext *ctx) { patterns.insert(ctx, converter); + YieldOpConversion>(ctx, converter); // Populate the type conversions for the linalg types. converter.addConversion( diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -206,12 +206,12 @@ // If either inputPerm or outputPerm are non-identities, insert transposes. 
auto inputPerm = op.inputPermutation(); if (inputPerm.hasValue() && !inputPerm->isIdentity()) - in = rewriter.create(op.getLoc(), in, - AffineMapAttr::get(*inputPerm)); + in = rewriter.create(op.getLoc(), in, + AffineMapAttr::get(*inputPerm)); auto outputPerm = op.outputPermutation(); if (outputPerm.hasValue() && !outputPerm->isIdentity()) - out = rewriter.create( - op.getLoc(), out, AffineMapAttr::get(*outputPerm)); + out = rewriter.create(op.getLoc(), out, + AffineMapAttr::get(*outputPerm)); // If nothing was transposed, fail and let the conversion kick in. if (in == op.input() && out == op.output()) @@ -270,7 +270,7 @@ ConversionTarget target(getContext()); target.addLegalDialect(); target.addLegalOp(); - target.addLegalOp(); + target.addLegalOp(); OwningRewritePatternList patterns; populateLinalgToStandardConversionPatterns(patterns, &getContext()); if (failed(applyFullConversion(module, target, patterns))) diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -3011,6 +3011,57 @@ } }; +/// Conversion pattern that transforms a transpose op into: +/// 1. An `llvm.mlir.undef` operation to create a new memref descriptor. +/// 2. Copies of the allocated and aligned pointers and of the offset from the +/// source descriptor. +/// 3. Updates of the sizes and strides: target dim perm(i) takes source dim i. +/// The transpose op is replaced by the new descriptor, or directly by the +/// source descriptor when the permutation is the identity. 
+class TransposeOpLowering : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + TransposeOpAdaptor adaptor(operands); + MemRefDescriptor viewMemRef(adaptor.in()); + + auto transposeOp = cast(op); + // Identity permutation, early exit. + if (transposeOp.permutation().isIdentity()) + return rewriter.replaceOp(op, {viewMemRef}), success(); + + auto targetMemRef = MemRefDescriptor::undef( + rewriter, loc, typeConverter.convertType(transposeOp.getShapedType())); + + // Copy the base and aligned pointers from the old descriptor to the new + // one. + targetMemRef.setAllocatedPtr(rewriter, loc, + viewMemRef.allocatedPtr(rewriter, loc)); + targetMemRef.setAlignedPtr(rewriter, loc, + viewMemRef.alignedPtr(rewriter, loc)); + + // Copy the offset from the old descriptor to the new one. + targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc)); + + // Apply the permutation: target dim perm(i) takes source dim i. + for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) { + int sourcePos = en.index(); + int targetPos = en.value().cast().getPosition(); + targetMemRef.setSize(rewriter, loc, targetPos, + viewMemRef.size(rewriter, loc, sourcePos)); + targetMemRef.setStride(rewriter, loc, targetPos, + viewMemRef.stride(rewriter, loc, sourcePos)); + } + + rewriter.replaceOp(op, {targetMemRef}); + return success(); + } +}; + /// Conversion pattern that transforms an op into: /// 1. An `llvm.mlir.undef` operation to create a memref descriptor /// 2. 
Updates to the descriptor to introduce the data ptr, offset, size @@ -3425,6 +3476,7 @@ RankOpLowering, StoreOpLowering, SubViewOpLowering, + TransposeOpLowering, ViewOpLowering, AllocOpLowering>(converter); // clang-format on diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -973,86 +973,6 @@ Value SliceOp::getViewSource() { return view(); } -//===----------------------------------------------------------------------===// -// TransposeOp -//===----------------------------------------------------------------------===// - -static MemRefType inferTransposeResultType(MemRefType memRefType, - AffineMap permutationMap) { - auto rank = memRefType.getRank(); - auto originalSizes = memRefType.getShape(); - // Compute permuted sizes. - SmallVector sizes(rank, 0); - for (auto en : llvm::enumerate(permutationMap.getResults())) - sizes[en.index()] = - originalSizes[en.value().cast().getPosition()]; - - // Compute permuted strides. - int64_t offset; - SmallVector strides; - auto res = getStridesAndOffset(memRefType, strides, offset); - assert(succeeded(res) && strides.size() == static_cast(rank)); - (void)res; - auto map = - makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); - map = permutationMap ? map.compose(permutationMap) : map; - return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map); -} - -void mlir::linalg::TransposeOp::build(OpBuilder &b, OperationState &result, - Value view, AffineMapAttr permutation, - ArrayRef attrs) { - auto permutationMap = permutation.getValue(); - assert(permutationMap); - - auto memRefType = view.getType().cast(); - // Compute result type. 
- MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); - - build(b, result, resultType, view, attrs); - result.addAttribute(TransposeOp::getPermutationAttrName(), permutation); -} - -static void print(OpAsmPrinter &p, TransposeOp op) { - p << op.getOperationName() << " " << op.view() << " " << op.permutation(); - p.printOptionalAttrDict(op.getAttrs(), - {TransposeOp::getPermutationAttrName()}); - p << " : " << op.view().getType() << " to " << op.getType(); -} - -static ParseResult parseTransposeOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType view; - AffineMap permutation; - MemRefType srcType, dstType; - if (parser.parseOperand(view) || parser.parseAffineMap(permutation) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(srcType) || - parser.resolveOperand(view, srcType, result.operands) || - parser.parseKeywordType("to", dstType) || - parser.addTypeToList(dstType, result.types)) - return failure(); - - result.addAttribute(TransposeOp::getPermutationAttrName(), - AffineMapAttr::get(permutation)); - return success(); -} - -static LogicalResult verify(TransposeOp op) { - if (!op.permutation().isPermutation()) - return op.emitOpError("expected a permutation map"); - if (op.permutation().getNumDims() != op.getShapedType().getRank()) - return op.emitOpError( - "expected a permutation map of same rank as the view"); - - auto srcType = op.view().getType().cast(); - auto dstType = op.getType().cast(); - if (dstType != inferTransposeResultType(srcType, op.permutation())) - return op.emitOpError("output type ") - << dstType << " does not match transposed input type " << srcType; - return success(); -} - //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// @@ -1359,11 +1279,6 @@ OpFoldResult TensorReshapeOp::fold(ArrayRef operands) { return foldReshapeOp(*this, operands); } 
-OpFoldResult TransposeOp::fold(ArrayRef) { - if (succeeded(foldMemRefCast(*this))) - return getResult(); - return {}; -} //===----------------------------------------------------------------------===// // Auto-generated Linalg named ops. diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -3491,6 +3491,96 @@ return NoneType::get(type.getContext()); } +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +/// Build a strided memref type by applying `permutationMap` to `memRefType`. +static MemRefType inferTransposeResultType(MemRefType memRefType, + AffineMap permutationMap) { + auto rank = memRefType.getRank(); + auto originalSizes = memRefType.getShape(); + // Compute permuted sizes: result dim i gets source dim perm(i). + SmallVector sizes(rank, 0); + for (auto en : llvm::enumerate(permutationMap.getResults())) + sizes[en.index()] = + originalSizes[en.value().cast().getPosition()]; + + // Compute permuted strides. NOTE(review): dim i gets stride perm^-1(i). + int64_t offset; + SmallVector strides; + auto res = getStridesAndOffset(memRefType, strides, offset); + assert(succeeded(res) && strides.size() == static_cast(rank)); + (void)res; + auto map = + makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); + map = permutationMap ? map.compose(permutationMap) : map; + return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map); +} + +void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, + AffineMapAttr permutation, + ArrayRef attrs) { + auto permutationMap = permutation.getValue(); + assert(permutationMap); + + auto memRefType = in.getType().cast(); + // Compute result type. 
+ MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); + + build(b, result, resultType, in, attrs); + result.addAttribute(TransposeOp::getPermutationAttrName(), permutation); +} + +// transpose $in $permutation attr-dict : type($in) `to` type(results) +static void print(OpAsmPrinter &p, TransposeOp op) { + p << "transpose " << op.in() << " " << op.permutation(); + p.printOptionalAttrDict(op.getAttrs(), + {TransposeOp::getPermutationAttrName()}); + p << " : " << op.in().getType() << " to " << op.getType(); +} + +static ParseResult parseTransposeOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::OperandType in; + AffineMap permutation; + MemRefType srcType, dstType; + if (parser.parseOperand(in) || parser.parseAffineMap(permutation) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(srcType) || + parser.resolveOperand(in, srcType, result.operands) || + parser.parseKeywordType("to", dstType) || + parser.addTypeToList(dstType, result.types)) + return failure(); + + result.addAttribute(TransposeOp::getPermutationAttrName(), + AffineMapAttr::get(permutation)); + return success(); +} + +static LogicalResult verify(TransposeOp op) { + if (!op.permutation().isPermutation()) + return op.emitOpError("expected a permutation map"); + if (op.permutation().getNumDims() != op.getShapedType().getRank()) + return op.emitOpError( + "expected a permutation map of same rank as the input"); + + auto srcType = op.in().getType().cast(); + auto dstType = op.getType().cast(); + auto transposedType = inferTransposeResultType(srcType, op.permutation()); + if (dstType != transposedType) + return op.emitOpError("output type ") + << dstType << " does not match transposed input type " << srcType + << ", " << transposedType; + return success(); +} + +OpFoldResult TransposeOp::fold(ArrayRef) { + if (succeeded(foldMemRefCast(*this))) + return getResult(); + return {}; +} + 
//===----------------------------------------------------------------------===// // TruncateIOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -673,7 +673,7 @@ /// Fold the result of an ExtractOp in place when it comes from a TransposeOp. static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) { - auto transposeOp = extractOp.vector().getDefiningOp(); + auto transposeOp = extractOp.vector().getDefiningOp(); if (!transposeOp) return failure(); @@ -2521,7 +2521,7 @@ // Eliminates transpose operations, which produce values identical to their // input values. This happens when the dimensions of the input vector remain in // their original order after the transpose operation. -OpFoldResult TransposeOp::fold(ArrayRef operands) { +OpFoldResult vector::TransposeOp::fold(ArrayRef operands) { SmallVector transp; getTransp(transp); @@ -2535,7 +2535,7 @@ return vector(); } -static LogicalResult verify(TransposeOp op) { +static LogicalResult verify(vector::TransposeOp op) { VectorType vectorType = op.getVectorType(); VectorType resultType = op.getResultType(); int64_t rank = resultType.getRank(); @@ -2563,14 +2563,14 @@ namespace { // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. -class TransposeFolder final : public OpRewritePattern { +class TransposeFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TransposeOp transposeOp, + LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, PatternRewriter &rewriter) const override { - // Wrapper around TransposeOp::getTransp() for cleaner code. 
- auto getPermutation = [](TransposeOp transpose) { + // Wrapper around vector::TransposeOp::getTransp() for cleaner code. + auto getPermutation = [](vector::TransposeOp transpose) { SmallVector permutation; transpose.getTransp(permutation); return permutation; @@ -2586,15 +2586,15 @@ }; // Return if the input of 'transposeOp' is not defined by another transpose. - TransposeOp parentTransposeOp = - transposeOp.vector().getDefiningOp(); + vector::TransposeOp parentTransposeOp = + transposeOp.vector().getDefiningOp(); if (!parentTransposeOp) return failure(); SmallVector permutation = composePermutations( getPermutation(parentTransposeOp), getPermutation(transposeOp)); // Replace 'transposeOp' with a new transpose operation. - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( transposeOp, transposeOp.getResult().getType(), parentTransposeOp.vector(), vector::getVectorSubscriptAttr(rewriter, permutation)); @@ -2604,12 +2604,12 @@ } // end anonymous namespace -void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { +void vector::TransposeOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } -void TransposeOp::getTransp(SmallVectorImpl &results) { +void vector::TransposeOp::getTransp(SmallVectorImpl &results) { populateFromInt64AttrArray(transp(), results); } diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir --- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir +++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir @@ -114,3 +114,20 @@ return } +// ----- + +// CHECK-LABEL: func @transpose +// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue {{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue {{.*}}[1] : !llvm.struct<(ptr, ptr, i64, 
array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue {{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue {{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue {{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue {{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +func @transpose(%arg0: memref) { + %0 = transpose %arg0 (i, j, k) -> (k, i, j) : memref to memref (d2 * s1 + s0 + d0 * s2 + d1)>> + return +} diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -33,27 +33,6 @@ // ----- -func @transpose_not_permutation(%v : memref(off + M * i + j)>>) { - // expected-error @+1 {{expected a permutation map}} - linalg.transpose %v (i, j) -> (i, i) : memref(off + M * i + j)>> to memref(off + M * i + j)>> -} - -// ----- - -func @transpose_bad_rank(%v : memref(off + M * i + j)>>) { - // expected-error @+1 {{expected a permutation map of same rank as the view}} - linalg.transpose %v (i) -> (i) : memref(off + M * i + j)>> to memref(off + M * i + j)>> -} - -// ----- - -func @transpose_wrong_type(%v : memref(off + M * i + j)>>) { - // expected-error @+1 {{output type 'memref (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref (d0 * s1 + s0 + d1)>>'}} - linalg.transpose %v (i, j) -> (j, i) : memref(off + M * i + j)>> to memref(off + M * i + j)>> -} - -// ----- - func @yield_parent(%arg0: memref(off + i)>>) { // expected-error @+1 {{op expected 
parent op with LinalgOp interface}} linalg.yield %arg0: memref(off + i)>> diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir --- a/mlir/test/Dialect/Linalg/llvm.mlir +++ b/mlir/test/Dialect/Linalg/llvm.mlir @@ -69,22 +69,6 @@ // CHECK: llvm.insertvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: llvm.insertvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> -func @transpose(%arg0: memref) { - %0 = linalg.transpose %arg0 (i, j, k) -> (k, i, j) : memref to memref (d2 * s1 + s0 + d0 * s2 + d1)>> - return -} -// CHECK-LABEL: func @transpose -// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue {{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue {{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue {{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue {{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue {{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.extractvalue {{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> - func @reshape_static_expand(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> { // Reshapes that expand a contiguous tensor with some 1's. 
%0 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>, diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -126,11 +126,11 @@ // CHECK-DAG: #[[$strided3DT:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)> func @transpose(%arg0: memref) { - %0 = linalg.transpose %arg0 (i, j, k) -> (k, j, i) : memref to memref (d2 * s1 + s0 + d1 * s2 + d0)>> + %0 = transpose %arg0 (i, j, k) -> (k, j, i) : memref to memref (d2 * s1 + s0 + d1 * s2 + d0)>> return } // CHECK-LABEL: func @transpose -// CHECK: linalg.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) : +// CHECK: transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) : // CHECK-SAME: memref to memref // ----- diff --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir --- a/mlir/test/Dialect/Linalg/standard.mlir +++ b/mlir/test/Dialect/Linalg/standard.mlir @@ -55,9 +55,9 @@ // CHECK-LABEL: func @copy_transpose( // CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref, // CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref) { -// CHECK: %[[t0:.*]] = linalg.transpose %[[arg0]] +// CHECK: %[[t0:.*]] = transpose %[[arg0]] // CHECK-SAME: (d0, d1, d2) -> (d0, d2, d1) : memref -// CHECK: %[[t1:.*]] = linalg.transpose %[[arg1]] +// CHECK: %[[t1:.*]] = transpose %[[arg1]] // CHECK-SAME: (d0, d1, d2) -> (d2, d1, d0) : memref // CHECK: %[[o0:.*]] = memref_cast %[[t0]] : // CHECK-SAME: memref to memref diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir --- a/mlir/test/Dialect/Standard/invalid.mlir +++ b/mlir/test/Dialect/Standard/invalid.mlir @@ -81,3 +81,24 @@ } : tensor return %tnsr : tensor } + +// ----- + +func @transpose_not_permutation(%v : memref(off + M * i + j)>>) { + // expected-error @+1 {{expected a permutation map}} + transpose %v (i, j) -> (i, i) : 
memref(off + M * i + j)>> to memref(off + M * i + j)>> +} + +// ----- + +func @transpose_bad_rank(%v : memref(off + M * i + j)>>) { + // expected-error @+1 {{expected a permutation map of same rank as the input}} + transpose %v (i) -> (i) : memref(off + M * i + j)>> to memref(off + M * i + j)>> +} + +// ----- + +func @transpose_wrong_type(%v : memref(off + M * i + j)>>) { + // expected-error @+1 {{output type 'memref (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref (d0 * s1 + s0 + d1)>>'}} + transpose %v (i, j) -> (j, i) : memref(off + M * i + j)>> to memref(off + M * i + j)>> +}