mirror of
https://github.com/Gericom/teak-llvm.git
synced 2025-06-19 11:35:51 -04:00
[mlir] Enable specifying verify on OpInterface
Summary: Add method in ODS to specify verification for operations implementing a OpInterface. Use this with infer type op interface to verify that the inferred type matches the return type and remove special case in TestPatterns. This could also have been achieved by using OpInterfaceMethod but verify seems pretty common and it is not an arbitrary method that just happened to be named verifyTrait, so having it be defined in special way seems appropriate/better documenting. Differential Revision: https://reviews.llvm.org/D73122
This commit is contained in:
parent
963f268186
commit
178562fb35
@ -400,6 +400,10 @@ def OpWithInferTypeInterfaceOp : Op<...
|
|||||||
[DeclareOpInterfaceMethods<MyInterface>]> { ... }
|
[DeclareOpInterfaceMethods<MyInterface>]> { ... }
|
||||||
```
|
```
|
||||||
|
|
||||||
|
A verification method can also be specified on the `OpInterface` by setting
|
||||||
|
`verify`. Setting `verify` results in the generated trait having a `verifyTrait`
|
||||||
|
method that is applied to all operations implementing the trait.
|
||||||
|
|
||||||
### Builder methods
|
### Builder methods
|
||||||
|
|
||||||
For each operation, there are a few builders automatically generated based on
|
For each operation, there are a few builders automatically generated based on
|
||||||
|
@ -72,8 +72,6 @@ private:
|
|||||||
Attribute attr;
|
Attribute attr;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include "mlir/Analysis/InferTypeOpInterface.h.inc"
|
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
// Helper function to infer return tensor returns types given element and shape
|
// Helper function to infer return tensor returns types given element and shape
|
||||||
// inference function.
|
// inference function.
|
||||||
@ -89,8 +87,14 @@ LogicalResult inferReturnTensorTypes(
|
|||||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||||
ArrayRef<NamedAttribute> attributes, RegionRange regions,
|
ArrayRef<NamedAttribute> attributes, RegionRange regions,
|
||||||
SmallVectorImpl<Type> &inferedReturnTypes);
|
SmallVectorImpl<Type> &inferedReturnTypes);
|
||||||
|
|
||||||
|
/// Verifies that the inferred result types match the actual result types for
|
||||||
|
/// the op. Precondition: op implements InferTypeOpInterface.
|
||||||
|
LogicalResult verifyInferredResultTypes(Operation *op);
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
|
#include "mlir/Analysis/InferTypeOpInterface.h.inc"
|
||||||
|
|
||||||
namespace OpTrait {
|
namespace OpTrait {
|
||||||
|
|
||||||
/// Tensor type inference trait that constructs a tensor from the infered
|
/// Tensor type inference trait that constructs a tensor from the infered
|
||||||
|
@ -60,6 +60,10 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
|
|||||||
}]
|
}]
|
||||||
>,
|
>,
|
||||||
];
|
];
|
||||||
|
|
||||||
|
let verify = [{
|
||||||
|
return detail::verifyInferredResultTypes($_op);
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
|
def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
|
||||||
|
@ -1411,8 +1411,12 @@ def ins;
|
|||||||
// OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in
|
// OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in
|
||||||
// C++. The purpose to wrap around C++ symbol string with this class is to make
|
// C++. The purpose to wrap around C++ symbol string with this class is to make
|
||||||
// interfaces specified for ops in TableGen less alien and more integrated.
|
// interfaces specified for ops in TableGen less alien and more integrated.
|
||||||
class OpInterfaceTrait<string name> : NativeOpTrait<""> {
|
class OpInterfaceTrait<string name, code verifyBody = [{}]> : NativeOpTrait<""> {
|
||||||
let trait = name # "::Trait";
|
let trait = name # "::Trait";
|
||||||
|
|
||||||
|
// Specify the body of the verification function. `$_op` will be replaced with
|
||||||
|
// the operation being verified.
|
||||||
|
code verify = verifyBody;
|
||||||
}
|
}
|
||||||
|
|
||||||
// This class represents a single, optionally static, interface method.
|
// This class represents a single, optionally static, interface method.
|
||||||
|
@ -86,6 +86,9 @@ public:
|
|||||||
// Return the description of this method if it has one.
|
// Return the description of this method if it has one.
|
||||||
llvm::Optional<StringRef> getDescription() const;
|
llvm::Optional<StringRef> getDescription() const;
|
||||||
|
|
||||||
|
// Return the verify method body if it has one.
|
||||||
|
llvm::Optional<StringRef> getVerify() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// The TableGen definition of this interface.
|
// The TableGen definition of this interface.
|
||||||
const llvm::Record *def;
|
const llvm::Record *def;
|
||||||
|
@ -45,3 +45,17 @@ LogicalResult mlir::detail::inferReturnTensorTypes(
|
|||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
|
||||||
|
SmallVector<Type, 4> inferedReturnTypes;
|
||||||
|
auto retTypeFn = cast<InferTypeOpInterface>(op);
|
||||||
|
if (failed(retTypeFn.inferReturnTypes(op->getContext(), op->getLoc(),
|
||||||
|
op->getOperands(), op->getAttrs(),
|
||||||
|
op->getRegions(), inferedReturnTypes)))
|
||||||
|
return failure();
|
||||||
|
SmallVector<Type, 4> resultTypes(op->getResultTypes());
|
||||||
|
if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes))
|
||||||
|
return op->emitOpError(
|
||||||
|
"inferred type incompatible with return type of operation");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
@ -85,3 +85,9 @@ llvm::Optional<StringRef> OpInterface::getDescription() const {
|
|||||||
auto value = def->getValueAsString("description");
|
auto value = def->getValueAsString("description");
|
||||||
return value.empty() ? llvm::Optional<StringRef>() : value;
|
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the body for this method if it has one.
|
||||||
|
llvm::Optional<StringRef> OpInterface::getVerify() const {
|
||||||
|
auto value = def->getValueAsString("verify");
|
||||||
|
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||||
|
}
|
||||||
|
@ -103,26 +103,6 @@ struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
|
|||||||
};
|
};
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verification check.
|
|
||||||
// TODO: Move to ops that implement type infer interface.
|
|
||||||
getFunction().walk([this](Operation *op) -> void {
|
|
||||||
auto retTypeFn = dyn_cast<InferTypeOpInterface>(op);
|
|
||||||
if (!retTypeFn)
|
|
||||||
return;
|
|
||||||
auto *context = &getContext();
|
|
||||||
SmallVector<Type, 2> inferedReturnTypes;
|
|
||||||
if (failed(retTypeFn.inferReturnTypes(
|
|
||||||
context, op->getLoc(), op->getOperands(), op->getAttrs(),
|
|
||||||
op->getRegions(), inferedReturnTypes)))
|
|
||||||
return;
|
|
||||||
SmallVector<Type, 1> resultTypes(op->getResultTypes());
|
|
||||||
if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes)) {
|
|
||||||
op->emitOpError(
|
|
||||||
"inferred type incompatible with return type of operation");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
@ -23,7 +23,6 @@ func @testCreateFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: testReturnTypeOpInterface
|
|
||||||
func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
|
func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
|
||||||
// expected-error@+1 {{incompatible with return type}}
|
// expected-error@+1 {{incompatible with return type}}
|
||||||
%bad = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
|
%bad = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
|
||||||
@ -32,7 +31,6 @@ func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: testReturnTypeOpInterface
|
|
||||||
func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) {
|
func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) {
|
||||||
// expected-error@+1 {{operand type mismatch}}
|
// expected-error@+1 {{operand type mismatch}}
|
||||||
%bad = "test.op_with_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<*xf32>
|
%bad = "test.op_with_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<*xf32>
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
|
|
||||||
#include "DocGenUtilities.h"
|
#include "DocGenUtilities.h"
|
||||||
#include "mlir/Support/STLExtras.h"
|
#include "mlir/Support/STLExtras.h"
|
||||||
|
#include "mlir/TableGen/Format.h"
|
||||||
#include "mlir/TableGen/GenInfo.h"
|
#include "mlir/TableGen/GenInfo.h"
|
||||||
#include "mlir/TableGen/OpInterfaces.h"
|
#include "mlir/TableGen/OpInterfaces.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
@ -152,6 +153,12 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
|
|||||||
|
|
||||||
// Insert the default implementation for any methods.
|
// Insert the default implementation for any methods.
|
||||||
for (auto &method : interface.getMethods()) {
|
for (auto &method : interface.getMethods()) {
|
||||||
|
// Flag interface methods named verifyTrait.
|
||||||
|
if (method.getName() == "verifyTrait")
|
||||||
|
PrintFatalError(
|
||||||
|
formatv("'verifyTrait' method cannot be specified as interface "
|
||||||
|
"method for '{0}'; set 'verify' on OpInterfaceTrait instead",
|
||||||
|
interfaceName));
|
||||||
auto defaultImpl = method.getDefaultImplementation();
|
auto defaultImpl = method.getDefaultImplementation();
|
||||||
if (!defaultImpl)
|
if (!defaultImpl)
|
||||||
continue;
|
continue;
|
||||||
@ -162,6 +169,13 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
|
|||||||
os << " {\n" << defaultImpl.getValue() << " }\n";
|
os << " {\n" << defaultImpl.getValue() << " }\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tblgen::FmtContext traitCtx;
|
||||||
|
traitCtx.withOp("op");
|
||||||
|
if (auto verify = interface.getVerify()) {
|
||||||
|
os << " static LogicalResult verifyTrait(Operation* op) {\n"
|
||||||
|
<< tblgen::tgfmt(*verify, &traitCtx) << "\n }\n";
|
||||||
|
}
|
||||||
|
|
||||||
os << " };\n";
|
os << " };\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user